Line data Source code
1 : #pragma once 2 : 3 : #include "Solver.h" 4 : 5 : namespace elsa 6 : { 7 : /** 8 : * @brief Class representing a simple gradient descent solver with a fixed, given step size. 9 : * 10 : * This class implements a simple gradient descent iterative solver with a fixed, given step 11 : * size. No particular stopping rule is currently implemented (only a fixed number of 12 : * iterations, default to 100). 13 : * 14 : * @tparam data_t data type for the domain and range of the problem, defaulting to real_t 15 : * 16 : * @see \verbatim embed:rst 17 : For a basic introduction and problem statement of first-order methods, see 18 : :ref:`here <elsa-first-order-methods-doc>` \endverbatim 19 : * 20 : * @author 21 : * - Tobias Lasser - initial code 22 : */ 23 : template <typename data_t = real_t> 24 : class GradientDescent : public Solver<data_t> 25 : { 26 : public: 27 : /// Scalar alias 28 : using Scalar = typename Solver<data_t>::Scalar; 29 : 30 : /** 31 : * @brief Constructor for gradient descent, accepting a problem and a fixed step size 32 : * 33 : * @param[in] problem the problem that is supposed to be solved 34 : * @param[in] stepSize the fixed step size to be used while solving 35 : */ 36 : GradientDescent(const Problem<data_t>& problem, data_t stepSize); 37 : 38 : /** 39 : * @brief Constructor for gradient descent, accepting a problem. The step size will be 40 : * computed as \f$ 1 \over L \f$ with \f$ L \f$ being the Lipschitz constant of the 41 : * function. 42 : * 43 : * @param[in] problem the problem that is supposed to be solved 44 : */ 45 : GradientDescent(const Problem<data_t>& problem); 46 : 47 : /// make copy constructor deletion explicit 48 : GradientDescent(const GradientDescent<data_t>&) = delete; 49 : 50 : /// default destructor 51 0 : ~GradientDescent() override = default; 52 : 53 : /// lift the base class method getCurrentSolution 54 : using Solver<data_t>::getCurrentSolution; 55 : 56 : private: 57 : /// the step size 58 : data_t _stepSize; 59 : 60 : /// the default number of iterations 61 : const index_t _defaultIterations{100}; 62 : 63 : /// lift the base class variable _problem 64 : using Solver<data_t>::_problem; 65 : 66 : /** 67 : * @brief Solve the optimization problem, i.e. apply iterations number of iterations of 68 : * gradient descent 69 : * 70 : * @param[in] iterations number of iterations to execute (the default 0 value executes 71 : * _defaultIterations of iterations) 72 : * 73 : * @returns a reference to the current solution 74 : */ 75 : DataContainer<data_t>& solveImpl(index_t iterations) override; 76 : 77 : /// implement the polymorphic clone operation 78 : GradientDescent<data_t>* cloneImpl() const override; 79 : 80 : /// implement the polymorphic comparison operation 81 : bool isEqual(const Solver<data_t>& other) const override; 82 : }; 83 : } // namespace elsa