LCOV - code coverage report
Current view: top level - elsa/solvers - CGNL.h (source / functions) Hit Total Coverage
Test: coverage-all.lcov Lines: 8 8 100.0 %
Date: 2024-05-16 04:22:26 Functions: 4 4 100.0 %

          Line data    Source code
       1             : #pragma once
       2             : 
       3             : #include <memory>
       4             : #include <optional>
       5             : 
       6             : #include "Solver.h"
       7             : #include "Functional.h"
       8             : #include "LineSearchMethod.h"
       9             : 
      10             : namespace elsa
      11             : {
      12             :     /**
      13             :      * @brief Class implementing Nonlinear Conjugate Gradients with customizable line search and
      14             :      * beta calculation
      15             :      *
      16             :      * @author Eddie Groh - initial code
      17             :      *
       18             :      * This Nonlinear CG can minimize any continuous function f for which the first and second
      19             :      * derivative can be computed or approximated. By this usage of the Gradient and Hessian
      20             :      * respectively, it will converge to a local minimum near the starting point.
      21             :      *
      22             :      * Because CG can only generate n conjugate vectors, if the problem has dimension n, it improves
      23             :      * convergence to reset the search direction every n iterations, especially for small n.
      24             :      * Restarting means that the search direction is "forgotten" and CG is started again in the
      25             :      * direction of the steepest descent
      26             :      *
       27             :      * Convergence is considered reached when \f$ \| f'(x) \| \leq \epsilon \| f'(x_0) \| \f$ is
       28             :      * satisfied for some small \f$ \epsilon > 0\f$. Here \f$ x \f$ denotes the solution
      29             :      * obtained in the last step, and \f$ x_0 \f$ denotes the initial guess.
      30             :      *
      31             :      * References:
      32             :      * https://www.cs.cmu.edu/~quake-papers/painless-conjugate-gradient.pdf
      33             :      */
      34             :     template <typename data_t = real_t>
      35             :     class CGNL : public Solver<data_t>
      36             :     {
      37             :     public:
      38             :         /// Scalar alias
      39             :         using Scalar = typename Solver<data_t>::Scalar;
      40             : 
      41             :         /**
      42             :          * @brief Function Object which calculates a beta value based on the direction
      43             :          * vector and residual vector
      44             :          *
      45             :          * @param[in] dVector the vector representing the direction of the current CGNL step
      46             :          * @param[in] rVector the residual vector representing the negative gradient
      47             :          *
       48             :          * @return a pair consisting of the calculated beta and the deltaNew
      49             :          */
      50             :         using BetaFunction = std::function<std::pair<data_t, data_t>(
      51             :             const DataContainer<data_t>& dVector, const DataContainer<data_t>& rVector,
      52             :             data_t deltaNew)>;
      53             : 
      54             :         /**
       55             :          * @brief Constructor for CGNL, accepting the functional to be minimized and a line
       56             :          * search method
       57             :          *
       58             :          * @param[in] functional the functional that is supposed to be minimized
       59             :          * @param[in] lineSearch the line search method evaluated at each iteration
      60             :          */
      61             :         CGNL(const Functional<data_t>& functional, const LineSearchMethod<data_t>& lineSearch);
      62             : 
      63             :         /**
       64             :          * @brief Constructor for CGNL, accepting the functional to be minimized, a line
       65             :          * search method, and a custom beta calculation function
       66             :          *
       67             :          * @param[in] functional the functional that is supposed to be minimized
       68             :          * @param[in] line_search the line search method evaluated at each iteration
       69             :          * @param[in] beta_function function object used to compute beta and the new delta
      70             :          */
      71             :         CGNL(const Functional<data_t>& functional, const LineSearchMethod<data_t>& line_search,
      72             :              const BetaFunction& beta_function);
      73             : 
      74             :         /// make copy constructor deletion explicit
      75             :         CGNL(const CGNL<data_t>&) = delete;
      76             : 
      77             :         /// default destructor
      78           8 :         ~CGNL() override = default;
      79             : 
      80             :         DataContainer<data_t> setup(std::optional<DataContainer<data_t>> x) override;
      81             : 
      82             :         DataContainer<data_t> step(DataContainer<data_t> x) override;
      83             : 
      84             :         bool shouldStop() const override;
      85             : 
      86             :         std::string formatHeader() const override;
      87             : 
      88             :         std::string formatStep(const DataContainer<data_t>& x) const override;
      89             : 
       90             :         /// beta calculation Polak-Ribière
      91             :         static const inline BetaFunction betaPolakRibiere =
      92             :             [](const DataContainer<data_t>& dVector, const DataContainer<data_t>& rVector,
      93          83 :                data_t deltaNew) -> std::pair<data_t, data_t> {
      94             :             // deltaOld <= deltaNew
      95          83 :             auto deltaOld = deltaNew;
      96             :             // deltaMid <= r^T * d
      97          83 :             auto deltaMid = rVector.dot(dVector);
      98             :             // deltaNew <= r^T * r
      99          83 :             deltaNew = rVector.dot(rVector);
     100             : 
     101             :             // beta <= (deltaNew - deltaMid) / deltaOld
     102          83 :             auto beta = (deltaNew - deltaMid) / deltaOld;
     103          83 :             return {beta, deltaNew};
     104          83 :         };
     105             : 
     106             :     private:
     107             :         /// implement the polymorphic clone operation
     108             :         CGNL<data_t>* cloneImpl() const override;
     109             : 
     110             :         /// implement the polymorphic comparison operation
     111             :         bool isEqual(const Solver<data_t>& other) const override;
     112             : 
     113             :         /// the differentiable optimization problem
     114             :         std::unique_ptr<Functional<data_t>> f_;
     115             : 
     116             :         DataContainer<data_t> r_;
     117             : 
     118             :         DataContainer<data_t> d_;
     119             : 
     120             :         data_t delta_;
     121             : 
     122             :         data_t deltaZero_;
     123             : 
     124             :         data_t beta_;
     125             : 
     126             :         data_t alpha_;
     127             : 
     128             :         index_t restart_ = 0;
     129             : 
     130             :         /// pointer to line search function (e.g. Armijo)
     131             :         std::unique_ptr<LineSearchMethod<data_t>> lineSearch_;
     132             : 
     133             :         /// Function to evaluate beta
     134             :         BetaFunction beta_function_;
     135             : 
     136             :         /// variable affecting the stopping condition
     137             :         data_t epsilon_ = data_t{1e-10};
     138             :     };
     139             : } // namespace elsa

Generated by: LCOV version 1.14