LCOV - code coverage report
Current view: top level - ml - BackendAdaptor.h (source / functions) Hit Total Coverage
Test: test_coverage.info.cleaned Lines: 0 24 0.0 %
Date: 2022-07-06 02:47:47 Functions: 0 8 0.0 %

          Line data    Source code
       1             : #pragma once
       2             : 
       3             : #include "Common.h"
       4             : #include "DataContainer.h"
       5             : #include "TypeCasts.hpp"
       6             : 
       7             : namespace elsa::ml
       8             : {
       9             :     template <typename data_t, MlBackend Backend>
      10             :     class Model;
      11             : 
      12             :     namespace detail
      13             :     {
      14             :         template <typename data_t, MlBackend Backend, LayerType Layer>
      15             :         struct BackendSelector {
      16             :             using Type = std::false_type;
      17             :         };
      18             : 
      19             : #define ELSA_ML_MAKE_BACKEND_LAYER_SELECTOR(Backend, Layer, BackendLayer)  \
      20             :     template <typename data_t>                                             \
      21             :     struct BackendSelector<data_t, MlBackend::Backend, LayerType::Layer> { \
      22             :         using Type = BackendLayer<data_t>;                                 \
      23             :     }
      24             : 
      25             :         /// Generic BackendAdaptor.
      26             :         ///
      27             :         /// Backend-specific logic is implemented in template specializations
      28             :         /// of this struct.
      29             :         template <typename data_t, MlBackend Backend>
      30             :         struct BackendAdaptor {
      31           0 :             static void constructBackendGraph([[maybe_unused]] Model<data_t, Backend>*)
      32             :             {
      33           0 :                 throw std::logic_error("No elsa ML backend available");
      34             :             }
      35             : 
      36           0 :             static DataContainer<data_t> predict([[maybe_unused]] Model<data_t, Backend>*,
      37             :                                                  [[maybe_unused]] const DataContainer<data_t>&)
      38             :             {
      39           0 :                 throw std::logic_error("No elsa ML backend available");
      40             :             }
      41             : 
      42             :             static typename Model<data_t, Backend>::History
      43           0 :                 fit([[maybe_unused]] Model<data_t, Backend>*,
      44             :                     [[maybe_unused]] const std::vector<DataContainer<data_t>>&,
      45             :                     [[maybe_unused]] const std::vector<DataContainer<data_t>>&, index_t)
      46             :             {
      47           0 :                 throw std::logic_error("No elsa ML backend available");
      48             :             }
      49             :         };
      50             : 
      51             :         /// Attach batch-size to a volume-descriptor
      52             :         ///
      53             :         /// If a we have a descriptor
      54             :         ///   {w, h, c}
      55             :         /// this creates a new descriptor
      56             :         ///   {w, h, c, n}.
      57             :         static inline VolumeDescriptor
      58           0 :             attachBatchSizeToVolumeDescriptor(index_t batchSize, const VolumeDescriptor& desc)
      59             :         {
      60           0 :             IndexVector_t dims(desc.getNumberOfDimensions() + 1);
      61           0 :             dims.head(desc.getNumberOfDimensions()) = desc.getNumberOfCoefficientsPerDimension();
      62           0 :             dims.tail(1)[asUnsigned(0)] = batchSize;
      63           0 :             return VolumeDescriptor(dims);
      64             :         }
      65             : 
      66             :         /// Reverse a volume-descriptor
      67             :         ///
      68             :         /// If we have a descriptor
      69             :         ///   {w, h, c, n}
      70             :         /// this creates a descriptor
      71             :         ///   {n, c, h, w}.
      72           0 :         static inline VolumeDescriptor reverseVolumeDescriptor(const VolumeDescriptor& desc)
      73             :         {
      74           0 :             IndexVector_t dims = desc.getNumberOfCoefficientsPerDimension().reverse();
      75           0 :             return VolumeDescriptor(dims);
      76             :         }
      77             : 
      78             :         template <typename T>
      79           0 :         static auto getCheckedLayerPtr(T&& node)
      80             :         {
      81           0 :             auto layer = node.getData();
      82           0 :             assert(layer != nullptr && "Pointer to backend-layer is null");
      83           0 :             return layer;
      84             :         }
      85             : 
      86             :         template <typename T>
      87             :         static auto getCheckedLayerPtr(T* node)
      88             :         {
      89             :             auto layer = node->getData();
      90             :             assert(layer != nullptr && "Pointer to backend-layer is null");
      91             :             return layer;
      92             :         }
      93             : 
      94             :         template <typename GraphType>
      95           0 :         void setNumberOfOutputGradients(GraphType* backendGraph)
      96             :         {
      97           0 :             for (auto&& node : backendGraph->getNodes()) {
      98           0 :                 auto layer = getCheckedLayerPtr(node.second);
      99           0 :                 layer->setNumberOfOutputGradients(
     100           0 :                     asSigned(backendGraph->getOutgoingEdges(node.first).size()));
     101             :             }
     102           0 :         }
     103             :     } // namespace detail
     104             : 
     105             : } // namespace elsa::ml
     106             : 
     107             : #ifdef ELSA_HAS_DNNL_BACKEND
     108             : #include "DnnlBackendAdaptor.h"
     109             : #endif
     110             : 
     111             : #ifdef ELSA_HAS_CUDNN_BACKEND
     112             : #include "CudnnBackendAdaptor.h"
     113             : #endif

Generated by: LCOV version 1.15