Line data Source code
1 : #pragma once 2 : 3 : #include "TypeTraits.hpp" 4 : #include "functions/Conj.hpp" 5 : #include "Functions.hpp" 6 : 7 : #include <thrust/complex.h> 8 : #include <thrust/inner_product.h> 9 : #include <thrust/iterator/iterator_traits.h> 10 : 11 : namespace elsa 12 : { 13 : /// @brief Compute the dot product between two vectors 14 : /// 15 : /// Compute the sum of products of each entry in the vectors, i.e. \f$\sum_i x_i * y_i\f$. 16 : /// If any of the two vectors is complex, the dot product is conjugate linear in the first 17 : /// component and linear in the second, i.e. \f$\sum_i \bar{x}_i * y_i\f$, as is done in Eigen 18 : /// and Numpy. 19 : /// 20 : /// The return type is determined from the value types of the two iterators. If any is a complex 21 : /// type, the return type will also be a complex type. 22 : /// 23 : /// @ingroup reductions 24 : template <class InputIter1, class InputIter2, 25 : class data_t = std::common_type_t<thrust::iterator_value_t<InputIter1>, 26 : thrust::iterator_value_t<InputIter2>>> 27 : auto dot(InputIter1 xfirst, InputIter1 xlast, InputIter2 yfirst) 28 : -> std::common_type_t<thrust::iterator_value_t<InputIter1>, 29 : thrust::iterator_value_t<InputIter2>> 30 2028 : { 31 2028 : using xdata_t = thrust::iterator_value_t<InputIter1>; 32 2028 : using ydata_t = thrust::iterator_value_t<InputIter2>; 33 : 34 : // using data_t = std::common_type_t<xdata_t, ydata_t>; 35 : 36 2028 : if constexpr (is_specialization_v< 37 2028 : xdata_t, 38 2028 : thrust::complex> || is_specialization_v<ydata_t, thrust::complex>) { 39 2004 : return thrust::inner_product( 40 2004 : xfirst, xlast, yfirst, data_t(0), elsa::plus{}, 41 34733 : [] __host__ __device__(const xdata_t& x, const ydata_t& y) { 42 34733 : return elsa::fn::conj(x) * y; 43 34733 : }); 44 2004 : } else { 45 2004 : return thrust::inner_product(xfirst, xlast, yfirst, xdata_t(0)); 46 2004 : } 47 2028 : } 48 : } // namespace elsa