Line data Source code
1 : #pragma once 2 : 3 : #include "Solver.h" 4 : 5 : #include "StrongTypes.h" 6 : #include "LASSOProblem.h" 7 : 8 : namespace elsa 9 : { 10 : /** 11 : * @brief Class representing a Fast Iterative Shrinkage-Thresholding Algorithm solver 12 : * 13 : * This class represents a FISTA solver i.e. 14 : * 15 : * - @f$ x_{k} = shrinkageOperator(y_k - \mu * A^T (Ay_k - b)) @f$ 16 : * - @f$ t_{k+1} = \frac{1 + \sqrt{1 + 4 * t_{k}^2}}{2} @f$ 17 : * - @f$ y_{k+1} = x_{k} + (\frac{t_{k} - 1}{t_{k+1}}) * (x_{k} - x_{k - 1}) @f$ 18 : * 19 : * in which shrinkageOperator is the SoftThresholding operator defined as @f$ 20 : * shrinkageOperator(z_k) = sign(z_k)ยท(|z_k| - \mu*\lambda)_+ @f$. 21 : * 22 : * FISTA has a worst-case complexity result of @f$ O(1/k^2) @f$. 23 : * 24 : * @author Andi Braimllari - initial code 25 : * 26 : * @tparam data_t data type for the domain and range of the problem, defaulting to real_t 27 : * 28 : * References: 29 : * http://www.cs.cmu.edu/afs/cs/Web/People/airg/readings/2012_02_21_a_fast_iterative_shrinkage-thresholding.pdf 30 : * https://arxiv.org/pdf/2008.02683.pdf 31 : */ 32 : template <typename data_t = real_t> 33 : class FISTA : public Solver<data_t> 34 : { 35 : public: 36 : /// Scalar alias 37 : using Scalar = typename Solver<data_t>::Scalar; 38 : 39 : /** 40 : * @brief Constructor for FISTA, accepting a LASSO problem, a fixed step size and 41 : * optionally, a value for epsilon 42 : * 43 : * @param[in] problem the LASSO problem that is supposed to be solved 44 : * @param[in] mu the fixed step size to be used while solving 45 : * @param[in] epsilon affects the stopping condition 46 : */ 47 : FISTA(const LASSOProblem<data_t>& problem, geometry::Threshold<data_t> mu, 48 : data_t epsilon = std::numeric_limits<data_t>::epsilon()); 49 : 50 : /** 51 : * @brief Constructor for FISTA, accepting a problem, a fixed step size and optionally, a 52 : * value for epsilon 53 : * 54 : * @param[in] problem the problem that is supposed to be solved 55 : * @param[in] mu the fixed step size to be used while solving 56 : * @param[in] epsilon affects the stopping condition 57 : * 58 : * Conversion to a LASSOProblem will be attempted. Throws if conversion fails. See 59 : * LASSOProblem for further details. 60 : */ 61 : FISTA(const Problem<data_t>& problem, geometry::Threshold<data_t> mu, 62 : data_t epsilon = std::numeric_limits<data_t>::epsilon()); 63 : 64 : /** 65 : * @brief Constructor for FISTA, accepting a problem and optionally, a value for 66 : * epsilon 67 : * 68 : * @param[in] problem the problem that is supposed to be solved 69 : * @param[in] epsilon affects the stopping condition 70 : * 71 : * The step size will be computed as @f$ 1 \over L @f$ with @f$ L @f$ being the Lipschitz 72 : * constant of the WLSProblem. 73 : * 74 : * Conversion to a LASSOProblem will be attempted. Throws if conversion fails. See 75 : * LASSOProblem for further details. 76 : */ 77 : FISTA(const Problem<data_t>& problem, 78 : data_t epsilon = std::numeric_limits<data_t>::epsilon()); 79 : 80 : /// make copy constructor deletion explicit 81 : FISTA(const FISTA<data_t>&) = delete; 82 : 83 : /// default destructor 84 5 : ~FISTA() override = default; 85 : 86 : protected: 87 : /** 88 : * @brief Solve the optimization problem, i.e. apply iterations number of iterations of 89 : * FISTA 90 : * 91 : * @param[in] iterations number of iterations to execute (the default 0 value executes 92 : * _defaultIterations of iterations) 93 : * 94 : * @returns a reference to the current solution 95 : */ 96 : auto solveImpl(index_t iterations) -> DataContainer<data_t>& override; 97 : 98 : /// implement the polymorphic clone operation 99 : auto cloneImpl() const -> FISTA<data_t>* override; 100 : 101 : /// implement the polymorphic comparison operation 102 : auto isEqual(const Solver<data_t>& other) const -> bool override; 103 : 104 : private: 105 : /// private constructor called by a public constructor without the step size so that 106 : /// getLipschitzConstant is called by a LASSOProblem and not by a non-converted Problem 107 : FISTA(const LASSOProblem<data_t>& lassoProb, data_t epsilon); 108 : 109 : /// The LASSO optimization problem 110 : LASSOProblem<data_t> _problem; 111 : 112 : /// the default number of iterations 113 : const index_t _defaultIterations{100}; 114 : 115 : /// the step size 116 : data_t _mu; 117 : 118 : /// variable affecting the stopping condition 119 : data_t _epsilon; 120 : }; 121 : } // namespace elsa