Line data Source code
1 : #pragma once 2 : 3 : #include "DataDescriptor.h" 4 : #include "elsaDefines.h" 5 : #include "DataContainer.h" 6 : #include "Cloneable.h" 7 : #include <chrono> 8 : #include <optional> 9 : 10 : namespace elsa 11 : { 12 : 13 : /** 14 : * @brief Base class representing a solver for an optimization problem. 15 : * 16 : * This class represents abstract (typically iterative) solvers acting on optimization problems. 17 : * 18 : * @author 19 : * - Matthias Wieczorek - initial code 20 : * - Maximilian Hornung - modularization 21 : * - Tobias Lasser - rewrite, modernization 22 : * 23 : * @tparam data_t data type for the domain and range of the problem, defaulting to real_t 24 : */ 25 : template <typename data_t = real_t> 26 : class Solver : public Cloneable<Solver<data_t>> 27 : { 28 : public: 29 : /// Scalar alias 30 : using Scalar = data_t; 31 : 32 189 : Solver() = default; 33 : 34 : /// default destructor 35 189 : ~Solver() override = default; 36 : 37 : /** 38 : * @brief Solve the optimization problem from a zero starting point. This 39 : * will setup the iterative algorithm and then run the algorithm for 40 : * the specified number of iterations (assuming the algorithm doesn't convergence before). 41 : * 42 : * @param[in] iterations number of iterations to execute 43 : * @param[in] x0 optional initial solution, initial solution set to zero if not present 44 : * 45 : * @returns the current solution (after solving) 46 : */ 47 : virtual DataContainer<data_t> solve(index_t iterations, 48 : std::optional<DataContainer<data_t>> x0 = std::nullopt); 49 : 50 : /** 51 : * @brief Setup and reset the solver. This function should set all values, 52 : * vectors and temporaries necessary to run the iterative reconstruction 53 : * algorithm. 54 : * 55 : * @param[in] x optional to initial solution, if optional is empty, some defaulted (e.g. 0 56 : * filled) initial solution should be returned 57 : * 58 : * @returns initial solution 59 : */ 60 : virtual DataContainer<data_t> setup(std::optional<DataContainer<data_t>> x); 61 : 62 : /** 63 : * @brief Perform a single step of the iterative reconstruction algorithm. 64 : * 65 : * @param[in] x the current solution 66 : * 67 : * @returns the updated solution estimate 68 : */ 69 : virtual DataContainer<data_t> step(DataContainer<data_t> x); 70 : 71 : /** 72 : * @brief Run the iterative reconstruction algorithm for the 73 : * given number of iterations on the initial starting point. This fun 74 : * 75 : * @param[in] iterations the number of iterations to execute 76 : * @param[in] x0 the initial starting point of the run 77 : * @param[in] show if the algorithm should print 78 : * 79 : * @returns the current solution estimate 80 : */ 81 : virtual DataContainer<data_t> run(index_t iterations, const DataContainer<data_t>& x0, 82 : bool show = false); 83 : 84 : /** 85 : * @brief Function to determine when to stop. This function should implement 86 : * algorithm specific stopping criterions. 87 : * 88 : * @returns true if the algorithm should stop 89 : */ 90 : virtual bool shouldStop() const; 91 : 92 : /** 93 : * @brief Format the header string for the iterative reconstruction algorithm 94 : */ 95 : virtual std::string formatHeader() const; 96 : 97 : /** 98 : * @brief Print the header of the iterative reconstruction algorithm 99 : */ 100 : virtual void printHeader() const; 101 : 102 : /** 103 : * @brief Format the step string for the iterative reconstruction algorithm 104 : */ 105 : virtual std::string formatStep(const DataContainer<data_t>& x) const; 106 : 107 : /** 108 : * @brief Print a step of the iterative reconstruction algorithm 109 : */ 110 : virtual void printStep(const DataContainer<data_t>& x, index_t curiter, 111 : std::chrono::duration<double> steptime, 112 : std::chrono::duration<double> elapsed) const; 113 : 114 : /** 115 : * @brief Set the maximum number of iterations 116 : */ 117 : void setMaxiters(index_t maxiters); 118 : 119 : /** 120 : * @brief Set the callback function 121 : * 122 : * An example usage to later plot the loss (assuming some ground truth data): 123 : * ```cpp 124 : * auto phantom = getSomePhantom(); 125 : * auto solver = SomeSolver(...); 126 : * 127 : * // compute mean square error for each iteration 128 : * std::vector<float> msre; 129 : * solver.setCallback([&msre](const DataContainer<data_t>& x){ 130 : * msre.push_back(square(phantom - x).sum()); 131 : * }); 132 : * ``` 133 : * Similarly from Python: 134 : * ```py 135 : * phantom = getSomePhantom() 136 : * solver = SomeSolver(...) 137 : * 138 : * // compute mean square error for each iteration 139 : * msre = [] 140 : * solver.setCallback(lambda x: msre.append(square(phantom - x).sum())) 141 : * ``` 142 : */ 143 : void setCallback(const std::function<void(const DataContainer<data_t>&)>& callback); 144 : 145 : /** 146 : * @brief Set variable to print only every n steps of the iterative reconstruction algorithm 147 : */ 148 : void printEvery(index_t printEvery); 149 : 150 : protected: 151 : /// Is the solver already configured (no need to call `setup`) 152 : bool configured_ = false; 153 : 154 : /// Logging name of the iterative reconstruction algorithm 155 : std::string name_ = "Solver"; 156 : 157 : /// Current iteration 158 : index_t curiter_ = 0; 159 : 160 : private: 161 : /// Max iteration to run in total 162 : index_t maxiters_ = 10000; 163 : 164 : index_t printEvery_ = 1; 165 : 166 : /// Callback function to call each iteration, by default, do nothing 167 759 : std::function<void(const DataContainer<data_t>&)> callback_ = [](auto) {}; 168 : }; 169 : 170 : /// Extract the default value from the optional if present, else create a 171 : /// new DataContainer with the given Descriptor and initial value 172 : template <class data_t> 173 : DataContainer<data_t> extract_or(std::optional<DataContainer<data_t>> x0, 174 : const DataDescriptor& domain, 175 : SelfType_t<data_t> val = data_t{0}) 176 140 : { 177 140 : return x0.has_value() ? x0.value() : full<data_t>(domain, val); 178 140 : } 179 : } // namespace elsa