Line data Source code
1 : #include "ISTA.h" 2 : #include "SoftThresholding.h" 3 : #include "TypeCasts.hpp" 4 : #include "Logger.h" 5 : 6 : #include "spdlog/stopwatch.h" 7 : 8 : namespace elsa 9 : { 10 : template <typename data_t> 11 : ISTA<data_t>::ISTA(const LASSOProblem<data_t>& problem, geometry::Threshold<data_t> mu, 12 : data_t epsilon) 13 : : Solver<data_t>(), _problem(problem), _mu{mu}, _epsilon{epsilon} 14 1 : { 15 1 : } 16 : 17 : template <typename data_t> 18 : ISTA<data_t>::ISTA(const Problem<data_t>& problem, geometry::Threshold<data_t> mu, 19 : data_t epsilon) 20 : : ISTA(LASSOProblem<data_t>(problem), mu, epsilon) 21 0 : { 22 0 : } 23 : 24 : template <typename data_t> 25 : ISTA<data_t>::ISTA(const Problem<data_t>& problem, data_t epsilon) 26 : : ISTA<data_t>(LASSOProblem<data_t>(problem), epsilon) 27 4 : { 28 4 : } 29 : 30 : template <typename data_t> 31 : ISTA<data_t>::ISTA(const LASSOProblem<data_t>& lassoProb, data_t epsilon) 32 : : Solver<data_t>(), 33 : _problem(lassoProb), 34 : _mu{1 / lassoProb.getLipschitzConstant()}, 35 : _epsilon{epsilon} 36 2 : { 37 2 : } 38 : 39 : template <typename data_t> 40 : auto ISTA<data_t>::solveImpl(index_t iterations) -> DataContainer<data_t>& 41 1 : { 42 1 : if (iterations == 0) 43 0 : iterations = _defaultIterations; 44 : 45 1 : spdlog::stopwatch aggregate_time; 46 1 : Logger::get("ISTA")->info("Start preparations..."); 47 : 48 1 : SoftThresholding<data_t> shrinkageOp{_problem.getCurrentSolution().getDataDescriptor()}; 49 : 50 1 : data_t lambda = _problem.getRegularizationTerms()[0].getWeight(); 51 : 52 : // Safe as long as only LinearResidual exits 53 1 : const auto& linResid = 54 1 : downcast<LinearResidual<data_t>>((_problem.getDataTerm()).getResidual()); 55 1 : const LinearOperator<data_t>& A = linResid.getOperator(); 56 1 : const DataContainer<data_t>& b = linResid.getDataVector(); 57 : 58 1 : DataContainer<data_t>& x = _problem.getCurrentSolution(); 59 1 : DataContainer<data_t> Atb = A.applyAdjoint(b); 60 1 : DataContainer<data_t> gradient = A.applyAdjoint(A.apply(x)) - Atb; 61 : 62 1 : Logger::get("ISTA")->info("Preparations done, tooke {}s", aggregate_time); 63 1 : Logger::get("ISTA")->info("{:^6}|{:*^16}|{:*^8}|{:*^8}|", "iter", "gradient", "time", 64 1 : "elapsed"); 65 : 66 1 : auto deltaZero = gradient.squaredL2Norm(); 67 101 : for (index_t iter = 0; iter < iterations; ++iter) { 68 100 : spdlog::stopwatch iter_time; 69 : 70 100 : gradient = A.applyAdjoint(A.apply(x)) - Atb; 71 : 72 100 : x = shrinkageOp.apply(x - _mu * gradient, geometry::Threshold{_mu * lambda}); 73 : 74 100 : Logger::get("ISTA")->info("{:>5} |{:>15} | {:>6.3} |{:>6.3}s |", iter, 75 100 : gradient.squaredL2Norm(), iter_time, aggregate_time); 76 100 : if (gradient.squaredL2Norm() <= _epsilon * _epsilon * deltaZero) { 77 0 : Logger::get("ISTA")->info("SUCCESS: Reached convergence at {}/{} iteration", 78 0 : iter + 1, iterations); 79 0 : return x; 80 0 : } 81 100 : } 82 : 83 1 : Logger::get("ISTA")->warn("Failed to reach convergence at {} iterations", iterations); 84 : 85 1 : return _problem.getCurrentSolution(); 86 1 : } 87 : 88 : template <typename data_t> 89 : auto ISTA<data_t>::cloneImpl() const -> ISTA<data_t>* 90 1 : { 91 1 : return new ISTA(_problem, geometry::Threshold<data_t>{_mu}, _epsilon); 92 1 : } 93 : 94 : template <typename data_t> 95 : auto ISTA<data_t>::isEqual(const Solver<data_t>& other) const -> bool 96 1 : { 97 1 : auto otherISTA = downcast_safe<ISTA>(&other); 98 1 : if (!otherISTA) 99 0 : return false; 100 : 101 1 : if (_mu != otherISTA->_mu) 102 0 : return false; 103 : 104 1 : if (_epsilon != otherISTA->_epsilon) 105 0 : return false; 106 : 107 1 : return true; 108 1 : } 109 : 110 : // ------------------------------------------ 111 : // explicit template instantiation 112 : template class ISTA<float>; 113 : template class ISTA<double>; 114 : } // namespace elsa