LCOV - code coverage report
Current view: top level - ml/backend/Dnnl - DnnlOptimizer.h (source / functions) Hit Total Coverage
Test: test_coverage.info.cleaned Lines: 0 8 0.0 %
Date: 2022-07-06 02:47:47 Functions: 0 1 0.0 %

          Line data    Source code
       1             : #pragma once
       2             : 
       3             : #include "elsaDefines.h"
       4             : #include "Optimizer.h"
       5             : 
       6             : #ifdef ELSA_HAS_CUDNN_BACKEND
       7             : #include "CudnnMemory.h"
       8             : #endif
       9             : 
      10             : namespace elsa
      11             : {
      12             :     namespace ml
      13             :     {
      14             :         namespace detail
      15             :         {
      16             :             template <typename data_t>
      17             :             class OptimizerAdamImpl<data_t, MlBackend::Dnnl> : public OptimizerImplBase<data_t>
      18             :             {
      19             :             public:
      20             :                 OptimizerAdamImpl(index_t size, data_t learningRate = data_t(0.001),
      21             :                                   data_t beta1 = data_t(0.9), data_t beta2 = data_t(0.999),
      22             :                                   data_t epsilon = data_t(1e-7));
      23             : 
      24             :                 void updateParameter(const data_t* gradient, index_t batchSize,
      25             :                                      data_t* param) override;
      26             : 
      27             :             private:
      28             :                 /// \copydoc OptimizerImplBase::learningRate_
      29             :                 using OptimizerImplBase<data_t>::learningRate_;
      30             : 
      31             :                 /// \copydoc OptimizerImplBase::step_
      32             :                 using OptimizerImplBase<data_t>::step_;
      33             : 
      34             :                 /// \copydoc OptimizerImplBase::size_
      35             :                 using OptimizerImplBase<data_t>::size_;
      36             : 
      37             :                 /// exponential decay for 1st order momenta
      38             :                 data_t beta1_;
      39             : 
      40             :                 /// exponential decay for 2nd order momenta
      41             :                 data_t beta2_;
      42             : 
      43             :                 /// epsilon-value for numeric stability
      44             :                 data_t epsilon_;
      45             : 
      46             :                 /// 1st momentum
      47             :                 Eigen::ArrayX<data_t> firstMomentum_;
      48             : 
      49             :                 /// 2nd momentum
      50             :                 Eigen::ArrayX<data_t> secondMomentum_;
      51             :             };
      52             : 
      53             :             template <typename data_t>
      54             :             struct OptimizerFactory<data_t, MlBackend::Dnnl> {
      55           0 :                 static std::shared_ptr<OptimizerImplBase<data_t>> run(Optimizer<data_t>* opt,
      56             :                                                                       index_t size)
      57             :                 {
      58           0 :                     switch (opt->getOptimizerType()) {
      59           0 :                         case OptimizerType::Adam: {
      60           0 :                             auto downcastedOpt = downcast<Adam<data_t>>(opt);
      61           0 :                             return std::make_shared<OptimizerAdamImpl<data_t, MlBackend::Dnnl>>(
      62             :                                 size, downcastedOpt->getLearningRate(), downcastedOpt->getBeta1(),
      63           0 :                                 downcastedOpt->getBeta2(), downcastedOpt->getEpsilon());
      64             :                         }
      65           0 :                         default:
      66           0 :                             assert(false && "This execution path should never be reached");
      67             :                     }
      68             :                     return nullptr;
      69             :                 }
      70             :             };
      71             :         } // namespace detail
      72             :     }     // namespace ml
      73             : } // namespace elsa

Generated by: LCOV version 1.15