LCOV - code coverage report
Current view: top level - elsa/projectors - JosephsMethod.cpp (source / functions) Hit Total Coverage
Test: coverage-all.lcov Lines: 109 122 89.3 %
Date: 2024-05-15 03:55:36 Functions: 31 64 48.4 %

          Line data    Source code
       1             : #include "JosephsMethod.h"
       2             : #include "Timer.h"
       3             : #include "DrivingDirectionTraversal.h"
       4             : #include "Error.h"
       5             : #include "TypeCasts.hpp"
       6             : 
       7             : #include <type_traits>
       8             : 
       9             : namespace elsa
      10             : {
      11             :     template <typename data_t>
      12             :     JosephsMethod<data_t>::JosephsMethod(const VolumeDescriptor& domainDescriptor,
      13             :                                          const DetectorDescriptor& rangeDescriptor)
      14             :         : base_type(domainDescriptor, rangeDescriptor)
      15          62 :     {
      16          62 :         auto dim = domainDescriptor.getNumberOfDimensions();
      17          62 :         if (dim != 2 && dim != 3) {
      18           0 :             throw InvalidArgumentError("JosephsMethod:only supporting 2d/3d operations");
      19           0 :         }
      20             : 
      21          62 :         if (dim != rangeDescriptor.getNumberOfDimensions()) {
      22           0 :             throw InvalidArgumentError("JosephsMethod: domain and range dimension need to match");
      23           0 :         }
      24             : 
      25          62 :         if (rangeDescriptor.getNumberOfGeometryPoses() == 0) {
      26           0 :             throw InvalidArgumentError("JosephsMethod: geometry list was empty");
      27           0 :         }
      28          62 :     }
      29             : 
      30             :     template <int dim>
      31             :     bool isInAABB(const IndexArray_t<dim>& indices, const IndexArray_t<dim>& aabbMin,
      32             :                   const IndexArray_t<dim>& aabbMax)
      33     2842480 :     {
      34     2842480 :         return (indices >= aabbMin && indices < aabbMax).all();
      35     2842480 :     }
      36             : 
      37             :     template <int dim>
      38             :     std::pair<RealArray_t<dim>, RealArray_t<dim>>
      39             :         getLinearInterpolationWeights(const RealArray_t<dim>& currentPos,
      40             :                                       const IndexArray_t<dim>& voxelFloor,
      41             :                                       const index_t drivingDirection)
      42     1796424 :     {
      43             :         // subtract 0.5 because the weight calculation assumes indices that refer to the center of
      44             :         // the voxels, while elsa works with the lower corners of the indices.
      45     1796424 :         RealArray_t<dim> complement_weight = currentPos - voxelFloor.template cast<real_t>() - 0.5f;
      46     1796424 :         RealArray_t<dim> weight = RealArray_t<dim>{1} - complement_weight;
      47             :         // set weights along drivingDirection to 1 so that the interpolation does not have to handle
      48             :         // the drivingDirection as a special case
      49     1796424 :         weight(drivingDirection) = 1;
      50     1796424 :         complement_weight(drivingDirection) = 1;
      51     1796424 :         return std::make_pair(weight, complement_weight);
      52     1796424 :     }
      53             : 
      54             :     template <int dim>
      55             :     index_t coord2Idx(IndexArray_t<dim> coord, IndexArray_t<dim> strides)
      56     3467706 :     {
      57     3467706 :         return (coord * strides).sum();
      58     3467706 :     }
      59             : 
      60             :     template <typename data_t, int dim, class Fn>
      61             :     void doInterpolation(const IndexArray_t<dim>& voxelFloor, const IndexArray_t<dim>& voxelCeil,
      62             :                          const RealArray_t<dim>& weight, const RealArray_t<dim>& complement_weight,
      63             :                          const IndexArray_t<dim>& aabbMin, const IndexArray_t<dim>& aabbMax, Fn fn)
      64     1826795 :     {
      65     3521043 :         auto clip = [](auto coord, auto lower, auto upper) { return coord.min(upper).max(lower); };
      66     1826795 :         IndexArray_t<dim> tempIndices;
      67     1826795 :         if constexpr (dim == 2) {
      68     3491057 :             auto interpol = [&](auto v1, auto v2, auto w1, auto w2) {
      69     3491057 :                 tempIndices << v1, v2;
      70             : 
      71     3491057 :                 bool is_in_aab = isInAABB(tempIndices, aabbMin, aabbMax);
      72     3491057 :                 tempIndices = clip(tempIndices, aabbMin, (aabbMax - 1));
      73     3491057 :                 auto weight = is_in_aab * w1 * w2;
      74     3491057 :                 fn(tempIndices, weight);
      75     3491057 :             };
      76             : 
      77          96 :             interpol(voxelFloor[0], voxelCeil[1], weight[0], complement_weight[1]);
      78          96 :             interpol(voxelCeil[0], voxelFloor[1], complement_weight[0], weight[1]);
      79          96 :         } else {
      80         386 :             auto interpol = [&](auto v1, auto v2, auto v3, auto w1, auto w2, auto w3) {
      81         386 :                 tempIndices << v1, v2, v3;
      82             : 
      83         386 :                 bool is_in_aab = isInAABB(tempIndices, aabbMin, aabbMax);
      84         386 :                 tempIndices = clip(tempIndices, aabbMin, (aabbMax - 1));
      85         386 :                 auto weight = is_in_aab * w1 * w2 * w3;
      86         386 :                 fn(tempIndices, weight);
      87         386 :             };
      88             : 
      89          96 :             interpol(voxelFloor[0], voxelFloor[1], voxelFloor[2], weight[0], weight[1], weight[2]);
      90          96 :             interpol(voxelFloor[0], voxelCeil[1], voxelCeil[2], weight[0], complement_weight[1],
      91          96 :                      complement_weight[2]);
      92          96 :             interpol(voxelCeil[0], voxelFloor[1], voxelCeil[2], complement_weight[0], weight[1],
      93          96 :                      complement_weight[2]);
      94          96 :             interpol(voxelCeil[0], voxelCeil[1], voxelFloor[2], complement_weight[0],
      95          96 :                      complement_weight[1], weight[2]);
      96          96 :         }
      97     1826795 :     }
      98             : 
      99             :     template <typename data_t>
     100             :     void JosephsMethod<data_t>::forward(const BoundingBox& aabb, const DataContainer<data_t>& x,
     101             :                                         DataContainer<data_t>& Ax) const
     102         210 :     {
     103         210 :         Timer timeguard("JosephsMethod", "apply");
     104         210 :         if (aabb.dim() == 2) {
     105         194 :             traverseVolume<false, 2>(aabb, x, Ax);
     106         194 :         } else if (aabb.dim() == 3) {
     107          16 :             traverseVolume<false, 3>(aabb, x, Ax);
     108          16 :         }
     109         210 :     }
     110             : 
     111             :     template <typename data_t>
     112             :     void JosephsMethod<data_t>::backward(const BoundingBox& aabb, const DataContainer<data_t>& y,
     113             :                                          DataContainer<data_t>& Aty) const
     114         144 :     {
     115         144 :         Timer timeguard("JosephsMethod", "applyAdjoint");
     116         144 :         if (aabb.dim() == 2) {
     117         135 :             traverseVolume<true, 2>(aabb, y, Aty);
     118         135 :         } else if (aabb.dim() == 3) {
     119           9 :             traverseVolume<true, 3>(aabb, y, Aty);
     120           9 :         }
     121         144 :     }
     122             : 
     123             :     template <typename data_t>
     124             :     JosephsMethod<data_t>* JosephsMethod<data_t>::_cloneImpl() const
     125          23 :     {
     126          23 :         return new self_type(downcast<VolumeDescriptor>(*this->_domainDescriptor),
     127          23 :                              downcast<DetectorDescriptor>(*this->_rangeDescriptor));
     128          23 :     }
     129             : 
     130             :     template <typename data_t>
     131             :     bool JosephsMethod<data_t>::_isEqual(const LinearOperator<data_t>& other) const
     132           4 :     {
     133           4 :         if (!LinearOperator<data_t>::isEqual(other))
     134           0 :             return false;
     135             : 
     136           4 :         auto otherJM = downcast_safe<JosephsMethod>(&other);
     137           4 :         return static_cast<bool>(otherJM);
     138           4 :     }
     139             : 
     140             :     template <typename data_t>
     141             :     template <bool adjoint, int dim>
     142             :     void JosephsMethod<data_t>::traverseVolume(const BoundingBox& aabb,
     143             :                                                const DataContainer<data_t>& vector,
     144             :                                                DataContainer<data_t>& result) const
     145         354 :     {
     146         354 :         if constexpr (adjoint)
     147         144 :             result = 0;
     148             : 
     149         354 :         const auto& domain = adjoint ? result.getDataDescriptor() : vector.getDataDescriptor();
     150         354 :         const auto& range = downcast<DetectorDescriptor>(adjoint ? vector.getDataDescriptor()
     151         354 :                                                                  : result.getDataDescriptor());
     152             : 
     153         354 :         const IndexArray_t<dim> strides = domain.getProductOfCoefficientsPerDimension();
     154         354 :         const auto sizeOfRange = range.getNumberOfCoefficients();
     155             : 
     156         354 :         const IndexArray_t<dim> aabbMin = aabb.min().template cast<index_t>();
     157         354 :         const IndexArray_t<dim> aabbMax = aabb.max().template cast<index_t>();
     158             : 
     159             :         // iterate over all rays
     160         354 : #pragma omp parallel for
     161         354 :         for (index_t ir = 0; ir < sizeOfRange; ir++) {
     162           0 :             const auto ray = range.computeRayFromDetectorCoord(ir);
     163             : 
     164             :             // --> setup traversal algorithm
     165             : 
     166           0 :             DrivingDirectionTraversal<dim> traverse(aabb, ray);
     167           0 :             const index_t drivingDirection = traverse.getDrivingDirection();
     168           0 :             const data_t intersection = traverse.getIntersectionLength();
     169             : 
     170           0 :             if constexpr (!adjoint)
     171       98674 :                 result[ir] = 0;
     172             : 
     173             :             // Make steps through the volume
     174     1818669 :             while (traverse.isInBoundingBox()) {
     175     1818669 :                 const IndexArray_t<dim> voxelFloor = traverse.getCurrentVoxelFloor();
     176     1818669 :                 const IndexArray_t<dim> voxelCeil = traverse.getCurrentVoxelCeil();
     177     1818669 :                 const auto [weight, complement_weight] = getLinearInterpolationWeights(
     178     1818669 :                     traverse.getCurrentPos(), voxelFloor, drivingDirection);
     179             : 
     180     1818669 :                 if constexpr (adjoint) {
     181     1046494 :                     doInterpolation<data_t>(voxelFloor, voxelCeil, weight, complement_weight,
     182     1472642 :                                             aabbMin, aabbMax, [&](const auto& coord, auto wght) {
     183     1472642 : #pragma omp atomic
     184     1472642 :                                                 result[coord2Idx(coord, strides)] +=
     185     1472642 :                                                     vector[ir] * intersection * wght;
     186     1472642 :                                             });
     187     1046494 :                 } else {
     188     1046494 :                     doInterpolation<data_t>(voxelFloor, voxelCeil, weight, complement_weight,
     189     2000701 :                                             aabbMin, aabbMax, [&](const auto& coord, auto wght) {
     190     2000701 :                                                 result[ir] += vector[coord2Idx(coord, strides)]
     191     2000701 :                                                               * intersection * wght;
     192     2000701 :                                             });
     193     1046494 :                 }
     194             :                 // update Traverse
     195     1818669 :                 traverse.updateTraverse();
     196     1818669 :             }
     197           0 :         }
     198         354 :     }
     199             : 
     200             :     // ------------------------------------------
     201             :     // explicit template instantiation
     202             :     template class JosephsMethod<float>;
     203             :     template class JosephsMethod<double>;
     204             : 
     205             : } // namespace elsa

Generated by: LCOV version 1.14