Line data Source code
1 : #pragma once 2 : 3 : #include "elsaDefines.h" 4 : #include "Optimizer.h" 5 : 6 : #ifdef ELSA_HAS_CUDNN_BACKEND 7 : #include "CudnnMemory.h" 8 : #endif 9 : 10 : namespace elsa 11 : { 12 : namespace ml 13 : { 14 : namespace detail 15 : { 16 : template <typename data_t> 17 : class OptimizerAdamImpl<data_t, MlBackend::Dnnl> : public OptimizerImplBase<data_t> 18 : { 19 : public: 20 : OptimizerAdamImpl(index_t size, data_t learningRate = data_t(0.001), 21 : data_t beta1 = data_t(0.9), data_t beta2 = data_t(0.999), 22 : data_t epsilon = data_t(1e-7)); 23 : 24 : void updateParameter(const data_t* gradient, index_t batchSize, 25 : data_t* param) override; 26 : 27 : private: 28 : /// \copydoc OptimizerImplBase::learningRate_ 29 : using OptimizerImplBase<data_t>::learningRate_; 30 : 31 : /// \copydoc OptimizerImplBase::step_ 32 : using OptimizerImplBase<data_t>::step_; 33 : 34 : /// \copydoc OptimizerImplBase::size_ 35 : using OptimizerImplBase<data_t>::size_; 36 : 37 : /// exponential decay for 1st order momenta 38 : data_t beta1_; 39 : 40 : /// exponential decay for 2nd order momenta 41 : data_t beta2_; 42 : 43 : /// epsilon-value for numeric stability 44 : data_t epsilon_; 45 : 46 : /// 1st momentum 47 : Eigen::ArrayX<data_t> firstMomentum_; 48 : 49 : /// 2nd momentum 50 : Eigen::ArrayX<data_t> secondMomentum_; 51 : }; 52 : 53 : template <typename data_t> 54 : struct OptimizerFactory<data_t, MlBackend::Dnnl> { 55 0 : static std::shared_ptr<OptimizerImplBase<data_t>> run(Optimizer<data_t>* opt, 56 : index_t size) 57 : { 58 0 : switch (opt->getOptimizerType()) { 59 0 : case OptimizerType::Adam: { 60 0 : auto downcastedOpt = downcast<Adam<data_t>>(opt); 61 0 : return std::make_shared<OptimizerAdamImpl<data_t, MlBackend::Dnnl>>( 62 : size, downcastedOpt->getLearningRate(), downcastedOpt->getBeta1(), 63 0 : downcastedOpt->getBeta2(), downcastedOpt->getEpsilon()); 64 : } 65 0 : default: 66 0 : assert(false && "This execution path should never be reached"); 67 : } 68 : return nullptr; 69 : } 70 : }; 71 : } // namespace detail 72 : } // namespace ml 73 : } // namespace elsa