Line data Source code
1 : #include "DnnlOptimizer.h" 2 : 3 : namespace elsa::ml 4 : { 5 : namespace detail 6 : { 7 : template <typename data_t> 8 0 : OptimizerAdamImpl<data_t, MlBackend::Dnnl>::OptimizerAdamImpl(index_t size, 9 : data_t learningRate, 10 : data_t beta1, data_t beta2, 11 : data_t epsilon) 12 : : OptimizerImplBase<data_t>(size, learningRate), 13 : beta1_(beta1), 14 : beta2_(beta2), 15 0 : epsilon_(epsilon) 16 : { 17 0 : firstMomentum_.setZero(size_); 18 0 : secondMomentum_.setZero(size_); 19 0 : } 20 : 21 : template <typename data_t> 22 0 : void OptimizerAdamImpl<data_t, MlBackend::Dnnl>::updateParameter( 23 : const data_t* gradient, [[maybe_unused]] index_t batchSize, data_t* param) 24 : { 25 0 : ++step_; 26 : 27 0 : Eigen::Map<const Eigen::ArrayX<data_t>> gradientMem(gradient, size_); 28 0 : Eigen::Map<Eigen::ArrayX<data_t>> paramMem(param, size_); 29 : 30 : // first momentum 31 0 : firstMomentum_ = beta1_ * firstMomentum_ + (1 - beta1_) * gradientMem; 32 0 : auto correctedFirstMomentum = firstMomentum_ / (1 - std::pow(beta1_, step_)); 33 : 34 : // second momentum 35 0 : secondMomentum_ = beta2_ * secondMomentum_ + (1 - beta2_) * gradientMem * gradientMem; 36 0 : auto correctedSecondMomentum = secondMomentum_ / (1 - std::pow(beta2_, step_)); 37 : 38 0 : paramMem = paramMem 39 0 : - learningRate_ * correctedFirstMomentum 40 0 : / (correctedSecondMomentum.sqrt() + epsilon_); 41 0 : } 42 : 43 : template class OptimizerAdamImpl<float, MlBackend::Dnnl>; 44 : } // namespace detail 45 : } // namespace elsa::ml