LCOV - code coverage report
Current view: top level - elsa/solvers - ALB.cpp (source / functions) Hit Total Coverage
Test: coverage-all.lcov Lines: 51 59 86.4 %
Date: 2024-05-16 04:22:26 Functions: 16 16 100.0 %

          Line data    Source code
       1             : #include "ALB.h"
       2             : #include "DataContainer.h"
       3             : #include "LinearOperator.h"
       4             : #include "TypeCasts.hpp"
       5             : #include "Logger.h"
       6             : 
       7             : #include "spdlog/stopwatch.h"
       8             : #include "PowerIterations.h"
       9             : 
      10             : namespace elsa
      11             : {
      12             :     template <typename data_t>
      13             :     ALB<data_t>::ALB(const LinearOperator<data_t>& A, const DataContainer<data_t>& b,
      14             :                      ProximalOperator<data_t> prox, data_t mu, std::optional<data_t> beta,
      15             :                      data_t epsilon)
      16             :         : A_(A.clone()),
      17             :           b_(b),
      18             :           v_(A_->getDomainDescriptor()),
      19             :           vPrev_(A_->getDomainDescriptor()),
      20             :           vTilda_(A_->getDomainDescriptor()),
      21             :           residual_(A_->getRangeDescriptor()),
      22             :           prox_(prox),
      23             :           mu_(mu),
      24             :           epsilon_(epsilon)
      25           4 :     {
      26           4 :         if (!beta.has_value()) {
      27           0 :             beta_ = data_t{2} / (mu_ * powerIterations(adjoint(*A_) * (*A_)));
      28           0 :             Logger::get("ALB")->info("Step length is chosen to be: {:8.5}", beta_);
      29           4 :         } else {
      30           4 :             beta_ = *beta;
      31           4 :         }
      32           4 :     }
      33             : 
      34             :     template <typename data_t>
      35             :     DataContainer<data_t> ALB<data_t>::setup(std::optional<DataContainer<data_t>> x0)
      36           4 :     {
      37           4 :         auto x = extract_or(x0, A_->getDomainDescriptor());
      38             : 
      39           4 :         v_ = emptylike(x);
      40           4 :         vPrev_ = zeroslike(x);
      41           4 :         vTilda_ = zeroslike(x);
      42             : 
      43           4 :         residual_ = emptylike(b_);
      44             : 
      45           4 :         return x;
      46           4 :     }
      47             : 
      48             :     template <typename data_t>
      49             :     DataContainer<data_t> ALB<data_t>::step(DataContainer<data_t> x)
      50          68 :     {
      51          68 :         vPrev_ = v_;
      52             : 
      53             :         // x^{k+1} = mu * prox(v^k, 1)
      54          68 :         x = mu_ * prox_.apply(vTilda_, 1);
      55             : 
      56             :         // residual = b - Ax^{k+1}
      57          68 :         lincomb(1, b_, -1, A_->apply(x), residual_);
      58             : 
      59             :         // v^{k+1} = v_tilda^k + beta * A^{*}(b - Ax^{k+1})
      60          68 :         lincomb(1, vTilda_, beta_, A_->applyAdjoint(residual_), v_);
      61             : 
      62             :         // a_k = (2i + 3) / (i + 3)
      63          68 :         auto a = static_cast<data_t>(2 * this->curiter_ + 3) / (this->curiter_ + 3);
      64             : 
      65             :         // v_tilda^{k+1} = a_k * v^{k+1} + (1 - a_k) * v^k
      66          68 :         lincomb(a, v_, (1 - a), vPrev_, vTilda_);
      67             : 
      68          68 :         return x;
      69          68 :     }
      70             : 
      71             :     template <typename data_t>
      72             :     bool ALB<data_t>::shouldStop() const
      73          70 :     {
      74          70 :         return this->curiter_ > 1 && residual_.squaredL2Norm() / b_.squaredL2Norm() <= epsilon_;
      75          70 :     }
      76             : 
      77             :     template <typename data_t>
      78             :     std::string ALB<data_t>::formatHeader() const
      79           2 :     {
      80           2 :         return fmt::format("{:^12} | {:^12} | {:^12} | {:^12}", "x-norm", "v-norm", "\\tilde{v}",
      81           2 :                            "error");
      82           2 :     }
      83             : 
      84             :     template <typename data_t>
      85             :     std::string ALB<data_t>::formatStep(const DataContainer<data_t>& x) const
      86          68 :     {
      87          68 :         auto error = residual_.squaredL2Norm() / b_.squaredL2Norm();
      88          68 :         return fmt::format("{:>12} | {:>12} | {:>12} | {:>12}", x.squaredL2Norm(),
      89          68 :                            v_.squaredL2Norm(), vTilda_.squaredL2Norm(), error);
      90          68 :     }
      91             : 
      92             :     template <typename data_t>
      93             :     auto ALB<data_t>::cloneImpl() const -> ALB<data_t>*
      94           2 :     {
      95           2 :         return new ALB<data_t>(*A_, b_, prox_, mu_, beta_, epsilon_);
      96           2 :     }
      97             : 
      98             :     template <typename data_t>
      99             :     auto ALB<data_t>::isEqual(const Solver<data_t>& other) const -> bool
     100           2 :     {
     101           2 :         auto otherAlb = downcast_safe<ALB>(&other);
     102           2 :         if (!otherAlb)
     103           0 :             return false;
     104             : 
     105           2 :         if (*A_ != *otherAlb->A_)
     106           0 :             return false;
     107             : 
     108           2 :         if (b_ != otherAlb->b_)
     109           0 :             return false;
     110             : 
     111           2 :         Logger::get("ALB")->info("beta: {}, {}", beta_, otherAlb->beta_);
     112           2 :         if (std::abs(beta_ - otherAlb->beta_) > 1e-5)
     113           0 :             return false;
     114             : 
     115           2 :         Logger::get("ALB")->info("mu: {}, {}", mu_, otherAlb->mu_);
     116           2 :         if (mu_ != otherAlb->mu_)
     117           0 :             return false;
     118             : 
     119           2 :         Logger::get("ALB")->info("epsilon: {}, {}", epsilon_, otherAlb->epsilon_);
     120           2 :         if (epsilon_ != otherAlb->epsilon_)
     121           0 :             return false;
     122             : 
     123           2 :         return true;
     124           2 :     }
     125             : 
     126             :     // ------------------------------------------
     127             :     // explicit template instantiation
     128             :     template class ALB<float>;
     129             :     template class ALB<double>;
     130             : } // namespace elsa

Generated by: LCOV version 1.14