Line data Source code
1 : #pragma once 2 : 3 : #include <memory> 4 : 5 : #include "Solver.h" 6 : #include "Problem.h" 7 : 8 : namespace elsa 9 : { 10 : /** 11 : * @brief Class representing an SQS Solver. 12 : * 13 : * @author Michael Loipführer - initial code 14 : * 15 : * @tparam data_t data type for the domain and range of the problem, defaulting to real_t 16 : * 17 : * This class implements an SQS solver with multiple options for momentum acceleration and 18 : * ordered subsets. 19 : * 20 : * No particular stopping rule is currently implemented (only a fixed number of iterations, 21 : * default to 100). 22 : * 23 : * References: 24 : * https://doi.org/10.1109/TMI.2014.2350962 25 : */ 26 : template <typename data_t = real_t> 27 : class SQS : public Solver<data_t> 28 : { 29 : public: 30 : /// Scalar alias 31 : using Scalar = typename Solver<data_t>::Scalar; 32 : 33 : /** 34 : * @brief Constructor for SQS, accepting an optimization problem and, optionally, a value 35 : * for epsilon. If the problem passed to the constructor is a SubsetProblem SQS will operate 36 : * in ordered subset mode. 37 : * 38 : * @param[in] problem the problem that is supposed to be solved 39 : * @param[in] momentumAcceleration whether to enable Nesterov's momentum acceleration 40 : * @param[in] epsilon affects the stopping condition 41 : */ 42 : SQS(const Problem<data_t>& problem, bool momentumAcceleration = true, 43 : data_t epsilon = std::numeric_limits<data_t>::epsilon()); 44 : 45 : /** 46 : * @brief Constructor for SQS, accepting an optimization problem, the inverse of the 47 : * preconditioner and, optionally, a value for epsilon. If the problem passed to the 48 : * constructor is a SubsetProblem SQS will operate in ordered subset mode. 49 : * 50 : * @param[in] problem the problem that is supposed to be solved 51 : * @param[in] preconditioner a preconditioner for the problem at hand 52 : * @param[in] momentumAcceleration whether or not to enable momentum acceleration 53 : * @param[in] epsilon affects the stopping condition 54 : */ 55 : SQS(const Problem<data_t>& problem, const LinearOperator<data_t>& preconditioner, 56 : bool momentumAcceleration = true, 57 : data_t epsilon = std::numeric_limits<data_t>::epsilon()); 58 : 59 : /// make copy constructor deletion explicit 60 : SQS(const SQS<data_t>&) = delete; 61 : 62 : /// default destructor 63 22 : ~SQS() override = default; 64 : 65 : private: 66 : /// The optimization to solve 67 : /// TODO: This might be a nice variant? 68 : std::unique_ptr<Problem<data_t>> _problem; 69 : 70 : /// the default number of iterations 71 : const index_t _defaultIterations{100}; 72 : 73 : /// variable affecting the stopping condition 74 : data_t _epsilon; 75 : 76 : /// the preconditioner (if supplied) 77 : std::unique_ptr<LinearOperator<data_t>> _preconditioner{}; 78 : 79 : /// whether to enable momentum acceleration 80 : bool _momentumAcceleration; 81 : 82 : /// whether to operate in subset based mode 83 : bool _subsetMode{false}; 84 : 85 : /** 86 : * @brief Solve the optimization problem, i.e. apply iterations number of iterations of 87 : * gradient descent 88 : * 89 : * @param[in] iterations number of iterations to execute (the default 0 value executes 90 : * _defaultIterations of iterations) 91 : * 92 : * @returns a reference to the current solution 93 : */ 94 : DataContainer<data_t>& solveImpl(index_t iterations) override; 95 : 96 : /// implement the polymorphic clone operation 97 : SQS<data_t>* cloneImpl() const override; 98 : 99 : /// implement the polymorphic comparison operation 100 : bool isEqual(const Solver<data_t>& other) const override; 101 : }; 102 : } // namespace elsa