Line data Source code
1 : #include "ISTA.h" 2 : #include "SoftThresholding.h" 3 : #include "TypeCasts.hpp" 4 : #include "Logger.h" 5 : 6 : #include "spdlog/stopwatch.h" 7 : 8 : namespace elsa 9 : { 10 : template <typename data_t> 11 0 : ISTA<data_t>::ISTA(const Problem<data_t>& problem, geometry::Threshold<data_t> mu, 12 : data_t epsilon) 13 0 : : Solver<data_t>(LASSOProblem(problem)), _mu{mu}, _epsilon{epsilon} 14 : { 15 0 : } 16 : 17 : template <typename data_t> 18 0 : ISTA<data_t>::ISTA(const Problem<data_t>& problem, data_t epsilon) 19 0 : : ISTA<data_t>(LASSOProblem(problem), epsilon) 20 : { 21 0 : } 22 : 23 : template <typename data_t> 24 0 : ISTA<data_t>::ISTA(const LASSOProblem<data_t>& lassoProb, data_t epsilon) 25 0 : : Solver<data_t>(lassoProb), _mu{1 / lassoProb.getLipschitzConstant()}, _epsilon{epsilon} 26 : { 27 0 : } 28 : 29 : template <typename data_t> 30 0 : auto ISTA<data_t>::solveImpl(index_t iterations) -> DataContainer<data_t>& 31 : { 32 0 : if (iterations == 0) 33 0 : iterations = _defaultIterations; 34 : 35 0 : spdlog::stopwatch aggregate_time; 36 0 : Logger::get("ISTA")->info("Start preparations..."); 37 : 38 0 : SoftThresholding<data_t> shrinkageOp{getCurrentSolution().getDataDescriptor()}; 39 : 40 0 : data_t lambda = _problem->getRegularizationTerms()[0].getWeight(); 41 : 42 : // Safe as long as only LinearResidual exits 43 : const auto& linResid = 44 0 : downcast<LinearResidual<data_t>>((_problem->getDataTerm()).getResidual()); 45 0 : const LinearOperator<data_t>& A = linResid.getOperator(); 46 0 : const DataContainer<data_t>& b = linResid.getDataVector(); 47 : 48 0 : DataContainer<data_t>& x = getCurrentSolution(); 49 0 : DataContainer<data_t> Atb = A.applyAdjoint(b); 50 0 : DataContainer<data_t> gradient = A.applyAdjoint(A.apply(x)) - Atb; 51 : 52 0 : Logger::get("ISTA")->info("Preparations done, tooke {}s", aggregate_time); 53 0 : Logger::get("ISTA")->info("{:^6}|{:*^16}|{:*^8}|{:*^8}|", "iter", "gradient", "time", 54 : "elapsed"); 55 : 56 0 : auto deltaZero = gradient.squaredL2Norm(); 57 0 : for (index_t iter = 0; iter < iterations; ++iter) { 58 0 : spdlog::stopwatch iter_time; 59 : 60 0 : gradient = A.applyAdjoint(A.apply(x)) - Atb; 61 : 62 0 : x = shrinkageOp.apply(x - _mu * gradient, geometry::Threshold{_mu * lambda}); 63 : 64 0 : Logger::get("ISTA")->info("{:>5} |{:>15} | {:>6.3} |{:>6.3}s |", iter, 65 0 : gradient.squaredL2Norm(), iter_time, aggregate_time); 66 0 : if (gradient.squaredL2Norm() <= _epsilon * _epsilon * deltaZero) { 67 0 : Logger::get("ISTA")->info("SUCCESS: Reached convergence at {}/{} iteration", 68 0 : iter + 1, iterations); 69 0 : return x; 70 : } 71 : } 72 : 73 0 : Logger::get("ISTA")->warn("Failed to reach convergence at {} iterations", iterations); 74 : 75 0 : return getCurrentSolution(); 76 0 : } 77 : 78 : template <typename data_t> 79 0 : auto ISTA<data_t>::cloneImpl() const -> ISTA<data_t>* 80 : { 81 0 : return new ISTA(*_problem, geometry::Threshold<data_t>{_mu}, _epsilon); 82 : } 83 : 84 : template <typename data_t> 85 0 : auto ISTA<data_t>::isEqual(const Solver<data_t>& other) const -> bool 86 : { 87 0 : if (!Solver<data_t>::isEqual(other)) 88 0 : return false; 89 : 90 0 : auto otherISTA = downcast_safe<ISTA>(&other); 91 0 : if (!otherISTA) 92 0 : return false; 93 : 94 0 : if (_mu != otherISTA->_mu) 95 0 : return false; 96 : 97 0 : if (_epsilon != otherISTA->_epsilon) 98 0 : return false; 99 : 100 0 : return true; 101 : } 102 : 103 : // ------------------------------------------ 104 : // explicit template instantiation 105 : template class ISTA<float>; 106 : template class ISTA<double>; 107 : } // namespace elsa