Line data Source code
1 : #include "Logger.h"
2 : #include "TypeCasts.hpp"
3 : #include "spdlog/stopwatch.h"
4 : #include "LinearOperator.h"
5 : #include <memory>
6 : #include <Eigen/Core>
7 :
8 : namespace elsa::detail
9 : {
10 : template <typename data_t>
11 : using CalcRFn = std::function<DataContainer<data_t>(
12 : const LinearOperator<data_t>&, const LinearOperator<data_t>&, const DataContainer<data_t>&,
13 : const DataContainer<data_t>&)>;
14 :
15 : template <typename data_t>
16 : using CalcQFn = std::function<DataContainer<data_t>(const LinearOperator<data_t>&,
17 : const LinearOperator<data_t>&,
18 : const DataContainer<data_t>&)>;
19 :
20 : template <typename data_t>
21 : using CalcXFn = std::function<DataContainer<data_t>(
22 : const LinearOperator<data_t>&, const DataContainer<data_t>&, const DataContainer<data_t>&)>;
23 :
24 : template <typename data_t>
25 : DataContainer<data_t> gmres(std::string name, std::unique_ptr<LinearOperator<data_t>>& A,
26 : std::unique_ptr<LinearOperator<data_t>>& B,
27 : DataContainer<data_t>& b, data_t _epsilon, DataContainer<data_t> x,
28 : index_t iterations, CalcRFn<data_t> calculate_r0,
29 : CalcQFn<data_t> calculate_q, CalcXFn<data_t> calculate_x)
30 10 : {
31 : // GMRES Implementation
32 10 : using Mat = Eigen::Matrix<data_t, Eigen::Dynamic, Eigen::Dynamic>;
33 :
34 10 : spdlog::stopwatch aggregate_time;
35 10 : Logger::get(name)->info("Start preparations...");
36 :
37 : // setup DataContainer for Return Value which should be like x
38 10 : auto x_k = DataContainer<data_t>(A->getDomainDescriptor());
39 :
40 : // Custom function for AB/BA-GMRES
41 10 : auto r0 = calculate_r0(*A, *B, b, x);
42 :
43 10 : Mat h = Mat::Constant(iterations + 1, iterations, 0);
44 10 : Mat w = Mat::Constant(r0.getSize(), iterations, 0);
45 10 : Vector_t<data_t> e = Vector_t<data_t>::Constant(iterations + 1, 1, 0);
46 :
47 : // Initializing e Vector
48 10 : e(0) = r0.l2Norm();
49 :
50 : // Filling Matrix w with the vector r0/beta at the specified column
51 10 : auto w_i0 = r0 / e(0);
52 10 : w.col(0) = Eigen::Map<Vector_t<data_t>>(thrust::raw_pointer_cast(w_i0.storage().data()),
53 10 : w_i0.getSize());
54 :
55 10 : Logger::get(name)->info("Preparations done, took {}s", aggregate_time);
56 :
57 10 : Logger::get(name)->info("epsilon: {}", _epsilon);
58 10 : Logger::get(name)->info("||r0||: {}", e(0));
59 :
60 10 : Logger::get(name)->info("{:^6}|{:*^16}|{:*^8}|{:*^8}|", "iter", "r", "time", "elapsed");
61 :
62 78 : for (index_t k = 0; k < iterations; k++) {
63 68 : spdlog::stopwatch iter_time;
64 :
65 68 : auto w_k = DataContainer<data_t>(r0.getDataDescriptor(), w.col(k));
66 :
67 : // Custom function for AB/BA-GMRES
68 68 : auto temp = calculate_q(*A, *B, w_k);
69 :
70 : // casting the DataContainer result to an EigenVector for easier calculations
71 68 : auto q_k = Eigen::Map<Vector_t<data_t>>(thrust::raw_pointer_cast(temp.storage().data()),
72 68 : temp.getSize());
73 :
74 652 : for (index_t i = 0; i < iterations; i++) {
75 584 : auto w_i = w.col(i);
76 584 : auto h_ik = q_k.dot(w_i);
77 :
78 584 : h(i, k) = h_ik;
79 584 : q_k -= h_ik * w_i;
80 584 : }
81 :
82 68 : h(k + 1, k) = q_k.norm();
83 :
84 : // Source:
85 : // https://stackoverflow.com/questions/37962271/whats-wrong-with-my-AB_GMRES-implementation
86 : // This rule exists as we fill k+1 column of w and w matrix only has k columns
87 : // another way to implement this would be by having a matrix w with k + 1 columns and
88 : // instead always just getting the slice w0..wk for wy calculation
89 68 : if (k != iterations - 1) {
90 58 : w.col(k + 1) = q_k / h(k + 1, k);
91 58 : }
92 :
93 : // for other options see:
94 : // https://eigen.tuxfamily.org/dox/group__DenseDecompositionBenchmark.html
95 68 : Eigen::ColPivHouseholderQR<Mat> qr(h);
96 68 : Vector_t<data_t> y = qr.solve(e);
97 68 : auto wy = DataContainer<data_t>(r0.getDataDescriptor(), w * y);
98 :
99 : // Custom function for AB/BA-GMRES
100 68 : x_k = calculate_x(*B, x, wy);
101 :
102 : // disable r for faster results ?
103 68 : auto r = b - A->apply(x_k);
104 :
105 68 : Logger::get(name)->info("{:>5}|{:>15}|{:>6.3}|{:>6.3}s|", k, r.l2Norm(), iter_time,
106 68 : aggregate_time);
107 :
108 : // Break Condition via relative residual, there could be more interesting approaches
109 : // used here like NCP Criterion or discrepancy principle
110 68 : if (r.l2Norm() <= _epsilon) {
111 0 : Logger::get(name)->info("||rx|| {}", r.l2Norm());
112 0 : Logger::get(name)->info("SUCCESS: Reached convergence at {}/{} iteration", k + 1,
113 0 : iterations);
114 0 : return x_k;
115 0 : }
116 68 : }
117 :
118 10 : Logger::get(name)->warn("Failed to reach convergence at {} iterations", iterations);
119 10 : return x_k;
120 10 : };
121 : }; // namespace elsa::detail
|