Line data Source code
1 : #include "CGNE.h" 2 : #include "DataContainer.h" 3 : #include "Error.h" 4 : #include "LinearOperator.h" 5 : #include "LinearResidual.h" 6 : #include "Solver.h" 7 : #include "TypeCasts.hpp" 8 : #include "spdlog/stopwatch.h" 9 : #include "Logger.h" 10 : 11 : namespace elsa 12 : { 13 : template <class data_t> 14 : CGNE<data_t>::CGNE(const LinearOperator<data_t>& A, const DataContainer<data_t>& b, 15 : SelfType_t<data_t> tol) 16 : : A_(A.clone()), 17 : AtA_((adjoint(*A_) * (*A_)).clone()), 18 : b_(b), 19 : Atb_(A_->applyAdjoint(b_)), 20 : r_(empty<data_t>(AtA_->getRangeDescriptor())), 21 : c_(empty<data_t>(AtA_->getRangeDescriptor())), 22 : Ac_(empty<data_t>(AtA_->getRangeDescriptor())), 23 : kold_(100000), 24 : tol_(tol) 25 40 : { 26 40 : this->name_ = "CGNE"; 27 40 : } 28 : 29 : template <typename data_t> 30 : DataContainer<data_t> CGNE<data_t>::setup(std::optional<DataContainer<data_t>> x0) 31 40 : { 32 40 : auto x = extract_or(x0, A_->getDomainDescriptor()); 33 : 34 : // Residual b - A(x) 35 40 : r_ = AtA_->apply(x); 36 40 : r_ *= data_t{-1}; 37 40 : r_ += Atb_; 38 : 39 40 : c_.assign(r_); 40 : 41 40 : Ac_ = empty<data_t>(AtA_->getRangeDescriptor()); 42 : 43 : // Squared Norm of residual 44 40 : kold_ = r_.squaredL2Norm(); 45 : 46 : // setup done! 47 40 : this->configured_ = true; 48 : 49 40 : return x; 50 40 : } 51 : 52 : template <typename data_t> 53 : DataContainer<data_t> CGNE<data_t>::step(DataContainer<data_t> x) 54 116 : { 55 116 : AtA_->apply(c_, Ac_); 56 116 : auto cAc = c_.dot(Ac_); 57 : 58 116 : alpha_ = kold_ / cAc; 59 : 60 : // Update x and residual 61 116 : x += alpha_ * c_; 62 116 : r_ -= alpha_ * Ac_; 63 : 64 116 : auto k = r_.squaredL2Norm(); 65 116 : beta_ = k / kold_; 66 : 67 : // c = r + beta * c 68 116 : lincomb(1, r_, beta_, c_, c_); 69 : 70 : // store k for next iteration 71 116 : kold_ = k; 72 : 73 116 : return x; 74 116 : } 75 : 76 : template <typename data_t> 77 : bool CGNE<data_t>::shouldStop() const 78 126 : { 79 126 : return r_.l2Norm() < tol_; 80 126 : } 81 : 82 : template <typename data_t> 83 : std::string CGNE<data_t>::formatHeader() const 84 40 : { 85 40 : return fmt::format("{:^15} | {:^15} | {:^15} | {:^15}", "r", "c", "alpha", "beta"); 86 40 : } 87 : 88 : template <typename data_t> 89 : std::string CGNE<data_t>::formatStep(const DataContainer<data_t>&) const 90 116 : { 91 116 : return fmt::format("{:>15.10} | {:>15.10} | {:>15.10} | {:>15.10}", r_.l2Norm(), 92 116 : c_.l2Norm(), alpha_, beta_); 93 116 : } 94 : 95 : template <class data_t> 96 : bool CGNE<data_t>::isEqual(const Solver<data_t>& other) const 97 8 : { 98 8 : auto cgne = downcast_safe<CGNE>(&other); 99 8 : return cgne && *cgne->A_ == *A_ && cgne->b_ == b_; 100 8 : } 101 : 102 : template <class data_t> 103 : CGNE<data_t>* CGNE<data_t>::cloneImpl() const 104 8 : { 105 8 : return new CGNE(*A_, b_, tol_); 106 8 : } 107 : 108 : // ------------------------------------------ 109 : // explicit template instantiation 110 : template class CGNE<float>; 111 : template class CGNE<double>; 112 : 113 : } // namespace elsa