Line data Source code
1 : #include "FISTA.h" 2 : #include "SoftThresholding.h" 3 : #include "Logger.h" 4 : 5 : #include "spdlog/stopwatch.h" 6 : 7 : namespace elsa 8 : { 9 : template <typename data_t> 10 : FISTA<data_t>::FISTA(const LASSOProblem<data_t>& problem, geometry::Threshold<data_t> mu, 11 : data_t epsilon) 12 : : Solver<data_t>(), _problem(problem), _mu{mu}, _epsilon{epsilon} 13 3 : { 14 3 : } 15 : 16 : template <typename data_t> 17 : FISTA<data_t>::FISTA(const Problem<data_t>& problem, geometry::Threshold<data_t> mu, 18 : data_t epsilon) 19 : : FISTA(LASSOProblem(problem), mu, epsilon) 20 0 : { 21 0 : } 22 : 23 : template <typename data_t> 24 : FISTA<data_t>::FISTA(const Problem<data_t>& problem, data_t epsilon) 25 : : FISTA<data_t>(LASSOProblem(problem), epsilon) 26 4 : { 27 4 : } 28 : 29 : template <typename data_t> 30 : FISTA<data_t>::FISTA(const LASSOProblem<data_t>& problem, data_t epsilon) 31 : : Solver<data_t>(), 32 : _problem(LASSOProblem(problem)), 33 : _mu{1 / problem.getLipschitzConstant()}, 34 : _epsilon{epsilon} 35 2 : { 36 2 : } 37 : 38 : template <typename data_t> 39 : auto FISTA<data_t>::solveImpl(index_t iterations) -> DataContainer<data_t>& 40 3 : { 41 3 : if (iterations == 0) 42 0 : iterations = _defaultIterations; 43 : 44 3 : spdlog::stopwatch aggregate_time; 45 3 : Logger::get("FISTA")->info("Start preparations..."); 46 : 47 3 : SoftThresholding<data_t> shrinkageOp{_problem.getCurrentSolution().getDataDescriptor()}; 48 : 49 3 : data_t lambda = _problem.getRegularizationTerms()[0].getWeight(); 50 : 51 : // Safe as long as only LinearResidual exits 52 3 : const auto& linResid = 53 3 : downcast<LinearResidual<data_t>>((_problem.getDataTerm()).getResidual()); 54 3 : const LinearOperator<data_t>& A = linResid.getOperator(); 55 3 : const DataContainer<data_t>& b = linResid.getDataVector(); 56 : 57 3 : DataContainer<data_t>& x = _problem.getCurrentSolution(); 58 3 : DataContainer<data_t> xPrev = _problem.getCurrentSolution(); 59 3 : DataContainer<data_t> y = _problem.getCurrentSolution(); 60 3 : DataContainer<data_t> yPrev = _problem.getCurrentSolution(); 61 3 : data_t t; 62 3 : data_t tPrev = 1; 63 : 64 3 : DataContainer<data_t> Atb = A.applyAdjoint(b); 65 3 : DataContainer<data_t> gradient = A.applyAdjoint(A.apply(yPrev)) - Atb; 66 : 67 3 : Logger::get("FISTA")->info("Preparations done, tooke {}s", aggregate_time); 68 : 69 3 : Logger::get("FISTA")->info("{:^6}|{:*^16}|{:*^8}|{:*^8}|", "iter", "gradient", "time", 70 3 : "elapsed"); 71 : 72 3 : auto deltaZero = gradient.squaredL2Norm(); 73 703 : for (index_t iter = 0; iter < iterations; ++iter) { 74 700 : spdlog::stopwatch iter_time; 75 : 76 700 : gradient = A.applyAdjoint(A.apply(yPrev)) - Atb; 77 700 : x = shrinkageOp.apply(yPrev - _mu * gradient, geometry::Threshold{_mu * lambda}); 78 : 79 700 : t = (1 + std::sqrt(1 + 4 * tPrev * tPrev)) / 2; 80 700 : y = x + ((tPrev - 1) / t) * (x - xPrev); 81 : 82 700 : xPrev = x; 83 700 : yPrev = y; 84 700 : tPrev = t; 85 : 86 700 : Logger::get("FISTA")->info("{:>5} |{:>15} | {:>6.3} |{:>6.3}s |", iter, 87 700 : gradient.squaredL2Norm(), iter_time, aggregate_time); 88 : 89 700 : if (gradient.squaredL2Norm() <= _epsilon * _epsilon * deltaZero) { 90 0 : Logger::get("FISTA")->info("SUCCESS: Reached convergence at {}/{} iteration", 91 0 : iter + 1, iterations); 92 0 : return x; 93 0 : } 94 700 : } 95 : 96 3 : Logger::get("FISTA")->warn("Failed to reach convergence at {} iterations", iterations); 97 : 98 3 : return _problem.getCurrentSolution(); 99 3 : } 100 : 101 : template <typename data_t> 102 : auto FISTA<data_t>::cloneImpl() const -> FISTA<data_t>* 103 1 : { 104 1 : return new FISTA(_problem, geometry::Threshold<data_t>{_mu}, _epsilon); 105 1 : } 106 : 107 : template <typename data_t> 108 : auto FISTA<data_t>::isEqual(const Solver<data_t>& other) const -> bool 109 1 : { 110 1 : auto otherFISTA = downcast_safe<FISTA>(&other); 111 1 : if (!otherFISTA) 112 0 : return false; 113 : 114 1 : if (_mu != otherFISTA->_mu) 115 0 : return false; 116 : 117 1 : if (_epsilon != otherFISTA->_epsilon) 118 0 : return false; 119 : 120 1 : return true; 121 1 : } 122 : 123 : // ------------------------------------------ 124 : // explicit template instantiation 125 : template class FISTA<float>; 126 : template class FISTA<double>; 127 : } // namespace elsa