LCOV - code coverage report
Current view: top level - solvers - ISTA.cpp (source / functions) Hit Total Coverage
Test: test_coverage.info.cleaned Lines: 0 51 0.0 %
Date: 2022-08-04 03:43:28 Functions: 0 12 0.0 %

          Line data    Source code
       1             : #include "ISTA.h"
       2             : #include "SoftThresholding.h"
       3             : #include "TypeCasts.hpp"
       4             : #include "Logger.h"
       5             : 
       6             : #include "spdlog/stopwatch.h"
       7             : 
       8             : namespace elsa
       9             : {
      10             :     template <typename data_t>
      11           0 :     ISTA<data_t>::ISTA(const Problem<data_t>& problem, geometry::Threshold<data_t> mu,
      12             :                        data_t epsilon)
      13           0 :         : Solver<data_t>(LASSOProblem(problem)), _mu{mu}, _epsilon{epsilon}
      14             :     {
      15           0 :     }
      16             : 
      17             :     template <typename data_t>
      18           0 :     ISTA<data_t>::ISTA(const Problem<data_t>& problem, data_t epsilon)
      19           0 :         : ISTA<data_t>(LASSOProblem(problem), epsilon)
      20             :     {
      21           0 :     }
      22             : 
      23             :     template <typename data_t>
      24           0 :     ISTA<data_t>::ISTA(const LASSOProblem<data_t>& lassoProb, data_t epsilon)
      25           0 :         : Solver<data_t>(lassoProb), _mu{1 / lassoProb.getLipschitzConstant()}, _epsilon{epsilon}
      26             :     {
      27           0 :     }
      28             : 
      29             :     template <typename data_t>
      30           0 :     auto ISTA<data_t>::solveImpl(index_t iterations) -> DataContainer<data_t>&
      31             :     {
      32           0 :         if (iterations == 0)
      33           0 :             iterations = _defaultIterations;
      34             : 
      35           0 :         spdlog::stopwatch aggregate_time;
      36           0 :         Logger::get("ISTA")->info("Start preparations...");
      37             : 
      38           0 :         SoftThresholding<data_t> shrinkageOp{getCurrentSolution().getDataDescriptor()};
      39             : 
      40           0 :         data_t lambda = _problem->getRegularizationTerms()[0].getWeight();
      41             : 
      42             :         // Safe as long as only LinearResidual exits
      43             :         const auto& linResid =
      44           0 :             downcast<LinearResidual<data_t>>((_problem->getDataTerm()).getResidual());
      45           0 :         const LinearOperator<data_t>& A = linResid.getOperator();
      46           0 :         const DataContainer<data_t>& b = linResid.getDataVector();
      47             : 
      48           0 :         DataContainer<data_t>& x = getCurrentSolution();
      49           0 :         DataContainer<data_t> Atb = A.applyAdjoint(b);
      50           0 :         DataContainer<data_t> gradient = A.applyAdjoint(A.apply(x)) - Atb;
      51             : 
      52           0 :         Logger::get("ISTA")->info("Preparations done, tooke {}s", aggregate_time);
      53           0 :         Logger::get("ISTA")->info("{:^6}|{:*^16}|{:*^8}|{:*^8}|", "iter", "gradient", "time",
      54             :                                   "elapsed");
      55             : 
      56           0 :         auto deltaZero = gradient.squaredL2Norm();
      57           0 :         for (index_t iter = 0; iter < iterations; ++iter) {
      58           0 :             spdlog::stopwatch iter_time;
      59             : 
      60           0 :             gradient = A.applyAdjoint(A.apply(x)) - Atb;
      61             : 
      62           0 :             x = shrinkageOp.apply(x - _mu * gradient, geometry::Threshold{_mu * lambda});
      63             : 
      64           0 :             Logger::get("ISTA")->info("{:>5} |{:>15} | {:>6.3} |{:>6.3}s |", iter,
      65           0 :                                       gradient.squaredL2Norm(), iter_time, aggregate_time);
      66           0 :             if (gradient.squaredL2Norm() <= _epsilon * _epsilon * deltaZero) {
      67           0 :                 Logger::get("ISTA")->info("SUCCESS: Reached convergence at {}/{} iteration",
      68           0 :                                           iter + 1, iterations);
      69           0 :                 return x;
      70             :             }
      71             :         }
      72             : 
      73           0 :         Logger::get("ISTA")->warn("Failed to reach convergence at {} iterations", iterations);
      74             : 
      75           0 :         return getCurrentSolution();
      76           0 :     }
      77             : 
      78             :     template <typename data_t>
      79           0 :     auto ISTA<data_t>::cloneImpl() const -> ISTA<data_t>*
      80             :     {
      81           0 :         return new ISTA(*_problem, geometry::Threshold<data_t>{_mu}, _epsilon);
      82             :     }
      83             : 
      84             :     template <typename data_t>
      85           0 :     auto ISTA<data_t>::isEqual(const Solver<data_t>& other) const -> bool
      86             :     {
      87           0 :         if (!Solver<data_t>::isEqual(other))
      88           0 :             return false;
      89             : 
      90           0 :         auto otherISTA = downcast_safe<ISTA>(&other);
      91           0 :         if (!otherISTA)
      92           0 :             return false;
      93             : 
      94           0 :         if (_mu != otherISTA->_mu)
      95           0 :             return false;
      96             : 
      97           0 :         if (_epsilon != otherISTA->_epsilon)
      98           0 :             return false;
      99             : 
     100           0 :         return true;
     101             :     }
     102             : 
    // ------------------------------------------
    // explicit template instantiation
    // The member definitions above live in this .cpp file, so only the specializations listed
    // here are emitted. NOTE(review): any other data_t would need to be added here.
    template class ISTA<float>;
    template class ISTA<double>;
     107             : } // namespace elsa

Generated by: LCOV version 1.14