LCOV - code coverage report
Current view: top level - elsa/core - NdView.h (source / functions) Hit Total Coverage
Test: coverage-all.lcov Lines: 363 379 95.8 %
Date: 2025-01-02 06:42:49 Functions: 351 580 60.5 %

          Line data    Source code
       1             : #pragma once
       2             : 
       3             : #include <thrust/complex.h>
       4             : #include "ContiguousStorage.h"
       5             : #include "Utilities/StridedIterator.h"
       6             : #include "elsaDefines.h"
       7             : 
       8             : DISABLE_WARNING_PUSH
       9             : DISABLE_WARNING_SIGN_CONVERSION
      10             : #include <thrust/universal_vector.h>
      11             : #include <thrust/device_ptr.h>
      12             : #include <thrust/functional.h>
      13             : #include <thrust/transform.h>
      14             : #include <thrust/for_each.h>
      15             : #include <thrust/iterator/constant_iterator.h>
      16             : #include <thrust/iterator/zip_iterator.h>
      17             : DISABLE_WARNING_POP
      18             : 
      19             : #include <memory>
      20             : #include <type_traits>
      21             : #include <cinttypes>
      22             : #include <functional>
      23             : #include <exception>
      24             : 
      25             : namespace elsa
      26             : {
      27             :     class NdViewDimError : public std::runtime_error
      28             :     {
      29             :     public:
      30           2 :         NdViewDimError() : std::runtime_error("NdView dimensions do not match") {}
      31             :     };
      32             :     class NdViewUnsupportedShape : public std::runtime_error
      33             :     {
      34             :     public:
      35             :         NdViewUnsupportedShape()
      36             :             : std::runtime_error("NdView one or more dimensions are not positive")
      37           1 :         {
      38           1 :         }
      39             :     };
      40             :     class NdViewDimOutOfBounds : public std::runtime_error
      41             :     {
      42             :     public:
      43           2 :         NdViewDimOutOfBounds() : std::runtime_error("NdView dimension out-of-bounds") {}
      44             :     };
      45             :     class NdViewIndexOutOfBounds : public std::runtime_error
      46             :     {
      47             :     public:
      48             :         NdViewIndexOutOfBounds() : std::runtime_error("NdView index into dimension out-of-bounds")
      49           3 :         {
      50           3 :         }
      51             :     };
      52             :     class NdViewUnavailableIteration : public std::runtime_error
      53             :     {
      54             :     public:
      55             :         NdViewUnavailableIteration()
      56             :             : std::runtime_error("NdView iteration range not available for current data")
      57           2 :         {
      58           2 :         }
      59             :     };
      60             :     class NdViewEmptyView : public std::runtime_error
      61             :     {
      62             :     public:
      63           0 :         NdViewEmptyView() : std::runtime_error("NdView is empty") {}
      64             :     };
      65             : 
      66             :     template <typename data_t, mr::StorageType indexed_tag, mr::StorageType index_tag>
      67             :     class BoolIndexedView;
      68             : 
      69             :     template <class data_t, mr::StorageType tag>
      70             :     class NdViewTagged;
      71             : 
      72             :     namespace detail
      73             :     {
      74             :         template <typename View, typename Functor>
      75             :         auto with_canonical_range(View& view, Functor&& f)
      76             :             -> decltype(std::declval<Functor>()(std::declval<View>().canonical_range()))
      77         111 :         {
      78         111 :             static_assert(std::is_same_v<decltype(std::declval<Functor>()(view.canonical_range())),
      79         111 :                                          decltype(std::declval<Functor>()(view.range()))>,
      80         111 :                           "Functor return value differs, depending on input iterator range "
      81         111 :                           "type!");
      82         111 :             if (view.is_canonical()) {
      83          57 :                 return f(view.canonical_range());
      84          57 :             } else {
      85          54 :                 return f(view.range());
      86          54 :             }
      87         111 :         }
      88             : 
      89             :         /* throws NdViewDimError if the dimensions do not match */
      90             :         bool are_strides_compatible(const IndexVector_t& shape1, const IndexVector_t& strides1,
      91             :                                     const IndexVector_t& shape2, const IndexVector_t& strides2);
      92             : 
      93             :         template <typename size_type>
      94             :         std::pair<IndexVector_t, IndexVector_t>
      95             :             unpack_dim_data(const ContiguousStorage<DimData<size_type>>& shape)
      96             :         {
      97             :             size_t dim_count = shape.size();
      98             :             IndexVector_t out_shape(dim_count);
      99             :             IndexVector_t out_strides(dim_count);
     100             :             for (size_t i = 0; i < dim_count; i++) {
     101             :                 auto dim_data = shape[i];
     102             :                 out_shape(i) = dim_data.shape;
     103             :                 out_strides(i) = dim_data.stride;
     104             :             }
     105             :             return std::make_pair(out_shape, out_strides);
     106             :         }
     107             : 
     108             :         template <typename data_t, mr::StorageType tag>
     109             :         NdViewTagged<data_t, tag> create_uninitialized_owned_view(const IndexVector_t& shape,
     110             :                                                                   const IndexVector_t& strides,
     111             :                                                                   size_t allocation_size)
     112         159 :         {
     113         159 :             auto memres = mr::defaultResource();
     114         159 :             if (std::numeric_limits<size_t>::max() / sizeof(data_t) < allocation_size) {
     115           0 :                 throw std::runtime_error("Overflowing multiplication");
     116           0 :             }
     117         159 :             size_t out_data_size = sizeof(data_t) * allocation_size;
     118         159 :             void* out_data = memres->allocate(out_data_size, alignof(data_t));
     119         159 :             return NdViewTagged<data_t, mr::sysStorageType>(
     120         159 :                 reinterpret_cast<data_t*>(out_data), shape, strides,
     121         159 :                 [=]() { memres->deallocate(out_data, out_data_size, alignof(data_t)); });
     122         159 :         }
     123             : 
     124             :         template <typename data_t, typename Functor>
     125             :         struct binop_to_unop
     126             :             : public thrust::unary_function<data_t,
     127             :                                             decltype(std::declval<Functor>()(
     128             :                                                 std::declval<data_t>(), std::declval<data_t>()))> {
     129             :             data_t other;
     130             :             Functor f;
     131             : 
     132             :             binop_to_unop(data_t other, Functor f) : other{other}, f{f} {}
     133             : 
     134             :             __host__ __device__ auto operator()(data_t e) const -> decltype(f(e, other))
     135             :             {
     136             :                 return f(e, other);
     137             :             }
     138             :         };
     139             : 
     140             :     } // namespace detail
     141             : 
     142             :     /// @brief Represents a non-owning view. It can handle arbitrary strides (including negative).
     143             :     ///        Supports the creation of subviews. Supports iteration in canonical order with a
     144             :     ///        thrust::device compatible iterator, provided the storage type is device accessible.
     145             :     ///        Upon deletion, if no other NdView has a reference to the data, signals to the owner
     146             :     ///        of the data that it may be deleted via a destructor that is passed as
     147             :     ///        constructor parameter.
     148             :     ///        Additionally, NdViewTagged provides elementwise unary and binary operations and
     149             :     ///        filtered assignments.
     150             :     /// @tparam data_t type of the data that the NdView points to
     151             :     /// @tparam tag    storage type/location of the data; Unlike DataContainer, NdView supports
     152             :     ///                non-universal memory, if compiled with CUDA
     153             :     /// @see is_canonical()
     154             :     template <class data_t, mr::StorageType tag>
     155             :     class NdViewTagged
     156             :     {
     157             :     public:
     158             :         using pointer_type = std::conditional_t<
     159             :             tag == mr::StorageType::host, data_t*,
     160             :             std::conditional_t<tag == mr::StorageType::device, thrust::device_ptr<data_t>,
     161             :                                thrust::universal_ptr<data_t>>>;
     162             :         using const_pointer_type = std::conditional_t<
     163             :             tag == mr::StorageType::host, const data_t*,
     164             :             std::conditional_t<tag == mr::StorageType::device, thrust::device_ptr<const data_t>,
     165             :                                thrust::universal_ptr<const data_t>>>;
     166             : 
     167             :         using Scalar = data_t;
     168             :         using self_type = NdViewTagged<data_t, tag>;
     169             :         using value_type = data_t;
     170             :         using iterator = pointer_type;
     171             :         using const_iterator = const_pointer_type;
     172             :         using reference = data_t&;
     173             :         using const_reference = std::add_const_t<reference>;
     174             :         using dim_data_type = decltype(pointer_type() - pointer_type());
     175             :         using dim_data = DimData<dim_data_type>;
     176             : 
     177             :         struct Cleanup {
     178             :             std::function<void()> cleanup;
     179         402 :             Cleanup(std::function<void()>&& cleanup) : cleanup{cleanup} {}
     180         402 :             ~Cleanup() { cleanup(); }
     181             :         };
     182             : 
     183             :         template <class ItType>
     184             :         struct IteratorRange {
     185             :         private:
     186             :             ItType _begin = ItType();
     187             :             ItType _end = ItType();
     188             : 
     189             :         public:
     190         386 :             IteratorRange(const ItType& b, const ItType& e) : _begin(b), _end(e) {}
     191         385 :             ItType begin() const { return _begin; }
     192         173 :             ItType end() const { return _end; }
     193             :         };
     194             : 
     195             :     private:
     196             :         struct Container {
     197             :             std::shared_ptr<Cleanup> cleanup;
     198             :             pointer_type pointer = pointer_type();
     199             :             IndexVector_t shape;
     200             :             IndexVector_t strides;
     201             :             ContiguousStorage<dim_data> dims;
     202             :             pointer_type contiguous_start = pointer_type();
     203             :             size_t size{0};
     204             :             bool is_canonical = false;
     205             :             bool contiguous = false;
     206             :         };
     207             :         std::shared_ptr<Container> _container;
     208             : 
     209             :     public:
     210             :         /// @brief Create a view on raw (possibly non-contiguous) data
     211             :         /// @param cleanup shared encapsulated destructor; to be called, once this NdView
     212             :         ///                (and all of its parent or subviews) have been deleted
     213             :         NdViewTagged(data_t* raw_data, const IndexVector_t& shape, const IndexVector_t& strides,
     214             :                      const std::shared_ptr<Cleanup>& cleanup)
     215             :             : _container{std::make_shared<Container>()}
     216         416 :         {
     217         416 :             size_t dims = shape.size();
     218         416 :             if (static_cast<ssize_t>(dims) != strides.size())
     219           0 :                 throw NdViewDimError();
     220             : 
     221             :             /* check that all shapes are positive and if the overall container describes an
     222             :              *  empty volume (initialize size as 1 as zero-dimensional objects have this size) */
     223         416 :             _container->size = 1;
     224        1233 :             for (auto& x : shape) {
     225        1233 :                 if (x < 0)
     226           1 :                     throw NdViewUnsupportedShape();
     227        1232 :                 if (x == 0)
     228           0 :                     _container->size = 0;
     229        1232 :             }
     230         416 :             if (_container->size == 0)
     231           0 :                 dims = 0;
     232         415 :             _container->cleanup = cleanup;
     233         415 :             _container->pointer = pointer_type(raw_data);
     234         415 :             _container->shape = shape;
     235         415 :             _container->strides = strides;
     236         415 :             _container->dims = ContiguousStorage<dim_data>(static_cast<size_t>(dims));
     237         415 :             _container->contiguous_start = _container->pointer;
     238         415 :             _container->is_canonical = true;
     239         415 :             _container->contiguous = true;
     240             : 
     241             :             /* check if the container is empty or if it is zero-dimensional,
     242             :              * in which case the processing is done here */
     243         415 :             if (_container->size == 0)
     244           0 :                 return;
     245         415 :             if (dims == 0)
     246           1 :                 return;
     247             : 
     248             :             /* iterate over all dimensions and:
     249             :              *  - check if its a canonical layout
     250             :              *  - adjust the start-pointer to the actual first address
     251             :              *      to allow for contiugous iteration (in case of negative strides)
     252             :              *  - create the copy of the strides in order to determine if the range is contiguous */
     253         414 :             std::vector<dim_data> shapeAndStrides(dims);
     254        1646 :             for (size_t i = 0; i < dims; i++) {
     255        1232 :                 _container->dims[i] = {shape(i), strides(i)};
     256        1232 :                 shapeAndStrides[i] = {shape(i), std::abs(strides(i))};
     257             : 
     258             :                 /* check if the stride matches the canonical layout */
     259        1232 :                 if (static_cast<ssize_t>(_container->size) != _container->dims[i].stride)
     260         339 :                     _container->is_canonical = false;
     261             : 
     262             :                 /* check if the stride is negative and adjust the pointer accordingly */
     263        1232 :                 if (_container->dims[i].stride < 0) {
     264             :                     /* move the adjusted pointer back by the negative stides by its dimension */
     265           0 :                     _container->contiguous_start +=
     266           0 :                         _container->dims[i].stride * _container->dims[i].shape;
     267           0 :                 }
     268        1232 :                 _container->size *= shape(i);
     269        1232 :             }
     270             : 
     271             :             /* sort the shapes and strides by the strides and check if every next stride
     272             :              *  is equal to the previous stride times shape (in which case its contiguous) */
     273         414 :             std::sort(shapeAndStrides.begin(), shapeAndStrides.end(),
     274        1383 :                       [](auto& a, auto& b) { return a.stride < b.stride; });
     275             : 
     276             :             /* dims > 0 is ensured above */
     277         414 :             _container->contiguous = (shapeAndStrides[0].stride == 1);
     278        1184 :             for (size_t i = 1; i < dims && _container->contiguous; i++) {
     279         770 :                 _container->contiguous =
     280         770 :                     shapeAndStrides[i].stride
     281         770 :                     == shapeAndStrides[i - 1].stride * shapeAndStrides[i - 1].shape;
     282         770 :             }
     283             : 
     284             :             /* is_canonical will already be implicitly false, but this makes it clearer */
     285         414 :             if (!_container->contiguous)
     286          33 :                 _container->is_canonical = false;
     287         414 :         }
     288             : 
     289             :         /// @brief Create a view on raw (possibly non-contiguous) data
     290             :         /// @param cleanup destructor to be called, once this NdView
     291             :         ///                (and all of its parent or subviews) have been deleted
     292             :         NdViewTagged(data_t* raw_data, const IndexVector_t& shape, const IndexVector_t& strides,
     293             :                      std::function<void()> destructor)
     294             :             : NdViewTagged(raw_data, shape, strides,
     295             :                            std::make_shared<Cleanup>(std::move(destructor)))
     296         401 :         {
     297         401 :         }
     298             : 
     299             :         /// @brief Create an empty view
     300             :         NdViewTagged() : _container{std::make_shared<Container>()}
     301           1 :         {
     302           1 :             _container->size = 0;
     303           1 :             _container->is_canonical = true;
     304           1 :             _container->contiguous = true;
     305           1 :             _container->cleanup = std::make_shared<Cleanup>([]() {});
     306           1 :         }
     307             : 
     308             :         /// @return true iff the raw data is layed out contiguously
     309         290 :         bool is_contiguous() const { return _container->contiguous; }
     310             :         /// @return true iff the raw data follows the canonical layout
     311             :         /// Canonical layout is defined as follows:
     312             :         /// strides[0] = 1;
     313             :         /// strides[i] = strides[i - 1] * shape[i - 1];
     314             :         /// It is tempting to call this column major layout, but there is one
     315             :         /// caveat. In elsa, the first index refers to the column (i.e. x-coordinate)
     316             :         /// and the second refers to the row (i.e. y-coordinate).
     317         122 :         bool is_canonical() const { return _container->is_canonical; }
     318             : 
     319             :         /// @return true iff the view does not contain any data
     320             :         bool is_empty() const { return _container->size == 0; }
     321             : 
     322             :         /// @brief iterate over the raw data in any order, if the data
     323             :         /// is contiguous (iterators are pointer)
     324             :         /// Useful for reductions or transformations where the order of the
     325             :         /// data is not relevant (strides are ignored).
     326             :         IteratorRange<pointer_type> contiguous_range()
     327         155 :         {
     328         155 :             if (!_container->contiguous)
     329           1 :                 throw NdViewUnavailableIteration();
     330         154 :             return IteratorRange<pointer_type>(_container->contiguous_start,
     331         154 :                                                _container->contiguous_start + _container->size);
     332         154 :         }
     333             :         IteratorRange<const_pointer_type> contiguous_range() const
     334         169 :         {
     335         169 :             if (!_container->contiguous)
     336           0 :                 throw NdViewUnavailableIteration();
     337         169 :             return IteratorRange<const_pointer_type>(
     338         169 :                 _container->contiguous_start, _container->contiguous_start + _container->size);
     339         169 :         }
     340             : 
     341             :         /// @brief iterate over the raw data in canonical order, if the data
     342             :         /// is layed out in canonical layout (iterators are pointer).
     343             :         /// @see is_canonical()
     344             :         IteratorRange<pointer_type> canonical_range()
     345          16 :         {
     346          16 :             if (!_container->is_canonical)
     347           1 :                 throw NdViewUnavailableIteration();
     348          15 :             return IteratorRange<pointer_type>(_container->contiguous_start,
     349          15 :                                                _container->contiguous_start + _container->size);
     350          15 :         }
     351             :         IteratorRange<const_pointer_type> canonical_range() const
     352          48 :         {
     353          48 :             if (!_container->is_canonical)
     354           0 :                 throw NdViewUnavailableIteration();
     355          48 :             return IteratorRange<const_pointer_type>(
     356          48 :                 _container->contiguous_start, _container->contiguous_start + _container->size);
     357          48 :         }
     358             : 
     359             :         /// @brief iterate over the data in canonical order, regardless of its real layout
     360             :         StridedRange<pointer_type> range()
     361         353 :         {
     362         353 :             return StridedRange<pointer_type>(_container->pointer, _container->size,
     363         353 :                                               _container->dims.data().get(),
     364         353 :                                               _container->dims.size());
     365         353 :         }
     366             :         StridedRange<const_pointer_type> range() const
     367         132 :         {
     368         132 :             return StridedRange<const_pointer_type>(_container->pointer, _container->size,
     369         132 :                                                     _container->dims.data().get(),
     370         132 :                                                     _container->dims.size());
     371         132 :         }
     372             : 
     373             :     private:
     374             :         template <class...>
     375             :         bool UnpackAndCheckBounds(size_t index) const
     376          81 :         {
     377          81 :             static_cast<void>(index);
     378          81 :             return true;
     379          81 :         }
     380             : 
     381             :         template <class Index, class... Indices>
     382             :         bool UnpackAndCheckBounds(size_t current, Index index, Indices... indices) const
     383         245 :         {
     384         245 :             if (index < 0 || index >= _container->shape(current))
     385           1 :                 return false;
     386         244 :             return UnpackAndCheckBounds<Indices...>(current + 1, indices...);
     387         244 :         }
     388             : 
     389             :         template <class...>
     390             :         ssize_t UnpackAndComputeIndex(size_t index) const
     391          81 :         {
     392          81 :             static_cast<void>(index);
     393          81 :             return 0;
     394          81 :         }
     395             : 
     396             :         template <class Index, class... Indices>
     397             :         ssize_t UnpackAndComputeIndex(size_t current, Index index, Indices... indices) const
     398         243 :         {
     399         243 :             return (_container->strides(current) * index)
     400         243 :                    + UnpackAndComputeIndex<Indices...>(current + 1, indices...);
     401         243 :         }
     402             : 
     403             :         /// Performs an element-wise binary operation. The result is returned in
     404             :         /// a view, which owns a newly allocated buffer for the output. The strides
     405             :         /// of the output may not match the strides of either input.
     406             :         template <mr::StorageType other_tag, typename Functor>
     407             :         NdViewTagged<decltype(std::declval<Functor>()(std::declval<data_t>(),
     408             :                                                       std::declval<data_t>())),
     409             :                      mr::sysStorageType>
     410             :             binop(const NdViewTagged<data_t, other_tag>& other, Functor functor) const
     411          87 :         {
     412          87 :             static_assert(mr::are_storages_compatible<tag, other_tag>::value,
     413          87 :                           "Binary operations are only possible when both arguments live on "
     414          87 :                           "compatible devices");
     415          87 :             using output_type =
     416          87 :                 decltype(std::declval<Functor>()(std::declval<data_t>(), std::declval<data_t>()));
     417          87 :             static_assert(
     418          87 :                 std::is_trivially_copyable_v<output_type>,
     419          87 :                 "Binary operations are only implemented for trivially copyable result types");
     420             : 
     421          87 :             const auto& shape = this->shape();
     422          87 :             const auto& strides = this->strides();
     423          87 :             const auto& other_shape = other.shape();
     424          87 :             const auto& other_strides = other.strides();
     425             : 
     426          87 :             size_t element_count = this->size();
     427          87 :             bool strides_compatible =
     428          87 :                 detail::are_strides_compatible(shape, strides, other_shape, other_strides);
     429             : 
     430          87 :             IndexVector_t out_strides;
     431             : 
     432          87 :             bool same_layout = strides_compatible && this->is_contiguous() && other.is_contiguous();
     433          87 :             if (!same_layout) {
     434             :                 /* layouts do not match or are non-contiguous, replace out strides with column major
     435             :                  * layout */
     436          30 :                 size_t stride = 1;
     437          30 :                 size_t dim_count = shape.size();
     438          30 :                 out_strides = IndexVector_t(dim_count);
     439         120 :                 for (size_t i = 0; i < dim_count; i++) {
     440          90 :                     out_strides(i) = stride;
     441          90 :                     stride *= shape(i);
     442          90 :                 }
     443          57 :             } else {
     444          57 :                 out_strides = strides;
     445          57 :             }
     446             : 
     447             :             /* No initialization is performed on the buffer before assigning its contents with
     448             :              * thrust::transform, so this only works for trivially copyable types. To extend it for
     449             :              * other types, one could use thrust::for_each with a functor that relies on placement
     450             :              * new to assign the computed values. */
     451          87 :             auto out = detail::create_uninitialized_owned_view<output_type, mr::sysStorageType>(
     452          87 :                 shape, out_strides, element_count);
     453             : 
     454          87 :             auto out_range = out.contiguous_range();
     455          87 :             if (same_layout) {
     456          57 :                 auto left_range = this->contiguous_range();
     457          57 :                 auto right_range = other.contiguous_range();
     458          57 :                 thrust::transform(left_range.begin(), left_range.end(), right_range.begin(),
     459          57 :                                   out_range.begin(), functor);
     460          57 :             } else {
     461          30 :                 this->with_canonical_crange([&](auto left_range) mutable {
     462          30 :                     other.with_canonical_crange([&](auto right_range) mutable {
     463          30 :                         thrust::transform(left_range.begin(), left_range.end(), right_range.begin(),
     464          30 :                                           out_range.begin(), functor);
     465          30 :                     });
     466          30 :                 });
     467          30 :             }
     468          87 :             return out;
     469          87 :         }
     470             : 
     471             :         /// Performs an element-wise unary operation. The result is returned in a view, which owns a
     472             :         /// newly allocated buffer for the output. The strides of the output may not match the
     473             :         /// strides of the input.
     474             :         template <typename Functor>
     475             :         NdViewTagged<decltype(std::declval<Functor>()(std::declval<data_t>())), mr::sysStorageType>
     476             :             unop(Functor functor) const
     477          72 :         {
     478          72 :             using output_type = decltype(std::declval<Functor>()(std::declval<data_t>()));
     479          72 :             static_assert(
     480          72 :                 std::is_trivially_copyable_v<output_type>,
     481          72 :                 "Binary operations are only implemented for trivially copyable result types");
     482             : 
     483          72 :             auto& shape = this->shape();
     484          72 :             IndexVector_t out_strides;
     485          72 :             size_t element_count = this->size();
     486             : 
     487          72 :             if (!is_contiguous()) {
     488             :                 /* layout is non-contiguous, replace out strides with column major layout */
     489          22 :                 size_t stride = 1;
     490          22 :                 size_t dim_count = shape.size();
     491          22 :                 out_strides = IndexVector_t(dim_count);
     492          88 :                 for (size_t i = 0; i < dim_count; i++) {
     493          66 :                     out_strides(i) = stride;
     494          66 :                     stride *= shape(i);
     495          66 :                 }
     496          50 :             } else {
     497          50 :                 out_strides = this->strides();
     498          50 :             }
     499             : 
     500          72 :             auto out = detail::create_uninitialized_owned_view<output_type, mr::sysStorageType>(
     501          72 :                 shape, out_strides, element_count);
     502             : 
     503          72 :             auto compute_functor = [=](auto src_range, auto dst_range) {
     504          72 :                 thrust::transform(src_range.begin(), src_range.end(), dst_range.begin(), functor);
     505          72 :             };
     506             : 
     507          72 :             if (is_contiguous()) {
     508          50 :                 auto src_range = contiguous_range();
     509          50 :                 auto dst_range = out.contiguous_range();
     510          50 :                 compute_functor(src_range, dst_range);
     511          50 :             } else {
     512          22 :                 auto src_range = range();
     513          22 :                 auto dst_range = out.range();
     514          22 :                 compute_functor(src_range, dst_range);
     515          22 :             }
     516          72 :             return out;
     517          72 :         }
     518             : 
     519             :     public:
     520             :         template <class... Indices>
     521             :         const data_t& operator()(Indices... index) const
     522             :         {
     523             :             return const_cast<self_type*>(this)->operator()(index...);
     524             :         }
     525             : 
     526             :         /// @brief extract a single element; Indices for all dimensions must be supplied
     527             :         template <class... Indices>
     528             :         data_t& operator()(Indices... index)
     529          84 :         {
     530          84 :             if (sizeof...(Indices) != _container->shape.size())
     531           2 :                 throw NdViewDimError();
     532          82 :             if (_container->size == 0)
     533           0 :                 throw NdViewEmptyView();
     534             : 
     535          82 :             if constexpr (sizeof...(Indices) == 0)
     536           0 :                 return _container->pointer[0];
     537          82 :             else {
     538          82 :                 if (!UnpackAndCheckBounds<Indices...>(0, index...))
     539           1 :                     throw NdViewIndexOutOfBounds();
     540          81 :                 return _container->pointer[UnpackAndComputeIndex<Indices...>(0, index...)];
     541          81 :             }
     542          82 :         }
     543             : 
     544             :         /// @brief returned NdView has its dimensionality reduce by one by
     545             :         ///        selecting a point along one dimension to bind to a set value
     546             :         /// @param dim dimension to fix
     547             :         /// @param where the index of the slice along the dimension
     548             :         self_type fix(size_t dim, size_t where)
     549           9 :         {
     550           9 :             if (dim >= static_cast<size_t>(_container->shape.size()))
     551           1 :                 throw NdViewDimOutOfBounds();
     552           8 :             if (static_cast<dim_data_type>(where) >= _container->shape(dim))
     553           1 :                 throw NdViewIndexOutOfBounds();
     554             : 
     555             :             /* allocate the new strides */
     556           7 :             IndexVector_t shape(_container->shape.size() - 1);
     557           7 :             IndexVector_t strides(_container->strides.size() - 1);
     558          18 :             for (index_t i = 0; i < _container->shape.size() - 1; ++i) {
     559          11 :                 shape(i) = _container->shape(i + (i >= dim ? 1 : 0));
     560          11 :                 strides(i) = _container->strides(i + (i >= dim ? 1 : 0));
     561          11 :             }
     562             : 
     563             :             /* adjust the pointer to the offset */
     564           7 :             data_t* raw =
     565           7 :                 thrust::raw_pointer_cast(_container->pointer) + _container->strides(dim) * where;
     566           7 :             return self_type(raw, shape, strides, _container->cleanup);
     567           7 :         }
     568             : 
     569             :         /// @brief returned NdView has same dimensionality but a shape of less or
     570             :         ///        equal to the original along the corresponding dimension
     571             :         /// @param dim         dimension along which to take a sub-range
     572             :         /// @param where_begin lowest index along dimension dim (inclusive)
     573             :         /// @param where_end   highest index along dimension dim (exclusive);
     574             :         ///                    where_begin <= where_end must hold!
     575             :         self_type slice(size_t dim, size_t where_begin, size_t where_end)
     576          10 :         {
     577          10 :             if (dim >= _container->shape.size())
     578           1 :                 throw NdViewDimOutOfBounds();
     579           9 :             if (where_begin >= static_cast<size_t>(_container->shape(dim))
     580           9 :                 || where_end >= static_cast<size_t>(_container->shape(dim)))
     581           1 :                 throw NdViewIndexOutOfBounds();
     582           8 :             if (where_begin == where_end)
     583           0 :                 throw NdViewUnsupportedShape();
     584             : 
     585           8 :             data_t* raw = thrust::raw_pointer_cast(_container->pointer);
     586           8 :             size_t dims = _container->shape.size();
     587           8 :             IndexVector_t shape(dims);
     588           8 :             IndexVector_t strides(dims);
     589          32 :             for (size_t i = 0; i < dims; i++) {
     590          24 :                 shape(i) = _container->shape(i);
     591          24 :                 strides(i) = _container->strides(i);
     592          24 :             }
     593             : 
     594             :             /* limit the size along dimension dim */
     595           8 :             shape(dim) = where_end - where_begin;
     596             :             /* adjust start to skip all */
     597           8 :             raw += where_begin * strides(dim);
     598             : 
     599           8 :             return self_type(raw, shape, strides, _container->cleanup);
     600           8 :         }
     601             : 
     602             :         /// @return the shape of this NdView
     603         417 :         const IndexVector_t& shape() const { return _container->shape; }
     604             : 
     605             :         /// @return the strides of this NdView
     606         368 :         const IndexVector_t& strides() const { return _container->strides; }
     607             : 
     608             :         /// @return the shape and strides of this NdView; Stored in memory of type
     609             :         /// mr::sysStorageType, i.e. they are device accessible if compiled with CUDA
     610             :         const ContiguousStorage<dim_data>& layout_data() const { return _container->dims; }
     611             : 
     612             :         /// @return the cleanup sentinel; once its last reference is dropped, the destructor is
     613             :         /// called
     614           0 :         std::shared_ptr<Cleanup> getCleanup() { return _container->cleanup; }
     615             : 
     616             :         /// @return the number of elements in this view
     617         971 :         size_t size() const { return _container->size; }
     618             : 
     619             :         /// @brief Calls the functor with the lowest overhead iterator range that guarantees
     620             :         /// canonical iteration order. I.e. if the data is naturarlly layed out in canonical
     621             :         /// order, the iterators will be pointers.
     622             :         /// @param f functor to call with an iterator range object with methods .begin() and .end()
     623             :         template <typename Functor>
     624             :         auto with_canonical_range(Functor&& f)
     625             :             -> decltype(std::declval<Functor>()(this->canonical_range()))
     626          25 :         {
     627          25 :             detail::with_canonical_range(*this, std::forward<Functor>(f));
     628          25 :         }
     629             : 
     630             :         /// @brief Const version of with_canonical_range()
     631             :         /// @see with_canonical_range()
     632             :         template <typename Functor>
     633             :         auto with_canonical_crange(Functor&& f) const
     634             :             -> decltype(std::declval<Functor>()(this->canonical_range()))
     635          86 :         {
     636          86 :             detail::with_canonical_range(*this, std::forward<Functor>(f));
     637          86 :         }
     638             : 
     639             :         /// @brief Calls the functor with the lowest overhead iterator range available.
     640             :         /// I.e. if the data is naturarlly layed out in contiguously, the iterators will be
     641             :         /// pointers.
     642             :         /// @param f functor to call with an iterator range object with methods .begin() and .end()
     643             :         template <typename Functor>
     644             :         auto with_unordered_range(Functor f)
     645             :             -> decltype(std::declval<Functor>(this->contiguous_range()))
     646             :         {
     647             :             static_assert(std::is_same_v<decltype(std::declval<Functor>(this->contiguous_range())),
     648             :                                          decltype(std::declval<Functor>(this->range()))>,
     649             :                           "Functor return value differs, depending on input iterator range "
     650             :                           "type!");
     651             :             if (is_contiguous()) {
     652             :                 return f(contiguous_range());
     653             :             } else {
     654             :                 return f(range());
     655             :             }
     656             :         }
     657             : 
     658             : #define NDVIEW_BINOP(op)                                                                    \
     659             :     template <mr::StorageType other_tag>                                                    \
     660             :     auto operator op(const NdViewTagged<data_t, other_tag>& other)                          \
     661             :         const->decltype(binop(other, thrust::placeholders::_1 op thrust::placeholders::_2)) \
     662          87 :     {                                                                                       \
     663          87 :         return binop(other, thrust::placeholders::_1 op thrust::placeholders::_2);          \
     664          87 :     }                                                                                       \
     665             :                                                                                             \
     666             :     template <typename data2_t = data_t>                                                    \
     667             :     friend auto operator op(const NdViewTagged<data_t, tag>& ndview, data2_t other)         \
     668             :         ->typename std::enable_if<                                                          \
     669             :             std::is_same_v<data_t, data2_t>,                                                \
     670             :             NdViewTagged<decltype(std::declval<data_t>() op std::declval<data2_t>()),       \
     671             :                          mr::sysStorageType>>::type                                         \
     672          72 :     {                                                                                       \
     673          72 :         return ndview.unop(thrust::placeholders::_1 op other);                              \
     674          72 :     }                                                                                       \
     675             :                                                                                             \
     676             :     template <typename data2_t = data_t>                                                    \
     677             :     friend auto operator op(data2_t other, const NdViewTagged<data_t, tag>& ndview)         \
     678             :         ->typename std::enable_if<                                                          \
     679             :             std::is_same_v<data_t, data2_t>,                                                \
     680             :             NdViewTagged<decltype(std::declval<data2_t>() op std::declval<data_t>()),       \
     681             :                          mr::sysStorageType>>::type                                         \
     682             :     {                                                                                       \
     683             :         return ndview.unop(other op thrust::placeholders::_1);                              \
     684             :     }
     685             : 
     686             :         /* Elementwise binary operations. The result of the operation
     687             :          * is eagerly evaluated and returned in a new NdView that owns
     688             :          * the result.
     689             :          * Note assignment operators are not implemented, as their desired
     690             :          * semantics are unclear. One might expect that x *= 2 doubles
     691             :          * all elements of x in the underlying data that x is a view of.
     692             :          * However, then x = x * 2 would behave differently form x *= 2,
     693             :          * as x * 2 allocates a new buffer.
     694             :          **/
     695             :         NDVIEW_BINOP(+);
     696             :         NDVIEW_BINOP(-);
     697             :         NDVIEW_BINOP(*);
     698             :         NDVIEW_BINOP(/);
     699             :         NDVIEW_BINOP(%);
     700             :         NDVIEW_BINOP(>);
     701             :         NDVIEW_BINOP(<);
     702             :         NDVIEW_BINOP(>=);
     703             :         NDVIEW_BINOP(<=);
     704             :         NDVIEW_BINOP(==);
     705             :         NDVIEW_BINOP(!=);
     706             :         NDVIEW_BINOP(&);
     707             :         NDVIEW_BINOP(|);
     708             :         NDVIEW_BINOP(^);
     709             :         NDVIEW_BINOP(&&);
     710             :         NDVIEW_BINOP(||);
     711             : 
     712             : #undef NDVIEW_BINOP
     713             : 
     714             : #define NDVIEW_UNOP(op)                                             \
     715             :     auto operator op() const->decltype(op thrust::placeholders::_1) \
     716             :     {                                                               \
     717             :         return unop(op thrust::placeholders::_1);                   \
     718             :     }
     719             : 
     720             :         NDVIEW_UNOP(-);
     721             :         NDVIEW_UNOP(!);
     722             : 
     723             :         /// Creates a left hand side, to be used for filtered assignments.
     724             :         /// The index parameter must be an NdView of equal dimensions to this.
     725             :         /// The returned object can be assigned to, replacing all entries in this
     726             :         /// view, whose corresponding index element is true. Indices whose filter-
     727             :         /// element is false remain unchanged.
     728             :         /// If the index tensor does not have the correct dimensions, no exception
     729             :         /// is thrown until the actual assignment occurs.
     730             :         ///
     731             :         /// Example:
     732             :         /// ```
     733             :         /// template<typename T>
     734             :         /// void zero_value_range(NdViewTagged<float, T> &x, float lb, float ub) {
     735             :         ///     x[x >= lb && x < ub] = 0.0f;
     736             :         /// }
     737             :         /// ```
     738             :         /// This example function sets all elements in the value range lb <= x < ub
     739             :         /// to zero.
     740             :         /// @see BoolIndexedView
     741             :         template <mr::StorageType other_tag>
     742             :         BoolIndexedView<data_t, tag, other_tag>
     743             :             operator[](const NdViewTagged<bool, other_tag>& index)
     744          27 :         {
     745          27 :             return BoolIndexedView(*this, index);
     746          27 :         }
     747             :     };
     748             : 
     749             :     namespace detail
     750             :     {
     751             :         template <typename data_t>
     752             :         struct fill_const_if_functor
     753             :             : public thrust::unary_function<thrust::tuple<data_t&, bool>, void> {
     754             :             data_t fill_value;
     755             : 
     756           9 :             fill_const_if_functor(const data_t& fill_value) : fill_value{fill_value} {}
     757             : 
     758             :             __host__ __device__ void operator()(thrust::tuple<data_t&, bool> args)
     759         243 :             {
     760         243 :                 if (thrust::get<1>(args)) {
     761          81 :                     thrust::get<0>(args) = fill_value;
     762          81 :                 }
     763         243 :             }
     764             :         };
     765             : 
     766             :         template <typename data_t>
     767             :         struct fill_if_functor : public thrust::unary_function<thrust::tuple<data_t&, bool>, void> {
     768             :             __host__ __device__ void operator()(thrust::tuple<data_t&, bool, data_t> args)
     769         486 :             {
     770         486 :                 if (thrust::get<1>(args)) {
     771         162 :                     thrust::get<0>(args) = thrust::get<2>(args);
     772         162 :                 }
     773         486 :             }
     774             :         };
     775             : 
     776             :         template <mr::StorageType... tags>
     777             :         struct are_tags_host_compatible {
     778             :             static constexpr bool value = false;
     779             :         };
     780             : 
     781             :         template <mr::StorageType... tags>
     782             :         struct are_tags_host_compatible<mr::StorageType::universal, tags...> {
     783             :             static constexpr bool value = are_tags_host_compatible<tags...>::value;
     784             :         };
     785             : 
     786             :         template <mr::StorageType... tags>
     787             :         struct are_tags_host_compatible<mr::StorageType::host, tags...> {
     788             :             static constexpr bool value = are_tags_host_compatible<tags...>::value;
     789             :         };
     790             : 
     791             :         template <>
     792             :         struct are_tags_host_compatible<> {
     793             :             static constexpr bool value = true;
     794             :         };
     795             : 
     796             :         template <mr::StorageType... tags>
     797             :         struct select_policy {
     798             :         };
     799             : 
     800             :         template <>
     801             :         struct select_policy<> {
     802             :             static constexpr decltype(thrust::device) policy = thrust::device;
     803             :         };
     804             : 
     805             :         template <mr::StorageType... tags>
     806             :         struct select_policy<mr::StorageType::universal, tags...> {
     807             :             static constexpr decltype(select_policy<tags...>::policy) policy =
     808             :                 select_policy<tags...>::policy;
     809             :         };
     810             : 
     811             :         template <mr::StorageType... tags>
     812             :         struct select_policy<mr::StorageType::host, tags...> {
     813             :             static_assert(are_tags_host_compatible<tags...>::value,
     814             :                           "Cannot select policy, as storage tags are incompatible");
     815             :             static constexpr decltype(thrust::host) policy = thrust::host;
     816             :         };
     817             :     } // namespace detail
     818             : 
     819             :     /// @brief NdView using the system storage type
     820             :     /// @tparam data_t type of the data that the NdView points to
     821             :     /// @see elsa::mr::sysStorageType
     822             :     template <class data_t>
     823             :     using NdView = NdViewTagged<data_t, mr::sysStorageType>;
     824             : 
     825             :     /// @brief Class representing a view with a boolean index tensor applied. Objects of this
     826             :     ///        class are intended to be assigned to, which overwrites the element whose filter is
     827             :     ///        true.
     828             :     template <typename data_t, mr::StorageType indexed_tag, mr::StorageType index_tag>
     829             :     class BoolIndexedView
     830             :     {
     831             :         NdViewTagged<data_t, indexed_tag> _indexed;
     832             :         NdViewTagged<bool, index_tag> _index;
     833             : 
     834             :     public:
     835             :         BoolIndexedView(const NdViewTagged<data_t, indexed_tag>& indexed,
     836             :                         NdViewTagged<bool, index_tag> index)
     837             :             : _indexed{indexed}, _index{index}
     838          27 :         {
     839          27 :         }
     840             : 
     841             :         /// @brief Assign the elements of rhs to the indexed tensor wherever the corresponding
     842             :         ///        element of the index tensor is true.
     843             :         /// @throw NdViewDimError if rhs, the left hand side NdView and the indexing tensor do
     844             :         ///        not all have the same dimensions.
     845             :         template <mr::StorageType rhs_tag>
     846             :         BoolIndexedView& operator=(const NdViewTagged<data_t, rhs_tag>& rhs)
     847          18 :         {
     848          18 :             static_assert(mr::are_storages_compatible<indexed_tag, rhs_tag>::value,
     849          18 :                           "Indexed view storage type and assignment right side storage type are "
     850          18 :                           "incompatible");
     851          18 :             static_assert(mr::are_storages_compatible<index_tag, rhs_tag>::value,
     852          18 :                           "Boolean index storage type and assignment right side storage type are "
     853          18 :                           "incompatible");
     854             : 
     855          18 :             const auto& indexed_shape = _indexed.shape();
     856          18 :             const auto& indexed_strides = _indexed.strides();
     857          18 :             const auto& index_shape = _index.shape();
     858          18 :             const auto& index_strides = _index.strides();
     859          18 :             const auto& rhs_shape = rhs.shape();
     860          18 :             const auto& rhs_strides = rhs.strides();
     861          18 :             bool strides_compatible = detail::are_strides_compatible(indexed_shape, indexed_strides,
     862          18 :                                                                      index_shape, index_strides)
     863          18 :                                       && detail::are_strides_compatible(
     864           8 :                                           indexed_shape, indexed_strides, rhs_shape, rhs_strides);
     865             : 
     866          18 :             auto copy_functor = [&](auto& indexed_range, auto& index_range, auto& rhs_range) {
     867          18 :                 auto begin = thrust::make_zip_iterator(thrust::make_tuple(
     868          18 :                     indexed_range.begin(), index_range.begin(), rhs_range.begin()));
     869          18 :                 auto end = thrust::make_zip_iterator(
     870          18 :                     thrust::make_tuple(indexed_range.end(), index_range.end(), rhs_range.end()));
     871             : 
     872          18 :                 thrust::for_each(detail::select_policy<indexed_tag, index_tag, rhs_tag>::policy,
     873          18 :                                  begin, end, detail::fill_if_functor<data_t>());
     874          18 :             };
     875             : 
     876          18 :             if (strides_compatible && _indexed.is_contiguous() && _index.is_contiguous()
     877          18 :                 && rhs.is_contiguous()) {
     878           5 :                 auto indexed_range = _indexed.contiguous_range();
     879           5 :                 auto index_range = _index.contiguous_range();
     880           5 :                 auto rhs_range = rhs.contiguous_range();
     881           5 :                 copy_functor(indexed_range, index_range, rhs_range);
     882          13 :             } else {
     883          13 :                 _indexed.with_canonical_range([&](auto indexed_range) mutable {
     884          13 :                     _index.with_canonical_crange([&](auto index_range) mutable {
     885          13 :                         rhs.with_canonical_crange([&](auto rhs_range) mutable {
     886          13 :                             copy_functor(indexed_range, index_range, rhs_range);
     887          13 :                         });
     888          13 :                     });
     889          13 :                 });
     890          13 :             }
     891          18 :             return *this;
     892          18 :         }
     893             : 
     894             :         /// @brief Assign rhs to the indexed tensor wherever the corresponding element of the index
     895             :         ///        tensor is true
     896             :         /// @throw NdViewDimError if the left hand side NdView and the indexing tensor do
     897             :         ///        not have the same dimensions.
     898             :         BoolIndexedView& operator=(data_t rhs)
     899           9 :         {
     900           9 :             const auto& indexed_shape = _indexed.shape();
     901           9 :             const auto& indexed_strides = _indexed.strides();
     902           9 :             const auto& index_shape = _index.shape();
     903           9 :             const auto& index_strides = _index.strides();
     904           9 :             bool strides_compatible = detail::are_strides_compatible(indexed_shape, indexed_strides,
     905           9 :                                                                      index_shape, index_strides);
     906             : 
     907           9 :             auto copy_functor = [&](auto indexed_begin, auto indexed_end, auto index_begin,
     908           9 :                                     auto index_end) {
     909           9 :                 auto begin =
     910           9 :                     thrust::make_zip_iterator(thrust::make_tuple(indexed_begin, index_begin));
     911           9 :                 auto end = thrust::make_zip_iterator(thrust::make_tuple(indexed_end, index_end));
     912             : 
     913           9 :                 thrust::for_each(detail::select_policy<indexed_tag, index_tag>::policy, begin, end,
     914           9 :                                  detail::fill_const_if_functor(rhs));
     915           9 :             };
     916             : 
     917           9 :             if (strides_compatible && _indexed.is_contiguous() && _index.is_contiguous()) {
     918           3 :                 auto indexed_range = _indexed.contiguous_range();
     919           3 :                 auto index_range = _index.contiguous_range();
     920           3 :                 copy_functor(indexed_range.begin(), indexed_range.end(), index_range.begin(),
     921           3 :                              index_range.end());
     922           6 :             } else {
     923           6 :                 _indexed.with_canonical_range([&](auto indexed_range) {
     924           6 :                     _index.with_canonical_range([&](auto index_range) {
     925           6 :                         copy_functor(indexed_range.begin(), indexed_range.end(),
     926           6 :                                      index_range.begin(), index_range.end());
     927           6 :                     });
     928           6 :                 });
     929           6 :             }
     930           9 :             return *this;
     931           9 :         }
     932             :     };
     933             : } // namespace elsa

Generated by: LCOV version 1.14