#pragma once

#include <functional>
#include <memory>
#include <optional>

#include "Solver.h"
#include "Functional.h"
#include "LineSearchMethod.h"

namespace elsa
{
    /**
     * @brief Class implementing Nonlinear Conjugate Gradients with customizable line search and
     * beta calculation
     *
     * @author Eddie Groh - initial code
     *
     * This Nonlinear CG can minimize any continuous function f for which the first and second
     * derivative can be computed or approximated. Using the gradient and Hessian, respectively,
     * it will converge to a local minimum near the starting point.
     *
     * Because CG can only generate n conjugate vectors if the problem has dimension n, it
     * improves convergence to reset the search direction every n iterations, especially for
     * small n. Restarting means that the search direction is "forgotten" and CG is started
     * again in the direction of steepest descent.
     *
     * Convergence is considered reached when \f$ \| f'(x) \| \leq \epsilon \| f'(x_0) \| \f$ is
     * satisfied for some small \f$ \epsilon > 0 \f$. Here \f$ x \f$ denotes the solution
     * obtained in the last step, and \f$ x_0 \f$ denotes the initial guess.
     *
     * References:
     * https://www.cs.cmu.edu/~quake-papers/painless-conjugate-gradient.pdf
     */
    template <typename data_t = real_t>
    class CGNL : public Solver<data_t>
    {
    public:
        /// Scalar alias
        using Scalar = typename Solver<data_t>::Scalar;

        /**
         * @brief Function object which calculates a beta value based on the direction
         * vector and the residual vector
         *
         * @param[in] dVector the vector representing the direction of the current CGNL step
         * @param[in] rVector the residual vector representing the negative gradient
         * @param[in] deltaNew the delta value of the previous iteration
         *
         * @return a pair consisting of the calculated beta and the updated deltaNew
         */
        using BetaFunction = std::function<std::pair<data_t, data_t>(
            const DataContainer<data_t>& dVector, const DataContainer<data_t>& rVector,
            data_t deltaNew)>;

        /**
         * @brief Constructor for CGNL, accepting the functional to be minimized and a line
         * search method
         *
         * @param[in] functional the functional to be minimized
         * @param[in] lineSearch the line search method evaluated in each iteration to
         * determine the step size
         */
        CGNL(const Functional<data_t>& functional, const LineSearchMethod<data_t>& lineSearch);

        /**
         * @brief Constructor for CGNL, accepting the functional to be minimized, a line
         * search method and a beta calculation function
         *
         * @param[in] functional the functional to be minimized
         * @param[in] line_search the line search method evaluated in each iteration to
         * determine the step size
         * @param[in] beta_function the function used to calculate beta and deltaNew
         */
        CGNL(const Functional<data_t>& functional, const LineSearchMethod<data_t>& line_search,
             const BetaFunction& beta_function);

        /// make copy constructor deletion explicit
        CGNL(const CGNL<data_t>&) = delete;

        /// default destructor
        ~CGNL() override = default;

        DataContainer<data_t> setup(std::optional<DataContainer<data_t>> x) override;

        DataContainer<data_t> step(DataContainer<data_t> x) override;

        bool shouldStop() const override;

        std::string formatHeader() const override;

        std::string formatStep(const DataContainer<data_t>& x) const override;

        /// beta calculation following Polak-Ribière
        static const inline BetaFunction betaPolakRibiere =
            [](const DataContainer<data_t>& dVector, const DataContainer<data_t>& rVector,
               data_t deltaNew) -> std::pair<data_t, data_t> {
            // deltaOld <- deltaNew
            auto deltaOld = deltaNew;
            // deltaMid <- r^T * d
            auto deltaMid = rVector.dot(dVector);
            // deltaNew <- r^T * r
            deltaNew = rVector.dot(rVector);

            // beta <- (deltaNew - deltaMid) / deltaOld
            auto beta = (deltaNew - deltaMid) / deltaOld;
            return {beta, deltaNew};
        };

    private:
        /// implement the polymorphic clone operation
        CGNL<data_t>* cloneImpl() const override;

        /// implement the polymorphic comparison operation
        bool isEqual(const Solver<data_t>& other) const override;

        /// the differentiable functional to be minimized
        std::unique_ptr<Functional<data_t>> f_;

        /// residual vector, i.e. the negative gradient
        DataContainer<data_t> r_;

        /// current search direction
        DataContainer<data_t> d_;

        /// delta of the current iteration
        data_t delta_;

        /// delta of the initial iteration, used in the stopping condition
        data_t deltaZero_;

        /// beta value of the current iteration
        data_t beta_;

        /// step size determined by the line search
        data_t alpha_;

        /// iteration counter used to restart the search direction every n iterations
        index_t restart_ = 0;

        /// pointer to the line search method (e.g. Armijo)
        std::unique_ptr<LineSearchMethod<data_t>> lineSearch_;

        /// function to evaluate beta
        BetaFunction beta_function_;

        /// variable affecting the stopping condition
        data_t epsilon_ = data_t{1e-10};
    };
} // namespace elsa
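// Usage sketch: a minimal, hedged example of how a CGNL solver might be constructed and
// run. Here `myFunctional` and `myLineSearch` stand in for a concrete Functional<real_t>
// and LineSearchMethod<real_t> (the names are illustrative only), and the solve() driver
// is assumed to be provided by the Solver<data_t> base class.
//
//     // construct from the functional to minimize and a line search method
//     elsa::CGNL<elsa::real_t> cg(myFunctional, myLineSearch);
//
//     // or pass a beta calculation explicitly, e.g. Polak-Ribière
//     elsa::CGNL<elsa::real_t> cgExplicit(myFunctional, myLineSearch,
//                                         elsa::CGNL<elsa::real_t>::betaPolakRibiere);
//
//     // run up to 100 iterations; iteration stops early once
//     // || f'(x) || <= epsilon * || f'(x_0) || is satisfied
//     auto solution = cg.solve(100);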