Line data Source code
1 : #pragma once 2 : 3 : #include "Solver.h" 4 : 5 : namespace elsa 6 : { 7 : /** 8 : * @brief Class representing an SQS Solver. 9 : * 10 : * @author Michael Loipführer - initial code 11 : * 12 : * @tparam data_t data type for the domain and range of the problem, defaulting to real_t 13 : * 14 : * This class implements an SQS solver with multiple options for momentum acceleration and 15 : * ordered subsets. 16 : * 17 : * No particular stopping rule is currently implemented (only a fixed number of iterations, 18 : * default to 100). 19 : * 20 : * References: 21 : * https://doi.org/10.1109/TMI.2014.2350962 22 : */ 23 : template <typename data_t = real_t> 24 : class SQS : public Solver<data_t> 25 : { 26 : public: 27 : /// Scalar alias 28 : using Scalar = typename Solver<data_t>::Scalar; 29 : 30 : /** 31 : * @brief Constructor for SQS, accepting an optimization problem and, optionally, a value 32 : * for epsilon. If the problem passed to the constructor is a SubsetProblem SQS will operate 33 : * in ordered subset mode. 34 : * 35 : * @param[in] problem the problem that is supposed to be solved 36 : * @param[in] momentumAcceleration whether to enable Nesterov's momentum acceleration 37 : * @param[in] epsilon affects the stopping condition 38 : */ 39 : SQS(const Problem<data_t>& problem, bool momentumAcceleration = true, 40 : data_t epsilon = std::numeric_limits<data_t>::epsilon()); 41 : 42 : /** 43 : * @brief Constructor for SQS, accepting an optimization problem, the inverse of the 44 : * preconditioner and, optionally, a value for epsilon. If the problem passed to the 45 : * constructor is a SubsetProblem SQS will operate in ordered subset mode. 46 : * 47 : * @param[in] problem the problem that is supposed to be solved 48 : * @param[in] preconditioner a preconditioner for the problem at hand 49 : * @param[in] momentumAcceleration whether or not to enable momentum acceleration 50 : * @param[in] epsilon affects the stopping condition 51 : */ 52 : SQS(const Problem<data_t>& problem, const LinearOperator<data_t>& preconditioner, 53 : bool momentumAcceleration = true, 54 : data_t epsilon = std::numeric_limits<data_t>::epsilon()); 55 : 56 : /// make copy constructor deletion explicit 57 : SQS(const SQS<data_t>&) = delete; 58 : 59 : /// default destructor 60 0 : ~SQS() override = default; 61 : 62 : /// lift the base class method getCurrentSolution 63 : using Solver<data_t>::getCurrentSolution; 64 : 65 : private: 66 : /// the default number of iterations 67 : const index_t _defaultIterations{100}; 68 : 69 : /// variable affecting the stopping condition 70 : data_t _epsilon; 71 : 72 : /// the preconditioner (if supplied) 73 : std::unique_ptr<LinearOperator<data_t>> _preconditioner{}; 74 : 75 : /// whether to enable momentum acceleration 76 : bool _momentumAcceleration; 77 : 78 : /// whether to operate in subset based mode 79 : bool _subsetMode{false}; 80 : 81 : /// lift the base class variable _problem 82 : using Solver<data_t>::_problem; 83 : 84 : /** 85 : * @brief Solve the optimization problem, i.e. apply iterations number of iterations of 86 : * gradient descent 87 : * 88 : * @param[in] iterations number of iterations to execute (the default 0 value executes 89 : * _defaultIterations of iterations) 90 : * 91 : * @returns a reference to the current solution 92 : */ 93 : DataContainer<data_t>& solveImpl(index_t iterations) override; 94 : 95 : /// implement the polymorphic clone operation 96 : SQS<data_t>* cloneImpl() const override; 97 : 98 : /// implement the polymorphic comparison operation 99 : bool isEqual(const Solver<data_t>& other) const override; 100 : }; 101 : } // namespace elsa