Line data Source code
1 : #include "AB_GMRES.h"
2 : #include "GMRES_common.h"
3 : #include "TypeCasts.hpp"
4 : #include "spdlog/stopwatch.h"
5 :
6 : namespace elsa
7 : {
8 : template <typename data_t>
9 : AB_GMRES<data_t>::AB_GMRES(const LinearOperator<data_t>& projector,
10 : const DataContainer<data_t>& sinogram, data_t epsilon)
11 : : Solver<data_t>(),
12 : _A{projector.clone()},
13 : _B{adjoint(projector).clone()},
14 : _b{sinogram},
15 : _epsilon{epsilon}
16 3 : {
17 3 : }
18 :
19 : template <typename data_t>
20 : AB_GMRES<data_t>::AB_GMRES(const LinearOperator<data_t>& projector,
21 : const LinearOperator<data_t>& backprojector,
22 : const DataContainer<data_t>& sinogram, data_t epsilon)
23 : : Solver<data_t>(),
24 : _A{projector.clone()},
25 : _B{backprojector.clone()},
26 : _b{sinogram},
27 : _epsilon{epsilon}
28 1 : {
29 1 : }
30 :
31 : template <typename data_t>
32 : DataContainer<data_t> AB_GMRES<data_t>::solveAndRestart(index_t iterations, index_t restarts,
33 : std::optional<DataContainer<data_t>> x0)
34 1 : {
35 1 : auto x = extract_or(x0, _A->getDomainDescriptor());
36 :
37 4 : for (index_t k = 0; k < restarts; k++) {
38 3 : x = solve(iterations, x);
39 3 : }
40 :
41 1 : return x;
42 1 : }
43 :
44 : template <typename data_t>
45 : DataContainer<data_t> AB_GMRES<data_t>::solve(index_t iterations,
46 : std::optional<DataContainer<data_t>> x0)
47 5 : {
48 5 : detail::CalcRFn<data_t> calc_r0 =
49 5 : [](const LinearOperator<data_t>& A, const LinearOperator<data_t>&,
50 5 : const DataContainer<data_t>& b,
51 5 : const DataContainer<data_t>& x) -> DataContainer<data_t> {
52 5 : auto Ax = A.apply(x);
53 5 : auto r0 = b - Ax;
54 5 : return r0;
55 5 : };
56 :
57 5 : detail::CalcQFn<data_t> calc_q =
58 5 : [](const LinearOperator<data_t>& A, const LinearOperator<data_t>& B,
59 34 : const DataContainer<data_t>& w_k) -> DataContainer<data_t> {
60 34 : auto Bw_k = B.apply(w_k);
61 34 : auto q = A.apply(Bw_k);
62 34 : return q;
63 34 : };
64 :
65 5 : detail::CalcXFn<data_t> calc_x =
66 5 : [](const LinearOperator<data_t>& B, const DataContainer<data_t>& x,
67 34 : const DataContainer<data_t>& wy) -> DataContainer<data_t> {
68 34 : auto x_k = x + B.apply(wy);
69 34 : return x_k;
70 34 : };
71 :
72 5 : auto x = DataContainer<data_t>(_A->getDomainDescriptor());
73 5 : if (x0.has_value()) {
74 3 : x = *x0;
75 3 : } else {
76 2 : x = 0;
77 2 : }
78 :
79 5 : return detail::gmres("AB_GMRES", _A, _B, _b, _epsilon, x, iterations, calc_r0, calc_q,
80 5 : calc_x);
81 5 : }
82 :
83 : template <typename data_t>
84 : AB_GMRES<data_t>* AB_GMRES<data_t>::cloneImpl() const
85 1 : {
86 1 : return new AB_GMRES(*_A, *_B, _b, _epsilon);
87 1 : }
88 :
89 : template <typename data_t>
90 : bool AB_GMRES<data_t>::isEqual(const Solver<data_t>& other) const
91 1 : {
92 : // This is basically stolen from CG
93 :
94 1 : auto otherGMRES = downcast_safe<AB_GMRES>(&other);
95 :
96 1 : if (!otherGMRES)
97 0 : return false;
98 :
99 1 : if (_epsilon != otherGMRES->_epsilon)
100 0 : return false;
101 :
102 1 : if ((_A && !otherGMRES->_A) || (!_A && otherGMRES->_A))
103 0 : return false;
104 :
105 1 : if (_A && otherGMRES->_A)
106 1 : if (*_A != *otherGMRES->_A)
107 0 : return false;
108 :
109 1 : if ((_B && !otherGMRES->_B) || (!_B && otherGMRES->_B))
110 0 : return false;
111 :
112 1 : if (_B && otherGMRES->_B)
113 1 : if (*_B != *otherGMRES->_B)
114 0 : return false;
115 :
116 1 : if (_b != otherGMRES->_b)
117 0 : return false;
118 :
119 1 : return true;
120 1 : }
121 :
122 : // ------------------------------------------
123 : // explicit template instantiation
124 : template class AB_GMRES<float>;
125 : template class AB_GMRES<double>;
126 :
127 : } // namespace elsa
|