Line data Source code
1 : #pragma once 2 : 3 : #include "Solver.h" 4 : #include "LinearResidual.h" 5 : #include "StrongTypes.h" 6 : #include "LASSOProblem.h" 7 : 8 : namespace elsa 9 : { 10 : /** 11 : * @brief Class representing an Iterative Shrinkage-Thresholding Algorithm solver 12 : * 13 : * This class represents an ISTA solver i.e. 14 : * 15 : * - @f$ x_{k+1} = shrinkageOperator(x_k - \mu * A^T (Ax_k - b)) @f$ 16 : * 17 : * in which shrinkageOperator is the SoftThresholding operator defined as @f$ 18 : * shrinkageOperator(z_{k}) = sign(z_{k})ยท(|z_{k}| - \mu*\lambda)_+ @f$. Each iteration of ISTA 19 : * involves a gradient descent update followed by a shrinkage/soft-threshold step: 20 : * 21 : * ISTA has a worst-case complexity result of @f$ O(1/k) @f$. 22 : * 23 : * @author Andi Braimllari - initial code 24 : * 25 : * @tparam data_t data type for the domain and range of the problem, defaulting to real_t 26 : * 27 : * References: 28 : * http://www.cs.cmu.edu/afs/cs/Web/People/airg/readings/2012_02_21_a_fast_iterative_shrinkage-thresholding.pdf 29 : * https://arxiv.org/pdf/2008.02683.pdf 30 : */ 31 : template <typename data_t = real_t> 32 : class ISTA : public Solver<data_t> 33 : { 34 : public: 35 : /// Scalar alias 36 : using Scalar = typename Solver<data_t>::Scalar; 37 : 38 : /** 39 : * @brief Constructor for ISTA, accepting a problem, a fixed step size and optionally, a 40 : * value for epsilon 41 : * 42 : * @param[in] problem the problem that is supposed to be solved 43 : * @param[in] mu the fixed step size to be used while solving 44 : * @param[in] epsilon affects the stopping condition 45 : * 46 : * Conversion to a LASSOProblem will be attempted. Throws if conversion fails. See 47 : * LASSOProblem for further details. 48 : */ 49 : ISTA(const Problem<data_t>& problem, geometry::Threshold<data_t> mu, 50 : data_t epsilon = std::numeric_limits<data_t>::epsilon()); 51 : 52 : /** 53 : * @brief Constructor for ISTA, accepting a problem and optionally, a value for 54 : * epsilon 55 : * 56 : * @param[in] problem the problem that is supposed to be solved 57 : * @param[in] epsilon affects the stopping condition 58 : * 59 : * The step size will be computed as @f$ 1 \over L @f$ with @f$ L @f$ being the Lipschitz 60 : * constant of the WLSProblem. 61 : * 62 : * Conversion to a LASSOProblem will be attempted. Throws if conversion fails. See 63 : * LASSOProblem for further details. 64 : */ 65 : ISTA(const Problem<data_t>& problem, 66 : data_t epsilon = std::numeric_limits<data_t>::epsilon()); 67 : 68 : /// make copy constructor deletion explicit 69 : ISTA(const ISTA<data_t>&) = delete; 70 : 71 : /// default destructor 72 0 : ~ISTA() override = default; 73 : 74 : /// lift the base class method getCurrentSolution 75 : using Solver<data_t>::getCurrentSolution; 76 : 77 : protected: 78 : /// lift the base class variable _problem 79 : using Solver<data_t>::_problem; 80 : 81 : /** 82 : * @brief Solve the optimization problem, i.e. apply iterations number of iterations of 83 : * ISTA 84 : * 85 : * @param[in] iterations number of iterations to execute (the default 0 value executes 86 : * _defaultIterations of iterations) 87 : * 88 : * @returns a reference to the current solution 89 : */ 90 : auto solveImpl(index_t iterations) -> DataContainer<data_t>& override; 91 : 92 : /// implement the polymorphic clone operation 93 : auto cloneImpl() const -> ISTA<data_t>* override; 94 : 95 : /// implement the polymorphic comparison operation 96 : auto isEqual(const Solver<data_t>& other) const -> bool override; 97 : 98 : private: 99 : /// private constructor called by a public constructor without the step size so that 100 : /// getLipschitzConstant is called by a LASSOProblem and not by a non-converted Problem 101 : ISTA(const LASSOProblem<data_t>& lassoProb, data_t epsilon); 102 : 103 : /// the default number of iterations 104 : const index_t _defaultIterations{100}; 105 : 106 : /// the step size 107 : data_t _mu; 108 : 109 : /// variable affecting the stopping condition 110 : data_t _epsilon; 111 : }; 112 : } // namespace elsa