Line data Source code
1 : #pragma once 2 : 3 : #include <utility> 4 : #include "elsaDefines.h" 5 : #include "Cloneable.h" 6 : #include "Common.h" 7 : 8 : namespace elsa 9 : { 10 : namespace ml 11 : { 12 : 13 : enum class OptimizerType { SGD, Adam }; 14 : 15 : template <typename data_t> 16 : class Optimizer 17 : { 18 : public: 19 : virtual OptimizerType getOptimizerType() const; 20 : 21 0 : data_t getLearningRate() { return learningRate_; } 22 : 23 : protected: 24 : /// default constructor 25 : explicit Optimizer(OptimizerType optimizerType, data_t learningRate); 26 : 27 : /// The type of this optimizer 28 : OptimizerType optimizerType_; 29 : 30 : /// learning-rate 31 : data_t learningRate_; 32 : }; 33 : 34 : /// Gradient descent (with momentum) optimizer. 35 : /// 36 : /// Update rule for parameter \f$ w \f$ with gradient \f$ g \f$ when momentum is \f$ 0 \f$: 37 : /// 38 : /// \f[ 39 : /// w = w - \text{learning_rate} \cdot g 40 : /// \f] 41 : /// 42 : /// Update rule when momentum is larger than \f$ 0 \f$: 43 : /// 44 : /// \f[ 45 : /// \begin{eqnarray*} 46 : /// \text{velocity} &=& \text{momentum} \cdot \text{velocity} - \text{learning_rate} \cdot g 47 : /// \\ w &=& w \cdot \text{velocity} \end{eqnarray*} \f] 48 : /// 49 : /// When \p nesterov=True, this rule becomes: 50 : /// 51 : /// \f[ 52 : /// \begin{eqnarray*} 53 : /// \text{velocity} & = & \text{momentum} \cdot \text{velocity} - \text{learning_rate} 54 : /// \cdot g \\ w & = & w + \text{momentum} \cdot \text{velocity} - \text{learning_rate} 55 : /// \cdot g 56 : /// \end{eqnarray*} 57 : /// \f] 58 : template <typename data_t = real_t> 59 : class SGD : public Optimizer<data_t> 60 : { 61 : public: 62 : /// Construct an SGD optimizer 63 : /// 64 : /// @param learningRate The learning-rate. This parameter is 65 : /// optional and defaults to 0.01. 66 : /// @param momentum hyperparameter >= 0 that accelerates gradient 67 : /// descent in the relevant direction and dampens oscillations. This 68 : /// parameter is optional and defaults to 0, i.e., vanilla gradient 69 : /// descent. 70 : /// @param nesterov Whether to apply Nesterov momentum. This 71 : /// parameter is optional an defaults to false. 72 : SGD(data_t learningRate = data_t(0.01), data_t momentum = data_t(0.0), 73 : bool nesterov = false); 74 : 75 : /// Get momentum. 76 : data_t getMomentum() const; 77 : 78 : /// True if this optimizer applies Nesterov momentum, false otherwise. 79 : bool useNesterov() const; 80 : 81 : private: 82 : /// \copydoc Optimizer::learningRate_ 83 : using Optimizer<data_t>::learningRate_; 84 : 85 : /// momentum parameter 86 : data_t momentum_; 87 : 88 : /// True if the Nesterov momentum should be used, false otherwise 89 : bool nesterov_; 90 : }; 91 : 92 : /// Optimizer that implements the Adam algorithm. 93 : /// 94 : /// Adam optimization is a stochastic gradient descent method that is 95 : /// based on adaptive estimation of first-order and second-order 96 : /// moments. 97 : /// 98 : /// According to Kingma et al., 2014, the method is "computationally 99 : /// efficient, has little memory requirement, invariant to diagonal 100 : /// rescaling of gradients, and is well suited for problems that are 101 : /// large in terms of data/parameters". 102 : template <typename data_t = real_t> 103 : class Adam : public Optimizer<data_t> 104 : { 105 : public: 106 : /// Construct an Adam optimizer. 107 : /// 108 : /// @param learningRate The learning-rate. This parameter is 109 : /// optional and defaults to 0.001. 110 : /// @param beta1 The exponential decay rate for the 1st moment 111 : /// estimates. This parameter is optional and defaults to 0.9. 112 : /// @param beta2 The exponential decay rate for the 2nd moment 113 : /// estimates. This parameter is optional and defaults to 0.999. 114 : /// @param epsilon A small constant for numerical stability. This 115 : /// epsilon is "epsilon hat" in the Kingma and Ba paper (in the 116 : /// formula just before Section 2.1), not the epsilon in Algorithm 1 117 : /// of the paper. This parameter is optional and defaults to 1e-7. 118 : Adam(data_t learningRate = data_t(0.001), data_t beta1 = data_t(0.9), 119 : data_t beta2 = data_t(0.999), data_t epsilon = data_t(1e-7)); 120 : 121 : /// Get beta1. 122 : data_t getBeta1() const; 123 : 124 : /// Get beta2. 125 : data_t getBeta2() const; 126 : 127 : /// Get epsilon. 128 : data_t getEpsilon() const; 129 : 130 : private: 131 : /// \copydoc Optimizer::learningRate_ 132 : using Optimizer<data_t>::learningRate_; 133 : 134 : /// exponential decay for 1st order momenta 135 : data_t beta1_; 136 : 137 : /// exponential decay for 2nd order momenta 138 : data_t beta2_; 139 : 140 : /// epsilon-value for numeric stability 141 : data_t epsilon_; 142 : }; 143 : 144 : namespace detail 145 : { 146 : template <typename data_t> 147 : class OptimizerImplBase 148 : { 149 : public: 150 0 : OptimizerImplBase(index_t size, data_t learningRate) 151 0 : : size_(size), learningRate_(learningRate) 152 : { 153 0 : } 154 : 155 0 : virtual ~OptimizerImplBase() = default; 156 : 157 : virtual void updateParameter(const data_t* gradient, index_t batchSize, 158 : data_t* param) = 0; 159 : 160 : protected: 161 : /// size of weights and gradients 162 : index_t size_; 163 : 164 : /// learning-rate 165 : data_t learningRate_; 166 : 167 : /// current execution step 168 : index_t step_ = 0; 169 : }; 170 : 171 : template <typename data_t, MlBackend Backend> 172 : class OptimizerAdamImpl 173 : { 174 : }; 175 : 176 : template <typename data_t, MlBackend Backend> 177 : class OptimizerSGDImpl 178 : { 179 : }; 180 : 181 : template <typename data_t, MlBackend Backend> 182 : struct OptimizerFactory { 183 : static std::shared_ptr<OptimizerImplBase<data_t>> run([ 184 : [maybe_unused]] Optimizer<data_t>* opt) 185 : { 186 : throw std::logic_error("No Ml backend available"); 187 : } 188 : }; 189 : 190 : } // namespace detail 191 : } // namespace ml 192 : } // namespace elsa