Line data Source code
1 : #pragma once 2 : 3 : #include "Solver.h" 4 : 5 : namespace elsa 6 : { 7 : /** 8 : * @brief Class representing the Optimized Gradient Method. 9 : * 10 : * This class implements the Optimized Gradient Method as introduced by Kim and Fessler in 2016. 11 : * OGM is a first order method to efficiently optimize convex functions with 12 : * Lipschitz-Continuous gradients. It can be seen as an improvement on Nesterov's Fast Gradient 13 : * Method. 14 : * 15 : * @details 16 : * # Optimized Gradient method algorithm overview # 17 : * The algorithm repeats the following update steps for \f$i = 0, \dots, N-1\f$ 18 : * \f{align*}{ 19 : * y_{i+1} &= x_i - \frac{1}{L} f'(x_i) \\ 20 : * \Theta_{i+1} &= 21 : * \begin{cases} 22 : * \frac{1 + \sqrt{1 + 4 \Theta_i^2}}{2} & i \leq N - 2 \\ 23 : * \frac{1 + \sqrt{1 + 8 \Theta_i^2}}{2} & i \leq N - 1 \\ 24 : * \end{cases} \\ 25 : * x_{i+1} &= y_{i} + \frac{\Theta_i - 1}{\Theta_{i+1}}(y_{i+1} - y_i) + 26 : * \frac{\Theta_i}{\Theta_{i+1}}(y_{i+1} - x_i) 27 : * \f} 28 : * The inputs are \f$f \in C_{L}^{1, 1}(\mathbb{R}^d)\f$, \f$x_0 \in \mathbb{R}^d\f$, 29 : * \f$y_0 = x_0\f$, \f$t_0 = 1\f$ 30 : * 31 : * ## Comparison to Nesterov's Fast Gradient ## 32 : * The presented algorithm accelerates FGM by introducing an additional momentum term, which 33 : * doesn't add a great computational amount. 34 : * 35 : * ## OGM References ## 36 : * - Kim, D., Fessler, J.A. _Optimized first-order methods for smooth convex minimization_ 37 : (2016) https://doi.org/10.1007/s10107-015-0949-3 38 : * 39 : * @tparam data_t data type for the domain and range of the problem, defaulting to real_t 40 : * 41 : * @see \verbatim embed:rst 42 : For a basic introduction and problem statement of first-order methods, see 43 : :ref:`here <elsa-first-order-methods-doc>` \endverbatim 44 : * 45 : * @author 46 : * - Michael Loipführer - initial code 47 : * - David Frank - Detailed Documentation 48 : */ 49 : template <typename data_t = real_t> 50 : class OGM : public Solver<data_t> 51 : { 52 : public: 53 : /// Scalar alias 54 : using Scalar = typename Solver<data_t>::Scalar; 55 : 56 : /** 57 : * @brief Constructor for OGM, accepting an optimization problem and, optionally, a value 58 : * for epsilon 59 : * 60 : * @param[in] problem the problem that is supposed to be solved 61 : * @param[in] epsilon affects the stopping condition 62 : */ 63 : OGM(const Problem<data_t>& problem, 64 : data_t epsilon = std::numeric_limits<data_t>::epsilon()); 65 : 66 : /** 67 : * @brief Constructor for OGM, accepting an optimization problem the inverse of a 68 : * preconditioner and, optionally, a value for epsilon 69 : * 70 : * @param[in] problem the problem that is supposed to be solved 71 : * @param[in] preconditionerInverse the inverse of the preconditioner 72 : * @param[in] epsilon affects the stopping condition 73 : */ 74 : OGM(const Problem<data_t>& problem, const LinearOperator<data_t>& preconditionerInverse, 75 : data_t epsilon = std::numeric_limits<data_t>::epsilon()); 76 : 77 : /// make copy constructor deletion explicit 78 : OGM(const OGM<data_t>&) = delete; 79 : 80 : /// default destructor 81 0 : ~OGM() override = default; 82 : 83 : /// lift the base class method getCurrentSolution 84 : using Solver<data_t>::getCurrentSolution; 85 : 86 : private: 87 : /// the default number of iterations 88 : const index_t _defaultIterations{100}; 89 : 90 : /// variable affecting the stopping condition 91 : data_t _epsilon; 92 : 93 : /// the inverse of the preconditioner (if supplied) 94 : std::unique_ptr<LinearOperator<data_t>> _preconditionerInverse{}; 95 : 96 : /// lift the base class variable _problem 97 : using Solver<data_t>::_problem; 98 : 99 : /** 100 : * @brief Solve the optimization problem, i.e. apply iterations number of iterations of 101 : * gradient descent 102 : * 103 : * @param[in] iterations number of iterations to execute (the default 0 value executes 104 : * _defaultIterations of iterations) 105 : * 106 : * @returns a reference to the current solution 107 : */ 108 : DataContainer<data_t>& solveImpl(index_t iterations) override; 109 : 110 : /// implement the polymorphic clone operation 111 : OGM<data_t>* cloneImpl() const override; 112 : 113 : /// implement the polymorphic comparison operation 114 : bool isEqual(const Solver<data_t>& other) const override; 115 : }; 116 : } // namespace elsa