LCOV - code coverage report
Current view: top level - elsa/storage/reductions - DotProduct.h (source / functions) Hit Total Coverage
Test: coverage-all.lcov Lines: 15 15 100.0 %
Date: 2024-05-16 04:22:26 Functions: 13 13 100.0 %

          Line data    Source code
       1             : #pragma once
       2             : 
       3             : #include "TypeTraits.hpp"
       4             : #include "functions/Conj.hpp"
       5             : #include "Functions.hpp"
       6             : 
       7             : #include <thrust/complex.h>
       8             : #include <thrust/inner_product.h>
       9             : #include <thrust/iterator/iterator_traits.h>
      10             : 
      11             : namespace elsa
      12             : {
      13             :     /// @brief Compute the dot product between two vectors
      14             :     ///
      15             :     /// Compute the sum of products of each entry in the vectors, i.e. \f$\sum_i x_i * y_i\f$.
      16             :     /// If any of the two vectors is complex, the dot product is conjugate linear in the first
      17             :     /// component and linear in the second, i.e. \f$\sum_i \bar{x}_i * y_i\f$, as is done in Eigen
      18             :     /// and Numpy.
      19             :     ///
      20             :     /// The return type is determined from the value types of the two iterators. If any is a complex
      21             :     /// type, the return type will also be a complex type.
      22             :     ///
      23             :     /// @ingroup reductions
      24             :     template <class InputIter1, class InputIter2,
      25             :               class data_t = std::common_type_t<thrust::iterator_value_t<InputIter1>,
      26             :                                                 thrust::iterator_value_t<InputIter2>>>
      27             :     auto dot(InputIter1 xfirst, InputIter1 xlast, InputIter2 yfirst)
      28             :         -> std::common_type_t<thrust::iterator_value_t<InputIter1>,
      29             :                               thrust::iterator_value_t<InputIter2>>
      30        2028 :     {
      31        2028 :         using xdata_t = thrust::iterator_value_t<InputIter1>;
      32        2028 :         using ydata_t = thrust::iterator_value_t<InputIter2>;
      33             : 
      34             :         // using data_t = std::common_type_t<xdata_t, ydata_t>;
      35             : 
      36        2028 :         if constexpr (is_specialization_v<
      37        2028 :                           xdata_t,
      38        2028 :                           thrust::complex> || is_specialization_v<ydata_t, thrust::complex>) {
      39        2004 :             return thrust::inner_product(
      40        2004 :                 xfirst, xlast, yfirst, data_t(0), elsa::plus{},
      41       40032 :                 [] __host__ __device__(const xdata_t& x, const ydata_t& y) {
      42       40032 :                     return elsa::fn::conj(x) * y;
      43       40032 :                 });
      44        2004 :         } else {
      45        2004 :             return thrust::inner_product(xfirst, xlast, yfirst, xdata_t(0));
      46        2004 :         }
      47        2028 :     }
      48             : } // namespace elsa

Generated by: LCOV version 1.14