Line data Source code
1 : #include "BA_GMRES.h"
2 : #include "GMRES_common.h"
3 : #include "TypeCasts.hpp"
4 : #include "spdlog/stopwatch.h"
5 :
6 : namespace elsa
7 : {
8 : template <typename data_t>
9 : BA_GMRES<data_t>::BA_GMRES(const LinearOperator<data_t>& projector,
10 : const DataContainer<data_t>& sinogram, data_t epsilon)
11 : : Solver<data_t>(),
12 : _A{projector.clone()},
13 : _B{adjoint(projector).clone()},
14 : _b{sinogram},
15 : _epsilon{epsilon}
16 3 : {
17 3 : }
18 :
19 : template <typename data_t>
20 : BA_GMRES<data_t>::BA_GMRES(const LinearOperator<data_t>& projector,
21 : const LinearOperator<data_t>& backprojector,
22 : const DataContainer<data_t>& sinogram, data_t epsilon)
23 : : Solver<data_t>(),
24 : _A{projector.clone()},
25 : _B{backprojector.clone()},
26 : _b{sinogram},
27 : _epsilon{epsilon}
28 1 : {
29 1 : }
30 :
31 : template <typename data_t>
32 : DataContainer<data_t> BA_GMRES<data_t>::solveAndRestart(index_t iterations, index_t restarts,
33 : std::optional<DataContainer<data_t>> x0)
34 1 : {
35 1 : auto x = DataContainer<data_t>(_A->getDomainDescriptor());
36 1 : if (x0.has_value()) {
37 0 : x = *x0;
38 1 : } else {
39 1 : x = 0;
40 1 : }
41 :
42 4 : for (index_t k = 0; k < restarts; k++) {
43 3 : x = solve(iterations, x);
44 3 : }
45 :
46 1 : return x;
47 1 : }
48 :
49 : template <typename data_t>
50 : DataContainer<data_t> BA_GMRES<data_t>::solve(index_t iterations,
51 : std::optional<DataContainer<data_t>> x0)
52 5 : {
53 5 : auto x = DataContainer<data_t>(_A->getDomainDescriptor());
54 5 : if (x0.has_value()) {
55 3 : x = *x0;
56 3 : } else {
57 2 : x = 0;
58 2 : }
59 :
60 5 : detail::CalcRFn<data_t> calc_r0 =
61 5 : [](const LinearOperator<data_t>& A, const LinearOperator<data_t>& B,
62 5 : const DataContainer<data_t>& b,
63 5 : const DataContainer<data_t>& x) -> DataContainer<data_t> {
64 5 : auto Bb = B.apply(b);
65 5 : auto Ax = A.apply(x);
66 5 : auto BAx = B.apply(Ax);
67 :
68 5 : auto r0 = Bb - BAx;
69 5 : return r0;
70 5 : };
71 :
72 5 : detail::CalcQFn<data_t> calc_q =
73 5 : [](const LinearOperator<data_t>& A, const LinearOperator<data_t>& B,
74 34 : const DataContainer<data_t>& w_k) -> DataContainer<data_t> {
75 34 : auto Aw_k = A.apply(w_k);
76 34 : auto q = B.apply(Aw_k);
77 34 : return q;
78 34 : };
79 :
80 5 : detail::CalcXFn<data_t> calc_x =
81 5 : [](const LinearOperator<data_t>&, const DataContainer<data_t>& x,
82 34 : const DataContainer<data_t>& wy) -> DataContainer<data_t> {
83 34 : auto x_k = x + wy;
84 34 : return x_k;
85 34 : };
86 :
87 5 : return detail::gmres("BA_GMRES", _A, _B, _b, _epsilon, x, iterations, calc_r0, calc_q,
88 5 : calc_x);
89 5 : }
90 :
91 : template <typename data_t>
92 : BA_GMRES<data_t>* BA_GMRES<data_t>::cloneImpl() const
93 1 : {
94 1 : return new BA_GMRES(*_A, *_B, _b, _epsilon);
95 1 : }
96 :
97 : template <typename data_t>
98 : bool BA_GMRES<data_t>::isEqual(const Solver<data_t>& other) const
99 1 : {
100 : // This is basically stolen from CG
101 :
102 1 : auto otherGMRES = downcast_safe<BA_GMRES>(&other);
103 :
104 1 : if (!otherGMRES)
105 0 : return false;
106 :
107 1 : if (_epsilon != otherGMRES->_epsilon)
108 0 : return false;
109 :
110 1 : if ((_A && !otherGMRES->_A) || (!_A && otherGMRES->_A))
111 0 : return false;
112 :
113 1 : if (_A && otherGMRES->_A)
114 1 : if (*_A != *otherGMRES->_A)
115 0 : return false;
116 :
117 1 : if ((_B && !otherGMRES->_B) || (!_B && otherGMRES->_B))
118 0 : return false;
119 :
120 1 : if (_B && otherGMRES->_B)
121 1 : if (*_B != *otherGMRES->_B)
122 0 : return false;
123 :
124 1 : if (_b != otherGMRES->_b)
125 0 : return false;
126 :
127 1 : return true;
128 1 : }
129 :
130 : // ------------------------------------------
131 : // explicit template instantiation
132 : template class BA_GMRES<float>;
133 : template class BA_GMRES<double>;
134 :
135 : } // namespace elsa
|