Line data Source code
1 : #pragma once 2 : 3 : #include "Common.h" 4 : #include "DataContainer.h" 5 : #include "TypeCasts.hpp" 6 : 7 : namespace elsa::ml 8 : { 9 : template <typename data_t, MlBackend Backend> 10 : class Model; 11 : 12 : namespace detail 13 : { 14 : template <typename data_t, MlBackend Backend, LayerType Layer> 15 : struct BackendSelector { 16 : using Type = std::false_type; 17 : }; 18 : 19 : #define ELSA_ML_MAKE_BACKEND_LAYER_SELECTOR(Backend, Layer, BackendLayer) \ 20 : template <typename data_t> \ 21 : struct BackendSelector<data_t, MlBackend::Backend, LayerType::Layer> { \ 22 : using Type = BackendLayer<data_t>; \ 23 : } 24 : 25 : /// Generic BackendAdaptor. 26 : /// 27 : /// Backend-specific logic is implemented in template specializations 28 : /// of this struct. 29 : template <typename data_t, MlBackend Backend> 30 : struct BackendAdaptor { 31 0 : static void constructBackendGraph([[maybe_unused]] Model<data_t, Backend>*) 32 : { 33 0 : throw std::logic_error("No elsa ML backend available"); 34 : } 35 : 36 0 : static DataContainer<data_t> predict([[maybe_unused]] Model<data_t, Backend>*, 37 : [[maybe_unused]] const DataContainer<data_t>&) 38 : { 39 0 : throw std::logic_error("No elsa ML backend available"); 40 : } 41 : 42 : static typename Model<data_t, Backend>::History 43 0 : fit([[maybe_unused]] Model<data_t, Backend>*, 44 : [[maybe_unused]] const std::vector<DataContainer<data_t>>&, 45 : [[maybe_unused]] const std::vector<DataContainer<data_t>>&, index_t) 46 : { 47 0 : throw std::logic_error("No elsa ML backend available"); 48 : } 49 : }; 50 : 51 : /// Attach batch-size to a volume-descriptor 52 : /// 53 : /// If a we have a descriptor 54 : /// {w, h, c} 55 : /// this creates a new descriptor 56 : /// {w, h, c, n}. 57 : static inline VolumeDescriptor 58 0 : attachBatchSizeToVolumeDescriptor(index_t batchSize, const VolumeDescriptor& desc) 59 : { 60 0 : IndexVector_t dims(desc.getNumberOfDimensions() + 1); 61 0 : dims.head(desc.getNumberOfDimensions()) = desc.getNumberOfCoefficientsPerDimension(); 62 0 : dims.tail(1)[asUnsigned(0)] = batchSize; 63 0 : return VolumeDescriptor(dims); 64 : } 65 : 66 : /// Reverse a volume-descriptor 67 : /// 68 : /// If we have a descriptor 69 : /// {w, h, c, n} 70 : /// this creates a descriptor 71 : /// {n, c, h, w}. 72 0 : static inline VolumeDescriptor reverseVolumeDescriptor(const VolumeDescriptor& desc) 73 : { 74 0 : IndexVector_t dims = desc.getNumberOfCoefficientsPerDimension().reverse(); 75 0 : return VolumeDescriptor(dims); 76 : } 77 : 78 : template <typename T> 79 0 : static auto getCheckedLayerPtr(T&& node) 80 : { 81 0 : auto layer = node.getData(); 82 0 : assert(layer != nullptr && "Pointer to backend-layer is null"); 83 0 : return layer; 84 : } 85 : 86 : template <typename T> 87 : static auto getCheckedLayerPtr(T* node) 88 : { 89 : auto layer = node->getData(); 90 : assert(layer != nullptr && "Pointer to backend-layer is null"); 91 : return layer; 92 : } 93 : 94 : template <typename GraphType> 95 0 : void setNumberOfOutputGradients(GraphType* backendGraph) 96 : { 97 0 : for (auto&& node : backendGraph->getNodes()) { 98 0 : auto layer = getCheckedLayerPtr(node.second); 99 0 : layer->setNumberOfOutputGradients( 100 0 : asSigned(backendGraph->getOutgoingEdges(node.first).size())); 101 : } 102 0 : } 103 : } // namespace detail 104 : 105 : } // namespace elsa::ml 106 : 107 : #ifdef ELSA_HAS_DNNL_BACKEND 108 : #include "DnnlBackendAdaptor.h" 109 : #endif 110 : 111 : #ifdef ELSA_HAS_CUDNN_BACKEND 112 : #include "CudnnBackendAdaptor.h" 113 : #endif