LCOV - code coverage report
Current view: top level - solvers - FISTA.cpp (source / functions) Hit Total Coverage
Test: test_coverage.info.cleaned Lines: 0 62 0.0 %
Date: 2022-08-04 03:43:28 Functions: 0 12 0.0 %

          Line data    Source code
       1             : #include "FISTA.h"
       2             : #include "SoftThresholding.h"
       3             : #include "Logger.h"
       4             : 
       5             : #include "spdlog/stopwatch.h"
       6             : 
       7             : namespace elsa
       8             : {
       9             :     template <typename data_t>
      10           0 :     FISTA<data_t>::FISTA(const Problem<data_t>& problem, geometry::Threshold<data_t> mu,
      11             :                          data_t epsilon)
      12           0 :         : Solver<data_t>(LASSOProblem(problem)), _mu{mu}, _epsilon{epsilon}
      13             :     {
      14           0 :     }
      15             : 
      16             :     template <typename data_t>
      17           0 :     FISTA<data_t>::FISTA(const Problem<data_t>& problem, data_t epsilon)
      18           0 :         : FISTA<data_t>(LASSOProblem(problem), epsilon)
      19             :     {
      20           0 :     }
      21             : 
      22             :     template <typename data_t>
      23           0 :     FISTA<data_t>::FISTA(const LASSOProblem<data_t>& problem, data_t epsilon)
      24           0 :         : Solver<data_t>(LASSOProblem(problem)),
      25           0 :           _mu{1 / problem.getLipschitzConstant()},
      26           0 :           _epsilon{epsilon}
      27             :     {
      28           0 :     }
      29             : 
      30             :     template <typename data_t>
      31           0 :     auto FISTA<data_t>::solveImpl(index_t iterations) -> DataContainer<data_t>&
      32             :     {
      33           0 :         if (iterations == 0)
      34           0 :             iterations = _defaultIterations;
      35             : 
      36           0 :         spdlog::stopwatch aggregate_time;
      37           0 :         Logger::get("FISTA")->info("Start preparations...");
      38             : 
      39           0 :         SoftThresholding<data_t> shrinkageOp{getCurrentSolution().getDataDescriptor()};
      40             : 
      41           0 :         data_t lambda = _problem->getRegularizationTerms()[0].getWeight();
      42             : 
      43             :         // Safe as long as only LinearResidual exits
      44             :         const auto& linResid =
      45           0 :             downcast<LinearResidual<data_t>>((_problem->getDataTerm()).getResidual());
      46           0 :         const LinearOperator<data_t>& A = linResid.getOperator();
      47           0 :         const DataContainer<data_t>& b = linResid.getDataVector();
      48             : 
      49           0 :         DataContainer<data_t>& x = getCurrentSolution();
      50           0 :         DataContainer<data_t> xPrev = getCurrentSolution();
      51           0 :         DataContainer<data_t> y = getCurrentSolution();
      52           0 :         DataContainer<data_t> yPrev = getCurrentSolution();
      53             :         data_t t;
      54           0 :         data_t tPrev = 1;
      55             : 
      56           0 :         DataContainer<data_t> Atb = A.applyAdjoint(b);
      57           0 :         DataContainer<data_t> gradient = A.applyAdjoint(A.apply(yPrev)) - Atb;
      58             : 
      59           0 :         Logger::get("FISTA")->info("Preparations done, tooke {}s", aggregate_time);
      60             : 
      61           0 :         Logger::get("FISTA")->info("{:^6}|{:*^16}|{:*^8}|{:*^8}|", "iter", "gradient", "time",
      62             :                                    "elapsed");
      63             : 
      64           0 :         auto deltaZero = gradient.squaredL2Norm();
      65           0 :         for (index_t iter = 0; iter < iterations; ++iter) {
      66           0 :             spdlog::stopwatch iter_time;
      67             : 
      68           0 :             gradient = A.applyAdjoint(A.apply(yPrev)) - Atb;
      69           0 :             x = shrinkageOp.apply(yPrev - _mu * gradient, geometry::Threshold{_mu * lambda});
      70             : 
      71           0 :             t = (1 + std::sqrt(1 + 4 * tPrev * tPrev)) / 2;
      72           0 :             y = x + ((tPrev - 1) / t) * (x - xPrev);
      73             : 
      74           0 :             xPrev = x;
      75           0 :             yPrev = y;
      76           0 :             tPrev = t;
      77             : 
      78           0 :             Logger::get("FISTA")->info("{:>5} |{:>15} | {:>6.3} |{:>6.3}s |", iter,
      79           0 :                                        gradient.squaredL2Norm(), iter_time, aggregate_time);
      80             : 
      81           0 :             if (gradient.squaredL2Norm() <= _epsilon * _epsilon * deltaZero) {
      82           0 :                 Logger::get("FISTA")->info("SUCCESS: Reached convergence at {}/{} iteration",
      83           0 :                                            iter + 1, iterations);
      84           0 :                 return x;
      85             :             }
      86             :         }
      87             : 
      88           0 :         Logger::get("FISTA")->warn("Failed to reach convergence at {} iterations", iterations);
      89             : 
      90           0 :         return getCurrentSolution();
      91           0 :     }
      92             : 
      93             :     template <typename data_t>
      94           0 :     auto FISTA<data_t>::cloneImpl() const -> FISTA<data_t>*
      95             :     {
      96           0 :         return new FISTA(*_problem, geometry::Threshold<data_t>{_mu}, _epsilon);
      97             :     }
      98             : 
      99             :     template <typename data_t>
     100           0 :     auto FISTA<data_t>::isEqual(const Solver<data_t>& other) const -> bool
     101             :     {
     102           0 :         if (!Solver<data_t>::isEqual(other))
     103           0 :             return false;
     104             : 
     105           0 :         auto otherFISTA = downcast_safe<FISTA>(&other);
     106           0 :         if (!otherFISTA)
     107           0 :             return false;
     108             : 
     109           0 :         if (_mu != otherFISTA->_mu)
     110           0 :             return false;
     111             : 
     112           0 :         if (_epsilon != otherFISTA->_epsilon)
     113           0 :             return false;
     114             : 
     115           0 :         return true;
     116             :     }
     117             : 
     118             :     // ------------------------------------------
     119             :     // explicit template instantiation
     120             :     template class FISTA<float>;
     121             :     template class FISTA<double>;
     122             : } // namespace elsa

Generated by: LCOV version 1.14