Line data Source code
1 : #pragma once 2 : #include <thrust/iterator/counting_iterator.h> 3 : #include <thrust/iterator/transform_iterator.h> 4 : #include <thrust/iterator/permutation_iterator.h> 5 : #include <thrust/functional.h> 6 : #include <thrust/fill.h> 7 : #include <thrust/copy.h> 8 : #include <iostream> 9 : 10 : namespace elsa 11 : { 12 : template <typename diff_t> 13 : struct DimData { 14 : diff_t shape{0}; 15 : diff_t stride{0}; 16 : }; 17 : 18 : /// @brief Iterable range, in which elements traversed in canonical order 19 : /// @tparam Iterator 20 : /// @see elsa::NdView::is_canonical() 21 : template <typename Iterator> 22 : class StridedRange 23 : { 24 : public: 25 : typedef typename thrust::iterator_difference<Iterator>::type difference_type; 26 : 27 : struct StrideFunctor : public thrust::unary_function<difference_type, difference_type> { 28 : size_t _dimCount{0}; 29 : DimData<difference_type>* _dimData{nullptr}; 30 : 31 : StrideFunctor(DimData<difference_type>* dimData, size_t dimCount) 32 : : _dimCount{dimCount}, _dimData{dimData} 33 5636 : { 34 5636 : } 35 : 36 : ~StrideFunctor() = default; 37 : 38 : __host__ __device__ difference_type operator()(const difference_type& i) const 39 12934 : { 40 : /* special case for zero-dimensional volumes */ 41 12934 : if (_dimCount == 0) 42 0 : return i; 43 : 44 12934 : difference_type out = 0; 45 12934 : difference_type alongDim = i; 46 51688 : for (size_t dim = 0; dim < _dimCount; dim++) { 47 38754 : out += (alongDim % _dimData[dim].shape) * _dimData[dim].stride; 48 38754 : alongDim = alongDim / _dimData[dim].shape; 49 38754 : } 50 12934 : return out; 51 12934 : } 52 : }; 53 : 54 : typedef typename thrust::counting_iterator<difference_type> CountingIterator; 55 : typedef typename thrust::transform_iterator<StrideFunctor, CountingIterator> 56 : TransformIterator; 57 : typedef typename thrust::permutation_iterator<Iterator, TransformIterator> 58 : PermutationIterator; 59 : 60 : typedef PermutationIterator iterator; 61 : 62 : /// @param first iterator to the first element of the range 63 : /// @param count number of elements in the range 64 : /// @param dimData non-owning pointer to the dimensions and stride data. 65 : /// Make sure the strided range does outlive this pointer! 66 : /// If the Iterator type is a device iterator, dimData must be 67 : /// accessible on the device! 68 : /// @param dimCount dimensionality of the data 69 : StridedRange(Iterator first, difference_type count, DimData<difference_type>* dimData, 70 : size_t dimCount) 71 : : _first{first}, _count{count}, _dimCount{dimCount}, _dimData{dimData} 72 501 : { 73 501 : } 74 : 75 : iterator begin() 76 5636 : { 77 5636 : return PermutationIterator( 78 5636 : _first, TransformIterator(CountingIterator(0), StrideFunctor(_dimData, _dimCount))); 79 5636 : } 80 : 81 5117 : iterator end() { return begin() + _count; } 82 : 83 : iterator begin() const 84 : { 85 : return PermutationIterator( 86 : _first, TransformIterator(CountingIterator(0), StrideFunctor(_dimData, _dimCount))); 87 : } 88 : 89 : iterator end() const { return begin() + _count; } 90 : 91 : protected: 92 : Iterator _first; 93 : difference_type _count; 94 : size_t _dimCount; 95 : DimData<difference_type>* _dimData; 96 : }; 97 : } // namespace elsa