LCOV - code coverage report
Current view: top level - elsa/projectors - XrayProjector.h (source / functions) Hit Total Coverage
Test: coverage-all.lcov Lines: 35 41 85.4 %
Date: 2025-01-22 07:37:33 Functions: 29 84 34.5 %

          Line data    Source code
       1             : #pragma once
       2             : 
       3             : #include "LinearOperator.h"
       4             : #include "DetectorDescriptor.h"
       5             : #include "VolumeDescriptor.h"
       6             : #include "BoundingBox.h"
       7             : 
       8             : namespace elsa
       9             : {
      10             :     struct ray_driven_tag {
      11             :     };
      12             :     struct voxel_driven_tag {
      13             :     };
      14             :     struct any_projection_tag {
      15             :     };
      16             : 
      17             :     template <typename T>
      18             :     struct XrayProjectorInnerTypes;
      19             : 
      20             :     /**
      21             :      * @brief Interface class for X-ray based projectors.
      22             :      *
      23             :      * For X-ray CT based methods there are mainly two different implementation methods: ray and
      24             :      * voxel driven methods. The first iterates all rays for each pose of the acquisition
      25             :      * trajectory through the volume. Along the way the ray either accumulates each visited voxels
      26             :      * values (forward) or writes to each visited voxels (backward). The second implementation
      27             :      * methods, iterates over all voxels of the volume and calculates their contribution to
      28             :      * a detector cell.
      29             :      *
      30             :      * Basically, the main difference is the outer most loop of the computation. In the ray driven
      31             :      * case, all rays are iterated and for each ray some calculations are performed. For the voxel
      32             :      * driven approach, all voxels are visited and for each some calculation is performed.
      33             :      *
      34             :      * This base class should aid in the implementation of any X-ray based projector. So, if you
      35             :      * want to implement a new projector, you'd first need to derive from this class, then the
      36             :      * following interface is required:
      37             :      * 1. specialize the `XrayProjectorInnerTypes` class with the following values:
      38             :      *   - value_type
      39             :      *   - forward_tag
      40             :      *   - backward_tag
      41             :      * 2. The class needs to implement
      42             :      *   - `_isEqual(const LinearOperator<data_t>&)` (should call `isEqual` of `LinearOperator`
      43             :      * base)
      44             :      *   - `_cloneImpl()`
      45             :      * 3. If `forward_tag` is equal to `ray_driven_tag`, then the class needs to implement:
      46             :      *   - `data_t traverseRayForward(const BoundingBox&, const RealRay_t&, const
      47             :      * DataContainer<data_t>&) const` it traverses a single ray through the given bounding box and
      48             :      * accumulates the voxel
      49             :      * 4. If `backward_tag` is equal to `ray_driven_tag`, then the class needs to implement:
      50             :      *   - `void traverseRayBackward(const BoundingBox&, const RealRay_t&, const value_type&,
      51             :      *
      52             :      * The interface for voxel based projectors is still WIP, therefore not documented here.
      53             :      *
      54             :      * The `any_projection_tag` should serve for methods, which do not fit any of the two others
      55             :      * (e.g. distance driven projectors), plus the legacy projectors, which can easily fit here for
      56             :      * now and be refactored later.
      57             :      *
      58             :      * At the time of writing, both the API is highly experimental and will evolve quickly.
      59             :      */
      60             :     template <typename D>
      61             :     class XrayProjector : public LinearOperator<typename XrayProjectorInnerTypes<D>::value_type>
      62             :     {
      63             :     public:
      64             :         using derived_type = D;
      65             :         using self_type = XrayProjector<D>;
      66             :         using inner_type = XrayProjectorInnerTypes<derived_type>;
      67             : 
      68             :         using value_type = typename inner_type::value_type;
      69             :         using data_t = value_type;
      70             :         using forward_tag = typename inner_type::forward_tag;
      71             :         using backward_tag = typename inner_type::backward_tag;
      72             : 
      73             :         using base_type = LinearOperator<data_t>;
      74             : 
      75             :         XrayProjector() = delete;
      76        1002 :         ~XrayProjector() override = default;
      77             : 
      78             :         /// TODO: This is basically legacy, the projector itself does not need it...
      79             :         XrayProjector(const VolumeDescriptor& domain, const DetectorDescriptor& range)
      80             :             : base_type(domain, range)
      81        1002 :         {
      82        1002 :         }
      83             : 
      84             :     protected:
      85             :         void applyImpl(const DataContainer<data_t>& x, DataContainer<data_t>& Ax) const override
      86        1138 :         {
      87        1138 :             forward(x, Ax, forward_tag{});
      88        1138 :         }
      89             : 
      90             :         void applyAdjointImpl(const DataContainer<data_t>& y,
      91             :                               DataContainer<data_t>& Aty) const override
      92         193 :         {
      93         193 :             backward(y, Aty, backward_tag{});
      94         193 :         }
      95             : 
      96             :         /// implement the polymorphic comparison operation
      97             :         bool isEqual(const LinearOperator<data_t>& other) const override
      98           4 :         {
      99           4 :             return self()._isEqual(other);
     100           4 :         }
     101             : 
     102             :         /// implement the polymorphic comparison operation
     103          25 :         self_type* cloneImpl() const override { return self()._cloneImpl(); }
     104             : 
     105             :         derived_type& self() { return static_cast<derived_type&>(*this); }
     106      106200 :         const derived_type& self() const { return static_cast<const derived_type&>(*this); }
     107             : 
     108             :     private:
     109             :         void forward(const DataContainer<data_t>& x, DataContainer<data_t>& Ax,
     110             :                      ray_driven_tag) const
     111         928 :         {
     112             :             /// the bounding box of the volume
     113         928 :             const BoundingBox aabb(x.getDataDescriptor().getNumberOfCoefficientsPerDimension());
     114         928 :             auto& detectorDesc = downcast<DetectorDescriptor>(Ax.getDataDescriptor());
     115             : 
     116             :             // --> loop either over every voxel that should  updated or every detector
     117             :             // cell that should be calculated
     118         928 : #pragma omp parallel for
     119         928 :             for (index_t rangeIndex = 0; rangeIndex < Ax.getSize(); ++rangeIndex) {
     120             :                 // --> get the current ray to the detector center
     121           0 :                 const auto ray = detectorDesc.computeRayFromDetectorCoord(rangeIndex);
     122             : 
     123           0 :                 Ax[rangeIndex] = self().traverseRayForward(aabb, ray, x);
     124           0 :             }
     125         928 :         }
     126             : 
     127             :         void forward(const DataContainer<data_t>& x, DataContainer<data_t>& Ax,
     128             :                      voxel_driven_tag) const
     129             :         {
     130             :             auto& detectorDesc = downcast<DetectorDescriptor>(Ax.getDataDescriptor());
     131             : 
     132             :             for (index_t domainIndex = 0; domainIndex < x.getSize(); ++domainIndex) {
     133             :                 auto coord = x.getDataDescriptor().getCoordinateFromIndex(domainIndex);
     134             : 
     135             :                 // TODO: Maybe we need a different interface, but I need an implementation for that
     136             :                 self().forwardVoxel(coord, x[domainIndex], Ax);
     137             :             }
     138             :         }
     139             : 
     140             :         /// We can say nothing about it, so let it handle everything
     141             :         void forward(const DataContainer<data_t>& x, DataContainer<data_t>& Ax,
     142             :                      any_projection_tag) const
     143         210 :         {
     144         210 :             const BoundingBox aabb(x.getDataDescriptor().getNumberOfCoefficientsPerDimension());
     145         210 :             self().forward(aabb, x, Ax);
     146         210 :         }
     147             : 
     148             :         void backward(const DataContainer<data_t>& y, DataContainer<data_t>& Aty,
     149             :                       ray_driven_tag) const
     150          49 :         {
     151             :             /// the bounding box of the volume
     152          49 :             const BoundingBox aabb(Aty.getDataDescriptor().getNumberOfCoefficientsPerDimension());
     153          49 :             auto& detectorDesc = downcast<DetectorDescriptor>(y.getDataDescriptor());
     154             : 
     155             :             // Just to be sure, zero out the result
     156          49 :             Aty = 0;
     157             : 
     158             :             // --> loop either over every voxel that should  updated or every detector
     159             :             // cell that should be calculated
     160          49 : #pragma omp parallel for
     161          49 :             for (index_t rangeIndex = 0; rangeIndex < y.getSize(); ++rangeIndex) {
     162             :                 // --> get the current ray to the detector center
     163           0 :                 const auto ray = detectorDesc.computeRayFromDetectorCoord(rangeIndex);
     164             : 
     165           0 :                 self().traverseRayBackward(aabb, ray, y[rangeIndex], Aty);
     166           0 :             }
     167          49 :         }
     168             : 
     169             :         void backward(const DataContainer<data_t>& y, DataContainer<data_t>& Aty,
     170             :                       voxel_driven_tag) const
     171             :         {
     172             :             auto& detectorDesc = downcast<DetectorDescriptor>(y.getDataDescriptor());
     173             : 
     174             :             for (index_t domainIndex = 0; domainIndex < Aty.getSize(); ++domainIndex) {
     175             :                 auto coord = Aty.getDataDescriptor().getCoordinateFromIndex(domainIndex);
     176             : 
     177             :                 // TODO: Maybe we need a different interface, but I need an implementation for that
     178             :                 Aty[domainIndex] = self().backwardVoxel(coord, y);
     179             :             }
     180             :         }
     181             : 
     182             :         void backward(const DataContainer<data_t>& y, DataContainer<data_t>& Aty,
     183             :                       any_projection_tag) const
     184         144 :         {
     185         144 :             const BoundingBox aabb(Aty.getDataDescriptor().getNumberOfCoefficientsPerDimension());
     186         144 :             self().backward(aabb, y, Aty);
     187         144 :         }
     188             :     };
     189             : 
     190             : } // namespace elsa

Generated by: LCOV version 1.14