LCOV - code coverage report
Current view: top level - elsa/solvers - CGNE.cpp (source / functions) Hit Total Coverage
Test: coverage-all.lcov Lines: 43 43 100.0 %
Date: 2024-05-16 04:22:26 Functions: 16 16 100.0 %

          Line data    Source code
       1             : #include "CGNE.h"
       2             : #include "DataContainer.h"
       3             : #include "Error.h"
       4             : #include "LinearOperator.h"
       5             : #include "LinearResidual.h"
       6             : #include "Solver.h"
       7             : #include "TypeCasts.hpp"
       8             : #include "spdlog/stopwatch.h"
       9             : #include "Logger.h"
      10             : 
      11             : namespace elsa
      12             : {
      13             :     template <class data_t>
      14             :     CGNE<data_t>::CGNE(const LinearOperator<data_t>& A, const DataContainer<data_t>& b,
      15             :                        SelfType_t<data_t> tol)
      16             :         : A_(A.clone()),
      17             :           AtA_((adjoint(*A_) * (*A_)).clone()),
      18             :           b_(b),
      19             :           Atb_(A_->applyAdjoint(b_)),
      20             :           r_(empty<data_t>(AtA_->getRangeDescriptor())),
      21             :           c_(empty<data_t>(AtA_->getRangeDescriptor())),
      22             :           Ac_(empty<data_t>(AtA_->getRangeDescriptor())),
      23             :           kold_(100000),
      24             :           tol_(tol)
      25          40 :     {
      26          40 :         this->name_ = "CGNE";
      27          40 :     }
      28             : 
      29             :     template <typename data_t>
      30             :     DataContainer<data_t> CGNE<data_t>::setup(std::optional<DataContainer<data_t>> x0)
      31          40 :     {
      32          40 :         auto x = extract_or(x0, A_->getDomainDescriptor());
      33             : 
      34             :         // Residual b - A(x)
      35          40 :         r_ = AtA_->apply(x);
      36          40 :         r_ *= data_t{-1};
      37          40 :         r_ += Atb_;
      38             : 
      39          40 :         c_.assign(r_);
      40             : 
      41          40 :         Ac_ = empty<data_t>(AtA_->getRangeDescriptor());
      42             : 
      43             :         // Squared Norm of residual
      44          40 :         kold_ = r_.squaredL2Norm();
      45             : 
      46             :         // setup done!
      47          40 :         this->configured_ = true;
      48             : 
      49          40 :         return x;
      50          40 :     }
      51             : 
      52             :     template <typename data_t>
      53             :     DataContainer<data_t> CGNE<data_t>::step(DataContainer<data_t> x)
      54         116 :     {
      55         116 :         AtA_->apply(c_, Ac_);
      56         116 :         auto cAc = c_.dot(Ac_);
      57             : 
      58         116 :         alpha_ = kold_ / cAc;
      59             : 
      60             :         // Update x and residual
      61         116 :         x += alpha_ * c_;
      62         116 :         r_ -= alpha_ * Ac_;
      63             : 
      64         116 :         auto k = r_.squaredL2Norm();
      65         116 :         beta_ = k / kold_;
      66             : 
      67             :         // c = r + beta * c
      68         116 :         lincomb(1, r_, beta_, c_, c_);
      69             : 
      70             :         // store k for next iteration
      71         116 :         kold_ = k;
      72             : 
      73         116 :         return x;
      74         116 :     }
      75             : 
      76             :     template <typename data_t>
      77             :     bool CGNE<data_t>::shouldStop() const
      78         126 :     {
      79         126 :         return r_.l2Norm() < tol_;
      80         126 :     }
      81             : 
      82             :     template <typename data_t>
      83             :     std::string CGNE<data_t>::formatHeader() const
      84          40 :     {
      85          40 :         return fmt::format("{:^15} | {:^15} | {:^15} | {:^15}", "r", "c", "alpha", "beta");
      86          40 :     }
      87             : 
      88             :     template <typename data_t>
      89             :     std::string CGNE<data_t>::formatStep(const DataContainer<data_t>&) const
      90         116 :     {
      91         116 :         return fmt::format("{:>15.10} | {:>15.10} | {:>15.10} | {:>15.10}", r_.l2Norm(),
      92         116 :                            c_.l2Norm(), alpha_, beta_);
      93         116 :     }
      94             : 
      95             :     template <class data_t>
      96             :     bool CGNE<data_t>::isEqual(const Solver<data_t>& other) const
      97           8 :     {
      98           8 :         auto cgne = downcast_safe<CGNE>(&other);
      99           8 :         return cgne && *cgne->A_ == *A_ && cgne->b_ == b_;
     100           8 :     }
     101             : 
     102             :     template <class data_t>
     103             :     CGNE<data_t>* CGNE<data_t>::cloneImpl() const
     104           8 :     {
     105           8 :         return new CGNE(*A_, b_, tol_);
     106           8 :     }
     107             : 
     108             :     // ------------------------------------------
     109             :     // explicit template instantiation
     110             :     template class CGNE<float>;
     111             :     template class CGNE<double>;
     112             : 
     113             : } // namespace elsa

Generated by: LCOV version 1.14