LCOV - code coverage report
Current view: top level - elsa/solvers - LB.cpp (source / functions) Hit Total Coverage
Test: coverage-all.lcov Lines: 48 58 82.8 %
Date: 2024-05-16 04:22:26 Functions: 8 8 100.0 %

          Line data    Source code
#include "LB.h"

#include <cmath>

#include "spdlog/stopwatch.h"

#include "DataContainer.h"
#include "LinearOperator.h"
#include "Logger.h"
#include "PowerIterations.h"
#include "TypeCasts.hpp"
       9             : 
      10             : namespace elsa
      11             : {
      12             :     template <typename data_t>
      13             :     LB<data_t>::LB(const LinearOperator<data_t>& A, const DataContainer<data_t>& b,
      14             :                    const ProximalOperator<data_t>& prox, data_t mu, std::optional<data_t> beta,
      15             :                    data_t epsilon)
      16             :         : A_(A.clone()), b_(b), prox_(prox), mu_(mu), epsilon_(epsilon)
      17           4 :     {
      18           4 :         if (!beta.has_value()) {
      19             : 
      20           0 :             beta_ = data_t{2} / (mu_ * powerIterations(adjoint(*A_) * (*A_)));
      21           0 :             Logger::get("LB")->info("Step length is chosen to be: {:8.5}", beta_);
      22             : 
      23           4 :         } else {
      24           4 :             beta_ = *beta;
      25           4 :         }
      26           4 :     }
      27             : 
      28             :     template <typename data_t>
      29             :     auto LB<data_t>::solve(index_t iterations, std::optional<DataContainer<data_t>> x0)
      30             :         -> DataContainer<data_t>
      31           2 :     {
      32           2 :         spdlog::stopwatch iter_time;
      33             : 
      34           2 :         auto v = DataContainer<data_t>(A_->getDomainDescriptor());
      35           2 :         auto x = DataContainer<data_t>(A_->getDomainDescriptor());
      36             : 
      37           2 :         v = 0;
      38             : 
      39           2 :         if (x0.has_value()) {
      40           0 :             x = *x0;
      41           2 :         } else {
      42           2 :             x = 0;
      43           2 :         }
      44             : 
      45           2 :         auto residual = DataContainer<data_t>(b_.getDataDescriptor());
      46             : 
      47          20 :         for (int i = 0; i < iterations; ++i) {
      48             : 
      49             :             // x^{k+1} = mu * prox(v^k, 1)
      50          20 :             x = mu_ * prox_.apply(v, 1);
      51             : 
      52             :             // residual = b - Ax^{k+1}
      53          20 :             lincomb(1, b_, -1, A_->apply(x), residual);
      54             : 
      55             :             // v^{k+1} += beta * A^{*}(b - Ax^{k+1})
      56          20 :             v += beta_ * A_->applyAdjoint(residual);
      57             : 
      58          20 :             auto error = residual.squaredL2Norm() / b_.squaredL2Norm();
      59          20 :             if (residual.squaredL2Norm() / b_.squaredL2Norm() <= epsilon_) {
      60           2 :                 Logger::get("LB")->info("SUCCESS: Reached convergence at {}/{} iteration", i + 1,
      61           2 :                                         iterations);
      62           2 :                 return x;
      63           2 :             }
      64             : 
      65          18 :             Logger::get("LB")->info(
      66          18 :                 "|iter: {:>6} | x: {:>12} | v: {:>12} | error: {:>12} | time: {:>8.3} |", i,
      67          18 :                 x.squaredL2Norm(), v.squaredL2Norm(), error, iter_time);
      68          18 :         }
      69             : 
      70           2 :         Logger::get("LB")->warn("Failed to reach convergence at {} iterations", iterations);
      71             : 
      72           0 :         return x;
      73           2 :     }
      74             : 
      75             :     template <typename data_t>
      76             :     auto LB<data_t>::cloneImpl() const -> LB<data_t>*
      77           2 :     {
      78           2 :         return new LB<data_t>(*A_, b_, prox_, mu_, beta_, epsilon_);
      79           2 :     }
      80             : 
      81             :     template <typename data_t>
      82             :     auto LB<data_t>::isEqual(const Solver<data_t>& other) const -> bool
      83           2 :     {
      84           2 :         auto otherLb = downcast_safe<LB>(&other);
      85           2 :         if (!otherLb)
      86           0 :             return false;
      87             : 
      88           2 :         if (*A_ != *otherLb->A_)
      89           0 :             return false;
      90             : 
      91           2 :         if (b_ != otherLb->b_)
      92           0 :             return false;
      93             : 
      94           2 :         Logger::get("LB")->info("beta: {}, {}", beta_, otherLb->beta_);
      95           2 :         if (std::abs(beta_ - otherLb->beta_) > 1e-5)
      96           0 :             return false;
      97             : 
      98           2 :         Logger::get("LB")->info("mu: {}, {}", mu_, otherLb->mu_);
      99           2 :         if (mu_ != otherLb->mu_)
     100           0 :             return false;
     101             : 
     102           2 :         Logger::get("LB")->info("epsilon: {}, {}", epsilon_, otherLb->epsilon_);
     103           2 :         if (epsilon_ != otherLb->epsilon_)
     104           0 :             return false;
     105             : 
     106           2 :         return true;
     107           2 :     }
     108             : 
     109             :     // ------------------------------------------
     110             :     // explicit template instantiation
     111             :     template class LB<float>;
     112             :     template class LB<double>;
     113             : } // namespace elsa

Generated by: LCOV version 1.14