Line data Source code
1 : #include "Optimizer.h" 2 : 3 : namespace elsa 4 : { 5 : namespace ml 6 : { 7 : template <typename data_t> 8 0 : Optimizer<data_t>::Optimizer(OptimizerType optimizerType, data_t learningRate) 9 0 : : optimizerType_(optimizerType), learningRate_(learningRate) 10 : { 11 0 : } 12 : 13 : template <typename data_t> 14 0 : OptimizerType Optimizer<data_t>::getOptimizerType() const 15 : { 16 0 : return optimizerType_; 17 : } 18 : 19 : template <typename data_t> 20 0 : SGD<data_t>::SGD(data_t learningRate, data_t momentum, bool nesterov) 21 : : Optimizer<data_t>(OptimizerType::SGD, learningRate), 22 : momentum_(momentum), 23 0 : nesterov_(nesterov) 24 : { 25 0 : } 26 : 27 : // template <typename data_t> 28 : // std::pair<Eigen::ArrayX<data_t>, Eigen::ArrayX<data_t>> 29 : // SGD<data_t>::getParameterUpdates(const Eigen::ArrayX<data_t>& weights, 30 : // const Eigen::ArrayX<data_t>& bias) 31 : // { 32 : // if (!isInitialized_) { 33 : // weightsVelocity_.setZero(weights.size()); 34 : // biasVelocity_.setZero(bias.size()); 35 : // isInitialized_ = true; 36 : // } 37 : 38 : // // Do we have momentum 39 : // if (std::abs(momentum_) <= 0) { 40 : // // Do we use Nesterov-momentum? 41 : // if (nesterov_) { 42 : // weightsVelocity_ = momentum_ * weightsVelocity_ - learningRate_ * weights; 43 : // biasVelocity_ = momentum_ * biasVelocity_ - learningRate_ * bias; 44 : // return std::make_pair<Eigen::ArrayX<data_t>, Eigen::ArrayX<data_t>>( 45 : // data_t(-1) * momentum_ * weightsVelocity_ - learningRate_ * weights, 46 : // data_t(-1) * momentum_ * weightsVelocity_ - learningRate_ * bias); 47 : // } else { 48 : // return std::make_pair<Eigen::ArrayX<data_t>, Eigen::ArrayX<data_t>>( 49 : // data_t(-1) * momentum_ * weightsVelocity_ + learningRate_ * weights, 50 : // data_t(-1) * momentum_ * weightsVelocity_ + learningRate_ * bias); 51 : // } 52 : // } else { 53 : // // vanilla stochastic gradient descent 54 : // return std::make_pair<Eigen::ArrayX<data_t>, Eigen::ArrayX<data_t>>( 55 : // learningRate_ * weights, learningRate_ * bias); 56 : // } 57 : // } 58 : 59 : template <typename data_t> 60 0 : data_t SGD<data_t>::getMomentum() const 61 : { 62 0 : return momentum_; 63 : } 64 : 65 : template <typename data_t> 66 0 : bool SGD<data_t>::useNesterov() const 67 : { 68 0 : return nesterov_; 69 : } 70 : 71 : template <typename data_t> 72 0 : Adam<data_t>::Adam(data_t learningRate, data_t beta1, data_t beta2, data_t epsilon) 73 : : Optimizer<data_t>(OptimizerType::Adam, learningRate), 74 : beta1_(beta1), 75 : beta2_(beta2), 76 0 : epsilon_(epsilon) 77 : { 78 0 : } 79 : 80 : template <typename data_t> 81 0 : data_t Adam<data_t>::getBeta1() const 82 : { 83 0 : return beta1_; 84 : } 85 : 86 : template <typename data_t> 87 0 : data_t Adam<data_t>::getBeta2() const 88 : { 89 0 : return beta2_; 90 : } 91 : 92 : template <typename data_t> 93 0 : data_t Adam<data_t>::getEpsilon() const 94 : { 95 0 : return epsilon_; 96 : } 97 : 98 : template class Adam<float>; 99 : template class SGD<float>; 100 : } // namespace ml 101 : } // namespace elsa