Line data Source code
1 : #pragma once 2 : 3 : #include "Functions.hpp" 4 : 5 : #include <thrust/transform.h> 6 : 7 : namespace elsa 8 : { 9 : namespace detail 10 : { 11 : template <class T> 12 : struct MulVectorScalar { 13 5790 : MulVectorScalar(T scalar) : scalar_(scalar) {} 14 : 15 : template <class data_t> 16 : __host__ __device__ auto operator()(const data_t& x) -> std::common_type_t<T, data_t> 17 1295857 : { 18 1295857 : using U = std::common_type_t<T, data_t>; 19 1295857 : return static_cast<U>(x) * static_cast<U>(scalar_); 20 1295857 : } 21 : 22 : T scalar_; 23 : }; 24 : 25 : template <class T> 26 : struct MulScalarVector { 27 : MulScalarVector(T scalar) : scalar_(scalar) {} 28 : 29 : template <class data_t> 30 : __host__ __device__ auto operator()(const data_t& x) -> std::common_type_t<T, data_t> 31 : { 32 : using U = std::common_type_t<T, data_t>; 33 : return static_cast<U>(scalar_) * static_cast<U>(x); 34 : } 35 : 36 : T scalar_; 37 : }; 38 : } // namespace detail 39 : 40 : /// @brief Compute the component wise multiplication of two vectors 41 : /// @ingroup transforms 42 : template <class InputIter1, class InputIter2, class OutIter> 43 : void mul(InputIter1 xfirst, InputIter1 xlast, InputIter2 yfirst, OutIter out) 44 10246 : { 45 10246 : thrust::transform(xfirst, xlast, yfirst, out, elsa::multiplies{}); 46 10246 : } 47 : 48 : /// @brief Compute the component wise multiplies of a vectors and a scalar 49 : /// @ingroup transforms 50 : template <class data_t, class InputIter, class OutIter> 51 : void mulScalar(InputIter first, InputIter last, const data_t& scalar, OutIter out) 52 5790 : { 53 : // TODO: Find out why a lambda doesn't work here! 54 5790 : thrust::transform(first, last, out, detail::MulVectorScalar(scalar)); 55 5790 : } 56 : 57 : /// @brief Compute the component wise multiplication of a scalar and a vector 58 : /// @ingroup transforms 59 : template <class data_t, class InputIter, class OutIter> 60 : void mulScalar(const data_t& scalar, InputIter first, InputIter last, OutIter out) 61 : { 62 : // TODO: Find out why a lambda doesn't work here! 63 : thrust::transform(first, last, out, detail::MulScalarVector(scalar)); 64 : } 65 : } // namespace elsa