LCOV - code coverage report
Current view: top level - elsa/solvers - CGNL.cpp (source / functions) Hit Total Coverage
Test: coverage-all.lcov Lines: 45 47 95.7 %
Date: 2025-01-22 07:37:33 Functions: 18 18 100.0 %

          Line data    Source code
       1             : #include "CGNL.h"
       2             : #include "Logger.h"
       3             : #include "TypeCasts.hpp"
       4             : #include "spdlog/stopwatch.h"
       5             : #include "LineSearchMethod.h"
       6             : 
       7             : namespace elsa
       8             : {
       9             :     template <typename data_t>
      10             :     CGNL<data_t>::CGNL(const Functional<data_t>& functional,
      11             :                        const LineSearchMethod<data_t>& line_search_function)
      12             :         : Solver<data_t>(),
      13             :           f_{functional.clone()},
      14             :           r_(empty<data_t>(functional.getDomainDescriptor())),
      15             :           d_(empty<data_t>(functional.getDomainDescriptor())),
      16             :           delta_(0),
      17             :           deltaZero_(0),
      18             :           beta_(0),
      19             :           alpha_(0),
      20             :           lineSearch_{line_search_function.clone()},
      21             :           beta_function_{betaPolakRibiere}
      22           4 :     {
      23           4 :     }
      24             : 
      25             :     template <typename data_t>
      26             :     CGNL<data_t>::CGNL(const Functional<data_t>& functional,
      27             :                        const LineSearchMethod<data_t>& line_search_function,
      28             :                        const BetaFunction& beta_function)
      29             :         : Solver<data_t>(),
      30             :           f_{functional.clone()},
      31             :           r_(empty<data_t>(functional.getDomainDescriptor())),
      32             :           d_(empty<data_t>(functional.getDomainDescriptor())),
      33             :           delta_(0),
      34             :           deltaZero_(0),
      35             :           beta_(0),
      36             :           alpha_(0),
      37             :           lineSearch_{line_search_function.clone()},
      38             :           beta_function_{beta_function}
      39           4 :     {
      40           4 :     }
      41             : 
      42             :     template <typename data_t>
      43             :     DataContainer<data_t> CGNL<data_t>::setup(std::optional<DataContainer<data_t>> x0)
      44           8 :     {
      45           8 :         auto x = extract_or(x0, f_->getDomainDescriptor());
      46             : 
      47             :         // r <= -f'(x)
      48           8 :         r_ = data_t{-1.0} * f_->getGradient(x);
      49           8 :         d_ = r_;
      50             : 
      51             :         // delta <= r^T * d
      52           8 :         delta_ = r_.dot(d_);
      53             : 
      54             :         // deltaZero <= delta
      55           8 :         deltaZero_ = delta_;
      56             : 
      57             :         // restart every n
      58           8 :         restart_ = f_->getDomainDescriptor().getNumberOfCoefficients();
      59             : 
      60           8 :         this->name_ = "CGNL";
      61             : 
      62             :         // setup done!
      63           8 :         this->configured_ = true;
      64             : 
      65           8 :         return x;
      66           8 :     }
      67             : 
      68             :     template <typename data_t>
      69             :     DataContainer<data_t> CGNL<data_t>::step(DataContainer<data_t> x)
      70          83 :     {
      71             :         // line search
      72          83 :         alpha_ = lineSearch_->solve(x, d_);
      73             : 
      74             :         // update x
      75          83 :         x += alpha_ * d_;
      76             : 
      77             :         // beta function
      78             :         // r <= -f'(x)
      79          83 :         r_ = data_t{-1.0} * f_->getGradient(x);
      80             : 
      81          83 :         std::tie(beta_, delta_) = beta_function_(d_, r_, delta_);
      82             : 
      83          83 :         if (this->curiter_ % restart_ == 0 || beta_ <= 0.0) {
      84          23 :             d_ = r_;
      85          60 :         } else {
      86          60 :             d_ = r_ + beta_ * d_;
      87          60 :         }
      88             : 
      89          83 :         return x;
      90          83 :     }
      91             : 
      92             :     template <typename data_t>
      93             :     bool CGNL<data_t>::shouldStop() const
      94          86 :     {
      95          86 :         return delta_ <= epsilon_ * epsilon_ * deltaZero_;
      96          86 :     }
      97             : 
      98             :     template <typename data_t>
      99             :     std::string CGNL<data_t>::formatHeader() const
     100           8 :     {
     101           8 :         return fmt::format("{:^15} | {:^15} | {:^15}", "delta", "beta", "alpha");
     102           8 :     }
     103             : 
     104             :     template <typename data_t>
     105             :     std::string CGNL<data_t>::formatStep(const DataContainer<data_t>&) const
     106          83 :     {
     107          83 :         return fmt::format("{:>15.10} | {:>15.10} | {:>15.10}", delta_, beta_, alpha_);
     108          83 :     }
     109             : 
     110             :     template <typename data_t>
     111             :     CGNL<data_t>* CGNL<data_t>::cloneImpl() const
     112           4 :     {
     113           4 :         return new CGNL(*f_, *lineSearch_, beta_function_);
     114           4 :     }
     115             : 
     116             :     template <typename data_t>
     117             :     bool CGNL<data_t>::isEqual(const Solver<data_t>& other) const
     118           4 :     {
     119           4 :         auto otherCG = downcast_safe<CGNL>(&other);
     120           4 :         if (!otherCG)
     121           0 :             return false;
     122             : 
     123           4 :         if (epsilon_ != otherCG->epsilon_)
     124           0 :             return false;
     125             : 
     126           4 :         return true;
     127           4 :     }
     128             : 
     129             :     // ------------------------------------------
     130             :     // explicit template instantiation
     131             :     template class CGNL<float>;
     132             :     template class CGNL<double>;
     133             : } // namespace elsa

Generated by: LCOV version 1.14