LCOV - code coverage report
Current view: top level - elsa/solvers - Solver.h (source / functions) Hit Total Coverage
Test: coverage-all.lcov Lines: 6 6 100.0 %
Date: 2024-05-16 04:22:26 Functions: 8 10 80.0 %

          Line data    Source code
       1             : #pragma once
       2             : 
#include "DataDescriptor.h"
#include "elsaDefines.h"
#include "DataContainer.h"
#include "Cloneable.h"

#include <chrono>
#include <functional>
#include <optional>
#include <string>
#include <utility>
       9             : 
      10             : namespace elsa
      11             : {
      12             : 
      13             :     /**
      14             :      * @brief Base class representing a solver for an optimization problem.
      15             :      *
      16             :      * This class represents abstract (typically iterative) solvers acting on optimization problems.
      17             :      *
      18             :      * @author
      19             :      * - Matthias Wieczorek - initial code
      20             :      * - Maximilian Hornung - modularization
      21             :      * - Tobias Lasser - rewrite, modernization
      22             :      *
      23             :      * @tparam data_t data type for the domain and range of the problem, defaulting to real_t
      24             :      */
      25             :     template <typename data_t = real_t>
      26             :     class Solver : public Cloneable<Solver<data_t>>
      27             :     {
      28             :     public:
      29             :         /// Scalar alias
      30             :         using Scalar = data_t;
      31             : 
      32         189 :         Solver() = default;
      33             : 
      34             :         /// default destructor
      35         189 :         ~Solver() override = default;
      36             : 
      37             :         /**
      38             :          * @brief Solve the optimization problem from a zero starting point. This
      39             :          * will setup the iterative algorithm and then run the algorithm for
      40             :          * the specified number of iterations (assuming the algorithm doesn't convergence before).
      41             :          *
      42             :          * @param[in] iterations number of iterations to execute
      43             :          * @param[in] x0 optional initial solution, initial solution set to zero if not present
      44             :          *
      45             :          * @returns the current solution (after solving)
      46             :          */
      47             :         virtual DataContainer<data_t> solve(index_t iterations,
      48             :                                             std::optional<DataContainer<data_t>> x0 = std::nullopt);
      49             : 
      50             :         /**
      51             :          * @brief Setup and reset the solver. This function should set all values,
      52             :          * vectors and temporaries necessary to run the iterative reconstruction
      53             :          * algorithm.
      54             :          *
      55             :          * @param[in] x optional to initial solution, if optional is empty, some defaulted (e.g. 0
      56             :          * filled) initial solution should be returned
      57             :          *
      58             :          * @returns initial solution
      59             :          */
      60             :         virtual DataContainer<data_t> setup(std::optional<DataContainer<data_t>> x);
      61             : 
      62             :         /**
      63             :          * @brief Perform a single step of the iterative reconstruction algorithm.
      64             :          *
      65             :          * @param[in] x the current solution
      66             :          *
      67             :          * @returns the updated solution estimate
      68             :          */
      69             :         virtual DataContainer<data_t> step(DataContainer<data_t> x);
      70             : 
      71             :         /**
      72             :          * @brief Run the iterative reconstruction algorithm for the
      73             :          * given number of iterations on the initial starting point. This fun
      74             :          *
      75             :          * @param[in] iterations the number of iterations to execute
      76             :          * @param[in] x0 the initial starting point of the run
      77             :          * @param[in] show if the algorithm should print
      78             :          *
      79             :          * @returns the current solution estimate
      80             :          */
      81             :         virtual DataContainer<data_t> run(index_t iterations, const DataContainer<data_t>& x0,
      82             :                                           bool show = false);
      83             : 
      84             :         /**
      85             :          * @brief Function to determine when to stop. This function should implement
      86             :          * algorithm specific stopping criterions.
      87             :          *
      88             :          * @returns true if the algorithm should stop
      89             :          */
      90             :         virtual bool shouldStop() const;
      91             : 
      92             :         /**
      93             :          * @brief Format the header string for the iterative reconstruction algorithm
      94             :          */
      95             :         virtual std::string formatHeader() const;
      96             : 
      97             :         /**
      98             :          * @brief Print the header of the iterative reconstruction algorithm
      99             :          */
     100             :         virtual void printHeader() const;
     101             : 
     102             :         /**
     103             :          * @brief Format the step string for the iterative reconstruction algorithm
     104             :          */
     105             :         virtual std::string formatStep(const DataContainer<data_t>& x) const;
     106             : 
     107             :         /**
     108             :          * @brief Print a step of the iterative reconstruction algorithm
     109             :          */
     110             :         virtual void printStep(const DataContainer<data_t>& x, index_t curiter,
     111             :                                std::chrono::duration<double> steptime,
     112             :                                std::chrono::duration<double> elapsed) const;
     113             : 
     114             :         /**
     115             :          * @brief Set the maximum number of iterations
     116             :          */
     117             :         void setMaxiters(index_t maxiters);
     118             : 
     119             :         /**
     120             :          * @brief Set the callback function
     121             :          *
     122             :          * An example usage to later plot the loss (assuming some ground truth data):
     123             :          * ```cpp
     124             :          * auto phantom = getSomePhantom();
     125             :          * auto solver = SomeSolver(...);
     126             :          *
     127             :          * // compute mean square error for each iteration
     128             :          * std::vector<float> msre;
     129             :          * solver.setCallback([&msre](const DataContainer<data_t>& x){
     130             :          *     msre.push_back(square(phantom - x).sum());
     131             :          * });
     132             :          * ```
     133             :          * Similarly from Python:
     134             :          * ```py
     135             :          * phantom = getSomePhantom()
     136             :          * solver = SomeSolver(...)
     137             :          *
     138             :          * // compute mean square error for each iteration
     139             :          * msre = []
     140             :          * solver.setCallback(lambda x: msre.append(square(phantom - x).sum()))
     141             :          * ```
     142             :          */
     143             :         void setCallback(const std::function<void(const DataContainer<data_t>&)>& callback);
     144             : 
     145             :         /**
     146             :          * @brief Set variable to print only every n steps of the iterative reconstruction algorithm
     147             :          */
     148             :         void printEvery(index_t printEvery);
     149             : 
     150             :     protected:
     151             :         /// Is the solver already configured (no need to call `setup`)
     152             :         bool configured_ = false;
     153             : 
     154             :         /// Logging name of the iterative reconstruction algorithm
     155             :         std::string name_ = "Solver";
     156             : 
     157             :         /// Current iteration
     158             :         index_t curiter_ = 0;
     159             : 
     160             :     private:
     161             :         /// Max iteration to run in total
     162             :         index_t maxiters_ = 10000;
     163             : 
     164             :         index_t printEvery_ = 1;
     165             : 
     166             :         /// Callback function to call each iteration, by default, do nothing
     167         759 :         std::function<void(const DataContainer<data_t>&)> callback_ = [](auto) {};
     168             :     };
     169             : 
     170             :     /// Extract the default value from the optional if present, else create a
     171             :     /// new DataContainer with the given Descriptor and initial value
     172             :     template <class data_t>
     173             :     DataContainer<data_t> extract_or(std::optional<DataContainer<data_t>> x0,
     174             :                                      const DataDescriptor& domain,
     175             :                                      SelfType_t<data_t> val = data_t{0})
     176         140 :     {
     177         140 :         return x0.has_value() ? x0.value() : full<data_t>(domain, val);
     178         140 :     }
     179             : } // namespace elsa

Generated by: LCOV version 1.14