LCOV - code coverage report
Current view: top level - elsa/solvers - ADMML2.cpp (source / functions) Hit Total Coverage
Test: coverage-all.lcov Lines: 32 68 47.1 %
Date: 2024-12-21 07:37:52 Functions: 4 10 40.0 %

          Line data    Source code
       1             : #include "ADMML2.h"
       2             : 
       3             : #include "DataContainer.h"
       4             : #include "LinearOperator.h"
       5             : #include "Solver.h"
       6             : #include "ProximalOperator.h"
       7             : #include "TypeCasts.hpp"
       8             : #include "elsaDefines.h"
       9             : #include "Logger.h"
      10             : #include "RegularizedInversion.h"
      11             : #include "PowerIterations.h"
      12             : 
      13             : #include <cmath>
      14             : #include <memory>
      15             : #include <optional>
      16             : #include <vector>
      17             : 
      18             : namespace elsa
      19             : {
      20             :     template <class data_t>
      21             :     ADMML2<data_t>::ADMML2(const LinearOperator<data_t>& op, const DataContainer<data_t>& b,
      22             :                            const LinearOperator<data_t>& A, const ProximalOperator<data_t>& proxg,
      23             :                            std::optional<data_t> tau, index_t ninneriters)
      24             :         : Solver<data_t>(),
      25             :           op_(op.clone()),
      26             :           b_(b),
      27             :           A_(A.clone()),
      28             :           proxg_(proxg),
      29             :           tau_(0),
      30             :           ninneriters_(ninneriters)
      31           2 :     {
      32           2 :         auto eigenval = data_t{1} / powerIterations(adjoint(A) * A);
      33             : 
      34           2 :         if (tau.has_value()) {
      35           2 :             tau_ = *tau;
      36           2 :             if (tau_ < 0 || tau_ > eigenval) {
      37           0 :                 Logger::get("ADMML2")->info("tau ({:8.5}), should be between 0 and {:8.5}", tau_,
      38           0 :                                             eigenval);
      39           0 :             }
      40           2 :         } else {
      41           0 :             tau_ = 0.9 * eigenval;
      42           0 :             Logger::get("ADMML2")->info("tau is chosen {}", tau_, eigenval);
      43           0 :         }
      44           2 :     }
      45             : 
      46             :     template <class data_t>
      47             :     ADMML2<data_t>::ADMML2(const LinearOperator<data_t>& op, const DataContainer<data_t>& b,
      48             :                            const DataContainer<data_t>& W, const LinearOperator<data_t>& A,
      49             :                            const ProximalOperator<data_t>& proxg, std::optional<data_t> tau,
      50             :                            index_t ninneriters)
      51             :         : Solver<data_t>(),
      52             :           op_(op.clone()),
      53             :           b_(b),
      54             :           A_(A.clone()),
      55             :           W_(W),
      56             :           proxg_(proxg),
      57             :           tau_(0),
      58             :           ninneriters_(ninneriters)
      59           0 :     {
      60           0 :         auto eigenval = data_t{1} / powerIterations(adjoint(A) * A);
      61             : 
      62           0 :         if (tau.has_value()) {
      63           0 :             tau_ = *tau;
      64           0 :             if (tau_ < 0 || tau_ > eigenval) {
      65           0 :                 Logger::get("ADMML2")->info("tau ({:8.5}), should be between 0 and {:8.5}", tau_,
      66           0 :                                             eigenval);
      67           0 :             }
      68           0 :         } else {
      69           0 :             tau_ = 0.9 * eigenval;
      70           0 :             Logger::get("ADMML2")->info("tau is chosen {}", tau_, eigenval);
      71           0 :         }
      72           0 :     }
      73             : 
      74             :     template <class data_t>
      75             :     DataContainer<data_t> ADMML2<data_t>::solve(index_t iterations,
      76             :                                                 std::optional<DataContainer<data_t>> x0)
      77           2 :     {
      78           2 :         auto x = extract_or(x0, op_->getDomainDescriptor());
      79             : 
      80           2 :         const auto& range = A_->getRangeDescriptor();
      81           2 :         auto z = zeros<data_t>(range);
      82           2 :         auto u = zeros<data_t>(range);
      83             : 
      84           2 :         auto Ax = empty<data_t>(range);
      85           2 :         auto tmp = empty<data_t>(range);
      86             : 
      87           2 :         auto sqrttau = data_t{1} / std::sqrt(tau_);
      88             : 
      89           2 :         auto loglevel = Logger::getLevel();
      90           2 :         Logger::get("ADMML2")->info("| {:^4} | {:^12} | {:^12} | {:^12} |", "iter", "f", "z", "u");
      91          22 :         for (index_t iter = 0; iter < iterations; ++iter) {
      92          20 :             Logger::setLevel(Logger::LogLevel::ERR);
      93             : 
      94             :             // x_{k+1} = \min_x 0.5 ||Op x - b||_2^2 + \frac{1}{2\tau}||Ax - z_k + u_k||_2^2
      95          20 :             x = reguarlizedInversion<data_t>(*op_, b_, *A_, z - u, sqrttau, ninneriters_, W_, x);
      96             : 
      97          20 :             Logger::setLevel(loglevel);
      98             : 
      99          20 :             A_->apply(x, Ax);
     100             : 
     101             :             // Ax_{k+1} + u_k
     102          20 :             lincomb(1, Ax, 1, u, tmp);
     103             : 
     104             :             // z_{k+1} = prox_{\tau * g}(Ax_{k+1} + u_k)
     105          20 :             z = proxg_.apply(tmp, tau_);
     106             : 
     107             :             // u_{k+1} = u_k + Ax_{k+1} - z_{k+1}
     108          20 :             u += Ax;
     109          20 :             u -= z;
     110             : 
     111          20 :             Logger::get("ADMML2")->info("| {:>4} | {:12.7} | {:12.7} | {:12.7} |", iter,
     112          20 :                                         0.5 * (op_->apply(x) - b_).l2Norm(), z.l2Norm(),
     113          20 :                                         u.l2Norm());
     114          20 :         }
     115             : 
     116           2 :         return x;
     117           2 :     }
     118             : 
     119             :     template <class data_t>
     120             :     ADMML2<data_t>* ADMML2<data_t>::cloneImpl() const
     121           0 :     {
     122           0 :         return new ADMML2(*op_, b_, *A_, proxg_, tau_, ninneriters_);
     123           0 :     }
     124             : 
     125             :     template <class data_t>
     126             :     bool ADMML2<data_t>::isEqual(const Solver<data_t>& other) const
     127           0 :     {
     128           0 :         auto otherADMM = downcast_safe<ADMML2>(&other);
     129           0 :         if (!otherADMM)
     130           0 :             return false;
     131             : 
     132           0 :         if (*op_ != *otherADMM->op_)
     133           0 :             return false;
     134             : 
     135           0 :         if (*A_ != *otherADMM->A_)
     136           0 :             return false;
     137             : 
     138           0 :         if (tau_ != otherADMM->tau_)
     139           0 :             return false;
     140             : 
     141           0 :         if (ninneriters_ != otherADMM->ninneriters_)
     142           0 :             return false;
     143             : 
     144           0 :         return true;
     145           0 :     }
     146             : 
     147             :     // ------------------------------------------
     148             :     // explicit template instantiation
     149             :     template class ADMML2<float>;
     150             :     template class ADMML2<double>;
     151             : } // namespace elsa

Generated by: LCOV version 1.14