LCOV - code coverage report
Current view: top level - elsa/projectors - BinaryMethod.cpp (source / functions) Hit Total Coverage
Test: coverage-all.lcov Lines: 43 59 72.9 %
Date: 2024-05-15 03:55:36 Functions: 10 30 33.3 %

          Line data    Source code
       1             : #include "BinaryMethod.h"
       2             : #include "Timer.h"
       3             : #include "TraverseAABB.h"
       4             : #include "TypeCasts.hpp"
       5             : 
       6             : #include <stdexcept>
       7             : #include <type_traits>
       8             : 
       9             : namespace elsa
      10             : {
      11             :     template <typename data_t>
      12             :     BinaryMethod<data_t>::BinaryMethod(const VolumeDescriptor& domainDescriptor,
      13             :                                        const DetectorDescriptor& rangeDescriptor)
      14             :         : LinearOperator<data_t>(domainDescriptor, rangeDescriptor),
      15             :           _boundingBox{domainDescriptor.getNumberOfCoefficientsPerDimension()},
      16             :           _detectorDescriptor(static_cast<DetectorDescriptor&>(*_rangeDescriptor)),
      17             :           _volumeDescriptor(static_cast<VolumeDescriptor&>(*_domainDescriptor))
      18          46 :     {
      19             :         // sanity checks
      20          46 :         auto dim = _domainDescriptor->getNumberOfDimensions();
      21          46 :         if (dim < 2 || dim > 3)
      22           0 :             throw InvalidArgumentError("BinaryMethod: only supporting 2d/3d operations");
      23             : 
      24          46 :         if (dim != _rangeDescriptor->getNumberOfDimensions())
      25           0 :             throw InvalidArgumentError("BinaryMethod: domain and range dimension need to match");
      26             : 
      27          46 :         if (_detectorDescriptor.getNumberOfGeometryPoses() == 0)
      28           0 :             throw InvalidArgumentError("BinaryMethod: rangeDescriptor without any geometry");
      29          46 :     }
      30             : 
      31             :     template <typename data_t>
      32             :     void BinaryMethod<data_t>::applyImpl(const DataContainer<data_t>& x,
      33             :                                          DataContainer<data_t>& Ax) const
      34          71 :     {
      35          71 :         Timer t("BinaryMethod", "apply");
      36          71 :         traverseVolume<false>(x, Ax);
      37          71 :     }
      38             : 
      39             :     template <typename data_t>
      40             :     void BinaryMethod<data_t>::applyAdjointImpl(const DataContainer<data_t>& y,
      41             :                                                 DataContainer<data_t>& Aty) const
      42          41 :     {
      43          41 :         Timer t("BinaryMethod", "applyAdjoint");
      44          41 :         traverseVolume<true>(y, Aty);
      45          41 :     }
      46             : 
      47             :     template <typename data_t>
      48             :     BinaryMethod<data_t>* BinaryMethod<data_t>::cloneImpl() const
      49           2 :     {
      50           2 :         return new BinaryMethod(_volumeDescriptor, _detectorDescriptor);
      51           2 :     }
      52             : 
      53             :     template <typename data_t>
      54             :     bool BinaryMethod<data_t>::isEqual(const LinearOperator<data_t>& other) const
      55           0 :     {
      56           0 :         if (!LinearOperator<data_t>::isEqual(other))
      57           0 :             return false;
      58             : 
      59           0 :         auto otherBM = downcast_safe<BinaryMethod>(&other);
      60           0 :         if (!otherBM)
      61           0 :             return false;
      62             : 
      63           0 :         return true;
      64           0 :     }
      65             : 
      66             :     template <typename data_t>
      67             :     template <bool adjoint>
      68             :     void BinaryMethod<data_t>::traverseVolume(const DataContainer<data_t>& vector,
      69             :                                               DataContainer<data_t>& result) const
      70         112 :     {
      71         112 :         if (_domainDescriptor->getNumberOfDimensions() == 2) {
      72          86 :             return doTraverseVolume<adjoint, 2>(vector, result);
      73          86 :         } else if (_domainDescriptor->getNumberOfDimensions() == 3) {
      74          26 :             return doTraverseVolume<adjoint, 3>(vector, result);
      75          26 :         }
      76         112 :     }
      77             : 
      78             :     template <typename data_t>
      79             :     template <bool adjoint, int dim>
      80             :     void BinaryMethod<data_t>::doTraverseVolume(const DataContainer<data_t>& vector,
      81             :                                                 DataContainer<data_t>& result) const
      82         112 :     {
      83         112 :         const index_t maxIterations = adjoint ? vector.getSize() : result.getSize();
      84             : 
      85         112 :         if constexpr (adjoint) {
      86          41 :             result = 0; // initialize volume to 0, because we are not going to hit every voxel!
      87          41 :         }
      88             : 
      89             :         // --> loop either over every voxel that should be updated or every detector
      90             :         // cell that should be calculated
      91         112 : #pragma omp parallel for
      92         112 :         for (index_t rangeIndex = 0; rangeIndex < maxIterations; ++rangeIndex) {
      93             : 
      94             :             // --> get the current ray to the detector center (from reference to DetectorDescriptor)
      95           0 :             auto ray = _detectorDescriptor.computeRayFromDetectorCoord(rangeIndex);
      96             : 
      97             :             // --> setup traversal algorithm
      98           0 :             TraverseAABB<dim> traverse(_boundingBox, ray,
      99           0 :                                        _domainDescriptor->getNumberOfCoefficientsPerDimension());
     100             : 
     101           0 :             if constexpr (!adjoint)
     102        5080 :                 result[rangeIndex] = 0;
     103             : 
     104      557145 :             while (traverse.isInBoundingBox()) {
     105             :                 // --> initial index to access the data vector
     106      557145 :                 auto dataIndexForCurrentVoxel =
     107      557145 :                     _domainDescriptor->getIndexFromCoordinate(traverse.getCurrentVoxel());
     108             : 
     109             :                 // --> update result depending on the operation performed
     110      557145 :                 if constexpr (adjoint)
     111      282865 : #pragma omp atomic
     112      278203 :                     result[dataIndexForCurrentVoxel] += vector[rangeIndex];
     113      278203 :                 else
     114      278203 :                     result[rangeIndex] += vector[dataIndexForCurrentVoxel];
     115             : 
     116      557145 :                 traverse.updateTraverse();
     117      557145 :             }
     118           0 :         } // end for
     119         112 :     }
     120             : 
     121             :     // ------------------------------------------
     122             :     // explicit template instantiation
     123             :     template class BinaryMethod<float>;
     124             :     template class BinaryMethod<double>;
     125             : 
     126             : } // namespace elsa

Generated by: LCOV version 1.14