LCOV - code coverage report
Current view: top level - elsa/solvers - BA_GMRES.cpp (source / functions) Hit Total Coverage
Test: coverage-all.lcov Lines: 64 72 88.9 %
Date: 2024-05-16 04:22:26 Functions: 9 18 50.0 %

          Line data    Source code
       1             : #include "BA_GMRES.h"
       2             : #include "GMRES_common.h"
       3             : #include "TypeCasts.hpp"
       4             : #include "spdlog/stopwatch.h"
       5             : 
       6             : namespace elsa
       7             : {
       8             :     template <typename data_t>
       9             :     BA_GMRES<data_t>::BA_GMRES(const LinearOperator<data_t>& projector,
      10             :                                const DataContainer<data_t>& sinogram, data_t epsilon)
      11             :         : Solver<data_t>(),
      12             :           _A{projector.clone()},
      13             :           _B{adjoint(projector).clone()},
      14             :           _b{sinogram},
      15             :           _epsilon{epsilon}
      16           3 :     {
      17           3 :     }
      18             : 
      19             :     template <typename data_t>
      20             :     BA_GMRES<data_t>::BA_GMRES(const LinearOperator<data_t>& projector,
      21             :                                const LinearOperator<data_t>& backprojector,
      22             :                                const DataContainer<data_t>& sinogram, data_t epsilon)
      23             :         : Solver<data_t>(),
      24             :           _A{projector.clone()},
      25             :           _B{backprojector.clone()},
      26             :           _b{sinogram},
      27             :           _epsilon{epsilon}
      28           1 :     {
      29           1 :     }
      30             : 
      31             :     template <typename data_t>
      32             :     DataContainer<data_t> BA_GMRES<data_t>::solveAndRestart(index_t iterations, index_t restarts,
      33             :                                                             std::optional<DataContainer<data_t>> x0)
      34           1 :     {
      35           1 :         auto x = DataContainer<data_t>(_A->getDomainDescriptor());
      36           1 :         if (x0.has_value()) {
      37           0 :             x = *x0;
      38           1 :         } else {
      39           1 :             x = 0;
      40           1 :         }
      41             : 
      42           4 :         for (index_t k = 0; k < restarts; k++) {
      43           3 :             x = solve(iterations, x);
      44           3 :         }
      45             : 
      46           1 :         return x;
      47           1 :     }
      48             : 
      49             :     template <typename data_t>
      50             :     DataContainer<data_t> BA_GMRES<data_t>::solve(index_t iterations,
      51             :                                                   std::optional<DataContainer<data_t>> x0)
      52           5 :     {
      53           5 :         auto x = DataContainer<data_t>(_A->getDomainDescriptor());
      54           5 :         if (x0.has_value()) {
      55           3 :             x = *x0;
      56           3 :         } else {
      57           2 :             x = 0;
      58           2 :         }
      59             : 
      60           5 :         detail::CalcRFn<data_t> calc_r0 =
      61           5 :             [](const LinearOperator<data_t>& A, const LinearOperator<data_t>& B,
      62           5 :                const DataContainer<data_t>& b,
      63           5 :                const DataContainer<data_t>& x) -> DataContainer<data_t> {
      64           5 :             auto Bb = B.apply(b);
      65           5 :             auto Ax = A.apply(x);
      66           5 :             auto BAx = B.apply(Ax);
      67             : 
      68           5 :             auto r0 = Bb - BAx;
      69           5 :             return r0;
      70           5 :         };
      71             : 
      72           5 :         detail::CalcQFn<data_t> calc_q =
      73           5 :             [](const LinearOperator<data_t>& A, const LinearOperator<data_t>& B,
      74          34 :                const DataContainer<data_t>& w_k) -> DataContainer<data_t> {
      75          34 :             auto Aw_k = A.apply(w_k);
      76          34 :             auto q = B.apply(Aw_k);
      77          34 :             return q;
      78          34 :         };
      79             : 
      80           5 :         detail::CalcXFn<data_t> calc_x =
      81           5 :             [](const LinearOperator<data_t>&, const DataContainer<data_t>& x,
      82          34 :                const DataContainer<data_t>& wy) -> DataContainer<data_t> {
      83          34 :             auto x_k = x + wy;
      84          34 :             return x_k;
      85          34 :         };
      86             : 
      87           5 :         return detail::gmres("BA_GMRES", _A, _B, _b, _epsilon, x, iterations, calc_r0, calc_q,
      88           5 :                              calc_x);
      89           5 :     }
      90             : 
      91             :     template <typename data_t>
      92             :     BA_GMRES<data_t>* BA_GMRES<data_t>::cloneImpl() const
      93           1 :     {
      94           1 :         return new BA_GMRES(*_A, *_B, _b, _epsilon);
      95           1 :     }
      96             : 
      97             :     template <typename data_t>
      98             :     bool BA_GMRES<data_t>::isEqual(const Solver<data_t>& other) const
      99           1 :     {
     100             :         // This is basically stolen from CG
     101             : 
     102           1 :         auto otherGMRES = downcast_safe<BA_GMRES>(&other);
     103             : 
     104           1 :         if (!otherGMRES)
     105           0 :             return false;
     106             : 
     107           1 :         if (_epsilon != otherGMRES->_epsilon)
     108           0 :             return false;
     109             : 
     110           1 :         if ((_A && !otherGMRES->_A) || (!_A && otherGMRES->_A))
     111           0 :             return false;
     112             : 
     113           1 :         if (_A && otherGMRES->_A)
     114           1 :             if (*_A != *otherGMRES->_A)
     115           0 :                 return false;
     116             : 
     117           1 :         if ((_B && !otherGMRES->_B) || (!_B && otherGMRES->_B))
     118           0 :             return false;
     119             : 
     120           1 :         if (_B && otherGMRES->_B)
     121           1 :             if (*_B != *otherGMRES->_B)
     122           0 :                 return false;
     123             : 
     124           1 :         if (_b != otherGMRES->_b)
     125           0 :             return false;
     126             : 
     127           1 :         return true;
     128           1 :     }
     129             : 
     130             :     // ------------------------------------------
     131             :     // explicit template instantiation
     132             :     template class BA_GMRES<float>;
     133             :     template class BA_GMRES<double>;
     134             : 
     135             : } // namespace elsa

Generated by: LCOV version 1.14