LCOV - code coverage report
Current view: top level - elsa/solvers - CGLS.cpp (source / functions) Hit Total Coverage
Test: coverage-all.lcov Lines: 56 61 91.8 %
Date: 2024-05-16 04:22:26 Functions: 18 18 100.0 %

          Line data    Source code
       1             : #include "CGLS.h"
       2             : #include "DataContainer.h"
       3             : #include "Error.h"
       4             : #include "LinearOperator.h"
       5             : #include "LinearResidual.h"
       6             : #include "TypeCasts.hpp"
       7             : #include "spdlog/fmt/bundled/core.h"
       8             : #include "spdlog/stopwatch.h"
       9             : #include "Logger.h"
      10             : 
      11             : namespace elsa
      12             : {
      13             :     template <class data_t>
      14             :     CGLS<data_t>::CGLS(const LinearOperator<data_t>& A, const DataContainer<data_t>& b,
      15             :                        SelfType_t<data_t> eps, SelfType_t<data_t> tol)
      16             :         : A_(A.clone()),
      17             :           b_(b),
      18             :           r_(empty<data_t>(A.getDomainDescriptor())),
      19             :           s_(empty<data_t>(A.getRangeDescriptor())),
      20             :           c_(empty<data_t>(A.getDomainDescriptor())),
      21             :           q_(empty<data_t>(A.getRangeDescriptor())),
      22             :           k_(100000000),
      23             :           kold_(100000000),
      24             :           damp_(eps * eps),
      25             :           tol_(tol)
      26          18 :     {
      27          18 :         this->name_ = "CGLS";
      28          18 :     }
      29             : 
      30             :     template <typename data_t>
      31             :     DataContainer<data_t> CGLS<data_t>::setup(std::optional<DataContainer<data_t>> x0)
      32          18 :     {
      33          18 :         auto x = extract_or(x0, A_->getDomainDescriptor());
      34             : 
      35          18 :         if (x0.has_value()) {
      36           0 :             x = *x0;
      37             : 
      38             :             // s = b_ - A_->applx(x)
      39           0 :             A_->apply(x, s_);
      40           0 :             lincomb(1, b_, -1, s_, s_);
      41             : 
      42           0 :             A_->applyAdjoint(s_, r_);
      43           0 :             lincomb(1, r_, -damp_, x, r_);
      44          18 :         } else {
      45          18 :             x = 0;
      46          18 :             s_ = b_;
      47          18 :             A_->applyAdjoint(s_, r_);
      48          18 :         }
      49             : 
      50          18 :         c_ = r_;
      51             : 
      52          18 :         k_ = r_.squaredL2Norm();
      53          18 :         kold_ = k_;
      54             : 
      55             :         // setup done!
      56          18 :         this->configured_ = true;
      57             : 
      58          18 :         return x;
      59          18 :     }
      60             : 
      61             :     template <typename data_t>
      62             :     DataContainer<data_t> CGLS<data_t>::step(DataContainer<data_t> x)
      63          78 :     {
      64          78 :         A_->apply(c_, q_);
      65             : 
      66          78 :         auto delta = q_.squaredL2Norm();
      67          78 :         auto alpha = [&]() {
      68          78 :             if (damp_ == 0) {
      69          44 :                 return kold_ / delta;
      70          44 :             } else {
      71          34 :                 return kold_ / (delta + damp_ * c_.squaredL2Norm());
      72          34 :             }
      73          78 :         }();
      74             : 
      75          78 :         lincomb(1, x, alpha, c_, x);
      76          78 :         s_ -= alpha * q_;
      77             : 
      78          78 :         A_->applyAdjoint(s_, r_);
      79          78 :         if (damp_ != 0.0) {
      80          34 :             r_ -= damp_ * x;
      81          34 :         }
      82             : 
      83          78 :         k_ = r_.squaredL2Norm();
      84          78 :         auto beta = k_ / kold_;
      85             : 
      86             :         // c = r + beta * c;
      87          78 :         lincomb(1, r_, beta, c_, c_);
      88          78 :         kold_ = k_;
      89          78 :         return x;
      90          78 :     }
      91             : 
      92             :     template <typename data_t>
      93             :     bool CGLS<data_t>::shouldStop() const
      94          96 :     {
      95          96 :         return kold_ < tol_;
      96          96 :     }
      97             : 
      98             :     template <typename data_t>
      99             :     std::string CGLS<data_t>::formatHeader() const
     100          18 :     {
     101          18 :         return fmt::format("{:^15} | {:^15}", "|| x ||_2", "|| s ||_2");
     102          18 :     }
     103             : 
     104             :     template <typename data_t>
     105             :     std::string CGLS<data_t>::formatStep(const DataContainer<data_t>& x) const
     106          78 :     {
     107          78 :         return fmt::format("{:>15.10} | {:>15.10}", x.l2Norm(), s_.l2Norm());
     108          78 :     }
     109             : 
     110             :     template <class data_t>
     111             :     CGLS<data_t>* CGLS<data_t>::cloneImpl() const
     112           8 :     {
     113           8 :         return new CGLS(*A_, b_, std::sqrt(damp_));
     114           8 :     }
     115             : 
     116             :     template <class data_t>
     117             :     bool CGLS<data_t>::isEqual(const Solver<data_t>& other) const
     118           8 :     {
     119           8 :         auto otherCGLS = downcast_safe<CGLS>(&other);
     120             : 
     121           8 :         return otherCGLS && *otherCGLS->A_ == *A_ && otherCGLS->b_ == b_
     122           8 :                && otherCGLS->damp_ == damp_;
     123           8 :     }
     124             : 
     125             :     // ------------------------------------------
     126             :     // explicit template instantiation
     127             :     template class CGLS<float>;
     128             :     template class CGLS<double>;
     129             : 
     130             : } // namespace elsa

Generated by: LCOV version 1.14