Line data Source code
1 : #pragma once 2 : 3 : #include <memory> 4 : 5 : #include "Solver.h" 6 : #include "Problem.h" 7 : #include "DataContainer.h" 8 : 9 : namespace elsa 10 : { 11 : /** 12 : * @brief Class representing the Optimized Gradient Method. 13 : * 14 : * This class implements the Optimized Gradient Method as introduced by Kim and Fessler in 2016. 15 : * OGM is a first order method to efficiently optimize convex functions with 16 : * Lipschitz-Continuous gradients. It can be seen as an improvement on Nesterov's Fast Gradient 17 : * Method. 18 : * 19 : * @details 20 : * # Optimized Gradient method algorithm overview # 21 : * The algorithm repeats the following update steps for \f$i = 0, \dots, N-1\f$ 22 : * \f{align*}{ 23 : * y_{i+1} &= x_i - \frac{1}{L} f'(x_i) \\ 24 : * \Theta_{i+1} &= 25 : * \begin{cases} 26 : * \frac{1 + \sqrt{1 + 4 \Theta_i^2}}{2} & i \leq N - 2 \\ 27 : * \frac{1 + \sqrt{1 + 8 \Theta_i^2}}{2} & i \leq N - 1 \\ 28 : * \end{cases} \\ 29 : * x_{i+1} &= y_{i} + \frac{\Theta_i - 1}{\Theta_{i+1}}(y_{i+1} - y_i) + 30 : * \frac{\Theta_i}{\Theta_{i+1}}(y_{i+1} - x_i) 31 : * \f} 32 : * The inputs are \f$f \in C_{L}^{1, 1}(\mathbb{R}^d)\f$, \f$x_0 \in \mathbb{R}^d\f$, 33 : * \f$y_0 = x_0\f$, \f$t_0 = 1\f$ 34 : * 35 : * ## Comparison to Nesterov's Fast Gradient ## 36 : * The presented algorithm accelerates FGM by introducing an additional momentum term, which 37 : * doesn't add a great computational amount. 38 : * 39 : * ## OGM References ## 40 : * - Kim, D., Fessler, J.A. _Optimized first-order methods for smooth convex minimization_ 41 : (2016) https://doi.org/10.1007/s10107-015-0949-3 42 : * 43 : * @tparam data_t data type for the domain and range of the problem, defaulting to real_t 44 : * 45 : * @see \verbatim embed:rst 46 : For a basic introduction and problem statement of first-order methods, see 47 : :ref:`here <elsa-first-order-methods-doc>` \endverbatim 48 : * 49 : * @author 50 : * - Michael Loipführer - initial code 51 : * - David Frank - Detailed Documentation 52 : */ 53 : template <typename data_t = real_t> 54 : class OGM : public Solver<data_t> 55 : { 56 : public: 57 : /// Scalar alias 58 : using Scalar = typename Solver<data_t>::Scalar; 59 : 60 : /** 61 : * @brief Constructor for OGM, accepting an optimization problem and, optionally, a value 62 : * for epsilon 63 : * 64 : * @param[in] problem the problem that is supposed to be solved 65 : * @param[in] epsilon affects the stopping condition 66 : */ 67 : OGM(const Problem<data_t>& problem, 68 : data_t epsilon = std::numeric_limits<data_t>::epsilon()); 69 : 70 : /** 71 : * @brief Constructor for OGM, accepting an optimization problem the inverse of a 72 : * preconditioner and, optionally, a value for epsilon 73 : * 74 : * @param[in] problem the problem that is supposed to be solved 75 : * @param[in] preconditionerInverse the inverse of the preconditioner 76 : * @param[in] epsilon affects the stopping condition 77 : */ 78 : OGM(const Problem<data_t>& problem, const LinearOperator<data_t>& preconditionerInverse, 79 : data_t epsilon = std::numeric_limits<data_t>::epsilon()); 80 : 81 : /// make copy constructor deletion explicit 82 : OGM(const OGM<data_t>&) = delete; 83 : 84 : /// default destructor 85 18 : ~OGM() override = default; 86 : 87 : private: 88 : /// the differentiable optimizaion problem 89 : std::unique_ptr<Problem<data_t>> _problem; 90 : 91 : /// the default number of iterations 92 : const index_t _defaultIterations{100}; 93 : 94 : /// variable affecting the stopping condition 95 : data_t _epsilon; 96 : 97 : /// the inverse of the preconditioner (if supplied) 98 : std::unique_ptr<LinearOperator<data_t>> _preconditionerInverse{}; 99 : 100 : /** 101 : * @brief Solve the optimization problem, i.e. apply iterations number of iterations of 102 : * gradient descent 103 : * 104 : * @param[in] iterations number of iterations to execute (the default 0 value executes 105 : * _defaultIterations of iterations) 106 : * 107 : * @returns a reference to the current solution 108 : */ 109 : DataContainer<data_t>& solveImpl(index_t iterations) override; 110 : 111 : /// implement the polymorphic clone operation 112 : OGM<data_t>* cloneImpl() const override; 113 : 114 : /// implement the polymorphic comparison operation 115 : bool isEqual(const Solver<data_t>& other) const override; 116 : }; 117 : } // namespace elsa