Line data Source code
1 : #pragma once
2 :
3 : #include "elsaDefines.h"
4 : #include "DataDescriptor.h"
5 : #include "DataHandler.h"
6 : #include "DataHandlerCPU.h"
7 : #include "DataHandlerMapCPU.h"
8 : #include "DataContainerIterator.h"
9 : #include "Error.h"
10 : #include "Expression.h"
11 : #include "FormatConfig.h"
12 : #include "TypeCasts.hpp"
13 :
14 : #ifdef ELSA_CUDA_VECTOR
15 : #include "DataHandlerGPU.h"
16 : #include "DataHandlerMapGPU.h"
17 : #endif
18 :
19 : #include <memory>
20 : #include <type_traits>
21 :
22 : namespace elsa
23 : {
24 : /**
25 : * @brief class representing and storing a linearized n-dimensional signal
26 : *
27 : * This class provides a container for a signal that is stored in memory. This signal can
28 : * be n-dimensional, and will be stored in memory in a linearized fashion. The information
29 : * on how this linearization is performed is provided by an associated DataDescriptor.
30 : *
31 : * @tparam data_t data type that is stored in the DataContainer, defaulting to real_t.
32 : *
33 : * @author
34 : * - Matthias Wieczorek - initial code
35 : * - Tobias Lasser - rewrite, modularization, modernization
36 : * - David Frank - added DataHandler concept, iterators
37 : * - Nikola Dinev - add block support
38 : * - Jens Petit - expression templates
39 : * - Jonas Jelten - various enhancements, fft, complex handling, pretty formatting
40 : */
41 : template <typename data_t>
42 : class DataContainer
43 : {
44 : public:
45 : /// Scalar alias
46 : using Scalar = data_t;
47 :
48 : /// delete default constructor (without metadata there can be no valid container)
49 : DataContainer() = delete;
50 :
51 : /**
52 : * @brief Constructor for empty DataContainer, no initialisation is performed,
53 : * but the underlying space is allocated.
54 : *
55 : * @param[in] dataDescriptor containing the associated metadata
56 : * @param[in] handlerType the data handler (default: CPU)
57 : */
58 : explicit DataContainer(const DataDescriptor& dataDescriptor,
59 : DataHandlerType handlerType = defaultHandlerType);
60 :
61 : /**
62 : * @brief Constructor for DataContainer, initializing it with a DataVector
63 : *
64 : * @param[in] dataDescriptor containing the associated metadata
65 : * @param[in] data vector containing the initialization data
66 : * @param[in] handlerType the data handler (default: CPU)
67 : */
68 : DataContainer(const DataDescriptor& dataDescriptor,
69 : const Eigen::Matrix<data_t, Eigen::Dynamic, 1>& data,
70 : DataHandlerType handlerType = defaultHandlerType);
71 :
72 : /**
73 : * @brief Copy constructor for DataContainer
74 : *
75 : * @param[in] other DataContainer to copy
76 : */
77 : DataContainer(const DataContainer<data_t>& other);
78 :
79 : /**
80 : * @brief copy assignment for DataContainer
81 : *
82 : * @param[in] other DataContainer to copy
83 : *
84 : * Note that a copy assignment with a DataContainer on a different device (CPU vs GPU) will
85 : * result in an "infectious" copy which means that afterwards the current container will use
86 : * the same device as "other".
87 : */
88 : DataContainer<data_t>& operator=(const DataContainer<data_t>& other);
89 :
90 : /**
91 : * @brief Move constructor for DataContainer
92 : *
93 : * @param[in] other DataContainer to move from
94 : *
95 : * The moved-from object remains in a valid state. However, as preconditions are not
96 : * fulfilled for any member functions, the object should not be used. After move- or copy-
97 : * assignment, this is possible again.
98 : */
99 : DataContainer(DataContainer<data_t>&& other) noexcept;
100 :
101 : /**
102 : * @brief Move assignment for DataContainer
103 : *
104 : * @param[in] other DataContainer to move from
105 : *
106 : * The moved-from object remains in a valid state. However, as preconditions are not
107 : * fulfilled for any member functions, the object should not be used. After move- or copy-
108 : * assignment, this is possible again.
109 : *
110 : * Note that a move assignment with a DataContainer on a different device (CPU vs GPU) will
111 : * result in an "infectious" copy which means that afterwards the current container will use
112 : * the same device as "other".
113 : */
114 : DataContainer<data_t>& operator=(DataContainer<data_t>&& other);
115 :
116 : /**
117 : * @brief Expression evaluation assignment for DataContainer
118 : *
119 : * @param[in] source expression to evaluate
120 : *
121 : * This evaluates an expression template term into the underlying data member of
122 : * the DataHandler in use.
123 : */
124 : template <typename Source, typename = std::enable_if_t<isExpression<Source>>>
125 : DataContainer<data_t>& operator=(Source const& source)
126 67836 : {
127 67836 : if (auto handler = downcast_safe<DataHandlerCPU<data_t>>(_dataHandler.get())) {
128 67796 : handler->accessData() = source.template eval<false>();
129 67796 : } else if (auto handler =
130 40 : downcast_safe<DataHandlerMapCPU<data_t>>(_dataHandler.get())) {
131 40 : handler->accessData() = source.template eval<false>();
132 : #ifdef ELSA_CUDA_VECTOR
133 : } else if (auto handler = downcast_safe<DataHandlerGPU<data_t>>(_dataHandler.get())) {
134 : handler->accessData().eval(source.template eval<true>());
135 : } else if (auto handler =
136 : downcast_safe<DataHandlerMapGPU<data_t>>(_dataHandler.get())) {
137 : handler->accessData().eval(source.template eval<true>());
138 : #endif
139 40 : } else {
140 0 : throw LogicError("Unknown handler type");
141 0 : }
142 :
143 67836 : return *this;
144 67836 : }
145 :
146 : /**
147 : * @brief Expression constructor
148 : *
149 : * @param[in] source expression to evaluate
150 : *
151 : * It creates a new DataContainer out of an expression. For this the meta information which
152 : * is saved in the expression is used.
153 : */
154 : template <typename Source, typename = std::enable_if_t<isExpression<Source>>>
155 : DataContainer<data_t>(Source const& source)
156 : : DataContainer<data_t>(source.getDataMetaInfo().first, source.getDataMetaInfo().second)
157 10972 : {
158 10972 : this->operator=(source);
159 10972 : }
160 :
161 : /// return the current DataDescriptor
162 : const DataDescriptor& getDataDescriptor() const;
163 :
164 : /// return the size of the stored data (i.e. the number of elements in the linearized
165 : /// signal)
166 : index_t getSize() const;
167 :
168 : /// return the index-th element of linearized signal (not bounds-checked!)
169 : data_t& operator[](index_t index);
170 :
171 : /// return the index-th element of the linearized signal as read-only (not bounds-checked!)
172 : const data_t& operator[](index_t index) const;
173 :
174 : /// return an element by n-dimensional coordinate (not bounds-checked!)
175 : data_t& operator()(const IndexVector_t& coordinate);
176 :
177 : /// return an element by n-dimensional coordinate as read-only (not bounds-checked!)
178 : const data_t& operator()(const IndexVector_t& coordinate) const;
179 :
180 : data_t at(const IndexVector_t& coordinate) const;
181 :
182 : /// return an element by its coordinates (not bounds-checked!)
183 : template <typename idx0_t, typename... idx_t,
184 : typename = std::enable_if_t<
185 : std::is_integral_v<idx0_t> && (... && std::is_integral_v<idx_t>)>>
186 : data_t& operator()(idx0_t idx0, idx_t... indices)
187 695738 : {
188 695738 : IndexVector_t coordinate(sizeof...(indices) + 1);
189 695738 : ((coordinate << idx0), ..., indices);
190 695738 : return operator()(coordinate);
191 695738 : }
192 :
193 : /// return an element by its coordinates as read-only (not bounds-checked!)
194 : template <typename idx0_t, typename... idx_t,
195 : typename = std::enable_if_t<
196 : std::is_integral_v<idx0_t> && (... && std::is_integral_v<idx_t>)>>
197 : const data_t& operator()(idx0_t idx0, idx_t... indices) const
198 : {
199 : IndexVector_t coordinate(sizeof...(indices) + 1);
200 : ((coordinate << idx0), ..., indices);
201 : return operator()(coordinate);
202 : }
203 :
204 : /// return the dot product of this signal with the one from container other
205 : data_t dot(const DataContainer<data_t>& other) const;
206 :
207 : /// return the dot product of this signal with the one from an expression
208 : template <typename Source, typename = std::enable_if_t<isExpression<Source>>>
209 : data_t dot(const Source& source) const
210 458 : {
211 458 : if (auto handler = downcast_safe<DataHandlerCPU<data_t>>(_dataHandler.get())) {
212 458 : return (*this * source).template eval<false>().sum();
213 458 : } else if (auto handler =
214 0 : downcast_safe<DataHandlerMapCPU<data_t>>(_dataHandler.get())) {
215 0 : return (*this * source).template eval<false>().sum();
216 : #ifdef ELSA_CUDA_VECTOR
217 : } else if (auto handler = downcast_safe<DataHandlerGPU<data_t>>(_dataHandler.get())) {
218 : DataContainer temp = (*this * source);
219 : return temp.sum();
220 : } else if (auto handler =
221 : downcast_safe<DataHandlerMapGPU<data_t>>(_dataHandler.get())) {
222 : DataContainer temp = (*this * source);
223 : return temp.sum();
224 : #endif
225 0 : } else {
226 0 : throw LogicError("Unknown handler type");
227 0 : }
228 458 : }
229 :
230 : /// return the squared l2 norm of this signal (dot product with itself)
231 : GetFloatingPointType_t<data_t> squaredL2Norm() const;
232 :
233 : /// return the l2 norm of this signal (square root of dot product with itself)
234 : GetFloatingPointType_t<data_t> l2Norm() const;
235 :
236 : /// return the l0 pseudo-norm of this signal (number of non-zero values)
237 : index_t l0PseudoNorm() const;
238 :
239 : /// return the l1 norm of this signal (sum of absolute values)
240 : GetFloatingPointType_t<data_t> l1Norm() const;
241 :
242 : /// return the linf norm of this signal (maximum of absolute values)
243 : GetFloatingPointType_t<data_t> lInfNorm() const;
244 :
245 : /// return the sum of all elements of this signal
246 : data_t sum() const;
247 :
248 : /// return the min of all elements of this signal
249 : data_t minElement() const;
250 :
251 : /// return the max of all elements of this signal
252 : data_t maxElement() const;
253 :
254 : /// convert to the fourier transformed signal
255 : void fft(FFTNorm norm) const;
256 :
257 : /// convert to the inverse fourier transformed signal
258 : void ifft(FFTNorm norm) const;
259 :
260 : /// if the datacontainer is already complex, return itself.
261 : template <typename _data_t = data_t>
262 : typename std::enable_if_t<isComplex<_data_t>, DataContainer<_data_t>> asComplex() const
263 124 : {
264 124 : return *this;
265 124 : }
266 :
267 : /// if the datacontainer is not complex,
268 : /// return a copy and fill in 0 as imaginary values
269 : template <typename _data_t = data_t>
270 : typename std::enable_if_t<not isComplex<_data_t>, DataContainer<complex<_data_t>>>
271 : asComplex() const
272 248 : {
273 248 : DataContainer<complex<data_t>> ret{
274 248 : *this->_dataDescriptor,
275 248 : this->_dataHandlerType,
276 248 : };
277 :
278 : // extend with complex zero value
279 250168 : for (index_t idx = 0; idx < this->getSize(); ++idx) {
280 249920 : ret[idx] = complex<data_t>{(*this)[idx], 0};
281 249920 : }
282 :
283 248 : return ret;
284 248 : }
285 :
286 : /// compute in-place element-wise addition of another container
287 : DataContainer<data_t>& operator+=(const DataContainer<data_t>& dc);
288 :
289 : /// compute in-place element-wise addition with another expression
290 : template <typename Source, typename = std::enable_if_t<isExpression<Source>>>
291 : DataContainer<data_t>& operator+=(Source const& source)
292 599 : {
293 599 : *this = *this + source;
294 599 : return *this;
295 599 : }
296 :
297 : /// compute in-place element-wise subtraction of another container
298 : DataContainer<data_t>& operator-=(const DataContainer<data_t>& dc);
299 :
300 : /// compute in-place element-wise subtraction with another expression
301 : template <typename Source, typename = std::enable_if_t<isExpression<Source>>>
302 : DataContainer<data_t>& operator-=(Source const& source)
303 452 : {
304 452 : *this = *this - source;
305 452 : return *this;
306 452 : }
307 :
308 : /// compute in-place element-wise multiplication with another container
309 : DataContainer<data_t>& operator*=(const DataContainer<data_t>& dc);
310 :
311 : /// compute in-place element-wise multiplication with another expression
312 : template <typename Source, typename = std::enable_if_t<isExpression<Source>>>
313 : DataContainer<data_t>& operator*=(Source const& source)
314 2 : {
315 2 : *this = *this * source;
316 2 : return *this;
317 2 : }
318 :
319 : /// compute in-place element-wise division by another container
320 : DataContainer<data_t>& operator/=(const DataContainer<data_t>& dc);
321 :
322 : /// compute in-place element-wise division with another expression
323 : template <typename Source, typename = std::enable_if_t<isExpression<Source>>>
324 : DataContainer<data_t>& operator/=(Source const& source)
325 2 : {
326 2 : *this = *this / source;
327 2 : return *this;
328 2 : }
329 :
330 : /// compute in-place addition of a scalar
331 : DataContainer<data_t>& operator+=(data_t scalar);
332 :
333 : /// compute in-place subtraction of a scalar
334 : DataContainer<data_t>& operator-=(data_t scalar);
335 :
336 : /// compute in-place multiplication with a scalar
337 : DataContainer<data_t>& operator*=(data_t scalar);
338 :
339 : /// compute in-place division by a scalar
340 : DataContainer<data_t>& operator/=(data_t scalar);
341 :
342 : /// assign a scalar to the DataContainer
343 : DataContainer<data_t>& operator=(data_t scalar);
344 :
345 : /// comparison with another DataContainer
346 : bool operator==(const DataContainer<data_t>& other) const;
347 :
348 : /// comparison with another DataContainer
349 : bool operator!=(const DataContainer<data_t>& other) const;
350 :
351 : /// returns a reference to the i-th block, wrapped in a DataContainer
352 : DataContainer<data_t> getBlock(index_t i);
353 :
354 : /// returns a const reference to the i-th block, wrapped in a DataContainer
355 : const DataContainer<data_t> getBlock(index_t i) const;
356 :
357 : /// return a view of this DataContainer with a different descriptor
358 : DataContainer<data_t> viewAs(const DataDescriptor& dataDescriptor);
359 :
360 : /// return a const view of this DataContainer with a different descriptor
361 : const DataContainer<data_t> viewAs(const DataDescriptor& dataDescriptor) const;
362 :
363 : /// @brief Slice the container in the last dimension
364 : ///
365 : /// Access a portion of the container via a slice. The slicing always is in the last
366 : /// dimension. So for a 3D volume, the container is sliced along the z direction and each
367 : /// slice is a part of the x-y plane.
368 : ///
369 : /// A slice is always the same dimension as the original DataContainer, but with a thickness
370 : /// of 1 in the last dimension (i.e. the coefficient of the last dimension is 1)
371 : const DataContainer<data_t> slice(index_t i) const;
372 :
373 : /// @brief Slice the container in the last dimension, non-const overload
374 : ///
375 : /// @overload
376 : /// @see slice(index_t) const
377 : DataContainer<data_t> slice(index_t i);
378 :
379 : /// iterator for DataContainer (random access and continuous)
380 : using iterator = DataContainerIterator<DataContainer<data_t>>;
381 :
382 : /// const iterator for DataContainer (random access and continuous)
383 : using const_iterator = ConstDataContainerIterator<DataContainer<data_t>>;
384 :
385 : /// alias for reverse iterator
386 : using reverse_iterator = std::reverse_iterator<iterator>;
387 : /// alias for const reverse iterator
388 : using const_reverse_iterator = std::reverse_iterator<const_iterator>;
389 :
390 : /// returns iterator to the first element of the container
391 : iterator begin();
392 :
393 : /// returns const iterator to the first element of the container (cannot mutate data)
394 : const_iterator begin() const;
395 :
396 : /// returns const iterator to the first element of the container (cannot mutate data)
397 : const_iterator cbegin() const;
398 :
399 : /// returns iterator to one past the last element of the container
400 : iterator end();
401 :
402 : /// returns const iterator to one past the last element of the container (cannot mutate
403 : /// data)
404 : const_iterator end() const;
405 :
406 : /// returns const iterator to one past the last element of the container (cannot mutate
407 : /// data)
408 : const_iterator cend() const;
409 :
410 : /// returns reversed iterator to the last element of the container
411 : reverse_iterator rbegin();
412 :
413 : /// returns const reversed iterator to the last element of the container (cannot mutate
414 : /// data)
415 : const_reverse_iterator rbegin() const;
416 :
417 : /// returns const reversed iterator to the last element of the container (cannot mutate
418 : /// data)
419 : const_reverse_iterator crbegin() const;
420 :
421 : /// returns reversed iterator to one past the first element of container
422 : reverse_iterator rend();
423 :
424 : /// returns const reversed iterator to one past the first element of container (cannot
425 : /// mutate data)
426 : const_reverse_iterator rend() const;
427 :
428 : /// returns const reversed iterator to one past the first element of container (cannot
429 : /// mutate data)
430 : const_reverse_iterator crend() const;
431 :
432 : /// value_type of the DataContainer elements for iterators
433 : using value_type = data_t;
434 : /// pointer type of DataContainer elements for iterators
435 : using pointer = data_t*;
436 : /// const pointer type of DataContainer elements for iterators
437 : using const_pointer = const data_t*;
438 : /// reference type of DataContainer elements for iterators
439 : using reference = data_t&;
440 : /// const reference type of DataContainer elements for iterators
441 : using const_reference = const data_t&;
442 : /// difference type for iterators
443 : using difference_type = std::ptrdiff_t;
444 :
445 : /// returns the type of the DataHandler in use
446 : DataHandlerType getDataHandlerType() const;
447 :
448 : /// friend constexpr function to implement expression templates
449 : template <bool GPU, class Operand, std::enable_if_t<isDataContainer<Operand>, int>>
450 : friend constexpr auto evaluateOrReturn(Operand const& operand);
451 :
452 : /// write a pretty-formatted string representation to stream
453 : void format(std::ostream& os, format_config cfg = {}) const;
454 :
455 : /**
456 : * @brief Factory function which returns GPU based DataContainers
457 : *
458 : * @return the GPU based DataContainer
459 : *
460 : * Note that if this function is called on a container which is already GPU based, it
461 : * will throw an exception.
462 : */
463 : DataContainer loadToGPU();
464 :
465 : /**
466 : * @brief Factory function which returns CPU based DataContainers
467 : *
468 : * @return the CPU based DataContainer
469 : *
470 : * Note that if this function is called on a container which is already CPU based, it will
471 : * throw an exception.
472 : */
473 : DataContainer loadToCPU();
474 :
475 : private:
476 : /// the current DataDescriptor
477 : std::unique_ptr<DataDescriptor> _dataDescriptor;
478 :
479 : /// the current DataHandler
480 : std::unique_ptr<DataHandler<data_t>> _dataHandler;
481 :
482 : /// the current DataHandlerType
483 : DataHandlerType _dataHandlerType;
484 :
485 : /// factory method to create DataHandlers based on handlerType with perfect forwarding of
486 : /// constructor arguments
487 : template <typename... Args>
488 : std::unique_ptr<DataHandler<data_t>> createDataHandler(DataHandlerType handlerType,
489 : Args&&... args);
490 :
491 : /// private constructor accepting a DataDescriptor and a DataHandler
492 : explicit DataContainer(const DataDescriptor& dataDescriptor,
493 : std::unique_ptr<DataHandler<data_t>> dataHandler,
494 : DataHandlerType dataType = defaultHandlerType);
495 :
496 : /**
497 : * @brief Helper function to indicate if a regular assignment or a clone should be performed
498 : *
499 : * @param[in] handlerType the member variable of the other container in
500 : * copy-/move-assignment
501 : *
502 : * @return true if a regular assignment of the pointed to DataHandlers should be done
503 : *
504 : * An assignment operation with a DataContainer which does not use the same device (CPU /
505 : * GPU) has to be handled differently. This helper function indicates if a regular
506 : * assignment should be performed or not.
507 : */
508 : bool canAssign(DataHandlerType handlerType);
509 : };
510 :
511 : /// pretty output formatting.
512 : /// for configurable output, use `DataContainerFormatter` directly.
513 : template <typename T>
514 : std::ostream& operator<<(std::ostream& os, const elsa::DataContainer<T>& dc)
515 0 : {
516 0 : dc.format(os);
517 0 : return os;
518 0 : }
519 :
520 : /// clip the container values outside of the interval, to the interval edges
521 : template <typename data_t>
522 : DataContainer<data_t> clip(DataContainer<data_t> dc, data_t min, data_t max);
523 :
524 : /// Concatenate two DataContainers to one (requires copying of both)
525 : template <typename data_t>
526 : DataContainer<data_t> concatenate(const DataContainer<data_t>& dc1,
527 : const DataContainer<data_t>& dc2);
528 :
529 : /// Perform the FFT shift operation to the provided signal. Refer to
530 : /// https://numpy.org/doc/stable/reference/generated/numpy.fft.fftshift.html for further
531 : /// details.
532 : template <typename data_t>
533 : DataContainer<data_t> fftShift2D(DataContainer<data_t> dc);
534 :
535 : /// Perform the IFFT shift operation to the provided signal. Refer to
536 : /// https://numpy.org/doc/stable/reference/generated/numpy.fft.ifftshift.html for further
537 : /// details.
538 : template <typename data_t>
539 : DataContainer<data_t> ifftShift2D(DataContainer<data_t> dc);
540 :
541 : /// User-defined template argument deduction guide for the expression based constructor
542 : template <typename Source>
543 : DataContainer(Source const& source) -> DataContainer<typename Source::data_t>;
544 :
/// Bundles several callables into a single object: it derives from every given callable and
/// exposes all of their call operators, so overload resolution selects the matching one
template <typename... Fns>
struct Callables : Fns... {
    using Fns::operator()...;
};

/// Class template deduction guide: deduce the callable types from the initializers
template <typename... Fns>
Callables(Fns...) -> Callables<Fns...>;
554 :
555 : /// Multiplying two operands (including scalars)
556 : template <typename LHS, typename RHS, typename = std::enable_if_t<isBinaryOpOk<LHS, RHS>>>
557 : auto operator*(LHS const& lhs, RHS const& rhs)
558 61229 : {
559 61229 : auto multiplicationGPU = [](auto const& left, auto const& right, bool /**/) {
560 61229 : return left * right;
561 61229 : };
562 :
563 61229 : if constexpr (isDcOrExpr<LHS> && isDcOrExpr<RHS>) {
564 33979 : auto multiplication = [](auto const& left, auto const& right) {
565 33979 : return (left.array() * right.array()).matrix();
566 33979 : };
567 33983 : return Expression{Callables{multiplication, multiplicationGPU}, lhs, rhs};
568 33983 : } else if constexpr (isArithmetic<LHS>) {
569 27158 : auto multiplication = [](auto const& left, auto const& right) {
570 27158 : return (left * right.array()).matrix();
571 27158 : };
572 27158 : return Expression{Callables{multiplication, multiplicationGPU}, lhs, rhs};
573 27158 : } else if constexpr (isArithmetic<RHS>) {
574 88 : auto multiplication = [](auto const& left, auto const& right) {
575 88 : return (left.array() * right).matrix();
576 88 : };
577 88 : return Expression{Callables{multiplication, multiplicationGPU}, lhs, rhs};
578 88 : } else {
579 61229 : auto multiplication = [](auto const& left, auto const& right) { return left * right; };
580 61229 : return Expression{Callables{multiplication, multiplicationGPU}, lhs, rhs};
581 61229 : }
582 61229 : }
583 :
584 : /// Adding two operands (including scalars)
585 : template <typename LHS, typename RHS, typename = std::enable_if_t<isBinaryOpOk<LHS, RHS>>>
586 : auto operator+(LHS const& lhs, RHS const& rhs)
587 12780 : {
588 12780 : auto additionGPU = [](auto const& left, auto const& right, bool /**/) {
589 12780 : return left + right;
590 12780 : };
591 :
592 12780 : if constexpr (isDcOrExpr<LHS> && isDcOrExpr<RHS>) {
593 12770 : auto addition = [](auto const& left, auto const& right) { return left + right; };
594 12770 : return Expression{Callables{addition, additionGPU}, lhs, rhs};
595 12770 : } else if constexpr (isArithmetic<LHS>) {
596 5 : auto addition = [](auto const& left, auto const& right) {
597 5 : return (left + right.array()).matrix();
598 5 : };
599 5 : return Expression{Callables{addition, additionGPU}, lhs, rhs};
600 5 : } else if constexpr (isArithmetic<RHS>) {
601 5 : auto addition = [](auto const& left, auto const& right) {
602 5 : return (left.array() + right).matrix();
603 5 : };
604 5 : return Expression{Callables{addition, additionGPU}, lhs, rhs};
605 5 : } else {
606 12780 : auto addition = [](auto const& left, auto const& right) { return left + right; };
607 12780 : return Expression{Callables{addition, additionGPU}, lhs, rhs};
608 12780 : }
609 12780 : }
610 :
611 : /// Subtracting two operands (including scalars)
612 : template <typename LHS, typename RHS, typename = std::enable_if_t<isBinaryOpOk<LHS, RHS>>>
613 : auto operator-(LHS const& lhs, RHS const& rhs)
614 28257 : {
615 28257 : auto subtractionGPU = [](auto const& left, auto const& right, bool /**/) {
616 28257 : return left - right;
617 28257 : };
618 :
619 28257 : if constexpr (isDcOrExpr<LHS> && isDcOrExpr<RHS>) {
620 28245 : auto subtraction = [](auto const& left, auto const& right) { return left - right; };
621 28247 : return Expression{Callables{subtraction, subtractionGPU}, lhs, rhs};
622 28247 : } else if constexpr (isArithmetic<LHS>) {
623 5 : auto subtraction = [](auto const& left, auto const& right) {
624 5 : return (left - right.array()).matrix();
625 5 : };
626 5 : return Expression{Callables{subtraction, subtractionGPU}, lhs, rhs};
627 5 : } else if constexpr (isArithmetic<RHS>) {
628 5 : auto subtraction = [](auto const& left, auto const& right) {
629 5 : return (left.array() - right).matrix();
630 5 : };
631 5 : return Expression{Callables{subtraction, subtractionGPU}, lhs, rhs};
632 5 : } else {
633 28257 : auto subtraction = [](auto const& left, auto const& right) { return left - right; };
634 28257 : return Expression{Callables{subtraction, subtractionGPU}, lhs, rhs};
635 28257 : }
636 28257 : }
637 :
638 : /// Dividing two operands (including scalars)
639 : template <typename LHS, typename RHS, typename = std::enable_if_t<isBinaryOpOk<LHS, RHS>>>
640 : auto operator/(LHS const& lhs, RHS const& rhs)
641 14560 : {
642 14560 : auto divisionGPU = [](auto const& left, auto const& right, bool /**/) {
643 14560 : return left / right;
644 14560 : };
645 :
646 14560 : if constexpr (isDcOrExpr<LHS> && isDcOrExpr<RHS>) {
647 14533 : auto division = [](auto const& left, auto const& right) {
648 27 : return (left.array() / right.array()).matrix();
649 27 : };
650 27 : return Expression{Callables{division, divisionGPU}, lhs, rhs};
651 14533 : } else if constexpr (isArithmetic<LHS>) {
652 14513 : auto division = [](auto const& left, auto const& right) {
653 20 : return (left / right.array()).matrix();
654 20 : };
655 20 : return Expression{Callables{division, divisionGPU}, lhs, rhs};
656 14513 : } else if constexpr (isArithmetic<RHS>) {
657 14513 : auto division = [](auto const& left, auto const& right) {
658 14513 : return (left.array() / right).matrix();
659 14513 : };
660 14513 : return Expression{Callables{division, divisionGPU}, lhs, rhs};
661 14513 : } else {
662 14560 : auto division = [](auto const& left, auto const& right) { return left / right; };
663 14560 : return Expression{Callables{division, divisionGPU}, lhs, rhs};
664 14560 : }
665 14560 : }
666 :
667 : /// Element-wise maximum value operation between two operands
668 : template <typename LHS, typename RHS, typename = std::enable_if_t<isBinaryOpOk<LHS, RHS>>>
669 : auto cwiseMax(LHS const& lhs, RHS const& rhs)
670 20 : {
671 20 : constexpr bool isLHSComplex = isComplex<GetOperandDataType_t<LHS>>;
672 20 : constexpr bool isRHSComplex = isComplex<GetOperandDataType_t<RHS>>;
673 :
674 : #ifdef ELSA_CUDA_VECTOR
675 : auto cwiseMaxGPU = [](auto const& lhs, auto const& rhs, bool) {
676 : return quickvec::cwiseMax(lhs, rhs);
677 : };
678 : #endif
679 20 : auto cwiseMax = [] {
680 20 : if constexpr (isLHSComplex && isRHSComplex) {
681 15 : return [](auto const& left, auto const& right) {
682 5 : return (left.array().abs().max(right.array().abs())).matrix();
683 5 : };
684 15 : } else if constexpr (isLHSComplex) {
685 10 : return [](auto const& left, auto const& right) {
686 5 : return (left.array().abs().max(right.array())).matrix();
687 5 : };
688 10 : } else if constexpr (isRHSComplex) {
689 5 : return [](auto const& left, auto const& right) {
690 5 : return (left.array().max(right.array().abs())).matrix();
691 5 : };
692 5 : } else {
693 5 : return [](auto const& left, auto const& right) {
694 5 : return (left.array().max(right.array())).matrix();
695 5 : };
696 5 : }
697 20 : }();
698 :
699 : #ifdef ELSA_CUDA_VECTOR
700 : return Expression{Callables{cwiseMax, cwiseMaxGPU}, lhs, rhs};
701 : #else
702 20 : return Expression{cwiseMax, lhs, rhs};
703 20 : #endif
704 20 : }
705 :
706 : /// Element-wise absolute value operation
707 : template <typename Operand, typename = std::enable_if_t<isDcOrExpr<Operand>>>
708 : auto cwiseAbs(Operand const& operand)
709 13 : {
710 13 : auto abs = [](auto const& operand) { return (operand.array().abs()).matrix(); };
711 : #ifdef ELSA_CUDA_VECTOR
712 : auto absGPU = [](auto const& operand, bool) { return quickvec::cwiseAbs(operand); };
713 : return Expression{Callables{abs, absGPU}, operand};
714 : #else
715 13 : return Expression{abs, operand};
716 13 : #endif
717 13 : }
718 :
719 : /// Element-wise square operation
720 : template <typename Operand, typename = std::enable_if_t<isDcOrExpr<Operand>>>
721 : auto square(Operand const& operand)
722 13 : {
723 13 : auto square = [](auto const& operand) { return (operand.array().square()).matrix(); };
724 : #ifdef ELSA_CUDA_VECTOR
725 : auto squareGPU = [](auto const& operand, bool /**/) { return quickvec::square(operand); };
726 : return Expression{Callables{square, squareGPU}, operand};
727 : #else
728 13 : return Expression{square, operand};
729 13 : #endif
730 13 : }
731 :
732 : /// Element-wise square-root operation
733 : template <typename Operand, typename = std::enable_if_t<isDcOrExpr<Operand>>>
734 : auto sqrt(Operand const& operand)
735 9 : {
736 9 : auto sqrt = [](auto const& operand) { return (operand.array().sqrt()).matrix(); };
737 : #ifdef ELSA_CUDA_VECTOR
738 : auto sqrtGPU = [](auto const& operand, bool /**/) { return quickvec::sqrt(operand); };
739 : return Expression{Callables{sqrt, sqrtGPU}, operand};
740 : #else
741 9 : return Expression{sqrt, operand};
742 9 : #endif
743 9 : }
744 :
745 : /// Element-wise exponenation operation with euler base
746 : template <typename Operand, typename = std::enable_if_t<isDcOrExpr<Operand>>>
747 : auto exp(Operand const& operand)
748 6 : {
749 6 : auto exp = [](auto const& operand) { return (operand.array().exp()).matrix(); };
750 : #ifdef ELSA_CUDA_VECTOR
751 : auto expGPU = [](auto const& operand, bool /**/) { return quickvec::exp(operand); };
752 : return Expression{Callables{exp, expGPU}, operand};
753 : #else
754 6 : return Expression{exp, operand};
755 6 : #endif
756 6 : }
757 :
758 : /// Element-wise natural logarithm
759 : template <typename Operand, typename = std::enable_if_t<isDcOrExpr<Operand>>>
760 : auto log(Operand const& operand)
761 7 : {
762 7 : auto log = [](auto const& operand) { return (operand.array().log()).matrix(); };
763 : #ifdef ELSA_CUDA_VECTOR
764 : auto logGPU = [](auto const& operand, bool /**/) { return quickvec::log(operand); };
765 : return Expression{Callables{log, logGPU}, operand};
766 : #else
767 7 : return Expression{log, operand};
768 7 : #endif
769 7 : }
770 :
771 : /// Element-wise real parts of the Operand
772 : template <typename Operand, typename = std::enable_if_t<isDcOrExpr<Operand>>>
773 : auto real(Operand const& operand)
774 7 : {
775 7 : auto real = [](auto const& operand) { return (operand.array().real()).matrix(); };
776 : #ifdef ELSA_CUDA_VECTOR
777 : auto realGPU = [](auto const& operand, bool) { return quickvec::real(operand); };
778 : return Expression{Callables{real, realGPU}, operand};
779 : #else
780 7 : return Expression{real, operand};
781 7 : #endif
782 7 : }
783 :
784 : /// Element-wise imaginary parts of the Operand
785 : template <typename Operand, typename = std::enable_if_t<isDcOrExpr<Operand>>>
786 : auto imag(Operand const& operand)
787 5 : {
788 5 : auto imag = [](auto const& operand) { return (operand.array().imag()).matrix(); };
789 : #ifdef ELSA_CUDA_VECTOR
790 : auto imagGPU = [](auto const& operand, bool) { return quickvec::imag(operand); };
791 : return Expression{Callables{imag, imagGPU}, operand};
792 : #else
793 5 : return Expression{imag, operand};
794 5 : #endif
795 5 : }
796 : } // namespace elsa
|