LCOV - code coverage report
Current view: top level - elsa/solvers - FISTA.cpp (source / functions) Hit Total Coverage
Test: coverage-all.lcov Lines: 54 64 84.4 %
Date: 2022-08-25 03:05:39 Functions: 9 14 64.3 %

          Line data    Source code
       1             : #include "FISTA.h"
       2             : #include "SoftThresholding.h"
       3             : #include "Logger.h"
       4             : 
       5             : #include "spdlog/stopwatch.h"
       6             : 
       7             : namespace elsa
       8             : {
       9             :     template <typename data_t> // ctor: LASSO problem with a caller-supplied step size mu and tolerance epsilon
      10             :     FISTA<data_t>::FISTA(const LASSOProblem<data_t>& problem, geometry::Threshold<data_t> mu,
      11             :                          data_t epsilon)
      12             :         : Solver<data_t>(), _problem(problem), _mu{mu}, _epsilon{epsilon} // arguments are stored as given, nothing is derived here
      13           3 :     {
      14           3 :     }
      15             : 
      16             :     template <typename data_t> // ctor: generic Problem with a caller-supplied step size mu and tolerance epsilon
      17             :     FISTA<data_t>::FISTA(const Problem<data_t>& problem, geometry::Threshold<data_t> mu,
      18             :                          data_t epsilon)
      19             :         : FISTA(LASSOProblem(problem), mu, epsilon) // converts to a LASSOProblem, then delegates to the LASSOProblem overload
      20           0 :     {
      21           0 :     }
      22             : 
      23             :     template <typename data_t> // ctor: generic Problem and tolerance epsilon only, no explicit step size
      24             :     FISTA<data_t>::FISTA(const Problem<data_t>& problem, data_t epsilon)
      25             :         : FISTA<data_t>(LASSOProblem(problem), epsilon) // converts to a LASSOProblem, then delegates to the (LASSOProblem, epsilon) overload
      26           4 :     {
      27           4 :     }
      28             : 
      29             :     template <typename data_t> // ctor: LASSO problem and tolerance epsilon; the step size is derived from the problem
      30             :     FISTA<data_t>::FISTA(const LASSOProblem<data_t>& problem, data_t epsilon)
      31             :         : Solver<data_t>(),
      32             :           _problem(LASSOProblem(problem)),
      33             :           _mu{1 / problem.getLipschitzConstant()}, // step size = reciprocal of the Lipschitz constant reported by the problem
      34             :           _epsilon{epsilon}
      35           2 :     {
      36           2 :     }
      37             : 
      38             :     template <typename data_t> // runs at most `iterations` FISTA steps on _problem and returns a reference to its current solution
      39             :     auto FISTA<data_t>::solveImpl(index_t iterations) -> DataContainer<data_t>&
      40           3 :     {
      41           3 :         if (iterations == 0) // 0 selects the default iteration count
      42           0 :             iterations = _defaultIterations;
      43             : 
      44           3 :         spdlog::stopwatch aggregate_time; // started here; logged after the preparations and in every iteration line
      45           3 :         Logger::get("FISTA")->info("Start preparations...");
      46             : 
      47           3 :         SoftThresholding<data_t> shrinkageOp{_problem.getCurrentSolution().getDataDescriptor()}; // built on the solution's data descriptor
      48             : 
      49           3 :         data_t lambda = _problem.getRegularizationTerms()[0].getWeight(); // weight of the first regularization term
      50             : 
      51             :         // Safe as long as only LinearResidual exists
      52           3 :         const auto& linResid =
      53           3 :             downcast<LinearResidual<data_t>>((_problem.getDataTerm()).getResidual());
      54           3 :         const LinearOperator<data_t>& A = linResid.getOperator();
      55           3 :         const DataContainer<data_t>& b = linResid.getDataVector();
      56             : 
      57           3 :         DataContainer<data_t>& x = _problem.getCurrentSolution(); // reference: assignments to x update the problem's solution
      58           3 :         DataContainer<data_t> xPrev = _problem.getCurrentSolution(); // xPrev, y and yPrev are copies of the starting solution
      59           3 :         DataContainer<data_t> y = _problem.getCurrentSolution();
      60           3 :         DataContainer<data_t> yPrev = _problem.getCurrentSolution();
      61           3 :         data_t t; // assigned in every iteration before it is read
      62           3 :         data_t tPrev = 1; // initial value of the extrapolation parameter
      63             : 
      64           3 :         DataContainer<data_t> Atb = A.applyAdjoint(b); // adjoint of A applied to b, computed once and reused in the loop
      65           3 :         DataContainer<data_t> gradient = A.applyAdjoint(A.apply(yPrev)) - Atb; // gradient at the starting point, used for deltaZero below
      66             : 
      67           3 :         Logger::get("FISTA")->info("Preparations done, tooke {}s", aggregate_time);
      68             : 
      69           3 :         Logger::get("FISTA")->info("{:^6}|{:*^16}|{:*^8}|{:*^8}|", "iter", "gradient", "time",
      70           3 :                                    "elapsed");
      71             : 
      72           3 :         auto deltaZero = gradient.squaredL2Norm(); // reference value for the relative stopping test
      73         703 :         for (index_t iter = 0; iter < iterations; ++iter) {
      74         700 :             spdlog::stopwatch iter_time;
      75             : 
      76         700 :             gradient = A.applyAdjoint(A.apply(yPrev)) - Atb; // gradient evaluated at the extrapolated point yPrev
      77         700 :             x = shrinkageOp.apply(yPrev - _mu * gradient, geometry::Threshold{_mu * lambda}); // gradient step of size _mu, then soft-thresholding with threshold _mu * lambda
      78             : 
      79         700 :             t = (1 + std::sqrt(1 + 4 * tPrev * tPrev)) / 2; // update of the extrapolation parameter
      80         700 :             y = x + ((tPrev - 1) / t) * (x - xPrev); // extrapolate from x along the last change (x - xPrev)
      81             : 
      82         700 :             xPrev = x; // carry the current state over to the next iteration
      83         700 :             yPrev = y;
      84         700 :             tPrev = t;
      85             : 
      86         700 :             Logger::get("FISTA")->info("{:>5} |{:>15} | {:>6.3} |{:>6.3}s |", iter,
      87         700 :                                        gradient.squaredL2Norm(), iter_time, aggregate_time);
      88             : 
      89         700 :             if (gradient.squaredL2Norm() <= _epsilon * _epsilon * deltaZero) { // stop once ||gradient||^2 <= epsilon^2 * ||initial gradient||^2
      90           0 :                 Logger::get("FISTA")->info("SUCCESS: Reached convergence at {}/{} iteration",
      91           0 :                                            iter + 1, iterations);
      92           0 :                 return x; // x refers to the problem's current solution
      93           0 :             }
      94         700 :         }
      95             : 
      96           3 :         Logger::get("FISTA")->warn("Failed to reach convergence at {} iterations", iterations); // iteration budget used up without meeting the stopping test
      97             : 
      98           3 :         return _problem.getCurrentSolution();
      99           3 :     }
     100             : 
     101             :     template <typename data_t> // creates a new FISTA from this solver's problem, step size and tolerance
     102             :     auto FISTA<data_t>::cloneImpl() const -> FISTA<data_t>*
     103           1 :     {
     104           1 :         return new FISTA(_problem, geometry::Threshold<data_t>{_mu}, _epsilon); // raw `new`; NOTE(review): presumably the caller takes ownership - confirm
     105           1 :     }
     106             : 
     107             :     template <typename data_t> // true iff `other` casts to FISTA and has the same step size and tolerance
     108             :     auto FISTA<data_t>::isEqual(const Solver<data_t>& other) const -> bool
     109           1 :     {
     110           1 :         auto otherFISTA = downcast_safe<FISTA>(&other);
     111           1 :         if (!otherFISTA) // cast failed; presumably `other` is not a FISTA
     112           0 :             return false;
     113             : 
     114           1 :         if (_mu != otherFISTA->_mu) // exact comparison of the step sizes
     115           0 :             return false;
     116             : 
     117           1 :         if (_epsilon != otherFISTA->_epsilon) // exact comparison of the tolerances
     118           0 :             return false;
     119             : 
     120           1 :         return true; // NOTE(review): _problem is not compared here - confirm this is intended
     121           1 :     }
     122             : 
     123             :     // ------------------------------------------
     124             :     // explicit template instantiation of FISTA for the float and double scalar types
     125             :     template class FISTA<float>;
     126             :     template class FISTA<double>;
     127             : } // namespace elsa

Generated by: LCOV version 1.14