LCOV - code coverage report
Current view: top level - ml - Model.cpp (source / functions) Hit Total Coverage
Test: test_coverage.info.cleaned Lines: 0 68 0.0 %
Date: 2022-07-06 02:47:47 Functions: 0 32 0.0 %

          Line data    Source code
#include "Model.h"

#include <algorithm>
#include <deque>
#include <stdexcept>
#include <vector>

#include "TypeCasts.hpp"
       5             : 
       6             : namespace elsa::ml
       7             : {
       8             :     template <typename data_t, MlBackend Backend>
       9           0 :     Model<data_t, Backend>::Model(std::initializer_list<Input<data_t>*> inputs,
      10             :                                   std::initializer_list<Layer<data_t>*> outputs,
      11             :                                   const std::string& name)
      12           0 :         : name_(name), inputs_(inputs), outputs_(outputs)
      13             :     {
      14             :         // Save the batch-size this model uses
      15           0 :         batchSize_ = inputs_.front()->getBatchSize();
      16             : 
      17             :         // Set all input-descriptors by traversing the graph
      18           0 :         setInputDescriptors();
      19           0 :     }
      20             : 
      21             :     template <typename data_t, MlBackend Backend>
      22           0 :     Model<data_t, Backend>::Model(Input<data_t>* input, Layer<data_t>* output,
      23             :                                   const std::string& name)
      24           0 :         : Model({input}, {output}, name)
      25             :     {
      26           0 :     }
      27             : 
      28             :     template <typename data_t, MlBackend Backend>
      29           0 :     index_t Model<data_t, Backend>::getBatchSize() const
      30             :     {
      31           0 :         return batchSize_;
      32             :     }
      33             : 
      34             :     template <typename data_t, MlBackend Backend>
      35           0 :     void Model<data_t, Backend>::setInputDescriptors()
      36             :     {
      37             :         // TODO(tellenbach): Replace by Graph::visit method
      38             : 
      39             :         // Get the graph
      40           0 :         auto& graph = detail::State<data_t>::getGraph();
      41             : 
      42             :         // Get all nodes of the graph, i.e., a map with node-indices as keys
      43             :         // and nodes as values
      44           0 :         auto& nodes = graph.getNodes();
      45             : 
      46             :         // We maintain a list of nodes we've already visited and a call-queue to
      47             :         // ensure the correct order of traversal.
      48           0 :         std::vector<bool> visited(asUnsigned(graph.getNumberOfNodes()));
      49             :         // Note that this queue is in fact a deque, so we can push and pop from
      50             :         // both, the front and back.
      51           0 :         std::deque<index_t> queue;
      52             : 
      53             :         // Perform an iterative depth-first traversal through the graph
      54           0 :         for (auto in : inputs_) {
      55             :             // Push the input-node onto the call-queue
      56           0 :             queue.push_back(in->getGlobalIndex());
      57             : 
      58           0 :             while (!queue.empty()) {
      59             :                 // The current node is the top of the stack and we compute its
      60             :                 // output descriptor
      61           0 :                 index_t s = queue.back();
      62           0 :                 queue.pop_back();
      63             : 
      64           0 :                 if (!visited[static_cast<std::size_t>(s)]) {
      65             :                     // If the current node is a merging layer, its
      66             :                     // output-descriptor can depend on *all* input-descriptors.
      67             :                     // We therefore have to make sure that we really set all
      68             :                     // input-descriptors before attempting to compute a merging
      69             :                     // layer's output-descriptor or attempting to continue the
      70             :                     // traversal.
      71             :                     //
      72             :                     // We do this by checking if the number of edges that reach
      73             :                     // a merging layer is equal to the number of set inputs.
      74             :                     //
      75             :                     //      +---------+
      76             :                     //      | Layer 1 |   +---------+
      77             :                     //      +---------+   | Layer 2 |
      78             :                     //           |        +---------+
      79             :                     //           v             |
      80             :                     //      +---------+        |
      81             :                     //      | Merging |<-------+
      82             :                     //      +---------+
      83             :                     //  (?) Do we have the input from Layer1 *and* Layer2?
      84             :                     //
      85             :                     // We also have to make sure that a merging layer get's
      86             :                     // visited again when all of its inputs are set. Pushing the
      87             :                     // layer on top of the queue again causes an infinite loop
      88             :                     // since we will always visit it again, see that we can't
      89             :                     // compute its output-descriptor yet, visit it again...
      90             :                     //
      91             :                     // To solve this problem we push a merging layer to the
      92             :                     // *front* of the queue such that it get's visited again,
      93             :                     // in a delayed fashion.
      94           0 :                     if (!nodes.at(s).getData()->canMerge()
      95           0 :                         || nodes.at(s).getData()->getNumberOfInputs()
      96           0 :                                == static_cast<index_t>(graph.getIncomingEdges(s).size())) {
      97             :                         // We end up here if we either have no merging layer
      98             :                         // or if we have a merging layer but already gathered
      99             :                         // all of its inputs
     100           0 :                         nodes.at(s).getData()->computeOutputDescriptor();
     101             : 
     102           0 :                         visited[asUnsigned(s)] = true;
     103             :                     } else {
     104             :                         // We end up here if we have a merging layer but haven't
     105             :                         // collected all of its inputs yet. In this case we
     106             :                         // push the layer to the *front* of out call-queue.
     107           0 :                         queue.push_front(s);
     108             : 
     109             :                         // Make sure we don't handle a merging layer's childs
     110             :                         // before handling the layer itself
     111           0 :                         continue;
     112             :                     }
     113             :                 }
     114             : 
     115             :                 // TODO(tellenbach): Stop if we reach one of the model's output
     116             :                 // layers
     117             : 
     118             :                 // Consider all outgoing edges of a node and set their
     119             :                 // input-descriptors to the output-descriptor of their parent
     120             :                 // node
     121           0 :                 for (auto& e : graph.getOutgoingEdges(s)) {
     122           0 :                     auto idx = e.end()->getIndex();
     123             : 
     124             :                     // If we haven't visited this child node yet, add it to the
     125             :                     // call-queue and set its input-descriptor
     126           0 :                     if (!visited[static_cast<std::size_t>(idx)]) {
     127           0 :                         queue.push_back(idx);
     128           0 :                         e.end()->getData()->setInputDescriptor(
     129           0 :                             nodes.at(s).getData()->getOutputDescriptor());
     130             :                     }
     131             :                 }
     132             :             }
     133             :         }
     134           0 :     }
     135             : 
     136             :     template <typename data_t, MlBackend Backend>
     137           0 :     void Model<data_t, Backend>::compile(const Loss<data_t>& loss, Optimizer<data_t>* optimizer)
     138             :     {
     139           0 :         loss_ = loss;
     140           0 :         optimizer_ = optimizer;
     141           0 :         detail::BackendAdaptor<data_t, Backend>::constructBackendGraph(this);
     142           0 :     }
     143             : 
     144             :     template <typename data_t, MlBackend Backend>
     145             :     typename Model<data_t, Backend>::History
     146           0 :         Model<data_t, Backend>::fit(const std::vector<DataContainer<data_t>>& x,
     147             :                                     const std::vector<DataContainer<data_t>>& y, index_t epochs)
     148             :     {
     149             :         // Check if all elements of x have the same data-container
     150           0 :         if (std::adjacent_find(x.begin(), x.end(),
     151           0 :                                [](const auto& dc0, const auto& dc1) {
     152           0 :                                    return dc0.getDataDescriptor() != dc1.getDataDescriptor();
     153             :                                })
     154           0 :             != x.end())
     155           0 :             throw std::invalid_argument("All elements of x must have the same data-descriptor");
     156             : 
     157             :         // Check if all elements of y have the same data-container
     158           0 :         if (std::adjacent_find(y.begin(), y.end(),
     159           0 :                                [](const auto& dc0, const auto& dc1) {
     160           0 :                                    return dc0.getDataDescriptor() != dc1.getDataDescriptor();
     161             :                                })
     162           0 :             != y.end())
     163           0 :             throw std::invalid_argument("All elements of y must have the same data-descriptor");
     164             : 
     165           0 :         return detail::BackendAdaptor<data_t, Backend>::fit(this, x, y, epochs);
     166             :     }
     167             : 
     168             :     template <typename data_t, MlBackend Backend>
     169           0 :     DataContainer<data_t> Model<data_t, Backend>::predict(const DataContainer<data_t>& x)
     170             :     {
     171           0 :         return detail::BackendAdaptor<data_t, Backend>::predict(this, x);
     172             :     }
     173             : 
     174             :     template <typename data_t, MlBackend Backend>
     175           0 :     Optimizer<data_t>* Model<data_t, Backend>::getOptimizer()
     176             :     {
     177           0 :         return optimizer_;
     178             :     }
     179             : 
     180             :     template <typename data_t, MlBackend Backend>
     181           0 :     std::string Model<data_t, Backend>::getName() const
     182             :     {
     183           0 :         return name_;
     184             :     }
     185             : 
     186             :     template <typename data_t, MlBackend Backend>
     187           0 :     std::vector<Input<data_t>*> Model<data_t, Backend>::getInputs()
     188             :     {
     189           0 :         return inputs_;
     190             :     }
     191             : 
     192             :     template <typename data_t, MlBackend Backend>
     193           0 :     std::vector<Layer<data_t>*> Model<data_t, Backend>::getOutputs()
     194             :     {
     195           0 :         return outputs_;
     196             :     }
     197             : 
     198             :     template <typename data_t, MlBackend Backend>
     199             :     detail::Graph<typename detail::BackendSelector<data_t, Backend, LayerType::Undefined>::Type,
     200             :                   false>&
     201           0 :         Model<data_t, Backend>::getBackendGraph()
     202             :     {
     203           0 :         return backendGraph_;
     204             :     }
     205             : 
     206             :     template <typename data_t, MlBackend Backend>
     207             :     const detail::Graph<
     208             :         typename detail::BackendSelector<data_t, Backend, LayerType::Undefined>::Type, false>&
     209           0 :         Model<data_t, Backend>::getBackendGraph() const
     210             :     {
     211           0 :         return backendGraph_;
     212             :     }
     213             : 
     214             :     template <typename data_t, MlBackend Backend>
     215           0 :     const Loss<data_t>& Model<data_t, Backend>::getLoss() const
     216             :     {
     217           0 :         return loss_;
     218             :     }
     219             : 
     220             :     template class Model<float, MlBackend::Dnnl>;
     221             :     template class Model<float, MlBackend::Cudnn>;
     222             : } // namespace elsa::ml

Generated by: LCOV version 1.15