Line data Source code
1 : #pragma once 2 : 3 : #include "Solver.h" 4 : #include "Problem.h" 5 : 6 : namespace elsa 7 : { 8 : /** 9 : * @brief Class representing a simple gradient descent solver with a fixed, given step size. 10 : * 11 : * This class implements a simple gradient descent iterative solver with a fixed, given step 12 : * size. No particular stopping rule is currently implemented (only a fixed number of 13 : * iterations, default to 100). 14 : * 15 : * @tparam data_t data type for the domain and range of the problem, defaulting to real_t 16 : * 17 : * @see \verbatim embed:rst 18 : For a basic introduction and problem statement of first-order methods, see 19 : :ref:`here <elsa-first-order-methods-doc>` \endverbatim 20 : * 21 : * @author 22 : * - Tobias Lasser - initial code 23 : */ 24 : template <typename data_t = real_t> 25 : class GradientDescent : public Solver<data_t> 26 : { 27 : public: 28 : /// Scalar alias 29 : using Scalar = typename Solver<data_t>::Scalar; 30 : 31 : /** 32 : * @brief Constructor for gradient descent, accepting a problem and a fixed step size 33 : * 34 : * @param[in] problem the problem that is supposed to be solved 35 : * @param[in] stepSize the fixed step size to be used while solving 36 : */ 37 : GradientDescent(const Problem<data_t>& problem, data_t stepSize); 38 : 39 : /** 40 : * @brief Constructor for gradient descent, accepting a problem. The step size will be 41 : * computed as \f$ 1 \over L \f$ with \f$ L \f$ being the Lipschitz constant of the 42 : * function. 43 : * 44 : * @param[in] problem the problem that is supposed to be solved 45 : */ 46 : GradientDescent(const Problem<data_t>& problem); 47 : 48 : /// make copy constructor deletion explicit 49 : GradientDescent(const GradientDescent<data_t>&) = delete; 50 : 51 : /// default destructor 52 14 : ~GradientDescent() override = default; 53 : 54 : private: 55 : /// the differentiable optimizaion problem 56 : std::unique_ptr<Problem<data_t>> _problem; 57 : 58 : /// the step size 59 : data_t _stepSize; 60 : 61 : /// the default number of iterations 62 : const index_t _defaultIterations{100}; 63 : 64 : /** 65 : * @brief Solve the optimization problem, i.e. apply iterations number of iterations of 66 : * gradient descent 67 : * 68 : * @param[in] iterations number of iterations to execute (the default 0 value executes 69 : * _defaultIterations of iterations) 70 : * 71 : * @returns a reference to the current solution 72 : */ 73 : DataContainer<data_t>& solveImpl(index_t iterations) override; 74 : 75 : /// implement the polymorphic clone operation 76 : GradientDescent<data_t>* cloneImpl() const override; 77 : 78 : /// implement the polymorphic comparison operation 79 : bool isEqual(const Solver<data_t>& other) const override; 80 : }; 81 : } // namespace elsa