Line data Source code
1 : #pragma once
2 :
3 : #include <thrust/complex.h>
4 : #include "ContiguousStorage.h"
5 : #include "Utilities/StridedIterator.h"
6 : #include "elsaDefines.h"
7 :
8 : DISABLE_WARNING_PUSH
9 : DISABLE_WARNING_SIGN_CONVERSION
10 : #include <thrust/universal_vector.h>
11 : #include <thrust/device_ptr.h>
12 : #include <thrust/functional.h>
13 : #include <thrust/transform.h>
14 : #include <thrust/for_each.h>
15 : #include <thrust/iterator/constant_iterator.h>
16 : #include <thrust/iterator/zip_iterator.h>
17 : DISABLE_WARNING_POP
18 :
19 : #include <memory>
20 : #include <type_traits>
21 : #include <cinttypes>
22 : #include <functional>
23 : #include <exception>
24 :
25 : namespace elsa
26 : {
27 : class NdViewDimError : public std::runtime_error
28 : {
29 : public:
30 2 : NdViewDimError() : std::runtime_error("NdView dimensions do not match") {}
31 : };
32 : class NdViewUnsupportedShape : public std::runtime_error
33 : {
34 : public:
35 : NdViewUnsupportedShape()
36 : : std::runtime_error("NdView one or more dimensions are not positive")
37 1 : {
38 1 : }
39 : };
40 : class NdViewDimOutOfBounds : public std::runtime_error
41 : {
42 : public:
43 2 : NdViewDimOutOfBounds() : std::runtime_error("NdView dimension out-of-bounds") {}
44 : };
45 : class NdViewIndexOutOfBounds : public std::runtime_error
46 : {
47 : public:
48 : NdViewIndexOutOfBounds() : std::runtime_error("NdView index into dimension out-of-bounds")
49 3 : {
50 3 : }
51 : };
52 : class NdViewUnavailableIteration : public std::runtime_error
53 : {
54 : public:
55 : NdViewUnavailableIteration()
56 : : std::runtime_error("NdView iteration range not available for current data")
57 2 : {
58 2 : }
59 : };
60 : class NdViewEmptyView : public std::runtime_error
61 : {
62 : public:
63 0 : NdViewEmptyView() : std::runtime_error("NdView is empty") {}
64 : };
65 :
66 : template <typename data_t, mr::StorageType indexed_tag, mr::StorageType index_tag>
67 : class BoolIndexedView;
68 :
69 : template <class data_t, mr::StorageType tag>
70 : class NdViewTagged;
71 :
72 : namespace detail
73 : {
74 : template <typename View, typename Functor>
75 : auto with_canonical_range(View& view, Functor&& f)
76 : -> decltype(std::declval<Functor>()(std::declval<View>().canonical_range()))
77 111 : {
78 111 : static_assert(std::is_same_v<decltype(std::declval<Functor>()(view.canonical_range())),
79 111 : decltype(std::declval<Functor>()(view.range()))>,
80 111 : "Functor return value differs, depending on input iterator range "
81 111 : "type!");
82 111 : if (view.is_canonical()) {
83 57 : return f(view.canonical_range());
84 57 : } else {
85 54 : return f(view.range());
86 54 : }
87 111 : }
88 :
89 : /* throws NdViewDimError if the dimension coefficions do not match */
90 : bool are_strides_compatible(const IndexVector_t& shape1, const IndexVector_t& strides1,
91 : const IndexVector_t& shape2, const IndexVector_t& strides2);
92 :
93 : template <typename size_type>
94 : std::pair<IndexVector_t, IndexVector_t>
95 : unpack_dim_data(const ContiguousStorage<DimData<size_type>>& shape)
96 : {
97 : size_t dim_count = shape.size();
98 : IndexVector_t out_shape(dim_count);
99 : IndexVector_t out_strides(dim_count);
100 : for (size_t i = 0; i < dim_count; i++) {
101 : auto dim_data = shape[i];
102 : out_shape(i) = dim_data.shape;
103 : out_strides(i) = dim_data.stride;
104 : }
105 : return std::make_pair(out_shape, out_strides);
106 : }
107 :
108 : template <typename data_t, mr::StorageType tag>
109 : NdViewTagged<data_t, tag> create_uninitialized_owned_view(const IndexVector_t& shape,
110 : const IndexVector_t& strides,
111 : size_t allocation_size)
112 159 : {
113 159 : auto memres = mr::defaultResource();
114 159 : if (std::numeric_limits<size_t>::max() / sizeof(data_t) < allocation_size) {
115 0 : throw std::runtime_error("Overflowing multiplication");
116 0 : }
117 159 : size_t out_data_size = sizeof(data_t) * allocation_size;
118 159 : void* out_data = memres->allocate(out_data_size, alignof(data_t));
119 159 : return NdViewTagged<data_t, mr::sysStorageType>(
120 159 : reinterpret_cast<data_t*>(out_data), shape, strides,
121 159 : [=]() { memres->deallocate(out_data, out_data_size, alignof(data_t)); });
122 159 : }
123 :
        /// Adapts a binary functor into a unary one by binding its second argument
        /// to a fixed value; used to implement "view op scalar" operations.
        /// NOTE(review): thrust::unary_function is deprecated in newer thrust
        /// releases — confirm the targeted thrust version still provides it.
        template <typename data_t, typename Functor>
        struct binop_to_unop
            : public thrust::unary_function<data_t,
                                            decltype(std::declval<Functor>()(
                                                std::declval<data_t>(), std::declval<data_t>()))> {
            data_t other; // the bound right-hand operand
            Functor f;    // the wrapped binary operation

            binop_to_unop(data_t other, Functor f) : other{other}, f{f} {}

            /* callable from host and device code so thrust algorithms can use it
             * with either execution policy */
            __host__ __device__ auto operator()(data_t e) const -> decltype(f(e, other))
            {
                return f(e, other);
            }
        };
139 :
140 : } // namespace detail
141 :
142 : /// @brief Represents a non-owning view. It can handle arbitrary strides (including negative).
143 : /// Supports the creation of subviews. Supports iteration in canonical order with a
144 : /// thrust::device compatible iterator, provided the storage type is device accessible.
145 : /// Upon deletion, if no other NdView has a reference to the data, signals to the owner
146 : /// of the data that it may be deleted via a destructor that is passed as
147 : /// constructor parameter.
148 : /// Additionally, NdViewTagged provides elementwise unary and binary operations and
149 : /// filtered assignments.
150 : /// @tparam data_t type of the data that the NdView points to
151 : /// @tparam tag storage type/location of the data; Unlike DataContainer, NdView supports
152 : /// non-universal memory, if compiled with CUDA
153 : /// @see is_canonical()
154 : template <class data_t, mr::StorageType tag>
155 : class NdViewTagged
156 : {
157 : public:
158 : using pointer_type = std::conditional_t<
159 : tag == mr::StorageType::host, data_t*,
160 : std::conditional_t<tag == mr::StorageType::device, thrust::device_ptr<data_t>,
161 : thrust::universal_ptr<data_t>>>;
162 : using const_pointer_type = std::conditional_t<
163 : tag == mr::StorageType::host, const data_t*,
164 : std::conditional_t<tag == mr::StorageType::device, thrust::device_ptr<const data_t>,
165 : thrust::universal_ptr<const data_t>>>;
166 :
167 : using Scalar = data_t;
168 : using self_type = NdViewTagged<data_t, tag>;
169 : using value_type = data_t;
170 : using iterator = pointer_type;
171 : using const_iterator = const_pointer_type;
172 : using reference = data_t&;
173 : using const_reference = std::add_const_t<reference>;
174 : using dim_data_type = decltype(pointer_type() - pointer_type());
175 : using dim_data = DimData<dim_data_type>;
176 :
177 : struct Cleanup {
178 : std::function<void()> cleanup;
179 402 : Cleanup(std::function<void()>&& cleanup) : cleanup{cleanup} {}
180 402 : ~Cleanup() { cleanup(); }
181 : };
182 :
183 : template <class ItType>
184 : struct IteratorRange {
185 : private:
186 : ItType _begin = ItType();
187 : ItType _end = ItType();
188 :
189 : public:
190 386 : IteratorRange(const ItType& b, const ItType& e) : _begin(b), _end(e) {}
191 385 : ItType begin() const { return _begin; }
192 173 : ItType end() const { return _end; }
193 : };
194 :
    private:
        /// State shared by a view and every view derived from it via fix()/slice().
        /// Derived views get their own Container but share the cleanup sentinel.
        struct Container {
            std::shared_ptr<Cleanup> cleanup;               // owner-supplied destructor sentinel
            pointer_type pointer = pointer_type();          // element at index (0, ..., 0)
            IndexVector_t shape;                            // extent per dimension
            IndexVector_t strides;                          // element distance per dimension
            ContiguousStorage<dim_data> dims;               // device-accessible shape/stride copy
            pointer_type contiguous_start = pointer_type(); // lowest address of the data
            size_t size{0};                                 // total number of elements
            bool is_canonical = false; // strides follow the canonical layout
            bool contiguous = false;   // data occupies one gap-free memory block
        };
        std::shared_ptr<Container> _container;
208 :
209 : public:
210 : /// @brief Create a view on raw (possibly non-contiguous) data
211 : /// @param cleanup shared encapsulated destructor; to be called, once this NdView
212 : /// (and all of its parent or subviews) have been deleted
213 : NdViewTagged(data_t* raw_data, const IndexVector_t& shape, const IndexVector_t& strides,
214 : const std::shared_ptr<Cleanup>& cleanup)
215 : : _container{std::make_shared<Container>()}
216 416 : {
217 416 : size_t dims = shape.size();
218 416 : if (static_cast<ssize_t>(dims) != strides.size())
219 0 : throw NdViewDimError();
220 :
221 : /* check that all shapes are positive and if the overall container describes an
222 : * empty volume (initialize size as 1 as zero-dimensional objects have this size) */
223 416 : _container->size = 1;
224 1233 : for (auto& x : shape) {
225 1233 : if (x < 0)
226 1 : throw NdViewUnsupportedShape();
227 1232 : if (x == 0)
228 0 : _container->size = 0;
229 1232 : }
230 416 : if (_container->size == 0)
231 0 : dims = 0;
232 415 : _container->cleanup = cleanup;
233 415 : _container->pointer = pointer_type(raw_data);
234 415 : _container->shape = shape;
235 415 : _container->strides = strides;
236 415 : _container->dims = ContiguousStorage<dim_data>(static_cast<size_t>(dims));
237 415 : _container->contiguous_start = _container->pointer;
238 415 : _container->is_canonical = true;
239 415 : _container->contiguous = true;
240 :
241 : /* check if the container is empty or if it is zero-dimensional,
242 : * in which case the processing is done here */
243 415 : if (_container->size == 0)
244 0 : return;
245 415 : if (dims == 0)
246 1 : return;
247 :
248 : /* iterate over all dimensions and:
249 : * - check if its a canonical layout
250 : * - adjust the start-pointer to the actual first address
251 : * to allow for contiugous iteration (in case of negative strides)
252 : * - create the copy of the strides in order to determine if the range is contiguous */
253 414 : std::vector<dim_data> shapeAndStrides(dims);
254 1646 : for (size_t i = 0; i < dims; i++) {
255 1232 : _container->dims[i] = {shape(i), strides(i)};
256 1232 : shapeAndStrides[i] = {shape(i), std::abs(strides(i))};
257 :
258 : /* check if the stride matches the canonical layout */
259 1232 : if (static_cast<ssize_t>(_container->size) != _container->dims[i].stride)
260 339 : _container->is_canonical = false;
261 :
262 : /* check if the stride is negative and adjust the pointer accordingly */
263 1232 : if (_container->dims[i].stride < 0) {
264 : /* move the adjusted pointer back by the negative stides by its dimension */
265 0 : _container->contiguous_start +=
266 0 : _container->dims[i].stride * _container->dims[i].shape;
267 0 : }
268 1232 : _container->size *= shape(i);
269 1232 : }
270 :
271 : /* sort the shapes and strides by the strides and check if every next stride
272 : * is equal to the previous stride times shape (in which case its contiguous) */
273 414 : std::sort(shapeAndStrides.begin(), shapeAndStrides.end(),
274 1383 : [](auto& a, auto& b) { return a.stride < b.stride; });
275 :
276 : /* dims > 0 is ensured above */
277 414 : _container->contiguous = (shapeAndStrides[0].stride == 1);
278 1184 : for (size_t i = 1; i < dims && _container->contiguous; i++) {
279 770 : _container->contiguous =
280 770 : shapeAndStrides[i].stride
281 770 : == shapeAndStrides[i - 1].stride * shapeAndStrides[i - 1].shape;
282 770 : }
283 :
284 : /* is_canonical will already be implicitly false, but this makes it clearer */
285 414 : if (!_container->contiguous)
286 33 : _container->is_canonical = false;
287 414 : }
288 :
        /// @brief Create a view on raw (possibly non-contiguous) data
        /// @param raw_data pointer to the element with index (0, ..., 0)
        /// @param shape extent per dimension
        /// @param strides element distance per dimension; may be negative
        /// @param destructor destructor to be called, once this NdView
        /// (and all of its parent or subviews) have been deleted
        NdViewTagged(data_t* raw_data, const IndexVector_t& shape, const IndexVector_t& strides,
                     std::function<void()> destructor)
            : NdViewTagged(raw_data, shape, strides,
                           std::make_shared<Cleanup>(std::move(destructor)))
        {
        }
298 :
299 : /// @brief Create an empty view
300 : NdViewTagged() : _container{std::make_shared<Container>()}
301 1 : {
302 1 : _container->size = 0;
303 1 : _container->is_canonical = true;
304 1 : _container->contiguous = true;
305 1 : _container->cleanup = std::make_shared<Cleanup>([]() {});
306 1 : }
307 :
        /// @return true iff the raw data is layed out contiguously
        /// (one gap-free memory block, in some order — not necessarily canonical)
        bool is_contiguous() const { return _container->contiguous; }
        /// @return true iff the raw data follows the canonical layout
        /// Canonical layout is defined as follows:
        /// strides[0] = 1;
        /// strides[i] = strides[i - 1] * shape[i - 1];
        /// It is tempting to call this column major layout, but there is one
        /// caveat. In elsa, the first index refers to the column (i.e. x-coordinate)
        /// and the second refers to the row (i.e. y-coordinate).
        bool is_canonical() const { return _container->is_canonical; }

        /// @return true iff the view does not contain any data (size() == 0)
        bool is_empty() const { return _container->size == 0; }
321 :
322 : /// @brief iterate over the raw data in any order, if the data
323 : /// is contiguous (iterators are pointer)
324 : /// Useful for reductions or transformations where the order of the
325 : /// data is not relevant (strides are ignored).
326 : IteratorRange<pointer_type> contiguous_range()
327 155 : {
328 155 : if (!_container->contiguous)
329 1 : throw NdViewUnavailableIteration();
330 154 : return IteratorRange<pointer_type>(_container->contiguous_start,
331 154 : _container->contiguous_start + _container->size);
332 154 : }
333 : IteratorRange<const_pointer_type> contiguous_range() const
334 169 : {
335 169 : if (!_container->contiguous)
336 0 : throw NdViewUnavailableIteration();
337 169 : return IteratorRange<const_pointer_type>(
338 169 : _container->contiguous_start, _container->contiguous_start + _container->size);
339 169 : }
340 :
341 : /// @brief iterate over the raw data in canonical order, if the data
342 : /// is layed out in canonical layout (iterators are pointer).
343 : /// @see is_canonical()
344 : IteratorRange<pointer_type> canonical_range()
345 16 : {
346 16 : if (!_container->is_canonical)
347 1 : throw NdViewUnavailableIteration();
348 15 : return IteratorRange<pointer_type>(_container->contiguous_start,
349 15 : _container->contiguous_start + _container->size);
350 15 : }
351 : IteratorRange<const_pointer_type> canonical_range() const
352 48 : {
353 48 : if (!_container->is_canonical)
354 0 : throw NdViewUnavailableIteration();
355 48 : return IteratorRange<const_pointer_type>(
356 48 : _container->contiguous_start, _container->contiguous_start + _container->size);
357 48 : }
358 :
359 : /// @brief iterate over the data in canonical order, regardless of its real layout
360 : StridedRange<pointer_type> range()
361 353 : {
362 353 : return StridedRange<pointer_type>(_container->pointer, _container->size,
363 353 : _container->dims.data().get(),
364 353 : _container->dims.size());
365 353 : }
366 : StridedRange<const_pointer_type> range() const
367 132 : {
368 132 : return StridedRange<const_pointer_type>(_container->pointer, _container->size,
369 132 : _container->dims.data().get(),
370 132 : _container->dims.size());
371 132 : }
372 :
    private:
        /* Recursion terminator: every index of the pack has been validated. */
        template <class...>
        bool UnpackAndCheckBounds(size_t index) const
        {
            static_cast<void>(index);
            return true;
        }

        /* Validates `index` against the extent of dimension `current`, then
         * recurses on the remaining indices. Returns false on the first
         * out-of-bounds (or negative) index. */
        template <class Index, class... Indices>
        bool UnpackAndCheckBounds(size_t current, Index index, Indices... indices) const
        {
            if (index < 0 || index >= _container->shape(current))
                return false;
            return UnpackAndCheckBounds<Indices...>(current + 1, indices...);
        }

        /* Recursion terminator: no further indices contribute to the offset. */
        template <class...>
        ssize_t UnpackAndComputeIndex(size_t index) const
        {
            static_cast<void>(index);
            return 0;
        }

        /* Folds the index pack into a linear element offset:
         * sum over all dimensions of strides(current) * index. */
        template <class Index, class... Indices>
        ssize_t UnpackAndComputeIndex(size_t current, Index index, Indices... indices) const
        {
            return (_container->strides(current) * index)
                   + UnpackAndComputeIndex<Indices...>(current + 1, indices...);
        }
402 :
        /// Performs an element-wise binary operation. The result is returned in
        /// a view, which owns a newly allocated buffer for the output. The strides
        /// of the output may not match the strides of either input.
        /// @param other right-hand operand; must have a matching shape
        /// @param functor device-compatible binary operation applied per element pair
        template <mr::StorageType other_tag, typename Functor>
        NdViewTagged<decltype(std::declval<Functor>()(std::declval<data_t>(),
                                                      std::declval<data_t>())),
                     mr::sysStorageType>
            binop(const NdViewTagged<data_t, other_tag>& other, Functor functor) const
        {
            static_assert(mr::are_storages_compatible<tag, other_tag>::value,
                          "Binary operations are only possible when both arguments live on "
                          "compatible devices");
            using output_type =
                decltype(std::declval<Functor>()(std::declval<data_t>(), std::declval<data_t>()));
            static_assert(
                std::is_trivially_copyable_v<output_type>,
                "Binary operations are only implemented for trivially copyable result types");

            const auto& shape = this->shape();
            const auto& strides = this->strides();
            const auto& other_shape = other.shape();
            const auto& other_strides = other.strides();

            size_t element_count = this->size();
            /* throws NdViewDimError if the shapes do not match */
            bool strides_compatible =
                detail::are_strides_compatible(shape, strides, other_shape, other_strides);

            IndexVector_t out_strides;

            /* plain pointer iteration over both inputs is only valid when they agree
             * on one contiguous element order */
            bool same_layout = strides_compatible && this->is_contiguous() && other.is_contiguous();
            if (!same_layout) {
                /* layouts do not match or are non-contiguous, replace out strides with column major
                 * layout */
                size_t stride = 1;
                size_t dim_count = shape.size();
                out_strides = IndexVector_t(dim_count);
                for (size_t i = 0; i < dim_count; i++) {
                    out_strides(i) = stride;
                    stride *= shape(i);
                }
            } else {
                out_strides = strides;
            }

            /* No initialization is performed on the buffer before assigning its contents with
             * thrust::transform, so this only works for trivially copyable types. To extend it for
             * other types, one could use thrust::for_each with a functor that relies on placement
             * new to assign the computed values. */
            auto out = detail::create_uninitialized_owned_view<output_type, mr::sysStorageType>(
                shape, out_strides, element_count);

            auto out_range = out.contiguous_range();
            if (same_layout) {
                auto left_range = this->contiguous_range();
                auto right_range = other.contiguous_range();
                thrust::transform(left_range.begin(), left_range.end(), right_range.begin(),
                                  out_range.begin(), functor);
            } else {
                /* fall back to canonical-order iteration on both inputs; the output was
                 * allocated with canonical strides above, so its pointer range matches */
                this->with_canonical_crange([&](auto left_range) mutable {
                    other.with_canonical_crange([&](auto right_range) mutable {
                        thrust::transform(left_range.begin(), left_range.end(), right_range.begin(),
                                          out_range.begin(), functor);
                    });
                });
            }
            return out;
        }
470 :
471 : /// Performs an element-wise unary operation. The result is returned in a view, which owns a
472 : /// newly allocated buffer for the output. The strides of the output may not match the
473 : /// strides of the input.
474 : template <typename Functor>
475 : NdViewTagged<decltype(std::declval<Functor>()(std::declval<data_t>())), mr::sysStorageType>
476 : unop(Functor functor) const
477 72 : {
478 72 : using output_type = decltype(std::declval<Functor>()(std::declval<data_t>()));
479 72 : static_assert(
480 72 : std::is_trivially_copyable_v<output_type>,
481 72 : "Binary operations are only implemented for trivially copyable result types");
482 :
483 72 : auto& shape = this->shape();
484 72 : IndexVector_t out_strides;
485 72 : size_t element_count = this->size();
486 :
487 72 : if (!is_contiguous()) {
488 : /* layout is non-contiguous, replace out strides with column major layout */
489 22 : size_t stride = 1;
490 22 : size_t dim_count = shape.size();
491 22 : out_strides = IndexVector_t(dim_count);
492 88 : for (size_t i = 0; i < dim_count; i++) {
493 66 : out_strides(i) = stride;
494 66 : stride *= shape(i);
495 66 : }
496 50 : } else {
497 50 : out_strides = this->strides();
498 50 : }
499 :
500 72 : auto out = detail::create_uninitialized_owned_view<output_type, mr::sysStorageType>(
501 72 : shape, out_strides, element_count);
502 :
503 72 : auto compute_functor = [=](auto src_range, auto dst_range) {
504 72 : thrust::transform(src_range.begin(), src_range.end(), dst_range.begin(), functor);
505 72 : };
506 :
507 72 : if (is_contiguous()) {
508 50 : auto src_range = contiguous_range();
509 50 : auto dst_range = out.contiguous_range();
510 50 : compute_functor(src_range, dst_range);
511 50 : } else {
512 22 : auto src_range = range();
513 22 : auto dst_range = out.range();
514 22 : compute_functor(src_range, dst_range);
515 22 : }
516 72 : return out;
517 72 : }
518 :
    public:
        /// @brief const overload of the element access below
        template <class... Indices>
        const data_t& operator()(Indices... index) const
        {
            /* the non-const overload does not mutate the view; the returned
             * reference is re-qualified as const here */
            return const_cast<self_type*>(this)->operator()(index...);
        }

        /// @brief extract a single element; Indices for all dimensions must be supplied
        /// @throw NdViewDimError if the index count differs from the dimensionality
        /// @throw NdViewEmptyView if the view holds no elements
        /// @throw NdViewIndexOutOfBounds if any index exceeds its dimension's extent
        template <class... Indices>
        data_t& operator()(Indices... index)
        {
            if (sizeof...(Indices) != _container->shape.size())
                throw NdViewDimError();
            if (_container->size == 0)
                throw NdViewEmptyView();

            /* zero-dimensional view: the single element lives at the base pointer */
            if constexpr (sizeof...(Indices) == 0)
                return _container->pointer[0];
            else {
                if (!UnpackAndCheckBounds<Indices...>(0, index...))
                    throw NdViewIndexOutOfBounds();
                return _container->pointer[UnpackAndComputeIndex<Indices...>(0, index...)];
            }
        }
543 :
        /// @brief returned NdView has its dimensionality reduce by one by
        /// selecting a point along one dimension to bind to a set value
        /// @param dim dimension to fix
        /// @param where the index of the slice along the dimension
        /// @throw NdViewDimOutOfBounds if dim is not a valid dimension
        /// @throw NdViewIndexOutOfBounds if where exceeds the extent of dim
        self_type fix(size_t dim, size_t where)
        {
            if (dim >= static_cast<size_t>(_container->shape.size()))
                throw NdViewDimOutOfBounds();
            if (static_cast<dim_data_type>(where) >= _container->shape(dim))
                throw NdViewIndexOutOfBounds();

            /* allocate the new strides */
            IndexVector_t shape(_container->shape.size() - 1);
            IndexVector_t strides(_container->strides.size() - 1);
            /* copy all dimensions except dim, shifting the tail down by one */
            for (index_t i = 0; i < _container->shape.size() - 1; ++i) {
                shape(i) = _container->shape(i + (i >= dim ? 1 : 0));
                strides(i) = _container->strides(i + (i >= dim ? 1 : 0));
            }

            /* adjust the pointer to the offset */
            data_t* raw =
                thrust::raw_pointer_cast(_container->pointer) + _container->strides(dim) * where;
            /* the subview shares the cleanup sentinel, keeping the data alive */
            return self_type(raw, shape, strides, _container->cleanup);
        }
568 :
569 : /// @brief returned NdView has same dimensionality but a shape of less or
570 : /// equal to the original along the corresponding dimension
571 : /// @param dim dimension along which to take a sub-range
572 : /// @param where_begin lowest index along dimension dim (inclusive)
573 : /// @param where_end highest index along dimension dim (exclusive);
574 : /// where_begin <= where_end must hold!
575 : self_type slice(size_t dim, size_t where_begin, size_t where_end)
576 10 : {
577 10 : if (dim >= _container->shape.size())
578 1 : throw NdViewDimOutOfBounds();
579 9 : if (where_begin >= static_cast<size_t>(_container->shape(dim))
580 9 : || where_end >= static_cast<size_t>(_container->shape(dim)))
581 1 : throw NdViewIndexOutOfBounds();
582 8 : if (where_begin == where_end)
583 0 : throw NdViewUnsupportedShape();
584 :
585 8 : data_t* raw = thrust::raw_pointer_cast(_container->pointer);
586 8 : size_t dims = _container->shape.size();
587 8 : IndexVector_t shape(dims);
588 8 : IndexVector_t strides(dims);
589 32 : for (size_t i = 0; i < dims; i++) {
590 24 : shape(i) = _container->shape(i);
591 24 : strides(i) = _container->strides(i);
592 24 : }
593 :
594 : /* limit the size along dimension dim */
595 8 : shape(dim) = where_end - where_begin;
596 : /* adjust start to skip all */
597 8 : raw += where_begin * strides(dim);
598 :
599 8 : return self_type(raw, shape, strides, _container->cleanup);
600 8 : }
601 :
        /// @return the shape of this NdView (extent per dimension)
        const IndexVector_t& shape() const { return _container->shape; }

        /// @return the strides of this NdView (element distance per dimension)
        const IndexVector_t& strides() const { return _container->strides; }

        /// @return the shape and strides of this NdView; Stored in memory of type
        /// mr::sysStorageType, i.e. they are device accessible if compiled with CUDA
        const ContiguousStorage<dim_data>& layout_data() const { return _container->dims; }

        /// @return the cleanup sentinel; once its last reference is dropped, the destructor is
        /// called
        std::shared_ptr<Cleanup> getCleanup() { return _container->cleanup; }

        /// @return the number of elements in this view
        size_t size() const { return _container->size; }
618 :
619 : /// @brief Calls the functor with the lowest overhead iterator range that guarantees
620 : /// canonical iteration order. I.e. if the data is naturarlly layed out in canonical
621 : /// order, the iterators will be pointers.
622 : /// @param f functor to call with an iterator range object with methods .begin() and .end()
623 : template <typename Functor>
624 : auto with_canonical_range(Functor&& f)
625 : -> decltype(std::declval<Functor>()(this->canonical_range()))
626 25 : {
627 25 : detail::with_canonical_range(*this, std::forward<Functor>(f));
628 25 : }
629 :
630 : /// @brief Const version of with_canonical_range()
631 : /// @see with_canonical_range()
632 : template <typename Functor>
633 : auto with_canonical_crange(Functor&& f) const
634 : -> decltype(std::declval<Functor>()(this->canonical_range()))
635 86 : {
636 86 : detail::with_canonical_range(*this, std::forward<Functor>(f));
637 86 : }
638 :
639 : /// @brief Calls the functor with the lowest overhead iterator range available.
640 : /// I.e. if the data is naturarlly layed out in contiguously, the iterators will be
641 : /// pointers.
642 : /// @param f functor to call with an iterator range object with methods .begin() and .end()
643 : template <typename Functor>
644 : auto with_unordered_range(Functor f)
645 : -> decltype(std::declval<Functor>(this->contiguous_range()))
646 : {
647 : static_assert(std::is_same_v<decltype(std::declval<Functor>(this->contiguous_range())),
648 : decltype(std::declval<Functor>(this->range()))>,
649 : "Functor return value differs, depending on input iterator range "
650 : "type!");
651 : if (is_contiguous()) {
652 : return f(contiguous_range());
653 : } else {
654 : return f(range());
655 : }
656 : }
657 :
        /* NDVIEW_BINOP(op) generates three overloads per operator:
         *  - view op view   (elementwise, via binop)
         *  - view op scalar (elementwise, via unop)
         *  - scalar op view (elementwise, via unop)
         * The enable_if restricts the scalar overloads to exactly data_t so that no
         * implicit conversion of the scalar operand takes place. */
#define NDVIEW_BINOP(op)                                                                      \
        template <mr::StorageType other_tag>                                                  \
        auto operator op(const NdViewTagged<data_t, other_tag>& other)                        \
            const->decltype(binop(other, thrust::placeholders::_1 op thrust::placeholders::_2)) \
        {                                                                                     \
            return binop(other, thrust::placeholders::_1 op thrust::placeholders::_2);        \
        }                                                                                     \
                                                                                              \
        template <typename data2_t = data_t>                                                  \
        friend auto operator op(const NdViewTagged<data_t, tag>& ndview, data2_t other)       \
            ->typename std::enable_if<                                                        \
                std::is_same_v<data_t, data2_t>,                                              \
                NdViewTagged<decltype(std::declval<data_t>() op std::declval<data2_t>()),     \
                             mr::sysStorageType>>::type                                       \
        {                                                                                     \
            return ndview.unop(thrust::placeholders::_1 op other);                            \
        }                                                                                     \
                                                                                              \
        template <typename data2_t = data_t>                                                  \
        friend auto operator op(data2_t other, const NdViewTagged<data_t, tag>& ndview)       \
            ->typename std::enable_if<                                                        \
                std::is_same_v<data_t, data2_t>,                                              \
                NdViewTagged<decltype(std::declval<data2_t>() op std::declval<data_t>()),     \
                             mr::sysStorageType>>::type                                       \
        {                                                                                     \
            return ndview.unop(other op thrust::placeholders::_1);                            \
        }

        /* Elementwise binary operations. The result of the operation
         * is eagerly evaluated and returned in a new NdView that owns
         * the result.
         * Note assignment operators are not implemented, as their desired
         * semantics are unclear. One might expect that x *= 2 doubles
         * all elements of x in the underlying data that x is a view of.
         * However, then x = x * 2 would behave differently form x *= 2,
         * as x * 2 allocates a new buffer.
         **/
        NDVIEW_BINOP(+);
        NDVIEW_BINOP(-);
        NDVIEW_BINOP(*);
        NDVIEW_BINOP(/);
        NDVIEW_BINOP(%);
        NDVIEW_BINOP(>);
        NDVIEW_BINOP(<);
        NDVIEW_BINOP(>=);
        NDVIEW_BINOP(<=);
        NDVIEW_BINOP(==);
        NDVIEW_BINOP(!=);
        NDVIEW_BINOP(&);
        NDVIEW_BINOP(|);
        NDVIEW_BINOP(^);
        NDVIEW_BINOP(&&);
        NDVIEW_BINOP(||);

#undef NDVIEW_BINOP
713 :
714 : #define NDVIEW_UNOP(op) \
715 : auto operator op() const->decltype(op thrust::placeholders::_1) \
716 : { \
717 : return unop(op thrust::placeholders::_1); \
718 : }
719 :
720 : NDVIEW_UNOP(-);
721 : NDVIEW_UNOP(!);
722 :
        /// Creates a left hand side, to be used for filtered assignments.
        /// The index parameter must be an NdView of equal dimensions to this.
        /// The returned object can be assigned to, replacing all entries in this
        /// view, whose corresponding index element is true. Indices whose filter-
        /// element is false remain unchanged.
        /// If the index tensor does not have the correct dimensions, no exception
        /// is thrown until the actual assignment occurs.
        ///
        /// Example:
        /// ```
        /// template<typename T>
        /// void zero_value_range(NdViewTagged<float, T> &x, float lb, float ub) {
        ///     x[x >= lb && x < ub] = 0.0f;
        /// }
        /// ```
        /// This example function sets all elements in the value range lb <= x < ub
        /// to zero.
        /// @param index boolean filter tensor of equal dimensions
        /// @see BoolIndexedView
        template <mr::StorageType other_tag>
        BoolIndexedView<data_t, tag, other_tag>
            operator[](const NdViewTagged<bool, other_tag>& index)
        {
            /* class template argument deduction yields <data_t, tag, other_tag> */
            return BoolIndexedView(*this, index);
        }
747 : };
748 :
749 : namespace detail
750 : {
        /// Writes a fixed value into the first tuple element (the destination
        /// reference) wherever the second element (the filter flag) is true;
        /// used to implement scalar filtered assignment.
        template <typename data_t>
        struct fill_const_if_functor
            : public thrust::unary_function<thrust::tuple<data_t&, bool>, void> {
            data_t fill_value; // value written to every selected element

            fill_const_if_functor(const data_t& fill_value) : fill_value{fill_value} {}

            /* callable on host and device for use with any thrust execution policy */
            __host__ __device__ void operator()(thrust::tuple<data_t&, bool> args)
            {
                if (thrust::get<1>(args)) {
                    thrust::get<0>(args) = fill_value;
                }
            }
        };
765 :
766 : template <typename data_t>
767 : struct fill_if_functor : public thrust::unary_function<thrust::tuple<data_t&, bool>, void> {
768 : __host__ __device__ void operator()(thrust::tuple<data_t&, bool, data_t> args)
769 486 : {
770 486 : if (thrust::get<1>(args)) {
771 162 : thrust::get<0>(args) = thrust::get<2>(args);
772 162 : }
773 486 : }
774 : };
775 :
        /// Trait: true iff every storage tag in the pack is accessible from host
        /// code. Primary template handles any remaining tag kind (i.e. device
        /// memory) and yields false.
        template <mr::StorageType... tags>
        struct are_tags_host_compatible {
            static constexpr bool value = false;
        };

        /// universal memory is host-accessible; recurse on the remaining tags
        template <mr::StorageType... tags>
        struct are_tags_host_compatible<mr::StorageType::universal, tags...> {
            static constexpr bool value = are_tags_host_compatible<tags...>::value;
        };

        /// host memory is host-accessible; recurse on the remaining tags
        template <mr::StorageType... tags>
        struct are_tags_host_compatible<mr::StorageType::host, tags...> {
            static constexpr bool value = are_tags_host_compatible<tags...>::value;
        };

        /// empty pack: trivially host-compatible
        template <>
        struct are_tags_host_compatible<> {
            static constexpr bool value = true;
        };
795 :
        /// Trait selecting a thrust execution policy for a pack of storage tags:
        /// thrust::host if a host tag forces it (all other tags must then be
        /// host-accessible), thrust::device otherwise. The unspecialized primary
        /// template intentionally provides no policy member.
        template <mr::StorageType... tags>
        struct select_policy {
        };

        /// empty pack: default to the device policy
        template <>
        struct select_policy<> {
            static constexpr decltype(thrust::device) policy = thrust::device;
        };

        /// universal memory works with either policy; defer to the remaining tags
        template <mr::StorageType... tags>
        struct select_policy<mr::StorageType::universal, tags...> {
            static constexpr decltype(select_policy<tags...>::policy) policy =
                select_policy<tags...>::policy;
        };

        /// host memory forces the host policy; every other tag must be host-accessible
        template <mr::StorageType... tags>
        struct select_policy<mr::StorageType::host, tags...> {
            static_assert(are_tags_host_compatible<tags...>::value,
                          "Cannot select policy, as storage tags are incompatible");
            static constexpr decltype(thrust::host) policy = thrust::host;
        };
817 : } // namespace detail
818 :
    /// @brief Convenience alias for an NdView using the system storage type, so most
    /// code does not need to spell out a storage tag explicitly.
    /// @tparam data_t type of the data that the NdView points to
    /// @see elsa::mr::sysStorageType
    template <class data_t>
    using NdView = NdViewTagged<data_t, mr::sysStorageType>;
824 :
825 : /// @brief Class representing a view with a boolean index tensor applied. Objects of this
826 : /// class are intended to be assigned to, which overwrites the element whose filter is
827 : /// true.
828 : template <typename data_t, mr::StorageType indexed_tag, mr::StorageType index_tag>
829 : class BoolIndexedView
830 : {
831 : NdViewTagged<data_t, indexed_tag> _indexed;
832 : NdViewTagged<bool, index_tag> _index;
833 :
834 : public:
835 : BoolIndexedView(const NdViewTagged<data_t, indexed_tag>& indexed,
836 : NdViewTagged<bool, index_tag> index)
837 : : _indexed{indexed}, _index{index}
838 27 : {
839 27 : }
840 :
841 : /// @brief Assign the elements of rhs to the indexed tensor wherever the corresponding
842 : /// element of the index tensor is true.
843 : /// @throw NdViewDimError if rhs, the left hand side NdView and the indexing tensor do
844 : /// not all have the same dimensions.
845 : template <mr::StorageType rhs_tag>
846 : BoolIndexedView& operator=(const NdViewTagged<data_t, rhs_tag>& rhs)
847 18 : {
848 18 : static_assert(mr::are_storages_compatible<indexed_tag, rhs_tag>::value,
849 18 : "Indexed view storage type and assignment right side storage type are "
850 18 : "incompatible");
851 18 : static_assert(mr::are_storages_compatible<index_tag, rhs_tag>::value,
852 18 : "Boolean index storage type and assignment right side storage type are "
853 18 : "incompatible");
854 :
855 18 : const auto& indexed_shape = _indexed.shape();
856 18 : const auto& indexed_strides = _indexed.strides();
857 18 : const auto& index_shape = _index.shape();
858 18 : const auto& index_strides = _index.strides();
859 18 : const auto& rhs_shape = rhs.shape();
860 18 : const auto& rhs_strides = rhs.strides();
861 18 : bool strides_compatible = detail::are_strides_compatible(indexed_shape, indexed_strides,
862 18 : index_shape, index_strides)
863 18 : && detail::are_strides_compatible(
864 8 : indexed_shape, indexed_strides, rhs_shape, rhs_strides);
865 :
866 18 : auto copy_functor = [&](auto& indexed_range, auto& index_range, auto& rhs_range) {
867 18 : auto begin = thrust::make_zip_iterator(thrust::make_tuple(
868 18 : indexed_range.begin(), index_range.begin(), rhs_range.begin()));
869 18 : auto end = thrust::make_zip_iterator(
870 18 : thrust::make_tuple(indexed_range.end(), index_range.end(), rhs_range.end()));
871 :
872 18 : thrust::for_each(detail::select_policy<indexed_tag, index_tag, rhs_tag>::policy,
873 18 : begin, end, detail::fill_if_functor<data_t>());
874 18 : };
875 :
876 18 : if (strides_compatible && _indexed.is_contiguous() && _index.is_contiguous()
877 18 : && rhs.is_contiguous()) {
878 5 : auto indexed_range = _indexed.contiguous_range();
879 5 : auto index_range = _index.contiguous_range();
880 5 : auto rhs_range = rhs.contiguous_range();
881 5 : copy_functor(indexed_range, index_range, rhs_range);
882 13 : } else {
883 13 : _indexed.with_canonical_range([&](auto indexed_range) mutable {
884 13 : _index.with_canonical_crange([&](auto index_range) mutable {
885 13 : rhs.with_canonical_crange([&](auto rhs_range) mutable {
886 13 : copy_functor(indexed_range, index_range, rhs_range);
887 13 : });
888 13 : });
889 13 : });
890 13 : }
891 18 : return *this;
892 18 : }
893 :
894 : /// @brief Assign rhs to the indexed tensor wherever the corresponding element of the index
895 : /// tensor is true
896 : /// @throw NdViewDimError if the left hand side NdView and the indexing tensor do
897 : /// not have the same dimensions.
898 : BoolIndexedView& operator=(data_t rhs)
899 9 : {
900 9 : const auto& indexed_shape = _indexed.shape();
901 9 : const auto& indexed_strides = _indexed.strides();
902 9 : const auto& index_shape = _index.shape();
903 9 : const auto& index_strides = _index.strides();
904 9 : bool strides_compatible = detail::are_strides_compatible(indexed_shape, indexed_strides,
905 9 : index_shape, index_strides);
906 :
907 9 : auto copy_functor = [&](auto indexed_begin, auto indexed_end, auto index_begin,
908 9 : auto index_end) {
909 9 : auto begin =
910 9 : thrust::make_zip_iterator(thrust::make_tuple(indexed_begin, index_begin));
911 9 : auto end = thrust::make_zip_iterator(thrust::make_tuple(indexed_end, index_end));
912 :
913 9 : thrust::for_each(detail::select_policy<indexed_tag, index_tag>::policy, begin, end,
914 9 : detail::fill_const_if_functor(rhs));
915 9 : };
916 :
917 9 : if (strides_compatible && _indexed.is_contiguous() && _index.is_contiguous()) {
918 3 : auto indexed_range = _indexed.contiguous_range();
919 3 : auto index_range = _index.contiguous_range();
920 3 : copy_functor(indexed_range.begin(), indexed_range.end(), index_range.begin(),
921 3 : index_range.end());
922 6 : } else {
923 6 : _indexed.with_canonical_range([&](auto indexed_range) {
924 6 : _index.with_canonical_range([&](auto index_range) {
925 6 : copy_functor(indexed_range.begin(), indexed_range.end(),
926 6 : index_range.begin(), index_range.end());
927 6 : });
928 6 : });
929 6 : }
930 9 : return *this;
931 9 : }
932 : };
933 : } // namespace elsa
|