#pragma once

#include <memory>
#include <stdexcept>
#include <utility>
#include "elsaDefines.h"
#include "Cloneable.h"
#include "Common.h"

namespace elsa
{
    namespace ml
    {

        enum class OptimizerType { SGD, Adam };

        template <typename data_t>
        class Optimizer
        {
        public:
            virtual OptimizerType getOptimizerType() const;

            data_t getLearningRate() { return learningRate_; }

        protected:
            /// default constructor
            explicit Optimizer(OptimizerType optimizerType, data_t learningRate);

            /// The type of this optimizer
            OptimizerType optimizerType_;

            /// learning-rate
            data_t learningRate_;
        };
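
        // The concrete optimizers below (SGD, Adam) only carry hyperparameters; the
        // per-parameter update itself is performed by the backend-specific
        // implementations in the detail namespace, obtained via detail::OptimizerFactory.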

        /// Gradient descent (with momentum) optimizer.
        ///
        /// Update rule for parameter \f$ w \f$ with gradient \f$ g \f$ when momentum is \f$ 0 \f$:
        ///
        /// \f[
        /// w = w - \text{learning_rate} \cdot g
        /// \f]
        ///
        /// Update rule when momentum is larger than \f$ 0 \f$:
        ///
        /// \f[
        /// \begin{eqnarray*}
        /// \text{velocity} &=& \text{momentum} \cdot \text{velocity} - \text{learning_rate} \cdot g
        /// \\ w &=& w + \text{velocity}
        /// \end{eqnarray*}
        /// \f]
        ///
        /// When \p nesterov is true, this rule becomes:
        ///
        /// \f[
        /// \begin{eqnarray*}
        ///   \text{velocity} & = & \text{momentum} \cdot \text{velocity} - \text{learning_rate}
        ///   \cdot g \\ w & = & w + \text{momentum} \cdot \text{velocity} - \text{learning_rate}
        ///   \cdot g
        /// \end{eqnarray*}
        /// \f]
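        ///
        /// For illustration only (none of the names below are part of the elsa API),
        /// one momentum update step written out as plain scalar arithmetic:
        /// @code
        /// float learningRate = 0.01f, momentum = 0.9f;
        /// float velocity = 0.0f;  // accumulated velocity, initially zero
        /// float w = 1.0f;         // some parameter
        /// float g = 0.5f;         // its current gradient
        /// velocity = momentum * velocity - learningRate * g; // -0.005
        /// w = w + velocity;                                  // 0.995
        /// @endcode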
        template <typename data_t = real_t>
        class SGD : public Optimizer<data_t>
        {
        public:
            /// Construct an SGD optimizer.
            ///
            /// @param learningRate The learning-rate. This parameter is
            /// optional and defaults to 0.01.
            /// @param momentum Hyperparameter >= 0 that accelerates gradient
            /// descent in the relevant direction and dampens oscillations. This
            /// parameter is optional and defaults to 0, i.e., vanilla gradient
            /// descent.
            /// @param nesterov Whether to apply Nesterov momentum. This
            /// parameter is optional and defaults to false.
            SGD(data_t learningRate = data_t(0.01), data_t momentum = data_t(0.0),
                bool nesterov = false);

            /// Get momentum.
            data_t getMomentum() const;

            /// True if this optimizer applies Nesterov momentum, false otherwise.
            bool useNesterov() const;

        private:
            /// \copydoc Optimizer::learningRate_
            using Optimizer<data_t>::learningRate_;

            /// momentum parameter
            data_t momentum_;

            /// True if the Nesterov momentum should be used, false otherwise
            bool nesterov_;
        };

        /// Optimizer that implements the Adam algorithm.
        ///
        /// Adam optimization is a stochastic gradient descent method that is
        /// based on adaptive estimation of first-order and second-order
        /// moments.
        ///
        /// According to Kingma et al., 2014, the method is "computationally
        /// efficient, has little memory requirement, is invariant to diagonal
        /// rescaling of gradients, and is well suited for problems that are
        /// large in terms of data/parameters".
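        ///
        /// As a rough sketch of the algorithm (following Kingma & Ba, 2014; the exact
        /// placement of \f$ \epsilon \f$ in the backend implementation may differ), the
        /// update for a parameter \f$ w \f$ with gradient \f$ g \f$ at step \f$ t \f$ is:
        ///
        /// \f[
        /// \begin{eqnarray*}
        /// m_t &=& \beta_1 \cdot m_{t-1} + (1 - \beta_1) \cdot g \\
        /// v_t &=& \beta_2 \cdot v_{t-1} + (1 - \beta_2) \cdot g^2 \\
        /// \hat{m}_t &=& m_t / (1 - \beta_1^t) \\
        /// \hat{v}_t &=& v_t / (1 - \beta_2^t) \\
        /// w &=& w - \text{learning_rate} \cdot \hat{m}_t / (\sqrt{\hat{v}_t} + \epsilon)
        /// \end{eqnarray*}
        /// \f]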
        template <typename data_t = real_t>
        class Adam : public Optimizer<data_t>
        {
        public:
            /// Construct an Adam optimizer.
            ///
            /// @param learningRate The learning-rate. This parameter is
            /// optional and defaults to 0.001.
            /// @param beta1 The exponential decay rate for the 1st moment
            /// estimates. This parameter is optional and defaults to 0.9.
            /// @param beta2 The exponential decay rate for the 2nd moment
            /// estimates. This parameter is optional and defaults to 0.999.
            /// @param epsilon A small constant for numerical stability. This
            /// epsilon is "epsilon hat" in the Kingma and Ba paper (in the
            /// formula just before Section 2.1), not the epsilon in Algorithm 1
            /// of the paper. This parameter is optional and defaults to 1e-7.
            Adam(data_t learningRate = data_t(0.001), data_t beta1 = data_t(0.9),
                 data_t beta2 = data_t(0.999), data_t epsilon = data_t(1e-7));

            /// Get beta1.
            data_t getBeta1() const;

            /// Get beta2.
            data_t getBeta2() const;

            /// Get epsilon.
            data_t getEpsilon() const;

        private:
            /// \copydoc Optimizer::learningRate_
            using Optimizer<data_t>::learningRate_;

            /// exponential decay for 1st order momenta
            data_t beta1_;

            /// exponential decay for 2nd order momenta
            data_t beta2_;

            /// epsilon-value for numeric stability
            data_t epsilon_;
        };
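
        // Illustrative usage sketch (parameter values are arbitrary; how the optimizer
        // is handed to a model is outside the scope of this header):
        //
        //   elsa::ml::SGD<real_t> sgd(real_t(0.01), real_t(0.9), /* nesterov */ true);
        //   elsa::ml::Adam<real_t> adam; // lr = 0.001, beta1 = 0.9, beta2 = 0.999, eps = 1e-7
        //   auto lr = adam.getLearningRate(); // 0.001
        //   bool nag = sgd.useNesterov();     // true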

        namespace detail
        {
            /// Base class for backend-specific optimizer implementations that perform
            /// the actual parameter updates.
            template <typename data_t>
            class OptimizerImplBase
            {
            public:
                OptimizerImplBase(index_t size, data_t learningRate)
                    : size_(size), learningRate_(learningRate)
                {
                }

                virtual ~OptimizerImplBase() = default;

                /// Update the parameters in \p param in-place, using \p gradient and the
                /// current mini-batch size \p batchSize.
                virtual void updateParameter(const data_t* gradient, index_t batchSize,
                                             data_t* param) = 0;

            protected:
                /// size of weights and gradients
                index_t size_;

                /// learning-rate
                data_t learningRate_;

                /// current execution step
                index_t step_ = 0;
            };
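
            // Hypothetical sketch, not part of elsa: a backend implementation derives
            // from OptimizerImplBase and does the actual work in updateParameter(). A
            // minimal, backend-agnostic vanilla gradient-descent step might look
            // roughly like this.
            template <typename data_t>
            class ExampleVanillaSgdImpl : public OptimizerImplBase<data_t>
            {
            public:
                ExampleVanillaSgdImpl(index_t size, data_t learningRate)
                    : OptimizerImplBase<data_t>(size, learningRate)
                {
                }

                void updateParameter(const data_t* gradient, index_t batchSize,
                                     data_t* param) override
                {
                    // Average the gradient over the mini-batch and take a plain
                    // gradient-descent step for every parameter.
                    for (index_t i = 0; i < this->size_; ++i)
                        param[i] -=
                            this->learningRate_ * gradient[i] / static_cast<data_t>(batchSize);
                    ++this->step_;
                }
            };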

            /// Backend-specific implementation of the Adam optimizer. The primary
            /// template is empty and is meant to be specialized for each MlBackend.
            template <typename data_t, MlBackend Backend>
            class OptimizerAdamImpl
            {
            };

            /// Backend-specific implementation of the SGD optimizer. The primary
            /// template is empty and is meant to be specialized for each MlBackend.
            template <typename data_t, MlBackend Backend>
            class OptimizerSGDImpl
            {
            };

            /// Creates the backend implementation that matches a given Optimizer. The
            /// primary template acts as a fallback that reports a missing backend.
            template <typename data_t, MlBackend Backend>
            struct OptimizerFactory {
                static std::shared_ptr<OptimizerImplBase<data_t>>
                    run([[maybe_unused]] Optimizer<data_t>* opt)
                {
                    throw std::logic_error("No Ml backend available");
                }
            };
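
            // Sketch of what a backend specialization of OptimizerFactory is expected
            // to do (illustrative only; the concrete MlBackend tag and the constructor
            // arguments depend on the backend):
            //
            //   switch (opt->getOptimizerType()) {
            //       case OptimizerType::SGD:
            //           return std::make_shared<OptimizerSGDImpl<data_t, Backend>>(/* ... */);
            //       case OptimizerType::Adam:
            //           return std::make_shared<OptimizerAdamImpl<data_t, Backend>>(/* ... */);
            //   }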

        } // namespace detail
    } // namespace ml
} // namespace elsa
