Line data Source code
1 : #pragma once 2 : 3 : #include "Solver.h" 4 : #include "StrongTypes.h" 5 : #include "LASSOProblem.h" 6 : 7 : namespace elsa 8 : { 9 : /** 10 : * @brief Class representing an Iterative Shrinkage-Thresholding Algorithm solver 11 : * 12 : * This class represents an ISTA solver i.e. 13 : * 14 : * - @f$ x_{k+1} = shrinkageOperator(x_k - \mu * A^T (Ax_k - b)) @f$ 15 : * 16 : * in which shrinkageOperator is the SoftThresholding operator defined as @f$ 17 : * shrinkageOperator(z_{k}) = sign(z_{k})ยท(|z_{k}| - \mu*\lambda)_+ @f$. Each iteration of ISTA 18 : * involves a gradient descent update followed by a shrinkage/soft-threshold step: 19 : * 20 : * ISTA has a worst-case complexity result of @f$ O(1/k) @f$. 21 : * 22 : * @author Andi Braimllari - initial code 23 : * 24 : * @tparam data_t data type for the domain and range of the problem, defaulting to real_t 25 : * 26 : * References: 27 : * http://www.cs.cmu.edu/afs/cs/Web/People/airg/readings/2012_02_21_a_fast_iterative_shrinkage-thresholding.pdf 28 : * https://arxiv.org/pdf/2008.02683.pdf 29 : */ 30 : template <typename data_t = real_t> 31 : class ISTA : public Solver<data_t> 32 : { 33 : public: 34 : /// Scalar alias 35 : using Scalar = typename Solver<data_t>::Scalar; 36 : 37 : /** 38 : * @brief Constructor for ISTA, accepting a LASSO problem, a fixed step size and optionally, 39 : * a value for epsilon 40 : * 41 : * @param[in] problem the LASSO problem that is supposed to be solved 42 : * @param[in] mu the fixed step size to be used while solving 43 : * @param[in] epsilon affects the stopping condition 44 : */ 45 : ISTA(const LASSOProblem<data_t>& problem, geometry::Threshold<data_t> mu, 46 : data_t epsilon = std::numeric_limits<data_t>::epsilon()); 47 : 48 : /** 49 : * @brief Constructor for ISTA, accepting a problem, a fixed step size and optionally, a 50 : * value for epsilon 51 : * 52 : * @param[in] problem the problem that is supposed to be solved 53 : * @param[in] mu the fixed step size to be used while solving 54 : * @param[in] epsilon affects the stopping condition 55 : * 56 : * Conversion to a LASSOProblem will be attempted. Throws if conversion fails. See 57 : * LASSOProblem for further details. 58 : */ 59 : ISTA(const Problem<data_t>& problem, geometry::Threshold<data_t> mu, 60 : data_t epsilon = std::numeric_limits<data_t>::epsilon()); 61 : 62 : /** 63 : * @brief Constructor for ISTA, accepting a problem and optionally, a value for 64 : * epsilon 65 : * 66 : * @param[in] problem the problem that is supposed to be solved 67 : * @param[in] epsilon affects the stopping condition 68 : * 69 : * The step size will be computed as @f$ 1 \over L @f$ with @f$ L @f$ being the Lipschitz 70 : * constant of the WLSProblem. 71 : * 72 : * Conversion to a LASSOProblem will be attempted. Throws if conversion fails. See 73 : * LASSOProblem for further details. 74 : */ 75 : ISTA(const Problem<data_t>& problem, 76 : data_t epsilon = std::numeric_limits<data_t>::epsilon()); 77 : 78 : /// make copy constructor deletion explicit 79 : ISTA(const ISTA<data_t>&) = delete; 80 : 81 : /// default destructor 82 3 : ~ISTA() override = default; 83 : 84 : protected: 85 : /** 86 : * @brief Solve the optimization problem, i.e. apply iterations number of iterations of 87 : * ISTA 88 : * 89 : * @param[in] iterations number of iterations to execute (the default 0 value executes 90 : * _defaultIterations of iterations) 91 : * 92 : * @returns a reference to the current solution 93 : */ 94 : auto solveImpl(index_t iterations) -> DataContainer<data_t>& override; 95 : 96 : /// implement the polymorphic clone operation 97 : auto cloneImpl() const -> ISTA<data_t>* override; 98 : 99 : /// implement the polymorphic comparison operation 100 : auto isEqual(const Solver<data_t>& other) const -> bool override; 101 : 102 : private: 103 : /// private constructor called by a public constructor without the step size so that 104 : /// getLipschitzConstant is called by a LASSOProblem and not by a non-converted Problem 105 : ISTA(const LASSOProblem<data_t>& lassoProb, data_t epsilon); 106 : 107 : /// The LASSO optimization problem 108 : LASSOProblem<data_t> _problem; 109 : 110 : /// the default number of iterations 111 : const index_t _defaultIterations{100}; 112 : 113 : /// the step size 114 : data_t _mu; 115 : 116 : /// variable affecting the stopping condition 117 : data_t _epsilon; 118 : }; 119 : } // namespace elsa