LCOV - code coverage report
Current view: top level - elsa/solvers - SIRT.cpp (source / functions) Hit Total Coverage
Test: coverage-all.lcov Lines: 23 29 79.3 %
Date: 2024-12-21 07:37:52 Functions: 8 10 80.0 %

          Line data    Source code
       1             : #include "SIRT.h"
       2             : #include "Scaling.h"
       3             : #include "elsaDefines.h"
       4             : 
       5             : namespace elsa
       6             : {
       7             :     template <typename data_t>
       8             :     SIRT<data_t>::SIRT(const LinearOperator<data_t>& A, const DataContainer<data_t>& b,
       9             :                        SelfType_t<data_t> stepSize)
      10             :         : LandweberIteration<data_t>(A, b, stepSize)
      11           4 :     {
      12           4 :         this->name_ = "SIRT";
      13           4 :     }
      14             : 
      15             :     template <typename data_t>
      16             :     SIRT<data_t>::SIRT(const LinearOperator<data_t>& A, const DataContainer<data_t>& b)
      17             :         : LandweberIteration<data_t>(A, b)
      18           0 :     {
      19           0 :         this->name_ = "SIRT";
      20           0 :     }
      21             : 
      22             :     template <class data_t>
      23             :     std::unique_ptr<LinearOperator<data_t>>
      24             :         SIRT<data_t>::setupOperators(const LinearOperator<data_t>& A) const
      25           2 :     {
      26           2 :         const auto& domain = A.getDomainDescriptor();
      27           2 :         const auto& range = A.getRangeDescriptor();
      28             : 
      29           2 :         auto rowsum = A.apply(domain.template element<data_t>().one());
      30             : 
      31             :         // Prevent division by zero (and hence NaNs), by slightly lifting everything up a touch
      32           2 :         rowsum += data_t{1e-10f};
      33             : 
      34           2 :         Scaling<data_t> M(data_t{1.} / rowsum);
      35             : 
      36           2 :         auto colsum = A.applyAdjoint(range.template element<data_t>().one());
      37           2 :         Scaling<data_t> T(data_t{1.} / colsum);
      38             : 
      39           2 :         return (T * adjoint(A) * M).clone();
      40           2 :     }
      41             : 
      42             :     template <typename data_t>
      43             :     bool SIRT<data_t>::isEqual(const Solver<data_t>& other) const
      44           2 :     {
      45           2 :         if (!LandweberIteration<data_t>::isEqual(other))
      46           0 :             return false;
      47             : 
      48           2 :         auto otherSolver = downcast_safe<SIRT<data_t>>(&other);
      49           2 :         return static_cast<bool>(otherSolver);
      50           2 :     }
      51             : 
      52             :     template <typename data_t>
      53             :     SIRT<data_t>* SIRT<data_t>::cloneImpl() const
      54           2 :     {
      55           2 :         if (this->stepSize_.isInitialized()) {
      56           2 :             return new SIRT(*this->A_, this->b_, *this->stepSize_);
      57           2 :         } else {
      58           0 :             return new SIRT(*this->A_, this->b_);
      59           0 :         }
      60           2 :     }
      61             : 
      62             :     // ------------------------------------------
      63             :     // explicit template instantiation
      64             :     template class SIRT<float>;
      65             :     template class SIRT<double>;
      66             : } // namespace elsa

Generated by: LCOV version 1.14