LCOV - code coverage report
Current view: top level - projectors - BinaryMethod.cpp (source / functions) Hit Total Coverage
Test: test_coverage.info.cleaned Lines: 23 33 69.7 %
Date: 2022-05-27 02:48:28 Functions: 6 14 42.9 %

          Line data    Source code
       1             : #include "BinaryMethod.h"
       2             : #include "Timer.h"
       3             : #include "TraverseAABB.h"
       4             : #include "TypeCasts.hpp"
       5             : 
       6             : #include <stdexcept>
       7             : #include <type_traits>
       8             : 
       9             : namespace elsa
      10             : {
      11             :     template <typename data_t>
      12          47 :     BinaryMethod<data_t>::BinaryMethod(const VolumeDescriptor& domainDescriptor,
      13             :                                        const DetectorDescriptor& rangeDescriptor)
      14             :         : LinearOperator<data_t>(domainDescriptor, rangeDescriptor),
      15             :           _boundingBox{domainDescriptor.getNumberOfCoefficientsPerDimension()},
      16          47 :           _detectorDescriptor(static_cast<DetectorDescriptor&>(*_rangeDescriptor)),
      17          47 :           _volumeDescriptor(static_cast<VolumeDescriptor&>(*_domainDescriptor))
      18             :     {
      19             :         // sanity checks
      20          47 :         auto dim = _domainDescriptor->getNumberOfDimensions();
      21          47 :         if (dim < 2 || dim > 3)
      22           0 :             throw InvalidArgumentError("BinaryMethod: only supporting 2d/3d operations");
      23             : 
      24          47 :         if (dim != _rangeDescriptor->getNumberOfDimensions())
      25           0 :             throw InvalidArgumentError("BinaryMethod: domain and range dimension need to match");
      26             : 
      27          47 :         if (_detectorDescriptor.getNumberOfGeometryPoses() == 0)
      28           0 :             throw InvalidArgumentError("BinaryMethod: rangeDescriptor without any geometry");
      29          47 :     }
      30             : 
      31             :     template <typename data_t>
      32          72 :     void BinaryMethod<data_t>::applyImpl(const DataContainer<data_t>& x,
      33             :                                          DataContainer<data_t>& Ax) const
      34             :     {
      35         216 :         Timer t("BinaryMethod", "apply");
      36          72 :         traverseVolume<false>(x, Ax);
      37          72 :     }
      38             : 
      39             :     template <typename data_t>
      40          42 :     void BinaryMethod<data_t>::applyAdjointImpl(const DataContainer<data_t>& y,
      41             :                                                 DataContainer<data_t>& Aty) const
      42             :     {
      43         126 :         Timer t("BinaryMethod", "applyAdjoint");
      44          42 :         traverseVolume<true>(y, Aty);
      45          42 :     }
      46             : 
      47             :     template <typename data_t>
      48           2 :     BinaryMethod<data_t>* BinaryMethod<data_t>::cloneImpl() const
      49             :     {
      50           2 :         return new BinaryMethod(_volumeDescriptor, _detectorDescriptor);
      51             :     }
      52             : 
      53             :     template <typename data_t>
      54           0 :     bool BinaryMethod<data_t>::isEqual(const LinearOperator<data_t>& other) const
      55             :     {
      56           0 :         if (!LinearOperator<data_t>::isEqual(other))
      57           0 :             return false;
      58             : 
      59           0 :         auto otherBM = downcast_safe<BinaryMethod>(&other);
      60           0 :         if (!otherBM)
      61           0 :             return false;
      62             : 
      63           0 :         return true;
      64             :     }
      65             : 
      66             :     template <typename data_t>
      67             :     template <bool adjoint>
      68         114 :     void BinaryMethod<data_t>::traverseVolume(const DataContainer<data_t>& vector,
      69             :                                               DataContainer<data_t>& result) const
      70             :     {
      71         114 :         const index_t maxIterations = adjoint ? vector.getSize() : result.getSize();
      72             : 
      73             :         if constexpr (adjoint) {
      74          42 :             result = 0; // initialize volume to 0, because we are not going to hit every voxel!
      75             :         }
      76             : 
      77             :         // --> loop either over every voxel that should be updated or every detector
      78             :         // cell that should be calculated
      79         114 : #pragma omp parallel for
      80             :         for (index_t rangeIndex = 0; rangeIndex < maxIterations; ++rangeIndex) {
      81             : 
      82             :             // --> get the current ray to the detector center (from reference to DetectorDescriptor)
      83             :             auto ray = _detectorDescriptor.computeRayFromDetectorCoord(rangeIndex);
      84             : 
      85             :             // --> setup traversal algorithm
      86             :             TraverseAABB traverse(_boundingBox, ray);
      87             : 
      88             :             if constexpr (!adjoint)
      89             :                 result[rangeIndex] = 0;
      90             : 
      91             :             while (traverse.isInBoundingBox()) {
      92             :                 // --> initial index to access the data vector
      93             :                 auto dataIndexForCurrentVoxel =
      94             :                     _domainDescriptor->getIndexFromCoordinate(traverse.getCurrentVoxel());
      95             : 
      96             :                 // --> update result depending on the operation performed
      97             :                 if constexpr (adjoint)
      98             : #pragma omp atomic
      99             :                     result[dataIndexForCurrentVoxel] += vector[rangeIndex];
     100             :                 else
     101             :                     result[rangeIndex] += vector[dataIndexForCurrentVoxel];
     102             : 
     103             :                 traverse.updateTraverse();
     104             :             }
     105             :         } // end for
     106         114 :     }
     107             : 
     108             :     // ------------------------------------------
     109             :     // explicit template instantiation
     110             :     template class BinaryMethod<float>;
     111             :     template class BinaryMethod<double>;
     112             : 
     113             : } // namespace elsa

Generated by: LCOV version 1.15