LCOV - code coverage report
Current view: top level - elsa/core/Utilities - StridedIterator.h (source / functions) Hit Total Coverage
Test: coverage-all.lcov Lines: 19 20 95.0 %
Date: 2024-05-15 03:55:36 Functions: 35 100 35.0 %

          Line data    Source code
       1             : #pragma once
       2             : #include <thrust/iterator/counting_iterator.h>
       3             : #include <thrust/iterator/transform_iterator.h>
       4             : #include <thrust/iterator/permutation_iterator.h>
       5             : #include <thrust/functional.h>
       6             : #include <thrust/fill.h>
       7             : #include <thrust/copy.h>
       8             : #include <iostream>
       9             : 
      10             : namespace elsa
      11             : {
      12             :     template <typename diff_t>  // diff_t: index type used for both the extent and the stride
      13             :     struct DimData {  // shape and stride of one dimension, read by StridedRange::StrideFunctor
      14             :         diff_t shape{0};  // extent of the dimension: linear indices are taken modulo / divided by this value
      15             :         diff_t stride{0};  // offset step of the dimension: multiplied with the coordinate along it
      16             :     };
      17             : 
      18             :     /// @brief Iterable range over strided N-d data whose elements are traversed in canonical order (dimension 0 varies fastest)
      19             :     /// @tparam Iterator  underlying iterator type; its thrust::iterator_difference type is used for all index arithmetic
      20             :     /// @see elsa::NdView::is_canonical()
      21             :     template <typename Iterator>
      22             :     class StridedRange
      23             :     {
      24             :     public:
      25             :         typedef typename thrust::iterator_difference<Iterator>::type difference_type;
      26             : 
      27             :         struct StrideFunctor : public thrust::unary_function<difference_type, difference_type> {  // maps a linear index to a strided offset
      28             :             size_t _dimCount{0};  // number of DimData entries reachable through _dimData
      29             :             DimData<difference_type>* _dimData{nullptr};  // non-owning; never freed by this functor
      30             : 
      31             :             StrideFunctor(DimData<difference_type>* dimData, size_t dimCount)  // stores both arguments, copies nothing
      32             :                 : _dimCount{dimCount}, _dimData{dimData}
      33        5636 :             {
      34        5636 :             }
      35             : 
      36             :             ~StrideFunctor() = default;
      37             : 
      38             :             __host__ __device__ difference_type operator()(const difference_type& i) const  // i: linear index; returns the offset into the underlying iterator
      39       12935 :             {
      40             :                 /* special case for zero-dimensional volumes: the index is passed through unchanged */
      41       12935 :                 if (_dimCount == 0)
      42           0 :                     return i;
      43             : 
      44       12935 :                 difference_type out = 0;  // accumulated offset
      45       12935 :                 difference_type alongDim = i;  // part of the index not yet consumed by lower dimensions
      46       51689 :                 for (size_t dim = 0; dim < _dimCount; dim++) {
      47       38754 :                     out += (alongDim % _dimData[dim].shape) * _dimData[dim].stride;  // coordinate along dim, scaled by that dimension's stride
      48       38754 :                     alongDim = alongDim / _dimData[dim].shape;  // NOTE(review): a shape of 0 (the DimData default) would divide by zero here and on the line above; presumably callers guarantee shape > 0, confirm
      49       38754 :                 }
      50       12935 :                 return out;
      51       12935 :             }
      52             :         };
      53             : 
      54             :         typedef typename thrust::counting_iterator<difference_type> CountingIterator;  // yields the linear indices 0, 1, 2, ...
      55             :         typedef typename thrust::transform_iterator<StrideFunctor, CountingIterator>
      56             :             TransformIterator;  // applies StrideFunctor to each linear index
      57             :         typedef typename thrust::permutation_iterator<Iterator, TransformIterator>
      58             :             PermutationIterator;  // accesses Iterator at the offsets produced by TransformIterator
      59             : 
      60             :         typedef PermutationIterator iterator;
      61             : 
      62             :         /// @param first     iterator to the first element of the range
      63             :         /// @param count     number of elements in the range
      64             :         /// @param dimData   non-owning pointer to the dimensions and stride data.
      65             :         ///                  The data behind this pointer must outlive the strided range and every iterator obtained from it!
      66             :         ///                  If the Iterator type is a device iterator, dimData must be
      67             :         ///                  accessible on the device!
      68             :         /// @param dimCount  dimensionality of the data, i.e. the number of DimData entries behind dimData
      69             :         StridedRange(Iterator first, difference_type count, DimData<difference_type>* dimData,
      70             :                      size_t dimCount)
      71             :             : _first{first}, _count{count}, _dimCount{dimCount}, _dimData{dimData}
      72         501 :         {
      73         501 :         }
      74             : 
      75             :         iterator begin()  // iterator at linear index 0; a new StrideFunctor is constructed on every call
      76        5636 :         {
      77        5636 :             return PermutationIterator(
      78        5636 :                 _first, TransformIterator(CountingIterator(0), StrideFunctor(_dimData, _dimCount)));
      79        5636 :         }
      80             : 
      81        5117 :         iterator end() { return begin() + _count; }  // begin() advanced by _count linear indices
      82             : 
      83             :         iterator begin() const  // const overload; identical construction to the non-const begin()
      84             :         {
      85             :             return PermutationIterator(
      86             :                 _first, TransformIterator(CountingIterator(0), StrideFunctor(_dimData, _dimCount)));
      87             :         }
      88             : 
      89             :         iterator end() const { return begin() + _count; }  // const overload; begin() advanced by _count
      90             : 
      91             :     protected:
      92             :         Iterator _first;  // start of the underlying data
      93             :         difference_type _count;  // number of elements between begin() and end()
      94             :         size_t _dimCount;  // number of DimData entries behind _dimData
      95             :         DimData<difference_type>* _dimData;  // non-owning; handed to every StrideFunctor created by begin()
      96             :     };
      97             : } // namespace elsa

Generated by: LCOV version 1.14