Line data Source code
1 : #include "DnnlTrainableLayer.h" 2 : #include "TypeCasts.hpp" 3 : 4 : namespace elsa::ml 5 : { 6 : namespace detail 7 : { 8 : template <typename data_t> 9 0 : DnnlTrainableLayer<data_t>::DnnlTrainableLayer(const VolumeDescriptor& inputDescriptor, 10 : const VolumeDescriptor& outputDescriptor, 11 : const VolumeDescriptor& weightsDescriptor, 12 : Initializer initializer) 13 : : DnnlLayer<data_t>(inputDescriptor, outputDescriptor, "DnnlTrainableLayer"), 14 0 : _weightsGradientAcc(asUnsigned(weightsDescriptor.getNumberOfCoefficients())), 15 : _weightsDescriptor(weightsDescriptor.clone()), 16 0 : _initializer(initializer) 17 : { 18 : 19 0 : _input.front().canBeReordered = true; 20 : 21 : // Set the layer's fan-in and fan-out. This is needed for random initialization of 22 : // weights and biases 23 0 : _fanInOut.first = inputDescriptor.getNumberOfCoefficients(); 24 0 : _fanInOut.second = outputDescriptor.getNumberOfCoefficients(); 25 : 26 : // Set weights meta information 27 0 : for (const auto& dim : weightsDescriptor.getNumberOfCoefficientsPerDimension()) 28 0 : _weights.dimensions.push_back(dim); 29 : 30 0 : _weightsGradient.dimensions = _weights.dimensions; 31 : 32 0 : _weights.formatTag = 33 0 : BaseType::dataDescriptorToDnnlMemoryFormatTag(weightsDescriptor, 34 : /* No input but weights tag */ false); 35 : 36 0 : _weightsGradient.formatTag = _weights.formatTag; 37 : 38 0 : _weights.descriptor = 39 0 : dnnl::memory::desc({_weights.dimensions}, _typeTag, dnnl::memory::format_tag::any); 40 : 41 0 : _weightsGradient.descriptor = _weights.descriptor; 42 : 43 0 : _weightsGradientAcc.setZero(); 44 : 45 0 : IndexVector_t biasVec(1); 46 0 : biasVec << _weights.dimensions[0]; 47 : 48 0 : _biasDescriptor = VolumeDescriptor(biasVec).clone(); 49 : 50 : // Set weights bias information 51 0 : _bias.dimensions.push_back(_weights.dimensions[0]); 52 : 53 0 : _biasGradient.dimensions = _bias.dimensions; 54 : 55 0 : _bias.descriptor = 56 0 : dnnl::memory::desc({_bias.dimensions}, _typeTag, dnnl::memory::format_tag::any); 57 : 58 0 : _biasGradient.descriptor = _bias.descriptor; 59 : 60 0 : _bias.formatTag = dnnl::memory::format_tag::x; 61 : 62 0 : _biasGradient.formatTag = _bias.formatTag; 63 : 64 0 : _biasGradientAcc.setZero(_biasDescriptor->getNumberOfCoefficients()); 65 : 66 0 : initialize(); 67 0 : } 68 : 69 : template <typename data_t> 70 0 : void DnnlTrainableLayer<data_t>::setWeights(const DataContainer<data_t>& weights) 71 : { 72 0 : this->writeToDnnlMemory(weights, *_weights.describedMemory); 73 0 : } 74 : 75 : template <typename data_t> 76 0 : void DnnlTrainableLayer<data_t>::setBias(const DataContainer<data_t>& bias) 77 : { 78 0 : this->writeToDnnlMemory(bias, *_bias.describedMemory); 79 0 : } 80 : 81 : template <typename data_t> 82 0 : void DnnlTrainableLayer<data_t>::initialize() 83 : { 84 : // Construct weights memory and initialize it 85 0 : auto weightsDesc = 86 0 : dnnl::memory::desc({_weights.dimensions}, _typeTag, _weights.formatTag); 87 0 : _weights.describedMemory = std::make_shared<dnnl::memory>(weightsDesc, *_engine); 88 : 89 0 : InitializerImpl<data_t>::initialize( 90 0 : static_cast<data_t*>(_weights.describedMemory->get_data_handle()), 91 0 : _weightsDescriptor->getNumberOfCoefficients(), _initializer, _fanInOut); 92 : 93 : // Construct bias memory and initialize it with zero 94 0 : auto biasDesc = dnnl::memory::desc({_bias.dimensions}, _typeTag, _bias.formatTag); 95 0 : _bias.describedMemory = std::make_shared<dnnl::memory>(biasDesc, *_engine); 96 : 97 0 : InitializerImpl<data_t>::initialize( 98 0 : static_cast<data_t*>(_bias.describedMemory->get_data_handle()), _bias.dimensions[0], 99 0 : Initializer::Zeros, _fanInOut); 100 : 101 : // Bias can never be reordered 102 0 : _bias.effectiveMemory = _bias.describedMemory; 103 0 : } 104 : 105 : template <typename data_t> 106 0 : void DnnlTrainableLayer<data_t>::compileForwardStream() 107 : { 108 0 : BaseType::compileForwardStream(); 109 0 : } 110 : 111 : template <typename data_t> 112 0 : void DnnlTrainableLayer<data_t>::compileBackwardStream() 113 : { 114 0 : BaseType::compileBackwardStream(); 115 : 116 : // Construct weights memory descriptor and allocate weights memory 117 0 : auto weightsDesc = dnnl::memory::desc({_weightsGradient.dimensions}, _typeTag, 118 : _weightsGradient.formatTag); 119 0 : _weightsGradient.describedMemory = 120 0 : std::make_shared<dnnl::memory>(weightsDesc, *_engine); 121 : 122 : // Construct bias memory descriptor and allocate bias memory 123 0 : auto biasDesc = 124 0 : dnnl::memory::desc({_biasGradient.dimensions}, _typeTag, _biasGradient.formatTag); 125 0 : _biasGradient.describedMemory = std::make_shared<dnnl::memory>(biasDesc, *_engine); 126 : 127 : // Bias can never be reordered 128 0 : _biasGradient.effectiveMemory = _biasGradient.describedMemory; 129 0 : } 130 : 131 : template <typename data_t> 132 0 : DataContainer<data_t> DnnlTrainableLayer<data_t>::getGradientWeights() const 133 : { 134 0 : DataContainer<data_t> output(*_weightsDescriptor); 135 0 : this->readFromDnnlMemory(output, *_weightsGradient.effectiveMemory); 136 0 : return output; 137 : } 138 : 139 : template <typename data_t> 140 0 : DataContainer<data_t> DnnlTrainableLayer<data_t>::getGradientBias() const 141 : { 142 0 : DataContainer<data_t> output(*_biasDescriptor); 143 0 : this->readFromDnnlMemory(output, *_biasGradient.effectiveMemory); 144 0 : return output; 145 : } 146 : 147 : template <typename data_t> 148 0 : void DnnlTrainableLayer<data_t>::updateTrainableParameters() 149 : { 150 0 : index_t batchSize = 151 0 : this->_inputDescriptor.front()->getNumberOfCoefficientsPerDimension()[0]; 152 : 153 : // Update weights 154 0 : _weightsGradientAcc /= static_cast<data_t>(batchSize); 155 0 : weightsOptimizer_->updateParameter( 156 : _weightsGradientAcc.data(), batchSize, 157 0 : static_cast<data_t*>(_weights.effectiveMemory->get_data_handle())); 158 : 159 : // Update bias 160 0 : _biasGradientAcc /= static_cast<data_t>(batchSize); 161 0 : biasOptimizer_->updateParameter( 162 : _biasGradientAcc.data(), batchSize, 163 0 : static_cast<data_t*>(_bias.effectiveMemory->get_data_handle())); 164 : 165 : // Reset accumulated gradient 166 0 : _weightsGradientAcc.setZero(); 167 0 : _biasGradientAcc.setZero(); 168 0 : } 169 : 170 : template <typename data_t> 171 0 : void DnnlTrainableLayer<data_t>::accumulatedGradients() 172 : { 173 : // Accumulate weights 174 0 : Eigen::Map<Eigen::ArrayX<data_t>> weightsGradientMem( 175 0 : static_cast<data_t*>(_weightsGradient.effectiveMemory->get_data_handle()), 176 : _weightsDescriptor->getNumberOfCoefficients()); 177 : 178 0 : assert(_weightsGradientAcc.size() == weightsGradientMem.size() 179 : && "Size of accumulated weigths must match size of weights"); 180 : 181 0 : _weightsGradientAcc += weightsGradientMem; 182 : 183 : // Accumulate bias 184 0 : Eigen::Map<Eigen::ArrayX<data_t>> biasGradientMem( 185 0 : static_cast<data_t*>(_biasGradient.effectiveMemory->get_data_handle()), 186 : _biasDescriptor->getNumberOfCoefficients()); 187 : 188 0 : assert(_biasGradientAcc.size() == biasGradientMem.size() 189 : && "Size of accumulated bias must match size of bias"); 190 0 : _biasGradientAcc += biasGradientMem; 191 0 : } 192 : 193 : template <typename data_t> 194 0 : void DnnlTrainableLayer<data_t>::backwardPropagate(dnnl::stream& executionStream) 195 : { 196 : // Backward propagate as usual 197 0 : BaseType::backwardPropagate(executionStream); 198 : 199 : // Accumulate propagated gradients 200 0 : accumulatedGradients(); 201 0 : } 202 : 203 : template <typename data_t> 204 0 : bool DnnlTrainableLayer<data_t>::isTrainable() const 205 : { 206 0 : return true; 207 : } 208 : 209 : template class DnnlTrainableLayer<float>; 210 : 211 : } // namespace detail 212 : } // namespace elsa::ml