Line data Source code
1 : #pragma once 2 : 3 : #include "Solver.h" 4 : #include "LinearResidual.h" 5 : #include "StrongTypes.h" 6 : #include "LASSOProblem.h" 7 : 8 : namespace elsa 9 : { 10 : /** 11 : * @brief Class representing a Fast Iterative Shrinkage-Thresholding Algorithm solver 12 : * 13 : * This class represents a FISTA solver i.e. 14 : * 15 : * - @f$ x_{k} = shrinkageOperator(y_k - \mu * A^T (Ay_k - b)) @f$ 16 : * - @f$ t_{k+1} = \frac{1 + \sqrt{1 + 4 * t_{k}^2}}{2} @f$ 17 : * - @f$ y_{k+1} = x_{k} + (\frac{t_{k} - 1}{t_{k+1}}) * (x_{k} - x_{k - 1}) @f$ 18 : * 19 : * in which shrinkageOperator is the SoftThresholding operator defined as @f$ 20 : * shrinkageOperator(z_k) = sign(z_k)ยท(|z_k| - \mu*\lambda)_+ @f$. 21 : * 22 : * FISTA has a worst-case complexity result of @f$ O(1/k^2) @f$. 23 : * 24 : * @author Andi Braimllari - initial code 25 : * 26 : * @tparam data_t data type for the domain and range of the problem, defaulting to real_t 27 : * 28 : * References: 29 : * http://www.cs.cmu.edu/afs/cs/Web/People/airg/readings/2012_02_21_a_fast_iterative_shrinkage-thresholding.pdf 30 : * https://arxiv.org/pdf/2008.02683.pdf 31 : */ 32 : template <typename data_t = real_t> 33 : class FISTA : public Solver<data_t> 34 : { 35 : public: 36 : /// Scalar alias 37 : using Scalar = typename Solver<data_t>::Scalar; 38 : 39 : /** 40 : * @brief Constructor for FISTA, accepting a problem, a fixed step size and optionally, a 41 : * value for epsilon 42 : * 43 : * @param[in] problem the problem that is supposed to be solved 44 : * @param[in] mu the fixed step size to be used while solving 45 : * @param[in] epsilon affects the stopping condition 46 : * 47 : * Conversion to a LASSOProblem will be attempted. Throws if conversion fails. See 48 : * LASSOProblem for further details. 49 : */ 50 : FISTA(const Problem<data_t>& problem, geometry::Threshold<data_t> mu, 51 : data_t epsilon = std::numeric_limits<data_t>::epsilon()); 52 : 53 : /** 54 : * @brief Constructor for FISTA, accepting a problem and optionally, a value for 55 : * epsilon 56 : * 57 : * @param[in] problem the problem that is supposed to be solved 58 : * @param[in] epsilon affects the stopping condition 59 : * 60 : * The step size will be computed as @f$ 1 \over L @f$ with @f$ L @f$ being the Lipschitz 61 : * constant of the WLSProblem. 62 : * 63 : * Conversion to a LASSOProblem will be attempted. Throws if conversion fails. See 64 : * LASSOProblem for further details. 65 : */ 66 : FISTA(const Problem<data_t>& problem, 67 : data_t epsilon = std::numeric_limits<data_t>::epsilon()); 68 : 69 : /// make copy constructor deletion explicit 70 : FISTA(const FISTA<data_t>&) = delete; 71 : 72 : /// default destructor 73 0 : ~FISTA() override = default; 74 : 75 : /// lift the base class method getCurrentSolution 76 : using Solver<data_t>::getCurrentSolution; 77 : 78 : protected: 79 : /// lift the base class variable _problem 80 : using Solver<data_t>::_problem; 81 : 82 : /** 83 : * @brief Solve the optimization problem, i.e. apply iterations number of iterations of 84 : * FISTA 85 : * 86 : * @param[in] iterations number of iterations to execute (the default 0 value executes 87 : * _defaultIterations of iterations) 88 : * 89 : * @returns a reference to the current solution 90 : */ 91 : auto solveImpl(index_t iterations) -> DataContainer<data_t>& override; 92 : 93 : /// implement the polymorphic clone operation 94 : auto cloneImpl() const -> FISTA<data_t>* override; 95 : 96 : /// implement the polymorphic comparison operation 97 : auto isEqual(const Solver<data_t>& other) const -> bool override; 98 : 99 : private: 100 : /// private constructor called by a public constructor without the step size so that 101 : /// getLipschitzConstant is called by a LASSOProblem and not by a non-converted Problem 102 : FISTA(const LASSOProblem<data_t>& lassoProb, data_t epsilon); 103 : 104 : /// the default number of iterations 105 : const index_t _defaultIterations{100}; 106 : 107 : /// the step size 108 : data_t _mu; 109 : 110 : /// variable affecting the stopping condition 111 : data_t _epsilon; 112 : }; 113 : } // namespace elsa