Line data Source code
1 : #pragma once 2 : 3 : #include <optional> 4 : 5 : #include "Solver.h" 6 : #include "Functional.h" 7 : #include "MaybeUninitialized.hpp" 8 : #include "LineSearchMethod.h" 9 : #include "FixedStepSize.h" 10 : 11 : namespace elsa 12 : { 13 : 14 : /** 15 : * @brief Class representing a simple gradient descent solver with a fixed, given step size. 16 : * 17 : * This class implements a simple gradient descent iterative solver with a fixed, given step 18 : * size. No particular stopping rule is currently implemented (only a fixed number of 19 : * iterations, default to 100). 20 : * 21 : * @tparam data_t data type for the domain and range of the problem, defaulting to real_t 22 : * 23 : * @author 24 : * - Tobias Lasser - initial code 25 : */ 26 : template <typename data_t = real_t> 27 : class GradientDescent : public Solver<data_t> 28 : { 29 : public: 30 : /// Scalar alias 31 : using Scalar = typename Solver<data_t>::Scalar; 32 : 33 : /** 34 : * @brief Constructor for gradient descent, accepting a problem and a fixed step size 35 : * 36 : * @param[in] problem the problem that is supposed to be solved 37 : * @param[in] stepSize the fixed step size to be used while solving 38 : */ 39 : GradientDescent(const Functional<data_t>& problem, data_t stepSize); 40 : 41 : /** 42 : * @brief Constructor for gradient descent, accepting a problem. The step size will be 43 : * computed as \f$ 1 \over L \f$ with \f$ L \f$ being the Lipschitz constant of the 44 : * function. 45 : * 46 : * @param[in] problem the problem that is supposed to be solved 47 : */ 48 : explicit GradientDescent(const Functional<data_t>& problem); 49 : 50 : /** 51 : * @brief Constructor for gradient descent, accepting a problem and a line 52 : * search method 53 : * 54 : * @param[in] problem the problem that is supposed to be solved 55 : * @param[in] lineSearchMethod the line search method to find the step size at 56 : * each iteration 57 : */ 58 : GradientDescent(const Functional<data_t>& problem, 59 : const LineSearchMethod<data_t>& lineSearchMethod); 60 : 61 : /// make copy constructor deletion explicit 62 : GradientDescent(const GradientDescent<data_t>&) = delete; 63 : 64 : /// default destructor 65 14 : ~GradientDescent() override = default; 66 : 67 : /** 68 : * @brief Solve the optimization problem, i.e. apply iterations number of iterations of 69 : * gradient descent 70 : * 71 : * @param[in] iterations number of iterations to execute 72 : * 73 : * @returns the approximated solution 74 : */ 75 : DataContainer<data_t> 76 : solve(index_t iterations, 77 : std::optional<DataContainer<data_t>> x0 = std::nullopt) override; 78 : 79 : private: 80 : /// the differentiable optimizaion problem 81 : std::unique_ptr<Functional<data_t>> _problem; 82 : 83 : /// the line search method 84 : std::unique_ptr<LineSearchMethod<data_t>> _lineSearchMethod; 85 : 86 : /// implement the polymorphic clone operation 87 : GradientDescent<data_t>* cloneImpl() const override; 88 : 89 : /// implement the polymorphic comparison operation 90 : bool isEqual(const Solver<data_t>& other) const override; 91 : }; 92 : } // namespace elsa