Line data Source code
1 : #pragma once 2 : 3 : #include <memory> 4 : 5 : #include "DataContainer.h" 6 : #include "LinearOperator.h" 7 : #include "Solver.h" 8 : #include "elsaDefines.h" 9 : 10 : namespace elsa 11 : { 12 : /// @brief Conjugate Gradient via the Normal equation 13 : /// 14 : /// CG solves the system of equations: 15 : /// \[ 16 : /// A x = b 17 : /// \] 18 : /// where \f$A\f$ is symmetric positive definite operator. \f$b\f$ is the measured quantity. 19 : /// 20 : /// In our implementation, we always assume \f$A\f$ is non-symmetric and not positive 21 : /// definite. Hence, we compute the solution to the normal equation 22 : /// \[ 23 : /// A^T A x = A^t b 24 : /// \] 25 : /// 26 : /// References: 27 : /// - An Introduction to the Conjugate Gradient Method Without the Agonizing Pain, by Shewchuk 28 : /// 29 : /// @author David Frank 30 : template <typename data_t = real_t> 31 : class CGNE : public Solver<data_t> 32 : { 33 : public: 34 : /// Scalar alias 35 : using Scalar = typename Solver<data_t>::Scalar; 36 : 37 : /// @brief Construct the necessary form of CGNE using some linear operator and 38 : /// the measured data. 39 : /// 40 : /// @param A linear operator for the problem 41 : /// @param b the measured data 42 : /// @param tol stopping tolerance 43 : CGNE(const LinearOperator<data_t>& A, const DataContainer<data_t>& b, 44 : SelfType_t<data_t> tol = 1e-4); 45 : 46 : /// make copy constructor deletion explicit 47 : CGNE(const CGNE<data_t>&) = delete; 48 : 49 : /// default destructor 50 40 : ~CGNE() override = default; 51 : 52 : DataContainer<data_t> setup(std::optional<DataContainer<data_t>> x) override; 53 : 54 : DataContainer<data_t> step(DataContainer<data_t> x) override; 55 : 56 : bool shouldStop() const override; 57 : 58 : std::string formatHeader() const override; 59 : 60 : std::string formatStep(const DataContainer<data_t>& x) const override; 61 : 62 : private: 63 : std::unique_ptr<LinearOperator<data_t>> A_; 64 : 65 : std::unique_ptr<LinearOperator<data_t>> AtA_; 66 : 67 : DataContainer<data_t> b_; 68 : 69 : DataContainer<data_t> Atb_; 70 : 71 : DataContainer<data_t> r_; 72 : 73 : DataContainer<data_t> c_; 74 : 75 : DataContainer<data_t> Ac_; 76 : 77 : data_t kold_; 78 : 79 : data_t alpha_; 80 : 81 : data_t beta_; 82 : 83 : data_t tol_{0.0001}; 84 : 85 : /// implement the polymorphic clone operation 86 : CGNE<data_t>* cloneImpl() const override; 87 : 88 : /// implement the polymorphic comparison operation 89 : bool isEqual(const Solver<data_t>& other) const override; 90 : }; 91 : } // namespace elsa