LCOV - code coverage report
Current view: top level - elsa/solvers - AB_GMRES.cpp (source / functions) Hit Total Coverage
Test: coverage-all.lcov Lines: 58 65 89.2 %
Date: 2024-05-16 04:22:26 Functions: 9 18 50.0 %

          Line data    Source code
       1             : #include "AB_GMRES.h"
       2             : #include "GMRES_common.h"
       3             : #include "TypeCasts.hpp"
       4             : #include "spdlog/stopwatch.h"
       5             : 
       6             : namespace elsa
       7             : {
       8             :     template <typename data_t>
       9             :     AB_GMRES<data_t>::AB_GMRES(const LinearOperator<data_t>& projector,
      10             :                                const DataContainer<data_t>& sinogram, data_t epsilon)
      11             :         : Solver<data_t>(),
      12             :           _A{projector.clone()},
      13             :           _B{adjoint(projector).clone()},
      14             :           _b{sinogram},
      15             :           _epsilon{epsilon}
      16           3 :     {
      17           3 :     }
      18             : 
      19             :     template <typename data_t>
      20             :     AB_GMRES<data_t>::AB_GMRES(const LinearOperator<data_t>& projector,
      21             :                                const LinearOperator<data_t>& backprojector,
      22             :                                const DataContainer<data_t>& sinogram, data_t epsilon)
      23             :         : Solver<data_t>(),
      24             :           _A{projector.clone()},
      25             :           _B{backprojector.clone()},
      26             :           _b{sinogram},
      27             :           _epsilon{epsilon}
      28           1 :     {
      29           1 :     }
      30             : 
      31             :     template <typename data_t>
      32             :     DataContainer<data_t> AB_GMRES<data_t>::solveAndRestart(index_t iterations, index_t restarts,
      33             :                                                             std::optional<DataContainer<data_t>> x0)
      34           1 :     {
      35           1 :         auto x = extract_or(x0, _A->getDomainDescriptor());
      36             : 
      37           4 :         for (index_t k = 0; k < restarts; k++) {
      38           3 :             x = solve(iterations, x);
      39           3 :         }
      40             : 
      41           1 :         return x;
      42           1 :     }
      43             : 
      44             :     template <typename data_t>
      45             :     DataContainer<data_t> AB_GMRES<data_t>::solve(index_t iterations,
      46             :                                                   std::optional<DataContainer<data_t>> x0)
      47           5 :     {
      48           5 :         detail::CalcRFn<data_t> calc_r0 =
      49           5 :             [](const LinearOperator<data_t>& A, const LinearOperator<data_t>&,
      50           5 :                const DataContainer<data_t>& b,
      51           5 :                const DataContainer<data_t>& x) -> DataContainer<data_t> {
      52           5 :             auto Ax = A.apply(x);
      53           5 :             auto r0 = b - Ax;
      54           5 :             return r0;
      55           5 :         };
      56             : 
      57           5 :         detail::CalcQFn<data_t> calc_q =
      58           5 :             [](const LinearOperator<data_t>& A, const LinearOperator<data_t>& B,
      59          34 :                const DataContainer<data_t>& w_k) -> DataContainer<data_t> {
      60          34 :             auto Bw_k = B.apply(w_k);
      61          34 :             auto q = A.apply(Bw_k);
      62          34 :             return q;
      63          34 :         };
      64             : 
      65           5 :         detail::CalcXFn<data_t> calc_x =
      66           5 :             [](const LinearOperator<data_t>& B, const DataContainer<data_t>& x,
      67          34 :                const DataContainer<data_t>& wy) -> DataContainer<data_t> {
      68          34 :             auto x_k = x + B.apply(wy);
      69          34 :             return x_k;
      70          34 :         };
      71             : 
      72           5 :         auto x = DataContainer<data_t>(_A->getDomainDescriptor());
      73           5 :         if (x0.has_value()) {
      74           3 :             x = *x0;
      75           3 :         } else {
      76           2 :             x = 0;
      77           2 :         }
      78             : 
      79           5 :         return detail::gmres("AB_GMRES", _A, _B, _b, _epsilon, x, iterations, calc_r0, calc_q,
      80           5 :                              calc_x);
      81           5 :     }
      82             : 
      83             :     template <typename data_t>
      84             :     AB_GMRES<data_t>* AB_GMRES<data_t>::cloneImpl() const
      85           1 :     {
      86           1 :         return new AB_GMRES(*_A, *_B, _b, _epsilon);
      87           1 :     }
      88             : 
      89             :     template <typename data_t>
      90             :     bool AB_GMRES<data_t>::isEqual(const Solver<data_t>& other) const
      91           1 :     {
      92             :         // This is basically stolen from CG
      93             : 
      94           1 :         auto otherGMRES = downcast_safe<AB_GMRES>(&other);
      95             : 
      96           1 :         if (!otherGMRES)
      97           0 :             return false;
      98             : 
      99           1 :         if (_epsilon != otherGMRES->_epsilon)
     100           0 :             return false;
     101             : 
     102           1 :         if ((_A && !otherGMRES->_A) || (!_A && otherGMRES->_A))
     103           0 :             return false;
     104             : 
     105           1 :         if (_A && otherGMRES->_A)
     106           1 :             if (*_A != *otherGMRES->_A)
     107           0 :                 return false;
     108             : 
     109           1 :         if ((_B && !otherGMRES->_B) || (!_B && otherGMRES->_B))
     110           0 :             return false;
     111             : 
     112           1 :         if (_B && otherGMRES->_B)
     113           1 :             if (*_B != *otherGMRES->_B)
     114           0 :                 return false;
     115             : 
     116           1 :         if (_b != otherGMRES->_b)
     117           0 :             return false;
     118             : 
     119           1 :         return true;
     120           1 :     }
     121             : 
     122             :     // ------------------------------------------
     123             :     // explicit template instantiation
     124             :     template class AB_GMRES<float>;
     125             :     template class AB_GMRES<double>;
     126             : 
     127             : } // namespace elsa

Generated by: LCOV version 1.14