LCOV - code coverage report
Current view: top level - elsa/solvers - FISTA.h (source / functions) Hit Total Coverage
Test: coverage-all.lcov Lines: 1 1 100.0 %
Date: 2022-08-25 03:05:39 Functions: 2 2 100.0 %

          Line data    Source code
       1             : #pragma once
       2             : 
       3             : #include "Solver.h"
       4             : 
       5             : #include "StrongTypes.h"
       6             : #include "LASSOProblem.h"
       7             : 
       8             : namespace elsa
       9             : {
      10             :     /**
      11             :      * @brief Class representing a Fast Iterative Shrinkage-Thresholding Algorithm solver
      12             :      *
      13             :      * This class represents a FISTA solver, i.e. one that performs the iterations
      14             :      *
      15             :      *  - @f$ x_{k} = shrinkageOperator(y_k - \mu * A^T (Ay_k - b)) @f$
      16             :      *  - @f$ t_{k+1} = \frac{1 + \sqrt{1 + 4 * t_{k}^2}}{2} @f$
      17             :      *  - @f$ y_{k+1} = x_{k} + (\frac{t_{k} - 1}{t_{k+1}}) * (x_{k} - x_{k - 1}) @f$
      18             :      *
      19             :      * in which shrinkageOperator is the SoftThresholding operator defined as @f$
      20             :      * shrinkageOperator(z_k) = sign(z_k) \cdot (|z_k| - \mu*\lambda)_+ @f$.
      21             :      *
      22             :      * FISTA has a worst-case convergence rate of @f$ O(1/k^2) @f$ in objective function value.
      23             :      *
      24             :      * @author Andi Braimllari - initial code
      25             :      *
      26             :      * @tparam data_t data type for the domain and range of the problem, defaulting to real_t
      27             :      *
      28             :      * References:
      29             :      * http://www.cs.cmu.edu/afs/cs/Web/People/airg/readings/2012_02_21_a_fast_iterative_shrinkage-thresholding.pdf
      30             :      * https://arxiv.org/pdf/2008.02683.pdf
      31             :      */
      32             :     template <typename data_t = real_t>
      33             :     class FISTA : public Solver<data_t>
      34             :     {
      35             :     public:
      36             :         /// Scalar alias
      37             :         using Scalar = typename Solver<data_t>::Scalar;
      38             : 
      39             :         /**
      40             :          * @brief Constructor for FISTA, accepting a LASSO problem, a fixed step size and
      41             :          * optionally, a value for epsilon
      42             :          *
      43             :          * @param[in] problem the LASSO problem that is supposed to be solved
      44             :          * @param[in] mu the fixed step size to be used while solving
      45             :          * @param[in] epsilon affects the stopping condition
      46             :          */
      47             :         FISTA(const LASSOProblem<data_t>& problem, geometry::Threshold<data_t> mu,
      48             :               data_t epsilon = std::numeric_limits<data_t>::epsilon());
      49             : 
      50             :         /**
      51             :          * @brief Constructor for FISTA, accepting a problem, a fixed step size and optionally, a
      52             :          * value for epsilon
      53             :          *
      54             :          * @param[in] problem the problem that is supposed to be solved
      55             :          * @param[in] mu the fixed step size to be used while solving
      56             :          * @param[in] epsilon affects the stopping condition
      57             :          *
      58             :          * Conversion to a LASSOProblem will be attempted. Throws if conversion fails. See
      59             :          * LASSOProblem for further details.
      60             :          */
      61             :         FISTA(const Problem<data_t>& problem, geometry::Threshold<data_t> mu,
      62             :               data_t epsilon = std::numeric_limits<data_t>::epsilon());
      63             : 
      64             :         /**
      65             :          * @brief Constructor for FISTA, accepting a problem and optionally, a value for
      66             :          * epsilon
      67             :          *
      68             :          * @param[in] problem the problem that is supposed to be solved
      69             :          * @param[in] epsilon affects the stopping condition
      70             :          *
      71             :          * The step size will be computed as @f$ 1 \over L @f$ with @f$ L @f$ being the Lipschitz
      72             :          * constant of the WLSProblem.
      73             :          *
      74             :          * Conversion to a LASSOProblem will be attempted. Throws if conversion fails. See
      75             :          * LASSOProblem for further details.
      76             :          */
      77             :         FISTA(const Problem<data_t>& problem,
      78             :               data_t epsilon = std::numeric_limits<data_t>::epsilon());
      79             : 
      80             :         /// make copy constructor deletion explicit
      81             :         FISTA(const FISTA<data_t>&) = delete;
      82             : 
      83             :         /// default destructor
      84           5 :         ~FISTA() override = default;
      85             : 
      86             :     protected:
      87             :         /**
      88             :          * @brief Solve the optimization problem, i.e. apply the given number of iterations of
      89             :          * FISTA
      90             :          *
      91             :          * @param[in] iterations number of iterations to execute (a value of 0 executes
      92             :          * _defaultIterations iterations instead)
      93             :          *
      94             :          * @returns a reference to the current solution
      95             :          */
      96             :         auto solveImpl(index_t iterations) -> DataContainer<data_t>& override;
      97             : 
      98             :         /// implement the polymorphic clone operation
      99             :         auto cloneImpl() const -> FISTA<data_t>* override;
     100             : 
     101             :         /// implement the polymorphic comparison operation
     102             :         auto isEqual(const Solver<data_t>& other) const -> bool override;
     103             : 
     104             :     private:
     105             :         /// private constructor called by a public constructor without the step size, so that
     106             :         /// getLipschitzConstant is called on a LASSOProblem and not on an unconverted Problem
     107             :         FISTA(const LASSOProblem<data_t>& lassoProb, data_t epsilon);
     108             : 
     109             :         /// The LASSO optimization problem
     110             :         LASSOProblem<data_t> _problem;
     111             : 
     112             :         /// the default number of iterations
     113             :         const index_t _defaultIterations{100};
     114             : 
     115             :         /// the step size
     116             :         data_t _mu;
     117             : 
     118             :         /// variable affecting the stopping condition
     119             :         data_t _epsilon;
     120             :     };
     121             : } // namespace elsa

Generated by: LCOV version 1.14