LCOV - code coverage report
Current view: top level - ml/backend/Dnnl - DnnlTrainableLayer.h (source / functions)
Test: test_coverage.info.cleaned
Date: 2022-07-06 02:47:47
Coverage: Lines: 0 / 4 (0.0 %) | Functions: 0 / 1 (0.0 %)

Source code
#pragma once

#include <memory>

#include "Optimizer.h"
#include "DnnlLayer.h"
#include "DataDescriptor.h"
#include "DataContainer.h"
#include "DnnlOptimizer.h"
#include "Initializer.h"

#include "dnnl.hpp"

namespace elsa::ml
{
    namespace detail
    {
        /**
         * Trainable Dnnl layer.
         *
         * This layer serves as the base class for all Dnnl layers with trainable
         * parameters, such as convolutional or dense layers.
         *
         * @tparam data_t Type of all coefficients used in the layer
         */
        template <typename data_t>
        class DnnlTrainableLayer : public DnnlLayer<data_t>
        {
        public:
            /// Type of this layer's base class
            using BaseType = DnnlLayer<data_t>;

            /**
             * Construct a trainable Dnnl network layer by passing descriptors for its input,
             * its output, and its weights, along with an initializer for its weights and
             * biases.
             */
            DnnlTrainableLayer(const VolumeDescriptor& inputDescriptor,
                               const VolumeDescriptor& outputDescriptor,
                               const VolumeDescriptor& weightsDescriptor, Initializer initializer);

            /**
             * Set this layer's weights by passing a DataContainer.
             *
             * \note This function performs a copy from a DataContainer to Dnnl
             * memory and should therefore be used for debugging or testing purposes
             * only. The layer is capable of initializing its weights on its own.
             */
            void setWeights(const DataContainer<data_t>& weights);

            /**
             * Set this layer's biases by passing a DataContainer.
             *
             * \note This function performs a copy from a DataContainer to Dnnl
             * memory and should therefore be used for debugging or testing purposes
             * only. The layer is capable of initializing its biases on its own.
             */
            void setBias(const DataContainer<data_t>& bias);

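            // Illustrative sketch of overriding the trained parameters from a test, assuming
            // `layer` is a concrete trainable layer and that the DataContainers are built on
            // the layer's weights and bias descriptors (shown here as placeholders):
            //
            // @code
            //     DataContainer<float> weights(weightsDesc); // filled with known test values
            //     DataContainer<float> bias(biasDesc);
            //     layer.setWeights(weights); // copies into Dnnl memory; debugging/testing only
            //     layer.setBias(bias);
            // @endcode
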
            /// Get this layer's weights gradient as a DataContainer.
            DataContainer<data_t> getGradientWeights() const;

            /// Get this layer's bias gradient as a DataContainer.
            DataContainer<data_t> getGradientBias() const;
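
            // A possible way to inspect gradients after a backward pass, e.g. for debugging
            // or a numerical gradient check; `layer` and `executionStream` are placeholders
            // and the assumption is that gradients are ready once the backward pass ran:
            //
            // @code
            //     layer.backwardPropagate(executionStream);
            //     DataContainer<float> dW = layer.getGradientWeights();
            //     DataContainer<float> db = layer.getGradientBias();
            // @endcode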

            /// Update this layer's trainable parameters (weights and bias) using the
            /// configured optimizers.
            void updateTrainableParameters();

            /// Accumulate the current weights and bias gradients.
            void accumulatedGradients();

            /// \copydoc DnnlLayer::initialize
            void initialize() override;

            /// \copydoc DnnlLayer::backwardPropagate
            void backwardPropagate(dnnl::stream& executionStream) override;

            /// Return true to indicate that this layer has trainable parameters.
            bool isTrainable() const override;

            /// Set the optimizer used to update this layer's parameters. One optimizer
            /// instance is created for the weights and one for the bias.
            void setOptimizer(Optimizer<data_t>* optimizer)
            {
                weightsOptimizer_ = OptimizerFactory<data_t, MlBackend::Dnnl>::run(
                    optimizer, _weightsDescriptor->getNumberOfCoefficients());
                biasOptimizer_ = OptimizerFactory<data_t, MlBackend::Dnnl>::run(
                    optimizer, _biasDescriptor->getNumberOfCoefficients());
            }
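
            // Sketch of how the optimizer and the update methods above could interact in a
            // single training iteration; the exact ordering is assumed for illustration, and
            // `sgd` (an Optimizer<float>), `layer`, and `executionStream` are placeholders:
            //
            // @code
            //     layer.setOptimizer(&sgd);                 // create per-parameter optimizers
            //     layer.backwardPropagate(executionStream); // compute weights/bias gradients
            //     executionStream.wait();                   // wait for the backward stream
            //     layer.accumulatedGradients();             // accumulate gradients (e.g. over a batch)
            //     layer.updateTrainableParameters();        // apply the optimizer update
            // @endcode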

        protected:
            /// \copydoc DnnlLayer::compileForwardStream
            void compileForwardStream() override;

            /// \copydoc DnnlLayer::compileBackwardStream
            void compileBackwardStream() override;

            using DnnlMemory = typename BaseType::DnnlMemory;

            /// \copydoc DnnlLayer::_typeTag
            using BaseType::_typeTag;

            /// \copydoc DnnlLayer::_engine
            using BaseType::_engine;

            /// \copydoc DnnlLayer::_input
            using BaseType::_input;

            /// \copydoc DnnlLayer::_inputGradient
            using BaseType::_inputGradient;

            /// \copydoc DnnlLayer::_output
            using BaseType::_output;

            /// \copydoc DnnlLayer::_outputGradient
            using BaseType::_outputGradient;

            /// \copydoc DnnlLayer::_forwardStream
            using BaseType::_forwardStream;

            /// \copydoc DnnlLayer::_backwardStream
            using BaseType::_backwardStream;

            /// This layer's weights memory
            DnnlMemory _weights;

            /// This layer's weights gradient memory
            DnnlMemory _weightsGradient;

            /// This layer's accumulated weights gradient memory
            Eigen::ArrayX<data_t> _weightsGradientAcc;

            /// This layer's bias memory
            DnnlMemory _bias;

            /// This layer's bias gradient memory
            DnnlMemory _biasGradient;

            /// This layer's accumulated bias gradient memory
            Eigen::ArrayX<data_t> _biasGradientAcc;

            /// This layer's weights DataDescriptor
            std::unique_ptr<DataDescriptor> _weightsDescriptor;

            /// This layer's bias DataDescriptor
            std::unique_ptr<DataDescriptor> _biasDescriptor;

            /// This layer's initializer tag
            Initializer _initializer;

            /// This layer's fanIn/fanOut pair that is used during random initialization of weights
            /// and biases
            typename InitializerImpl<data_t>::FanPairType _fanInOut;

            /// Optimizer used to update this layer's weights
            std::shared_ptr<OptimizerImplBase<data_t>> weightsOptimizer_;

            /// Optimizer used to update this layer's bias
            std::shared_ptr<OptimizerImplBase<data_t>> biasOptimizer_;
        };
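
        // A minimal sketch of how a concrete trainable layer could build on the class above;
        // the class name is an illustrative placeholder, not an actual layer of the library:
        //
        // @code
        //     template <typename data_t>
        //     class MyDnnlDenseLayer : public DnnlTrainableLayer<data_t>
        //     {
        //     public:
        //         MyDnnlDenseLayer(const VolumeDescriptor& input, const VolumeDescriptor& output,
        //                          const VolumeDescriptor& weights, Initializer initializer)
        //             : DnnlTrainableLayer<data_t>(input, output, weights, initializer)
        //         {
        //         }
        //
        //     protected:
        //         // Assemble this layer's Dnnl forward/backward primitives.
        //         void compileForwardStream() override;
        //         void compileBackwardStream() override;
        //     };
        // @endcode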

    } // namespace detail
} // namespace elsa::ml

Generated by: LCOV version 1.15