LCOV - code coverage report
Current view: top level - elsa/projectors - LutProjector.h (source / functions) Hit Total Coverage
Test: coverage-all.lcov Lines: 62 69 89.9 %
Date: 2024-05-16 04:22:26 Functions: 11 20 55.0 %

          Line data    Source code
       1             : #pragma once
       2             : 
       3             : #include "elsaDefines.h"
       4             : #include "Timer.h"
       5             : #include "SliceTraversal.h"
       6             : #include "LinearOperator.h"
       7             : #include "VolumeDescriptor.h"
       8             : #include "DetectorDescriptor.h"
       9             : #include "DataContainer.h"
      10             : #include "BoundingBox.h"
      11             : #include "Logger.h"
      12             : #include "Blobs.h"
      13             : #include "BSplines.h"
      14             : #include "CartesianIndices.h"
      15             : 
      16             : #include "XrayProjector.h"
      17             : 
      18             : #include "spdlog/fmt/fmt.h"
      19             : #include "spdlog/fmt/ostr.h"
      20             : 
      21             : namespace elsa
      22             : {
      23             :     template <typename data_t, typename Derived>
      24             :     class LutProjector;
      25             : 
      26             :     template <typename data_t = real_t>
      27             :     class BlobProjector;
      28             : 
      29             :     template <typename data_t = real_t>
      30             :     class BSplineProjector;
      31             : 
      32             :     template <typename data_t>
      33             :     struct XrayProjectorInnerTypes<BlobProjector<data_t>> {
      34             :         using value_type = data_t;
      35             :         using forward_tag = ray_driven_tag;
      36             :         using backward_tag = ray_driven_tag;
      37             :     };
      38             : 
      39             :     template <typename data_t>
      40             :     struct XrayProjectorInnerTypes<BSplineProjector<data_t>> {
      41             :         using value_type = data_t;
      42             :         using forward_tag = ray_driven_tag;
      43             :         using backward_tag = ray_driven_tag;
      44             :     };
      45             : 
      46             :     template <typename data_t, typename Derived>
      47             :     class LutProjector : public XrayProjector<Derived>
      48             :     {
      49             :     public:
      50             :         using self_type = LutProjector<data_t, Derived>;
      51             :         using base_type = XrayProjector<Derived>;
      52             :         using value_type = typename base_type::value_type;
      53             :         using forward_tag = typename base_type::forward_tag;
      54             :         using backward_tag = typename base_type::backward_tag;
      55             : 
      56             :         LutProjector(const VolumeDescriptor& domainDescriptor,
      57             :                      const DetectorDescriptor& rangeDescriptor)
      58             :             : base_type(domainDescriptor, rangeDescriptor)
      59         871 :         {
      60             :             // sanity checks
      61         871 :             auto dim = domainDescriptor.getNumberOfDimensions();
      62         871 :             if (dim < 2 || dim > 3) {
      63           0 :                 throw InvalidArgumentError("LutProjector: only supporting 2d/3d operations");
      64           0 :             }
      65             : 
      66         871 :             if (dim != rangeDescriptor.getNumberOfDimensions()) {
      67           0 :                 throw InvalidArgumentError(
      68           0 :                     "LutProjector: domain and range dimension need to match");
      69           0 :             }
      70             : 
      71         871 :             if (rangeDescriptor.getNumberOfGeometryPoses() == 0) {
      72           0 :                 throw InvalidArgumentError("LutProjector: rangeDescriptor without any geometry");
      73           0 :             }
      74         871 :         }
      75             : 
      76             :         /// default destructor
      77         871 :         ~LutProjector() override = default;
      78             : 
      79             :     private:
      80             :         /// apply the binary method (i.e. forward projection)
      81             :         data_t traverseRayForward(const BoundingBox& boundingbox, const RealRay_t& ray,
      82             :                                   const DataContainer<data_t>& x) const
      83        1612 :         {
      84        1612 :             const IndexVector_t lower = boundingbox.min().template cast<index_t>();
      85        1612 :             const IndexVector_t upper = boundingbox.max().template cast<index_t>();
      86        1612 :             const auto support = this->self().support();
      87             : 
      88        1612 :             index_t leadingdir = 0;
      89        1612 :             ray.direction().array().cwiseAbs().maxCoeff(&leadingdir);
      90             : 
      91        1612 :             IndexVector_t distvec = IndexVector_t::Constant(lower.size(), support);
      92        1612 :             distvec[leadingdir] = 0;
      93             : 
      94        1612 :             auto rangeVal = data_t(0);
      95             : 
      96             :             // Expand bounding box as rays have larger support now
      97        1612 :             auto aabb = boundingbox;
      98        1612 :             aabb.min().array() -= static_cast<real_t>(support);
      99        1612 :             aabb.min()[leadingdir] += static_cast<real_t>(support);
     100             : 
     101        1612 :             aabb.max().array() += static_cast<real_t>(support);
     102        1612 :             aabb.max()[leadingdir] -= static_cast<real_t>(support);
     103             : 
     104             :             // Keep this here, as it saves us a couple of allocations on clang
     105        1612 :             CartesianIndices neighbours(upper);
     106             : 
     107             :             // --> setup traversal algorithm
     108        1612 :             SliceTraversal traversal(boundingbox, ray);
     109             : 
     110        7867 :             for (const auto& curVoxel : traversal) {
     111        7867 :                 neighbours = neighbours_in_slice(curVoxel, distvec, lower, upper);
     112       49905 :                 for (auto neighbour : neighbours) {
     113             :                     // Correct position, such that the distance is still correct
     114       49905 :                     const auto correctedPos = neighbour.template cast<real_t>().array() + 0.5;
     115       49905 :                     const auto distance = ray.distance(correctedPos);
     116       49905 :                     const auto weight = this->self().weight(distance);
     117             : 
     118       49905 :                     rangeVal += weight * x(neighbour);
     119       49905 :                 }
     120        7867 :             }
     121             : 
     122        1612 :             return rangeVal;
     123        1612 :         }
     124             : 
     125             :         void traverseRayBackward(const BoundingBox& boundingbox, const RealRay_t& ray,
     126             :                                  const value_type& detectorValue, DataContainer<data_t>& Aty) const
     127          15 :         {
     128          15 :             const IndexVector_t lower = boundingbox.min().template cast<index_t>();
     129          15 :             const IndexVector_t upper = boundingbox.max().template cast<index_t>();
     130          15 :             const auto support = this->self().support();
     131             : 
     132          15 :             index_t leadingdir = 0;
     133          15 :             ray.direction().array().cwiseAbs().maxCoeff(&leadingdir);
     134             : 
     135          15 :             IndexVector_t distvec = IndexVector_t::Constant(lower.size(), support);
     136          15 :             distvec[leadingdir] = 0;
     137             : 
     138             :             // Expand bounding box as rays have larger support now
     139          15 :             auto aabb = boundingbox;
     140          15 :             aabb.min().array() -= static_cast<real_t>(support);
     141          15 :             aabb.min()[leadingdir] += static_cast<real_t>(support);
     142             : 
     143          15 :             aabb.max().array() += static_cast<real_t>(support);
     144          15 :             aabb.max()[leadingdir] -= static_cast<real_t>(support);
     145             : 
     146             :             // Keep this here, as it saves us a couple of allocations on clang
     147          15 :             CartesianIndices neighbours(upper);
     148             : 
     149             :             // --> setup traversal algorithm
     150          15 :             SliceTraversal traversal(aabb, ray);
     151             : 
     152          75 :             for (const auto& curVoxel : traversal) {
     153          75 :                 neighbours = neighbours_in_slice(curVoxel, distvec, lower, upper);
     154         357 :                 for (auto neighbour : neighbours) {
     155             :                     // Correct position, such that the distance is still correct
     156         357 :                     const auto correctedPos = neighbour.template cast<real_t>().array() + 0.5;
     157         357 :                     const auto distance = ray.distance(correctedPos);
     158         357 :                     const auto weight = this->self().weight(distance);
     159             : 
     160         357 : #pragma omp atomic
     161         357 :                     Aty(neighbour) += weight * detectorValue;
     162         357 :                 }
     163          75 :             }
     164          15 :         }
     165             : 
     166             :         /// implement the polymorphic clone operation
     167             :         LutProjector<data_t, Derived>* _cloneImpl() const
     168             :         {
     169             :             return new LutProjector(downcast<VolumeDescriptor>(*this->_domainDescriptor),
     170             :                                     downcast<DetectorDescriptor>(*this->_rangeDescriptor));
     171             :         }
     172             : 
     173             :         /// implement the polymorphic comparison operation
     174             :         bool _isEqual(const LinearOperator<data_t>& other) const
     175             :         {
     176             :             if (!LinearOperator<data_t>::isEqual(other))
     177             :                 return false;
     178             : 
     179             :             auto otherOp = downcast_safe<LutProjector>(&other);
     180             :             return static_cast<bool>(otherOp);
     181             :         }
     182             : 
     183             :         friend class XrayProjector<Derived>;
     184             :     };
     185             : 
     186             :     template <typename data_t>
     187             :     class BlobProjector : public LutProjector<data_t, BlobProjector<data_t>>
     188             :     {
     189             :     public:
     190             :         using self_type = BlobProjector<data_t>;
     191             : 
     192             :         BlobProjector(const VolumeDescriptor& domainDescriptor,
     193             :                       const DetectorDescriptor& rangeDescriptor,
     194             :                       data_t radius = blobs::DEFAULT_RADIUS, data_t alpha = blobs::DEFAULT_ALPHA,
     195             :                       index_t order = blobs::DEFAULT_ORDER);
     196             : 
     197       51415 :         data_t weight(data_t distance) const { return blob_.get_lut()(distance); }
     198             : 
     199        1595 :         index_t support() const { return static_cast<index_t>(std::ceil(blob_.radius())); }
     200             : 
     201             :         /// implement the polymorphic clone operation
     202             :         BlobProjector<data_t>* _cloneImpl() const;
     203             : 
     204             :         /// implement the polymorphic comparison operation
     205             :         bool _isEqual(const LinearOperator<data_t>& other) const;
     206             : 
     207             :     private:
     208             :         ProjectedBlob<data_t> blob_;
     209             : 
     210             :         using Base = LutProjector<data_t, BlobProjector<data_t>>;
     211             : 
     212             :         friend class XrayProjector<self_type>;
     213             :     };
     214             : 
     215             :     template <typename data_t>
     216             :     class BSplineProjector : public LutProjector<data_t, BSplineProjector<data_t>>
     217             :     {
     218             :     public:
     219             :         using self_type = BlobProjector<data_t>;
     220             : 
     221             :         BSplineProjector(const VolumeDescriptor& domainDescriptor,
     222             :                          const DetectorDescriptor& rangeDescriptor,
     223             :                          index_t order = bspline::DEFAULT_ORDER);
     224             : 
     225             :         data_t weight(data_t distance) const;
     226             : 
     227             :         index_t support() const;
     228             : 
     229             :         /// implement the polymorphic clone operation
     230             :         BSplineProjector<data_t>* _cloneImpl() const;
     231             : 
     232             :         /// implement the polymorphic comparison operation
     233             :         bool _isEqual(const LinearOperator<data_t>& other) const;
     234             : 
     235             :     private:
     236             :         ProjectedBSpline<data_t> bspline_;
     237             : 
     238             :         using Base = LutProjector<data_t, BSplineProjector<data_t>>;
     239             : 
     240             :         friend class XrayProjector<self_type>;
     241             :     };
     242             : } // namespace elsa

Generated by: LCOV version 1.14