Line data Source code
1 : #include "Solver.h" 2 : #include "DataContainer.h" 3 : #include "spdlog/stopwatch.h" 4 : #include <chrono> 5 : 6 : namespace elsa 7 : { 8 : template <class data_t> 9 : DataContainer<data_t> Solver<data_t>::solve(index_t iterations, 10 : std::optional<DataContainer<data_t>> x0) 11 82 : { 12 82 : auto x = setup(x0); 13 82 : x = run(iterations, x, true); 14 82 : return x; 15 82 : } 16 : 17 : template <class data_t> 18 : DataContainer<data_t> Solver<data_t>::setup(std::optional<DataContainer<data_t>>) 19 0 : { 20 0 : throw Error("Solver: setup is not implemented for this solver"); 21 0 : } 22 : 23 : template <class data_t> 24 : DataContainer<data_t> Solver<data_t>::step(DataContainer<data_t>) 25 0 : { 26 0 : throw Error("Solver: step is not implemented for this solver"); 27 0 : } 28 : 29 : template <class data_t> 30 : DataContainer<data_t> Solver<data_t>::run(index_t iterations, const DataContainer<data_t>& x0, 31 : bool show) 32 82 : { 33 82 : auto x = materialize(x0); 34 : 35 82 : if (!configured_) { 36 2 : x = setup(x); 37 2 : } 38 : 39 82 : if (show) { 40 82 : printHeader(); 41 82 : } 42 : 43 82 : spdlog::stopwatch aggregate_time; 44 : 45 82 : auto enditer = std::min(curiter_ + iterations, maxiters_); 46 841 : while (curiter_ < enditer && !shouldStop()) { 47 759 : spdlog::stopwatch steptime; 48 759 : x = step(std::move(x)); 49 759 : auto stepelapsed = steptime.elapsed(); 50 759 : auto totalelapsed = aggregate_time.elapsed(); 51 : 52 759 : callback_(x); 53 : 54 759 : if (show && curiter_ % printEvery_ == 0) { 55 759 : printStep(x, curiter_, stepelapsed, totalelapsed); 56 759 : } 57 759 : ++curiter_; 58 759 : } 59 : 60 82 : return x; 61 82 : } 62 : 63 : template <class data_t> 64 : bool Solver<data_t>::shouldStop() const 65 4 : { 66 4 : return false; 67 4 : } 68 : 69 : template <class data_t> 70 : std::string Solver<data_t>::formatHeader() const 71 0 : { 72 0 : return ""; 73 0 : } 74 : 75 : template <class data_t> 76 : void Solver<data_t>::printHeader() const 77 82 : { 78 82 : auto str = formatHeader(); 79 82 : Logger::get(name_)->info("{:^5} | {:^8} | {:^8} | {} |", "Iters", "time (s)", "elapsed", 80 82 : str); 81 82 : } 82 : 83 : template <class data_t> 84 : std::string Solver<data_t>::formatStep(const DataContainer<data_t>& /* x */) const 85 0 : { 86 0 : return ""; 87 0 : } 88 : 89 : template <class data_t> 90 : void Solver<data_t>::printStep(const DataContainer<data_t>& x, index_t curiter, 91 : std::chrono::duration<double> steptime, 92 : std::chrono::duration<double> elapsed) const 93 759 : { 94 759 : auto str = formatStep(x); 95 759 : Logger::get(name_)->info("{:>5} | {:>8.3} | {:>8.3} | {} |", curiter, steptime.count(), 96 759 : elapsed.count(), str); 97 759 : } 98 : 99 : template <class data_t> 100 : void Solver<data_t>::setMaxiters(index_t maxiters) 101 0 : { 102 0 : maxiters_ = maxiters; 103 0 : } 104 : 105 : template <class data_t> 106 : void Solver<data_t>::setCallback( 107 : const std::function<void(const DataContainer<data_t>&)>& callback) 108 0 : { 109 0 : callback_ = callback; 110 0 : } 111 : 112 : template <class data_t> 113 : void Solver<data_t>::printEvery(index_t printevery) 114 0 : { 115 0 : printEvery_ = printevery; 116 0 : } 117 : 118 : // ------------------------------------------ 119 : // explicit template instantiation 120 : template class Solver<float>; 121 : template class Solver<double>; 122 : template class Solver<complex<float>>; 123 : template class Solver<complex<double>>; 124 : } // namespace elsa