Line data Source code
#pragma once

#include <memory>

#include "Optimizer.h"
#include "DnnlLayer.h"
#include "DataDescriptor.h"
#include "DataContainer.h"
#include "DnnlOptimizer.h"
#include "Initializer.h"

#include "dnnl.hpp"

namespace elsa::ml
{
    namespace detail
    {
        /**
         * @brief Trainable Dnnl layer.
         *
         * This layer is used as a base class for all Dnnl layers with trainable
         * parameters, such as convolutional or dense layers. It owns the Dnnl
         * memory for weights and biases, their gradients, and per-parameter
         * optimizer instances.
         *
         * @tparam data_t Type of all coefficients used in the layer
         */
        template <typename data_t>
        class DnnlTrainableLayer : public DnnlLayer<data_t>
        {
        public:
            /// Type of this layer's base class
            using BaseType = DnnlLayer<data_t>;

            /**
             * Construct a trainable Dnnl network layer by passing a descriptor for its input, its
             * output and weights and an initializer for its weights and biases.
             *
             * @param inputDescriptor descriptor of the layer's input volume
             * @param outputDescriptor descriptor of the layer's output volume
             * @param weightsDescriptor descriptor of the layer's weights
             * @param initializer strategy used to randomly initialize weights and biases
             */
            DnnlTrainableLayer(const VolumeDescriptor& inputDescriptor,
                               const VolumeDescriptor& outputDescriptor,
                               const VolumeDescriptor& weightsDescriptor, Initializer initializer);

            /**
             * Set this layer's weights by passing a DataContainer.
             *
             * \note This function performs a copy from a DataContainer to Dnnl
             * memory and should therefore be used for debugging or testing purposes
             * only. The layer is capable of initializing its weights on its own.
             */
            void setWeights(const DataContainer<data_t>& weights);

            /**
             * Set this layer's biases by passing a DataContainer.
             *
             * \note This function performs a copy from a DataContainer to Dnnl
             * memory and should therefore be used for debugging or testing purposes
             * only. The layer is capable of initializing its biases on its own.
             */
            void setBias(const DataContainer<data_t>& bias);

            /// Get this layer's weights gradient as a DataContainer.
            DataContainer<data_t> getGradientWeights() const;

            /// Get this layer's bias gradient as a DataContainer.
            DataContainer<data_t> getGradientBias() const;

            /// Update this layer's trainable parameters (weights and biases),
            /// presumably by applying the configured optimizers to the
            /// accumulated gradients — confirm against the implementation.
            void updateTrainableParameters();

            /// Accumulate the current weight/bias gradients into
            /// _weightsGradientAcc / _biasGradientAcc (declaration only here;
            /// exact accumulation semantics live in the implementation file).
            void accumulatedGradients();

            /// \copydoc DnnlLayer::initialize
            void initialize() override;

            /// \copydoc DnnlLayer::backwardPropagate
            void backwardPropagate(dnnl::stream& executionStream) override;

            /// @returns true, since this layer carries trainable parameters.
            bool isTrainable() const override;

            /// Set the optimizer used to update this layer's parameters.
            ///
            /// Builds one backend optimizer instance per parameter set (weights
            /// and bias), each sized to the number of coefficients of the
            /// corresponding descriptor.
            ///
            /// @param optimizer non-owning pointer to the optimizer prototype
            ///        passed to the OptimizerFactory
            void setOptimizer(Optimizer<data_t>* optimizer)
            {
                weightsOptimizer_ = OptimizerFactory<data_t, MlBackend::Dnnl>::run(
                    optimizer, _weightsDescriptor->getNumberOfCoefficients());
                biasOptimizer_ = OptimizerFactory<data_t, MlBackend::Dnnl>::run(
                    optimizer, _biasDescriptor->getNumberOfCoefficients());
            }

        protected:
            /// \copydoc DnnlLayer::compileForwardStream
            void compileForwardStream() override;

            /// \copydoc DnnlLayer::compileBackwardStream
            void compileBackwardStream() override;

            /// Shorthand for the base class' Dnnl memory wrapper type
            using DnnlMemory = typename BaseType::DnnlMemory;

            /// \copydoc DnnlLayer::_typeTag
            using BaseType::_typeTag;

            /// \copydoc DnnlLayer::_engine
            using BaseType::_engine;

            /// \copydoc DnnlLayer::_input
            using BaseType::_input;

            /// \copydoc DnnlLayer::_inputGradient
            using BaseType::_inputGradient;

            /// \copydoc DnnlLayer::_output
            using BaseType::_output;

            /// \copydoc DnnlLayer::_outputGradient
            using BaseType::_outputGradient;

            /// \copydoc DnnlLayer::_forwardStream
            using BaseType::_forwardStream;

            /// \copydoc DnnlLayer::_backwardStream
            using BaseType::_backwardStream;

            /// This layer's weights memory
            DnnlMemory _weights;

            /// This layer's weights gradient memory
            DnnlMemory _weightsGradient;

            /// This layer's accumulated weights gradient memory
            Eigen::ArrayX<data_t> _weightsGradientAcc;

            /// This layer's bias memory
            DnnlMemory _bias;

            /// This layer's bias gradient memory
            DnnlMemory _biasGradient;

            /// This layer's accumulated bias gradient memory
            Eigen::ArrayX<data_t> _biasGradientAcc;

            /// This layer's weights DataDescriptor
            std::unique_ptr<DataDescriptor> _weightsDescriptor;

            /// This layer's bias DataDescriptor
            std::unique_ptr<DataDescriptor> _biasDescriptor;

            /// This layer's initializer tag
            Initializer _initializer;

            /// This layer's fanIn/fanOut pair that is used during random initialization of weights
            /// and biases
            typename InitializerImpl<data_t>::FanPairType _fanInOut;

            /// Backend optimizer applied to the weights
            std::shared_ptr<OptimizerImplBase<data_t>> weightsOptimizer_;

            /// Backend optimizer applied to the bias
            std::shared_ptr<OptimizerImplBase<data_t>> biasOptimizer_;
        };

    } // namespace detail
} // namespace elsa::ml