Line data Source code
1 : #include "FISTA.h" 2 : #include "SoftThresholding.h" 3 : #include "Logger.h" 4 : 5 : #include "spdlog/stopwatch.h" 6 : 7 : namespace elsa 8 : { 9 : template <typename data_t> 10 0 : FISTA<data_t>::FISTA(const Problem<data_t>& problem, geometry::Threshold<data_t> mu, 11 : data_t epsilon) 12 0 : : Solver<data_t>(LASSOProblem(problem)), _mu{mu}, _epsilon{epsilon} 13 : { 14 0 : } 15 : 16 : template <typename data_t> 17 0 : FISTA<data_t>::FISTA(const Problem<data_t>& problem, data_t epsilon) 18 0 : : FISTA<data_t>(LASSOProblem(problem), epsilon) 19 : { 20 0 : } 21 : 22 : template <typename data_t> 23 0 : FISTA<data_t>::FISTA(const LASSOProblem<data_t>& problem, data_t epsilon) 24 0 : : Solver<data_t>(LASSOProblem(problem)), 25 0 : _mu{1 / problem.getLipschitzConstant()}, 26 0 : _epsilon{epsilon} 27 : { 28 0 : } 29 : 30 : template <typename data_t> 31 0 : auto FISTA<data_t>::solveImpl(index_t iterations) -> DataContainer<data_t>& 32 : { 33 0 : if (iterations == 0) 34 0 : iterations = _defaultIterations; 35 : 36 0 : spdlog::stopwatch aggregate_time; 37 0 : Logger::get("FISTA")->info("Start preparations..."); 38 : 39 0 : SoftThresholding<data_t> shrinkageOp{getCurrentSolution().getDataDescriptor()}; 40 : 41 0 : data_t lambda = _problem->getRegularizationTerms()[0].getWeight(); 42 : 43 : // Safe as long as only LinearResidual exits 44 : const auto& linResid = 45 0 : downcast<LinearResidual<data_t>>((_problem->getDataTerm()).getResidual()); 46 0 : const LinearOperator<data_t>& A = linResid.getOperator(); 47 0 : const DataContainer<data_t>& b = linResid.getDataVector(); 48 : 49 0 : DataContainer<data_t>& x = getCurrentSolution(); 50 0 : DataContainer<data_t> xPrev = getCurrentSolution(); 51 0 : DataContainer<data_t> y = getCurrentSolution(); 52 0 : DataContainer<data_t> yPrev = getCurrentSolution(); 53 : data_t t; 54 0 : data_t tPrev = 1; 55 : 56 0 : DataContainer<data_t> Atb = A.applyAdjoint(b); 57 0 : DataContainer<data_t> gradient = A.applyAdjoint(A.apply(yPrev)) - Atb; 58 : 59 0 : Logger::get("FISTA")->info("Preparations done, tooke {}s", aggregate_time); 60 : 61 0 : Logger::get("FISTA")->info("{:^6}|{:*^16}|{:*^8}|{:*^8}|", "iter", "gradient", "time", 62 : "elapsed"); 63 : 64 0 : auto deltaZero = gradient.squaredL2Norm(); 65 0 : for (index_t iter = 0; iter < iterations; ++iter) { 66 0 : spdlog::stopwatch iter_time; 67 : 68 0 : gradient = A.applyAdjoint(A.apply(yPrev)) - Atb; 69 0 : x = shrinkageOp.apply(yPrev - _mu * gradient, geometry::Threshold{_mu * lambda}); 70 : 71 0 : t = (1 + std::sqrt(1 + 4 * tPrev * tPrev)) / 2; 72 0 : y = x + ((tPrev - 1) / t) * (x - xPrev); 73 : 74 0 : xPrev = x; 75 0 : yPrev = y; 76 0 : tPrev = t; 77 : 78 0 : Logger::get("FISTA")->info("{:>5} |{:>15} | {:>6.3} |{:>6.3}s |", iter, 79 0 : gradient.squaredL2Norm(), iter_time, aggregate_time); 80 : 81 0 : if (gradient.squaredL2Norm() <= _epsilon * _epsilon * deltaZero) { 82 0 : Logger::get("FISTA")->info("SUCCESS: Reached convergence at {}/{} iteration", 83 0 : iter + 1, iterations); 84 0 : return x; 85 : } 86 : } 87 : 88 0 : Logger::get("FISTA")->warn("Failed to reach convergence at {} iterations", iterations); 89 : 90 0 : return getCurrentSolution(); 91 0 : } 92 : 93 : template <typename data_t> 94 0 : auto FISTA<data_t>::cloneImpl() const -> FISTA<data_t>* 95 : { 96 0 : return new FISTA(*_problem, geometry::Threshold<data_t>{_mu}, _epsilon); 97 : } 98 : 99 : template <typename data_t> 100 0 : auto FISTA<data_t>::isEqual(const Solver<data_t>& other) const -> bool 101 : { 102 0 : if (!Solver<data_t>::isEqual(other)) 103 0 : return false; 104 : 105 0 : auto otherFISTA = downcast_safe<FISTA>(&other); 106 0 : if (!otherFISTA) 107 0 : return false; 108 : 109 0 : if (_mu != otherFISTA->_mu) 110 0 : return false; 111 : 112 0 : if (_epsilon != otherFISTA->_epsilon) 113 0 : return false; 114 : 115 0 : return true; 116 : } 117 : 118 : // ------------------------------------------ 119 : // explicit template instantiation 120 : template class FISTA<float>; 121 : template class FISTA<double>; 122 : } // namespace elsa