LCOV - code coverage report
Current view: top level - elsa/solvers - GMRES_common.h (source / functions) Hit Total Coverage
Test: coverage-all.lcov Lines: 45 50 90.0 %
Date: 2024-12-21 07:37:52 Functions: 1 2 50.0 %

          Line data    Source code
#pragma once

#include "Logger.h"
#include "TypeCasts.hpp"
#include "spdlog/stopwatch.h"
#include "LinearOperator.h"

#include <functional>
#include <memory>
#include <string>

#include <Eigen/Core>
#include <Eigen/QR>
       7             : 
       8             : namespace elsa::detail
       9             : {
      10             :     template <typename data_t>
      11             :     using CalcRFn = std::function<DataContainer<data_t>(
      12             :         const LinearOperator<data_t>&, const LinearOperator<data_t>&, const DataContainer<data_t>&,
      13             :         const DataContainer<data_t>&)>;
      14             : 
      15             :     template <typename data_t>
      16             :     using CalcQFn = std::function<DataContainer<data_t>(const LinearOperator<data_t>&,
      17             :                                                         const LinearOperator<data_t>&,
      18             :                                                         const DataContainer<data_t>&)>;
      19             : 
      20             :     template <typename data_t>
      21             :     using CalcXFn = std::function<DataContainer<data_t>(
      22             :         const LinearOperator<data_t>&, const DataContainer<data_t>&, const DataContainer<data_t>&)>;
      23             : 
      24             :     template <typename data_t>
      25             :     DataContainer<data_t> gmres(std::string name, std::unique_ptr<LinearOperator<data_t>>& A,
      26             :                                 std::unique_ptr<LinearOperator<data_t>>& B,
      27             :                                 DataContainer<data_t>& b, data_t _epsilon, DataContainer<data_t> x,
      28             :                                 index_t iterations, CalcRFn<data_t> calculate_r0,
      29             :                                 CalcQFn<data_t> calculate_q, CalcXFn<data_t> calculate_x)
      30          10 :     {
      31             :         // GMRES Implementation
      32          10 :         using Mat = Eigen::Matrix<data_t, Eigen::Dynamic, Eigen::Dynamic>;
      33             : 
      34          10 :         spdlog::stopwatch aggregate_time;
      35          10 :         Logger::get(name)->info("Start preparations...");
      36             : 
      37             :         // setup DataContainer for Return Value which should be like x
      38          10 :         auto x_k = DataContainer<data_t>(A->getDomainDescriptor());
      39             : 
      40             :         // Custom function for AB/BA-GMRES
      41          10 :         auto r0 = calculate_r0(*A, *B, b, x);
      42             : 
      43          10 :         Mat h = Mat::Constant(iterations + 1, iterations, 0);
      44          10 :         Mat w = Mat::Constant(r0.getSize(), iterations, 0);
      45          10 :         Vector_t<data_t> e = Vector_t<data_t>::Constant(iterations + 1, 1, 0);
      46             : 
      47             :         // Initializing e Vector
      48          10 :         e(0) = r0.l2Norm();
      49             : 
      50             :         // Filling Matrix w with the vector r0/beta at the specified column
      51          10 :         auto w_i0 = r0 / e(0);
      52          10 :         w.col(0) = Eigen::Map<Vector_t<data_t>>(thrust::raw_pointer_cast(w_i0.storage().data()),
      53          10 :                                                 w_i0.getSize());
      54             : 
      55          10 :         Logger::get(name)->info("Preparations done, took {}s", aggregate_time);
      56             : 
      57          10 :         Logger::get(name)->info("epsilon: {}", _epsilon);
      58          10 :         Logger::get(name)->info("||r0||: {}", e(0));
      59             : 
      60          10 :         Logger::get(name)->info("{:^6}|{:*^16}|{:*^8}|{:*^8}|", "iter", "r", "time", "elapsed");
      61             : 
      62          78 :         for (index_t k = 0; k < iterations; k++) {
      63          68 :             spdlog::stopwatch iter_time;
      64             : 
      65          68 :             auto w_k = DataContainer<data_t>(r0.getDataDescriptor(), w.col(k));
      66             : 
      67             :             // Custom function for AB/BA-GMRES
      68          68 :             auto temp = calculate_q(*A, *B, w_k);
      69             : 
      70             :             // casting the DataContainer result to an EigenVector for easier calculations
      71          68 :             auto q_k = Eigen::Map<Vector_t<data_t>>(thrust::raw_pointer_cast(temp.storage().data()),
      72          68 :                                                     temp.getSize());
      73             : 
      74         652 :             for (index_t i = 0; i < iterations; i++) {
      75         584 :                 auto w_i = w.col(i);
      76         584 :                 auto h_ik = q_k.dot(w_i);
      77             : 
      78         584 :                 h(i, k) = h_ik;
      79         584 :                 q_k -= h_ik * w_i;
      80         584 :             }
      81             : 
      82          68 :             h(k + 1, k) = q_k.norm();
      83             : 
      84             :             // Source:
      85             :             // https://stackoverflow.com/questions/37962271/whats-wrong-with-my-AB_GMRES-implementation
      86             :             // This rule exists as we fill k+1 column of w and w matrix only has k columns
      87             :             // another way to implement this would be by having a matrix w with k + 1 columns and
      88             :             // instead always just getting the slice w0..wk for wy calculation
      89          68 :             if (k != iterations - 1) {
      90          58 :                 w.col(k + 1) = q_k / h(k + 1, k);
      91          58 :             }
      92             : 
      93             :             // for other options see:
      94             :             // https://eigen.tuxfamily.org/dox/group__DenseDecompositionBenchmark.html
      95          68 :             Eigen::ColPivHouseholderQR<Mat> qr(h);
      96          68 :             Vector_t<data_t> y = qr.solve(e);
      97          68 :             auto wy = DataContainer<data_t>(r0.getDataDescriptor(), w * y);
      98             : 
      99             :             // Custom function for AB/BA-GMRES
     100          68 :             x_k = calculate_x(*B, x, wy);
     101             : 
     102             :             // disable r for faster results ?
     103          68 :             auto r = b - A->apply(x_k);
     104             : 
     105          68 :             Logger::get(name)->info("{:>5}|{:>15}|{:>6.3}|{:>6.3}s|", k, r.l2Norm(), iter_time,
     106          68 :                                     aggregate_time);
     107             : 
     108             :             //  Break Condition via relative residual, there could be more interesting approaches
     109             :             //  used here like NCP Criterion or discrepancy principle
     110          68 :             if (r.l2Norm() <= _epsilon) {
     111           0 :                 Logger::get(name)->info("||rx|| {}", r.l2Norm());
     112           0 :                 Logger::get(name)->info("SUCCESS: Reached convergence at {}/{} iteration", k + 1,
     113           0 :                                         iterations);
     114           0 :                 return x_k;
     115           0 :             }
     116          68 :         }
     117             : 
     118          10 :         Logger::get(name)->warn("Failed to reach convergence at {} iterations", iterations);
     119          10 :         return x_k;
     120          10 :     };
     121             : }; // namespace elsa::detail

Generated by: LCOV version 1.14