Line data Source code
1 : #pragma once 2 : 3 : #include <memory> 4 : #include <optional> 5 : #include <vector> 6 : 7 : #include "LeastSquares.h" 8 : #include "LinearOperator.h" 9 : #include "Solver.h" 10 : 11 : namespace elsa 12 : { 13 : /** 14 : * @brief Class representing an SQS Solver. 15 : * 16 : * This class implements an SQS solver with multiple options for momentum acceleration and 17 : * ordered subsets. 18 : * 19 : * No particular stopping rule is currently implemented (only a fixed number of iterations, 20 : * default to 100). 21 : * 22 : * Currently, this is limited to least square problems. 23 : * 24 : * References: 25 : * https://doi.org/10.1109/TMI.2014.2350962 26 : * 27 : * @author Michael Loipführer - initial code 28 : * 29 : * @tparam data_t data type for the domain and range of the problem, defaulting to real_t 30 : */ 31 : template <typename data_t = real_t> 32 : class SQS : public Solver<data_t> 33 : { 34 : public: 35 : /// Scalar alias 36 : using Scalar = typename Solver<data_t>::Scalar; 37 : 38 : SQS(const LeastSquares<data_t>& problem, 39 : std::vector<std::unique_ptr<LeastSquares<data_t>>>&& subsets, 40 : bool momentumAcceleration = true, 41 : data_t epsilon = std::numeric_limits<data_t>::epsilon()); 42 : 43 : SQS(const LeastSquares<data_t>& problem, 44 : std::vector<std::unique_ptr<LeastSquares<data_t>>>&& subsets, 45 : const LinearOperator<data_t>& preconditioner, bool momentumAcceleration = true, 46 : data_t epsilon = std::numeric_limits<data_t>::epsilon()); 47 : 48 : SQS(const LeastSquares<data_t>& problem, bool momentumAcceleration = true, 49 : data_t epsilon = std::numeric_limits<data_t>::epsilon()); 50 : 51 : SQS(const LeastSquares<data_t>& problem, const LinearOperator<data_t>& preconditioner, 52 : bool momentumAcceleration = true, 53 : data_t epsilon = std::numeric_limits<data_t>::epsilon()); 54 : 55 : /// make copy constructor deletion explicit 56 : SQS(const SQS<data_t>&) = delete; 57 : 58 : /// default destructor 59 20 : ~SQS() override = default; 60 : 61 : /** 62 : * @brief Solve the optimization problem, i.e. apply iterations number of iterations of 63 : * gradient descent 64 : * 65 : * @param[in] iterations number of iterations to execute 66 : * @param[in] x0 optional initial solution, initial solution set to zero if not present 67 : * 68 : * @returns the approximated solution 69 : */ 70 : DataContainer<data_t> 71 : solve(index_t iterations, 72 : std::optional<DataContainer<data_t>> x0 = std::nullopt) override; 73 : 74 : private: 75 : /// The optimization to solve 76 : /// TODO: This might be a nice variant? 77 : std::unique_ptr<LeastSquares<data_t>> fullProblem_; 78 : 79 : std::vector<std::unique_ptr<LeastSquares<data_t>>> subsets_{}; 80 : 81 : /// variable affecting the stopping condition 82 : data_t epsilon_; 83 : 84 : /// the preconditioner (if supplied) 85 : std::unique_ptr<LinearOperator<data_t>> preconditioner_{}; 86 : 87 : /// whether to enable momentum acceleration 88 : bool momentumAcceleration_; 89 : 90 : /// whether to operate in subset based mode 91 : bool subsetMode_; 92 : 93 : /// implement the polymorphic clone operation 94 : SQS<data_t>* cloneImpl() const override; 95 : 96 : /// implement the polymorphic comparison operation 97 : bool isEqual(const Solver<data_t>& other) const override; 98 : }; 99 : } // namespace elsa