Line data Source code
1 : #include "CGLS.h"
2 : #include "DataContainer.h"
3 : #include "Error.h"
4 : #include "LinearOperator.h"
5 : #include "LinearResidual.h"
6 : #include "TypeCasts.hpp"
7 : #include "spdlog/fmt/bundled/core.h"
8 : #include "spdlog/stopwatch.h"
9 : #include "Logger.h"
10 :
11 : namespace elsa
12 : {
13 : template <class data_t>
14 : CGLS<data_t>::CGLS(const LinearOperator<data_t>& A, const DataContainer<data_t>& b,
15 : SelfType_t<data_t> eps, SelfType_t<data_t> tol)
16 : : A_(A.clone()),
17 : b_(b),
18 : r_(empty<data_t>(A.getDomainDescriptor())),
19 : s_(empty<data_t>(A.getRangeDescriptor())),
20 : c_(empty<data_t>(A.getDomainDescriptor())),
21 : q_(empty<data_t>(A.getRangeDescriptor())),
22 : k_(100000000),
23 : kold_(100000000),
24 : damp_(eps * eps),
25 : tol_(tol)
26 18 : {
27 18 : this->name_ = "CGLS";
28 18 : }
29 :
30 : template <typename data_t>
31 : DataContainer<data_t> CGLS<data_t>::setup(std::optional<DataContainer<data_t>> x0)
32 18 : {
33 18 : auto x = extract_or(x0, A_->getDomainDescriptor());
34 :
35 18 : if (x0.has_value()) {
36 0 : x = *x0;
37 :
38 : // s = b_ - A_->applx(x)
39 0 : A_->apply(x, s_);
40 0 : lincomb(1, b_, -1, s_, s_);
41 :
42 0 : A_->applyAdjoint(s_, r_);
43 0 : lincomb(1, r_, -damp_, x, r_);
44 18 : } else {
45 18 : x = 0;
46 18 : s_ = b_;
47 18 : A_->applyAdjoint(s_, r_);
48 18 : }
49 :
50 18 : c_ = r_;
51 :
52 18 : k_ = r_.squaredL2Norm();
53 18 : kold_ = k_;
54 :
55 : // setup done!
56 18 : this->configured_ = true;
57 :
58 18 : return x;
59 18 : }
60 :
61 : template <typename data_t>
62 : DataContainer<data_t> CGLS<data_t>::step(DataContainer<data_t> x)
63 78 : {
64 78 : A_->apply(c_, q_);
65 :
66 78 : auto delta = q_.squaredL2Norm();
67 78 : auto alpha = [&]() {
68 78 : if (damp_ == 0) {
69 44 : return kold_ / delta;
70 44 : } else {
71 34 : return kold_ / (delta + damp_ * c_.squaredL2Norm());
72 34 : }
73 78 : }();
74 :
75 78 : lincomb(1, x, alpha, c_, x);
76 78 : s_ -= alpha * q_;
77 :
78 78 : A_->applyAdjoint(s_, r_);
79 78 : if (damp_ != 0.0) {
80 34 : r_ -= damp_ * x;
81 34 : }
82 :
83 78 : k_ = r_.squaredL2Norm();
84 78 : auto beta = k_ / kold_;
85 :
86 : // c = r + beta * c;
87 78 : lincomb(1, r_, beta, c_, c_);
88 78 : kold_ = k_;
89 78 : return x;
90 78 : }
91 :
92 : template <typename data_t>
93 : bool CGLS<data_t>::shouldStop() const
94 96 : {
95 96 : return kold_ < tol_;
96 96 : }
97 :
98 : template <typename data_t>
99 : std::string CGLS<data_t>::formatHeader() const
100 18 : {
101 18 : return fmt::format("{:^15} | {:^15}", "|| x ||_2", "|| s ||_2");
102 18 : }
103 :
104 : template <typename data_t>
105 : std::string CGLS<data_t>::formatStep(const DataContainer<data_t>& x) const
106 78 : {
107 78 : return fmt::format("{:>15.10} | {:>15.10}", x.l2Norm(), s_.l2Norm());
108 78 : }
109 :
110 : template <class data_t>
111 : CGLS<data_t>* CGLS<data_t>::cloneImpl() const
112 8 : {
113 8 : return new CGLS(*A_, b_, std::sqrt(damp_));
114 8 : }
115 :
116 : template <class data_t>
117 : bool CGLS<data_t>::isEqual(const Solver<data_t>& other) const
118 8 : {
119 8 : auto otherCGLS = downcast_safe<CGLS>(&other);
120 :
121 8 : return otherCGLS && *otherCGLS->A_ == *A_ && otherCGLS->b_ == b_
122 8 : && otherCGLS->damp_ == damp_;
123 8 : }
124 :
125 : // ------------------------------------------
126 : // explicit template instantiation
127 : template class CGLS<float>;
128 : template class CGLS<double>;
129 :
130 : } // namespace elsa
|