LCOV - code coverage report
Current view: top level - elsa/solvers - ISTA.cpp (source / functions) Hit Total Coverage
Test: coverage-all.lcov Lines: 44 54 81.5 %
Date: 2022-08-25 03:05:39 Functions: 6 14 42.9 %

          Line data    Source code
       1             : #include "ISTA.h"
       2             : #include "SoftThresholding.h"
       3             : #include "TypeCasts.hpp"
       4             : #include "Logger.h"
       5             : 
       6             : #include "spdlog/stopwatch.h"
       7             : 
       8             : namespace elsa
       9             : {
      10             :     template <typename data_t>
      11             :     ISTA<data_t>::ISTA(const LASSOProblem<data_t>& problem, geometry::Threshold<data_t> mu,
      12             :                        data_t epsilon)
      13             :         : Solver<data_t>(), _problem(problem), _mu{mu}, _epsilon{epsilon}
      14           1 :     {
      15           1 :     }
      16             : 
      17             :     template <typename data_t>
      18             :     ISTA<data_t>::ISTA(const Problem<data_t>& problem, geometry::Threshold<data_t> mu,
      19             :                        data_t epsilon)
      20             :         : ISTA(LASSOProblem<data_t>(problem), mu, epsilon)
      21           0 :     {
      22           0 :     }
      23             : 
      24             :     template <typename data_t>
      25             :     ISTA<data_t>::ISTA(const Problem<data_t>& problem, data_t epsilon)
      26             :         : ISTA<data_t>(LASSOProblem<data_t>(problem), epsilon)
      27           4 :     {
      28           4 :     }
      29             : 
      30             :     template <typename data_t>
      31             :     ISTA<data_t>::ISTA(const LASSOProblem<data_t>& lassoProb, data_t epsilon)
      32             :         : Solver<data_t>(),
      33             :           _problem(lassoProb),
      34             :           _mu{1 / lassoProb.getLipschitzConstant()},
      35             :           _epsilon{epsilon}
      36           2 :     {
      37           2 :     }
      38             : 
      39             :     template <typename data_t>
      40             :     auto ISTA<data_t>::solveImpl(index_t iterations) -> DataContainer<data_t>&
      41           1 :     {
      42           1 :         if (iterations == 0)
      43           0 :             iterations = _defaultIterations;
      44             : 
      45           1 :         spdlog::stopwatch aggregate_time;
      46           1 :         Logger::get("ISTA")->info("Start preparations...");
      47             : 
      48           1 :         SoftThresholding<data_t> shrinkageOp{_problem.getCurrentSolution().getDataDescriptor()};
      49             : 
      50           1 :         data_t lambda = _problem.getRegularizationTerms()[0].getWeight();
      51             : 
      52             :         // Safe as long as only LinearResidual exits
      53           1 :         const auto& linResid =
      54           1 :             downcast<LinearResidual<data_t>>((_problem.getDataTerm()).getResidual());
      55           1 :         const LinearOperator<data_t>& A = linResid.getOperator();
      56           1 :         const DataContainer<data_t>& b = linResid.getDataVector();
      57             : 
      58           1 :         DataContainer<data_t>& x = _problem.getCurrentSolution();
      59           1 :         DataContainer<data_t> Atb = A.applyAdjoint(b);
      60           1 :         DataContainer<data_t> gradient = A.applyAdjoint(A.apply(x)) - Atb;
      61             : 
      62           1 :         Logger::get("ISTA")->info("Preparations done, tooke {}s", aggregate_time);
      63           1 :         Logger::get("ISTA")->info("{:^6}|{:*^16}|{:*^8}|{:*^8}|", "iter", "gradient", "time",
      64           1 :                                   "elapsed");
      65             : 
      66           1 :         auto deltaZero = gradient.squaredL2Norm();
      67         101 :         for (index_t iter = 0; iter < iterations; ++iter) {
      68         100 :             spdlog::stopwatch iter_time;
      69             : 
      70         100 :             gradient = A.applyAdjoint(A.apply(x)) - Atb;
      71             : 
      72         100 :             x = shrinkageOp.apply(x - _mu * gradient, geometry::Threshold{_mu * lambda});
      73             : 
      74         100 :             Logger::get("ISTA")->info("{:>5} |{:>15} | {:>6.3} |{:>6.3}s |", iter,
      75         100 :                                       gradient.squaredL2Norm(), iter_time, aggregate_time);
      76         100 :             if (gradient.squaredL2Norm() <= _epsilon * _epsilon * deltaZero) {
      77           0 :                 Logger::get("ISTA")->info("SUCCESS: Reached convergence at {}/{} iteration",
      78           0 :                                           iter + 1, iterations);
      79           0 :                 return x;
      80           0 :             }
      81         100 :         }
      82             : 
      83           1 :         Logger::get("ISTA")->warn("Failed to reach convergence at {} iterations", iterations);
      84             : 
      85           1 :         return _problem.getCurrentSolution();
      86           1 :     }
      87             : 
      88             :     template <typename data_t>
      89             :     auto ISTA<data_t>::cloneImpl() const -> ISTA<data_t>*
      90           1 :     {
      91           1 :         return new ISTA(_problem, geometry::Threshold<data_t>{_mu}, _epsilon);
      92           1 :     }
      93             : 
      94             :     template <typename data_t>
      95             :     auto ISTA<data_t>::isEqual(const Solver<data_t>& other) const -> bool
      96           1 :     {
      97           1 :         auto otherISTA = downcast_safe<ISTA>(&other);
      98           1 :         if (!otherISTA)
      99           0 :             return false;
     100             : 
     101           1 :         if (_mu != otherISTA->_mu)
     102           0 :             return false;
     103             : 
     104           1 :         if (_epsilon != otherISTA->_epsilon)
     105           0 :             return false;
     106             : 
     107           1 :         return true;
     108           1 :     }
     109             : 
     110             :     // ------------------------------------------
     111             :     // explicit template instantiation
     112             :     template class ISTA<float>;
     113             :     template class ISTA<double>;
     114             : } // namespace elsa

Generated by: LCOV version 1.14