          Line data    Source code
       1             : #pragma once
       2             : 
       3             : #include "FFTPolicy.h"
       4             : 
       5             : #ifdef ELSA_CUDA_TOOLKIT_PRESENT
       6             : #include <cuda_runtime.h>
       7             : #include <cufftXt.h>
       8             : 
       9             : #include <thrust/universal_ptr.h>
      10             : #include <thrust/functional.h>
      11             : #include <thrust/transform.h>
      12             : #include <thrust/iterator/constant_iterator.h>
      13             : #include <thrust/iterator/counting_iterator.h>
      14             : #include <thrust/gather.h>
      15             : #include <thrust/execution_policy.h>
      16             : 
      17             : #include <list>
      18             : #include <unordered_map>
      19             : #include <optional>
      20             : #endif
      21             : 
      22             : #include "Complex.h"
      23             : #include "elsaDefines.h"
      24             : #include "Error.h"
      25             : #include "DataDescriptor.h"
      26             : #include "VolumeDescriptor.h"
      27             : #include "ContiguousStorage.h"
      28             : #include "DataContainer.h"
      29             : 
      30             : #if WITH_FFTW
      31             : #define EIGEN_FFTW_DEFAULT
      32             : #endif
      33             : #include <unsupported/Eigen/FFT>
      34             : 
      35             : namespace elsa
      36             : {
      37             :     namespace detail
      38             :     {
      39             :         enum class FFTType : uint8_t { // kind of transform, as derived by FFTInfo::kind
      40             :             C2C,     // complex input, complex output
      41             :             R2C,     // real input, complex output
      42             :             C2R,     // complex input, real output
      43             :             INVALID, // neither input nor output is complex
      44             :         };
      45             : 
      46             :         template <FFTType tvalue> // wraps an FFTType enumerator in a type
      47             :         struct FFTType_t {
      48             :             static constexpr FFTType value = tvalue; // the wrapped enumerator
      49             :         };
      50             : 
      51             : #ifdef ELSA_CUDA_TOOLKIT_PRESENT
      52             :         template <FFTType fft_type, typename inner> // maps (transform kind, scalar type) to a cufftType
      53             :         struct FFTCuFFTType {
      54             :             /* fallback for unsupported combinations; the value is only a placeholder */
      55             :             static constexpr cufftType value = CUFFT_C2C;
      56             :         };
      57             : 
      58             :         template <>
      59             :         struct FFTCuFFTType<FFTType::C2C, float> { // single precision complex -> complex
      60             :             static constexpr cufftType value = CUFFT_C2C;
      61             :         };
      62             : 
      63             :         template <>
      64             :         struct FFTCuFFTType<FFTType::C2C, double> { // double precision complex -> complex
      65             :             static constexpr cufftType value = CUFFT_Z2Z;
      66             :         };
      67             : 
      68             :         template <>
      69             :         struct FFTCuFFTType<FFTType::R2C, float> { // single precision real -> complex
      70             :             static constexpr cufftType value = CUFFT_R2C;
      71             :         };
      72             : 
      73             :         template <>
      74             :         struct FFTCuFFTType<FFTType::R2C, double> { // double precision real -> complex
      75             :             static constexpr cufftType value = CUFFT_D2Z;
      76             :         };
      77             : 
      78             :         template <>
      79             :         struct FFTCuFFTType<FFTType::C2R, float> { // single precision complex -> real
      80             :             static constexpr cufftType value = CUFFT_C2R;
      81             :         };
      82             : 
      83             :         template <>
      84             :         struct FFTCuFFTType<FFTType::C2R, double> { // double precision complex -> real
      85             :             static constexpr cufftType value = CUFFT_Z2D;
      86             :         };
      87             : #endif
      88             : 
      89             :         template <typename in, typename out> // compile-time facts about an FFT reading `in` and writing `out`
      90             :         struct FFTInfo {
      91             :             using in_t = in;
      92             :             using out_t = out;
      93             :             using real_t = // value_type_of_t<in> if it equals value_type_of_t<out>, otherwise void
      94             :                 std::conditional_t<std::is_same_v<value_type_of_t<in>, value_type_of_t<out>>,
      95             :                                    value_type_of_t<in>, void>;
      96             :             using complex_t = std::conditional_t<is_complex_v<in>, in, out>; // `in` if it is complex, else `out`
      97             :             using eigen_in_t = // std::complex counterpart of a complex `in`, otherwise `in` itself
      98             :                 std::conditional_t<is_complex_v<in>, std::complex<value_type_of_t<in>>, in>;
      99             :             using eigen_out_t = // std::complex counterpart of a complex `out`, otherwise `out` itself
     100             :                 std::conditional_t<is_complex_v<out>, std::complex<value_type_of_t<out>>, out>;
     101             :             static constexpr FFTType kind = // both complex: C2C; only in: C2R; only out: R2C; else INVALID
     102             :                                          std::conditional_t < is_complex_v<in> && is_complex_v<out>,
     103             :                                      FFTType_t<FFTType::C2C>, std::conditional_t < is_complex_v<in>,
     104             :                                      FFTType_t<FFTType::C2R>,
     105             :                                      std::conditional_t < is_complex_v<out>,
     106             :                                      FFTType_t<FFTType::R2C>,
     107             :                                      FFTType_t < FFTType::INVALID >>>> ::value;
     108             :             static constexpr bool valid = !std::is_same_v<real_t, void> && kind != FFTType::INVALID; // matching scalar types and a known kind
     109             : #ifdef ELSA_CUDA_TOOLKIT_PRESENT
     110             :             /* cuFFT can only work with single precision and double precision floats; Since pointers
     111             :              * have to be cast to the corresponding cufft type pointers, it's important to make sure
     112             :              * the conversions are valid. (Eigen's FFT is templated, so I'd hope that it complains
     113             :              * when it is unhappy with its input types)
     114             :              */
     115             :             static constexpr bool cufft_valid = // valid, and the shared scalar type is float or double
     116             :                 valid && (std::is_same_v<real_t, float> || std::is_same_v<real_t, double>);
     117             :             static constexpr cufftType cufft_type = FFTCuFFTType<kind, real_t>::value; // only meaningful if cufft_valid
     118             : #endif
     119             :         };
     120             : 
     121             :         /// Derive output shape and normalization constant (element count of the real-space volume) for a transform of src_shape. @return [out_shape, normalization_constant]
     122             :         template <FFTType type>
     123             :         std::pair<IndexVector_t, index_t>
     124             :             computeOutputShapeAndNormalization(const IndexVector_t& src_shape)
     125          21 :         {
     126          21 :             static_assert(type == FFTType::C2C || type == FFTType::C2R || type == FFTType::R2C,
     127          21 :                           "Invalid fft type!");
     128          21 :             if constexpr (type == FFTType::C2C) {
     129          21 :                 return std::make_pair(src_shape, src_shape.prod()); // shape is unchanged
     130          21 :             } else if constexpr (type == FFTType::C2R) {
     131           9 :                 index_t last_dim = src_shape.size() - 1;
     132           9 :                 IndexVector_t dst_shape(src_shape.size());
     133           9 :                 dst_shape << src_shape.head(last_dim), (src_shape(last_dim) - 1) * 2; // last axis n/2+1 -> n, n taken as even
     134           9 :                 return std::make_pair(dst_shape, dst_shape.prod()); // normalize by the real (output) volume
     135           9 :             } else if constexpr (type == FFTType::R2C) {
     136           9 :                 index_t last_dim = src_shape.size() - 1;
     137           9 :                 IndexVector_t dst_shape(src_shape.size());
     138           9 :                 dst_shape << src_shape.head(last_dim), src_shape(last_dim) / 2 + 1; // last axis n -> n/2+1
     139           9 :                 return std::make_pair(dst_shape, src_shape.prod()); // normalize by the real (input) volume
     140           9 :             }
     141          21 :         }
     142             : 
     143             : #ifdef ELSA_CUDA_TOOLKIT_PRESENT
     144             : 
     145             :         /* Create a cufft handle and plan an FFT for the given cufft type and shape.
     146             :          * On success the handle must be released by the caller via cufftDestroy(...).
     147             :          * Returns a result indicating whether creating the plan was successful. If not,
     148             :          * the plan is not initialized and does not need to be freed.
     149             :          */
     150             :         cufftResult createPlan(cufftHandle* plan, cufftType type, const IndexVector_t& shape);
     151             : 
     152             :         /* Divide each of the `size` elements at ptr, in place, by normalization_constant, or by
     153             :            its square root if the flag applySqrt is set. */
     154             :         template <typename data_t>
     155             :         void fftNormalize(data_t* ptr, index_t size, index_t normalization_constant, bool applySqrt)
     156             :         {
     157             :             data_t normalizing_factor = static_cast<data_t>( // divisor, converted to the element type
     158             :                 applySqrt ? std::sqrt(normalization_constant) : normalization_constant);
     159             :             thrust::transform(thrust::device, ptr, ptr + size, // ptr[i] = ptr[i] / normalizing_factor
     160             :                               thrust::make_constant_iterator(normalizing_factor), ptr,
     161             :                               thrust::divides<data_t>());
     162             :         }
     163             : 
     164             :         /* Index map for reversing the dimension order of a rows x columns x depth block. Assuming row-major layout! */
     165             :         struct TransposeIndexFunctor3D : public thrust::unary_function<size_t, size_t> {
     166             :             size_t rows, columns, depth; // extents used to decompose the incoming linear index
     167             : 
     168             :             __host__ __device__ TransposeIndexFunctor3D(size_t rows, size_t columns, size_t depth)
     169             :                 : rows(rows), columns(columns), depth(depth)
     170             :             {
     171             :             }
     172             : 
     173             :             __host__ __device__ size_t operator()(size_t linear_index)
     174             :             {
     175             :                 size_t src_row = linear_index / (columns * depth); // slowest varying coordinate
     176             :                 size_t remainder = linear_index % (columns * depth);
     177             :                 size_t src_column = remainder / depth;
     178             :                 size_t src_depth = remainder % depth; // fastest varying coordinate
     179             : 
     180             :                 return src_depth * (rows * columns) + src_column * rows + src_row; // same coordinates, row now varies fastest
     181             :             }
     182             :         };
     183             : 
     184             :         /* Write src to dst with the dimension order reversed (1D: plain copy). Returns false for more than 3 dimensions. src and dst must not alias! */
     185             :         template <typename data_t>
     186             :         bool transposeDevice(const data_t* src, data_t* dst, const IndexVector_t& src_shape)
     187             :         {
     188             : 
     189             :             switch (src_shape.size()) {
     190             :                 case 1:
     191             :                     /* no need to transpose, but must copy from src to dst */
     192             :                     thrust::copy(thrust::device, src, src + src_shape.prod(), dst);
     193             :                     return true;
     194             :                 case 2: {
     195             :                     TransposeIndexFunctor3D functor(src_shape(0), src_shape(1), 1); // 2D handled as 3D with depth 1
     196             :                     thrust::counting_iterator<size_t> indices(0);
     197             :                     auto transpose_index_iterator =
     198             :                         thrust::make_transform_iterator(indices, functor);
     199             :                     thrust::gather(thrust::device, transpose_index_iterator, // dst[i] = src[functor(i)]
     200             :                                    transpose_index_iterator + src_shape.prod(), src, dst);
     201             :                     return true;
     202             :                 }
     203             :                 case 3: {
     204             :                     TransposeIndexFunctor3D functor(src_shape(0), src_shape(1), src_shape(2));
     205             :                     thrust::counting_iterator<size_t> indices(0);
     206             :                     auto transpose_index_iterator =
     207             :                         thrust::make_transform_iterator(indices, functor);
     208             :                     thrust::gather(thrust::device, transpose_index_iterator, // dst[i] = src[functor(i)]
     209             :                                    transpose_index_iterator + src_shape.prod(), src, dst);
     210             :                     return true;
     211             :                 }
     212             :                 default:
     213             :                     return false; // unsupported dimensionality, dst is left untouched
     214             :             }
     215             :         }
     216             : 
     217             :         /* Cache is not thread safe, also plans should not be used across multiple threads at the
     218             :          * same time! Hence, we use a thread local instance. Should the cache's GPU memory
     219             :          * consumption ever be a problem, then aside from disabling the caching mechanism there
     220             :          * are two potential optimizations:
     221             :          *  - Currently, when no new plan can be allocated, the cache is flushed. This could be
     222             :          * extended to also flush the caches of all other threads.
     223             :          *  - cuFFT allows us to manage the work area memory. Before planning,
     224             :          * cufftSetAutoAllocation(false) could be called, then the work area can be explicitly set.
     225             :          * This way, all elements of the cache could share a work area, fitting the requirements of
     226             :          * the largest cached plan. This would increase the management overhead, but reduce memory
     227             :          * consumption. Currently, the cache has too few elements for this to be worth it.
     228             :          */
     229             :         class CuFFTPlanCache
     230             :         {
     231             :         private:
     232             :             using CacheElement = std::tuple<cufftHandle, IndexVector_t, cufftType>; // plan handle with the shape and type it was created for
     233             :             using CacheList = std::list<CacheElement>;
     234             :             CacheList _cache;
     235             :             /* this should be very low, to conserve GPU memory!
     236             :                Initialized via ELSA_CUFFT_CACHE_SIZE */
     237             :             size_t _limit;
     238             : 
     239             :             void flush(); // defined elsewhere
     240             :             void evict(); // defined elsewhere
     241             : 
     242             :         public:
     243             :             CuFFTPlanCache();
     244             : 
     245             :             CuFFTPlanCache(const CuFFTPlanCache& other) = delete; // non-copyable
     246             :             CuFFTPlanCache& operator=(const CuFFTPlanCache& other) = delete;
     247             :             /* This performs linear search, because that should be less overhead than a
     248             :                map lookup for very small cache sizes. doFftDevice treats an empty result as failure. */
     249             :             std::optional<cufftHandle> get(cufftType type, const IndexVector_t& shape);
     250             :         };
     251             : 
     252             :         extern thread_local CuFFTPlanCache cufftCache; // one plan cache per thread, defined elsewhere
     253             : 
     254             :         /* Run one cuFFT transform from in_data to out_data and normalize the result according to
     255             :          * norm. Returns false if the types are unsupported, no plan is available or execution
     256             :          * fails. ATTENTION! Expects data to be laid out in ROW MAJOR order!
     257             :          *             | in_data_t       | out_data_t      |
     258             :          *        -----+-----------------+-----------------+
     259             :          *         C2C | complex<float>  | complex<float>  |
     260             :          *         Z2Z | complex<double> | complex<double> |
     261             :          *         C2R | complex<float>  | float           |
     262             :          *         Z2D | complex<double> | double          |
     263             :          *         R2C | float           | complex<float>  |
     264             :          *         D2Z | double          | complex<double> |
     265             :          */
     266             :         template <bool is_forward, typename in_data_t, typename out_data_t>
     267             :         bool doFftDevice(in_data_t* in_data, out_data_t* out_data, const IndexVector_t& shape,
     268             :                          const IndexVector_t& out_shape, FFTNorm norm,
     269             :                          index_t normalization_constant)
     270             :         {
     271             :             /* According to this example:
     272             :              * (the link to the example is missing here)
     273             :              * it is fine to reinterpret_cast std::complex to cufftComplex.
     274             :              * The same applies to thrust::complex, which is what elsa::complex maps to.
     275             :              */
     276             : 
     277             :             using info = FFTInfo<in_data_t, out_data_t>;
     278             :             if constexpr (!info::cufft_valid) {
     279             :                 static_cast<void>(in_data);
     280             :                 static_cast<void>(out_data);
     281             :                 static_cast<void>(shape);
     282             :                 static_cast<void>(out_shape);
     283             :                 static_cast<void>(norm);
     284             :                 static_cast<void>(normalization_constant);
     285             :                 return false;
     286             :             } else {
     287             :                 constexpr cufftType type = info::cufft_type;
     288             : 
     289             :                 cufftHandle plan;
     290             : #if ELSA_CUFFT_CACHE_SIZE != 0
     291             :                 std::optional<cufftHandle> planOpt = cufftCache.get(type, shape);
     292             :                 if (planOpt) {
     293             :                     plan = *planOpt;
     294             :                 } else {
     295             :                     return false;
     296             :                 }
     297             : #else
     298             :                 if (createPlan(&plan, type, shape) != CUFFT_SUCCESS) {
     299             :                     return false;
     300             :                 }
     301             : #endif
     302             : 
     303             :                 bool success = false; // overwritten by exactly one of the branches below
     304             :                 if constexpr (type == CUFFT_C2C) {
     305             :                     int direction = is_forward ? CUFFT_FORWARD : CUFFT_INVERSE;
     306             :                     /* cuFFT can handle in-place transforms */
     307             :                     success = cufftExecC2C(plan, reinterpret_cast<cufftComplex*>(in_data),
     308             :                                            reinterpret_cast<cufftComplex*>(out_data), direction)
     309             :                               == CUFFT_SUCCESS;
     310             :                 } else if constexpr (type == CUFFT_Z2Z) {
     311             :                     int direction = is_forward ? CUFFT_FORWARD : CUFFT_INVERSE;
     312             :                     success =
     313             :                         cufftExecZ2Z(plan, reinterpret_cast<cufftDoubleComplex*>(in_data),
     314             :                                      reinterpret_cast<cufftDoubleComplex*>(out_data), direction)
     315             :                         == CUFFT_SUCCESS;
     316             :                 } else if constexpr (type == CUFFT_R2C) {
     317             :                     static_assert(is_forward);
     318             :                     success = cufftExecR2C(plan, reinterpret_cast<cufftReal*>(in_data),
     319             :                                            reinterpret_cast<cufftComplex*>(out_data))
     320             :                               == CUFFT_SUCCESS;
     321             :                 } else if constexpr (type == CUFFT_C2R) {
     322             :                     static_assert(!is_forward);
     323             :                     success = cufftExecC2R(plan, reinterpret_cast<cufftComplex*>(in_data),
     324             :                                            reinterpret_cast<cufftReal*>(out_data))
     325             :                               == CUFFT_SUCCESS;
     326             :                 } else if constexpr (type == CUFFT_D2Z) {
     327             :                     static_assert(is_forward);
     328             :                     success = cufftExecD2Z(plan, reinterpret_cast<cufftDoubleReal*>(in_data),
     329             :                                            reinterpret_cast<cufftDoubleComplex*>(out_data))
     330             :                               == CUFFT_SUCCESS;
     331             :                 } else if constexpr (type == CUFFT_Z2D) {
     332             :                     static_assert(!is_forward);
     333             :                     success = cufftExecZ2D(plan, reinterpret_cast<cufftDoubleComplex*>(in_data),
     334             :                                            reinterpret_cast<cufftDoubleReal*>(out_data))
     335             :                               == CUFFT_SUCCESS;
     336             :                 }
     337             :                 cudaDeviceSynchronize();
     338             : 
     339             :                 if (likely(success)) {
     340             :                     /* cuFFT performs unnormalized FFTs, therefore we are left to do scaling
     341             :                        according to FFTNorm */
     342             :                     if ((norm == FFTNorm::FORWARD && is_forward) || norm == FFTNorm::ORTHO
     343             :                         || (norm == FFTNorm::BACKWARD && !is_forward)) {
     344             :                         fftNormalize(out_data, out_shape.prod(), normalization_constant,
     345             :                                      norm == FFTNorm::ORTHO);
     346             :                     }
     347             :                 }
     348             : 
     349             : #if ELSA_CUFFT_CACHE_SIZE == 0
     350             :                 cufftDestroy(plan);
     351             : #endif
     352             : 
     353             :                 return success;
     354             :             }
     355             :         }
     356             : 
     357             :         /* Perform an in-place complex fft on the GPU using cuFFT if possible.
     358             :          * Template parameter is_forward determines whether an fft or ifft is performed.
     359             :          * true -> fft
     360             :          * false -> ifft
     361             :          *
     362             :          * A return value of false indicates that the operation was not successful and the
     363             :          * input data is unchanged.
     364             :          */
     365             :         template <class data_t, bool is_forward>
     366             :         bool fftDevice(thrust::universal_ptr<data_t> this_data, const IndexVector_t& src_shape,
     367             :                        FFTNorm norm)
     368             :         {
     369             :             /* Rationale: cuFFT expects the data in row major layout, but DataContainer uses column
     370             :                major layout (and the CPU dft implementation expects it, as well)
     371             :                => We use the following identity: DFT(A^T)^T = DFT(A) (the proof is left as an
     372             :                exercise to the reader ;)) Hence, the data does not need to be transposed, as long as
     373             :                the shape is reversed, because to cuFFT our input is transposed. */
     374             :             IndexVector_t src_shape_transposed = src_shape.reverse();
     375             :             auto [_, normalization_constant] = // the C2C output shape equals the input shape, so it is discarded
     376             :                 computeOutputShapeAndNormalization<FFTType::C2C>(src_shape_transposed);
     377             : 
     378             :             return doFftDevice<is_forward>(this_data.get(), this_data.get(), src_shape_transposed, // same pointer twice: in-place
     379             :                                            src_shape_transposed, norm, normalization_constant);
     380             :         }
     381             : 
     382             :         /* Real-to-complex fft on the GPU; returns false on failure. shape is the actual size of the
     383             :          * volume being transformed, not the size of the output (less, due to the symmetry R2C exploits) */
     384             :         template <class data_t>
     385             :         bool rfftDevice(thrust::universal_ptr<const data_t> in_data, const IndexVector_t& shape,
     386             :                         const IndexVector_t& out_shape,
     387             :                         thrust::universal_ptr<elsa::complex<data_t>> out_data, FFTNorm norm,
     388             :                         index_t normalization_constant)
     389             :         {
     390             :             using info = FFTInfo<data_t, elsa::complex<data_t>>;
     391             :             if constexpr (!info::cufft_valid) {
     392             :                 /* only single and double precision floats are supported */
     393             :                 static_cast<void>(in_data);
     394             :                 static_cast<void>(shape);
     395             :                 static_cast<void>(out_shape);
     396             :                 static_cast<void>(out_data);
     397             :                 static_cast<void>(norm);
     398             :                 static_cast<void>(normalization_constant);
     399             :                 return false;
     400             :             }
     401             : 
     402             :             elsa::complex<data_t>* wc_buf; // scratch buffer receiving the complex fft output
     403             :             if (cudaMalloc(&wc_buf, sizeof(elsa::complex<data_t>) * out_shape.prod())
     404             :                 != cudaSuccess) {
     405             :                 return false;
     406             :             }
     407             : 
     408             :             /* making use of the fact that, for R2C, the input is at most as big as the output*/
     409             :             transposeDevice(in_data.get(), reinterpret_cast<data_t*>(out_data.get()), shape); // out_data doubles as staging area for the transposed input
     410             :             bool success = doFftDevice<true>(reinterpret_cast<data_t*>(out_data.get()), wc_buf,
     411             :                                              shape, out_shape, norm, normalization_constant);
     412             : 
     413             :             if (success) {
     414             :                 transposeDevice(wc_buf, out_data.get(), out_shape.reverse()); // transpose the result back into out_data
     415             :             }
     416             :             cudaFree(wc_buf);
     417             :             return success;
     418             :         }
     419             : 
     420             :         template <class data_t> // complex-to-real inverse fft on the GPU; returns false on failure
     421             :         bool irfftDevice(thrust::universal_ptr<const elsa::complex<data_t>> in_data,
     422             :                          const IndexVector_t& src_shape, const IndexVector_t& out_shape,
     423             :                          thrust::universal_ptr<data_t> out_data, FFTNorm norm,
     424             :                          index_t normalization_constant)
     425             :         {
     426             :             using info = FFTInfo<elsa::complex<data_t>, data_t>;
     427             :             if constexpr (!info::cufft_valid) {
     428             :                 /* only single and double precision are supported */
     429             :                 static_cast<void>(in_data);
     430             :                 static_cast<void>(src_shape);
     431             :                 static_cast<void>(out_shape);
     432             :                 static_cast<void>(out_data);
     433             :                 static_cast<void>(norm);
     434             :                 static_cast<void>(normalization_constant);
     435             :                 return false;
     436             :             }
     437             : 
     438             :             elsa::complex<data_t>* wc_buf1; // scratch buffer for the transposed complex input
     439             :             if (cudaMalloc(&wc_buf1, sizeof(elsa::complex<data_t>) * src_shape.prod())
     440             :                 != cudaSuccess) {
     441             :                 return false;
     442             :             }
     443             : 
     444             :             data_t* wc_buf2; // scratch buffer for the real fft output before transposing back
     445             :             if (cudaMalloc(&wc_buf2, sizeof(data_t) * out_shape.prod()) != cudaSuccess) {
     446             :                 cudaFree(wc_buf1);
     447             :                 return false;
     448             :             }
     449             : 
     450             :             transposeDevice(in_data.get(), wc_buf1, src_shape);
     451             : 
     452             :             bool success = doFftDevice<false>(wc_buf1, wc_buf2, out_shape, out_shape, norm,
     453             :                                               normalization_constant);
     454             : 
     455             :             cudaFree(wc_buf1); // the input copy is no longer needed once the fft has run
     456             :             if (success) {
     457             :                 transposeDevice(wc_buf2, out_data.get(), out_shape.reverse());
     458             :             }
     459             :             cudaFree(wc_buf2);
     460             :             return success;
     461             :         }
     462             : #endif
     463             : 
     // Runs one 1D FFT along axis `dim_idx` for every line ("ray") of the volume.
     //
     // Template parameters:
     //   is_forward            - true calls Eigen's fwd(), false calls inv()
     //   in_data_t, out_data_t - element types; FFTInfo<in_data_t, out_data_t>::kind tells
     //                           whether this is a C2C, R2C or C2R transform
     // Parameters:
     //   in_data, in_shape     - input buffer and its per-dimension extents
     //   out_data, out_shape   - output buffer and its per-dimension extents
     //   dim_idx               - axis that is transformed
     //   norm                  - selects the scaling: forward results are divided by n for
     //                           FFTNorm::FORWARD, inverse results by n for
     //                           FFTNorm::BACKWARD, and both by sqrt(n) for FFTNorm::ORTHO,
     //                           with n = in_dim_size for R2C and out_dim_size otherwise
     //
     // Throws Error when the two shapes are not compatible for the transform kind.
     // The loop over rays is an OpenMP parallel-for unless EIGEN_FFTW_DEFAULT is defined.
     464             :         template <bool is_forward, typename in_data_t, typename out_data_t>
     465             :         void fftHostSingleDim(const in_data_t* in_data, const IndexVector_t& in_shape,
     466             :                               out_data_t* out_data, const IndexVector_t& out_shape, index_t dim_idx,
     467             :                               FFTNorm norm)
     468        1640 :         {
     469        1640 :             using info = FFTInfo<in_data_t, out_data_t>;
     470        1640 :             using InputVector_t = const Eigen::Matrix<in_data_t, Eigen::Dynamic, 1>;
     471        1640 :             using OutputVector_t = Eigen::Matrix<out_data_t, Eigen::Dynamic, 1>;
     472             : 
     473             :             // distance (in elements) between two consecutive entries along dim_idx:
     474             :             // the product of all extents before dim_idx, dim_size[0] * dim_size[1] * ...
     475             :             // (1 for dim_idx == 0).
     476        1640 :             const index_t in_stride = in_shape.head(dim_idx).prod();
     477        1640 :             const index_t out_stride = out_shape.head(dim_idx).prod();
     478             : 
     479             :             // number of coefficients for the current dimension
     480        1640 :             const index_t in_dim_size = in_shape(dim_idx);
     481        1640 :             const index_t out_dim_size = out_shape(dim_idx);
                                   // length used for normalization: the input extent for R2C,
                                   // the output extent for C2C and C2R
     482        1640 :             index_t true_dim_size;
     483        1640 :             if constexpr (info::kind == FFTType::R2C) {
     484        1631 :                 true_dim_size = in_dim_size;
     485        1631 :             } else {
     486        1631 :                 true_dim_size = out_dim_size;
     487        1631 :             }
     488             : 
     489             :             // number of coefficients for the other dimensions
     490             :             // this is the number of 1d-ffts we'll do
     491             :             // e.g. shape=[2, 3, 4] and we do dim_idx=2 (=shape 4)
     492             :             //   -> == 24 / 4 = 6 == 2*3
                                   // NOTE(review): the dividends of the next two statements are missing
                                   // in this listing; per the example above they are presumably the total
                                   // element counts of in_shape / out_shape - confirm against the source.
     493        1640 :             const index_t other_dims_size = / in_dim_size;
     494        1640 :             const index_t out_other_dims_size = / out_dim_size;
                                   // reject shapes that disagree off-axis, or whose axis lengths do not
                                   // match the transform kind (C2C: equal, half spectrum: n / 2 + 1)
     495        1640 :             if (unlikely(other_dims_size != out_other_dims_size
     496        1640 :                          || (info::kind == FFTType::C2C && in_dim_size != out_dim_size)
     497        1640 :                          || (info::kind == FFTType::C2R && in_dim_size != out_dim_size / 2 + 1)
     498        1640 :                          || (info::kind == FFTType::R2C && in_dim_size / 2 + 1 != out_dim_size))) {
     499             :                 // if input & output do not have the same shape (in a C2R or R2C transform),
     500             :                 // only the transform axis may be different
     501           0 :                 throw Error{"FFT has incompatible input & output dimensions"};
     502           0 :             }
     503        1640 : #ifndef EIGEN_FFTW_DEFAULT
     504             : // when using eigen+fftw, this corrupts the memory, so don't parallelize.
     505             : // error messages may include:
     506             : // * double free or corruption (fasttop)
     507             : // * malloc_consolidate(): unaligned fastbin chunk detected
     508        1640 : #pragma omp parallel for
     509        1640 : #endif
     510             :             // do all the 1d-ffts along the current dimensions axis
     511        1640 :             for (index_t i = 0; i < other_dims_size; ++i) {
     512             : 
     513           0 :                 index_t in_ray_start = i;
     514             :                 // each time i is a multiple of stride,
     515             :                 // jump forward the current+previous dimensions' shape product
     516             :                 // (draw an indexed 3d cube to visualize this)
                                       // i.e. start = i + stride * (dim_size - 1) * floor(i / stride)
     517           0 :                 in_ray_start +=
     518           0 :                     (in_stride * (in_dim_size - 1)) * ((i - (i % in_stride)) / in_stride);
     519             : 
     520           0 :                 index_t out_ray_start = i;
     521           0 :                 out_ray_start +=
     522           0 :                     (out_stride * (out_dim_size - 1)) * ((i - (i % out_stride)) / out_stride);
     523             : 
     524             :                 // this is one "ray" through the volume
     525           0 :                 Eigen::Map<InputVector_t, Eigen::AlignmentType::Unaligned, Eigen::InnerStride<>>
     526           0 :                     input_map(in_data + in_ray_start, in_dim_size, Eigen::InnerStride<>(in_stride));
     527             : 
     528           0 :                 Eigen::Map<OutputVector_t, Eigen::AlignmentType::Unaligned, Eigen::InnerStride<>>
     529           0 :                     output_map(out_data + out_ray_start, out_dim_size,
     530           0 :                                Eigen::InnerStride<>(out_stride));
     531           0 :                 using inner_t = typename info::real_t;
     532             : 
     533           0 :                 Eigen::FFT<inner_t> fft_op;
     534             : 
     535             :                 // disable any scaling in eigen - normally it does 1/n for ifft
     536           0 :                 fft_op.SetFlag(Eigen::FFT<inner_t>::Flag::Unscaled);
     537             :                 // use half spectrum as input/output for complex-to-real/real-to-complex
     538             :                 // transforms
     539           0 :                 fft_op.SetFlag(Eigen::FFT<inner_t>::Flag::HalfSpectrum);
     540             : 
                                       // dense scratch vectors: the strided ray is copied into fft_in,
                                       // and fft_out is copied back to the strided output afterwards
     541           0 :                 Eigen::Matrix<typename info::eigen_in_t, Eigen::Dynamic, 1> fft_in{in_dim_size};
     542           0 :                 Eigen::Matrix<typename info::eigen_out_t, Eigen::Dynamic, 1> fft_out{out_dim_size};
     543             : 
     544             :                 // eigen internally copies the fwd input matrix anyway if
     545             :                 // it doesn't have stride == 1
     546           0 :                 fft_in = input_map.block(0, 0, in_dim_size, 1)
     547           0 :                              .template cast<const typename info::eigen_in_t>();
     548             : 
     549       21222 :                 if (unlikely(in_dim_size == 1)) {
     550             :                     // eigen kiss-fft crashes for size=1...
     551             :                     // TODO: resolve for R2C/C2R
     552       21222 :                     if constexpr (info::kind == FFTType::C2C) {
     553           0 :                         fft_out = fft_in;
     554           0 :                     } else if constexpr (info::kind == FFTType::C2R) {
     555           0 :                         fft_out(0) = fft_in(0).real();
     556           0 :                     } else {
     557           0 :                         fft_out(0) = fft_in(0);
     558           0 :                     }
     559 >1844*10^16 :                 } else {
     560             :                     // arguments for in and out _must not_ be the same matrix!
     561             :                     // they will corrupt wildly otherwise.
     562 >1844*10^16 :                     if constexpr (is_forward) {
     563 >1844*10^16 :                         fft_op.fwd(fft_out, fft_in);
     564 >1844*10^16 :                         if (norm == FFTNorm::FORWARD) {
     565          31 :                             fft_out /= true_dim_size;
     566       15953 :                         } else if (norm == FFTNorm::ORTHO) {
     567          31 :                             fft_out /= std::sqrt(true_dim_size);
     568          31 :                         }
     569 >1844*10^16 :                     } else {
     570 >1844*10^16 :                         fft_op.inv(fft_out, fft_in);
     571 >1844*10^16 :                         if (norm == FFTNorm::BACKWARD) {
     572       19874 :                             fft_out /= true_dim_size;
     573 >1844*10^16 :                         } else if (norm == FFTNorm::ORTHO) {
     574          33 :                             fft_out /= std::sqrt(true_dim_size);
     575          33 :                         }
     576 >1844*10^16 :                     }
     577 >1844*10^16 :                 }
     578             : 
     579             :                 // we can't directly use the map as fft output,
     580             :                 // since Eigen internally just uses the pointer to
     581             :                 // the map's first element, and doesn't respect stride at all..
     582           0 :                 output_map.block(0, 0, out_dim_size, 1) = fft_out.template cast<out_data_t>();
     583           0 :             }
     584        1640 :         }
     585             : 
     586             :         // Algorithm sketch for different FFTs:
     587             :         //
     588             :         // if(R2C)
     589             :         //     perform R2C FFTs along dimension N // input -> output
     590             :         //     start_dim = 0
     591             :         //     end_dim = N-1
     592             :         // else if(C2R)
     593             :         //     start_dim = 0
     594             :         //     end_dim = N-1
     595             :         // else
     596             :         //     start_dim = 0
     597             :         //     end_dim = N
     598             :         //
     599             :         // perform FFTs for dims start_dim..end_dim // C2C: input -> output -> ... output
     600             :         //                                          // C2R: input -> tmp -> ... tmp
     601             :         //                                          // R2C: output -> output -> ... output
     602             :         //
     603             :         // if(C2R)
     604             :         //     perform C2R FFTs along dimension N // tmp -> output
     605             :         //
     606             :         // ATTENTION: in_data == out_data is allowed, assuming the input & output buffer is large
     607             :         //            enough to hold the input and the output (not simultaneously). In- and output
     608             :         //            size differs for C2R & R2C transforms
                               //
                               // Host-side N-dimensional FFT, composed of one fftHostSingleDim pass per
                               // axis.
                               //
                               // Parameters:
                               //   in_data, in_shape   - input buffer and its per-dimension extents
                               //   out_data, out_shape - output buffer and its per-dimension extents
                               //   dims                - number of dimensions; the last one (dims - 1)
                               //                         is the half-spectrum axis for R2C / C2R
                               //   norm                - forwarded unchanged to every 1D pass
                               //
                               // Throws Error{"unsupported FFT types"} when
                               // FFTInfo<in_data_t, out_data_t>::valid is false.
     609             :         template <class in_data_t, class out_data_t, bool is_forward>
     610             :         void fftHost(const in_data_t* in_data, const IndexVector_t& in_shape, out_data_t* out_data,
     611             :                      const IndexVector_t& out_shape, index_t dims, FFTNorm norm)
     612         823 :         {
     613         823 :             using info = FFTInfo<in_data_t, out_data_t>;
     614             : 
     615         823 :             if constexpr (info::valid) {
     616             :                 // TODO: fftw variant
     617             : 
     618             :                 // R2C is implicitly forward, C2R is implicitly backward
     619           0 :                 static_assert((info::kind != FFTType::R2C || is_forward)
     620           0 :                               && (info::kind != FFTType::C2R || !is_forward));
     621             : 
     622             :                 // For C2C, half_spectrum_shape == in_shape == out_shape
                                       // (for C2R it is the input shape, for R2C the output shape)
     623           0 :                 IndexVector_t half_spectrum_shape;
     624         823 :                 if constexpr (info::kind == FFTType::C2R) {
     625         811 :                     half_spectrum_shape = in_shape;
     626         811 :                 } else {
     627         811 :                     half_spectrum_shape = out_shape;
     628         811 :                 }
     629             : 
     630             :                 // exclusive upper bound of the axes handled by the complex passes below
     631           0 :                 index_t end_dim;
     632           0 :                 const typename info::complex_t* current_input;
     633           0 :                 typename info::complex_t* working_area;
     634           0 :                 std::unique_ptr<typename info::complex_t[]> tmp;
     635         823 :                 if constexpr (info::kind == FFTType::R2C) {
                                           // transform the last axis first (in_data -> out_data); all
                                           // remaining passes then read from and write to out_data
     636         814 :                     end_dim = dims - 1;
     637         814 :                     fftHostSingleDim<is_forward>(in_data, in_shape, out_data, out_shape, end_dim,
     638         814 :                                                  norm);
     639         814 :                     current_input = out_data;
     640         814 :                     working_area = out_data;
     641         814 :                 } else if constexpr (info::kind == FFTType::C2R) {
     642         802 :                     end_dim = dims - 1;
     643         802 :                     current_input = in_data;
                                           // NOTE(review): the element count passed to make_unique is
                                           // missing in this listing; presumably it covers the whole
                                           // half-spectrum volume - confirm against the source.
     644         802 :                     tmp = std::make_unique<typename info::complex_t[]>(;
     645         802 :                     working_area = tmp.get();
     646         802 :                 } else {
     647         802 :                     end_dim = dims;
     648         802 :                     current_input = in_data;
     649         802 :                     working_area = out_data;
     650         802 :                 }
     651             : 
     652             :                 // generalization of an 1D-FFT
     653             :                 // walk over each dimension and 1d-fft one 'line' of data
                                       // (after the first pass, input and output are both working_area)
     654        2442 :                 for (index_t dim_idx = 0; dim_idx < end_dim; ++dim_idx) {
     655        1619 :                     fftHostSingleDim<is_forward>(current_input, half_spectrum_shape, working_area,
     656        1619 :                                                  half_spectrum_shape, dim_idx, norm);
     657        1619 :                     current_input = working_area;
     658        1619 :                 }
     659             : 
                                       // C2R: the last axis is converted back to real data last,
                                       // writing into the caller's output buffer
     660         823 :                 if constexpr (info::kind == FFTType::C2R) {
     661          12 :                     fftHostSingleDim<is_forward>(current_input, in_shape, out_data, out_shape,
     662          12 :                                                  end_dim, norm);
     663          12 :                 }
     664         823 :             } else {
                                       // the casts below only mark the parameters as used
     665           0 :                 static_cast<void>(in_data);
     666           0 :                 static_cast<void>(out_data);
     667           0 :                 static_cast<void>(in_shape);
     668           0 :                 static_cast<void>(out_shape);
     669           0 :                 static_cast<void>(dims);
     670           0 :                 static_cast<void>(norm);
     671             :                 // TODO: more concrete error with specific types
     672           0 :                 throw Error{"unsupported FFT types"};
     673           0 :             }
     674         823 :         }
     675             :     } // namespace detail
     676             : 
     // Forward FFT of the buffer `x`, whose layout is described by `desc`.
     //
     // The shape and the number of dimensions are read from `desc`. When built with
     // ELSA_CUDA_TOOLKIT_PRESENT and policy != FFTPolicy::HOST, detail::fftDevice is tried
     // first and the function returns as soon as it reports success. Otherwise, unless
     // policy == FFTPolicy::DEVICE, detail::fftHost runs with the same shape for input and
     // output. Throws Error when policy is DEVICE and no device transform succeeded.
     //
     // NOTE(review): the buffer arguments of the detail:: calls are missing in this listing;
     // presumably they are derived from `x`, which would make the transform in-place -
     // confirm against the original source.
     677             :     template <class data_t>
     678             :     void fft(ContiguousStorage<data_t>& x, const DataDescriptor& desc, FFTNorm norm,
     679             :              FFTPolicy policy)
     680         352 :     {
     681         352 :         const auto& src_shape = desc.getNumberOfCoefficientsPerDimension();
     682         352 :         const auto& src_dims = desc.getNumberOfDimensions();
     683             : 
     684             : #ifdef ELSA_CUDA_TOOLKIT_PRESENT
     685             :         if (policy != FFTPolicy::HOST) {
     686             :             if (detail::fftDevice<data_t, true>(, src_shape, norm)) {
     687             :                 return;
     688             :             }
     689             :         }
     690             : #endif
                               // reached when the device path was not taken or reported failure
     691         352 :         if (policy != FFTPolicy::DEVICE) {
     692         352 :             detail::fftHost<data_t, data_t, true>(, src_shape,,
     693         352 :                                                   src_shape, src_dims, norm);
     694         352 :         } else {
                                   // NOTE(review): "sucessfully" is misspelled in the message below;
                                   // it is runtime text and therefore left untouched here.
     695           0 :             throw Error{"Cannot sucessfully execute FFT with policy {}", policy};
     696           0 :         }
     697         352 :     }
     698             : 
     // Inverse FFT of the buffer `x`, whose layout is described by `desc`.
     //
     // The shape and the number of dimensions are read from `desc`. When built with
     // ELSA_CUDA_TOOLKIT_PRESENT and policy != FFTPolicy::HOST, detail::fftDevice (with
     // is_forward == false) is tried first and the function returns as soon as it reports
     // success. Otherwise, unless policy == FFTPolicy::DEVICE, detail::fftHost runs with
     // the same shape for input and output. Throws Error when policy is DEVICE and no
     // device transform succeeded.
     //
     // NOTE(review): the buffer arguments of the detail:: calls are missing in this listing;
     // presumably they are derived from `x`, which would make the transform in-place -
     // confirm against the original source.
     699             :     template <class data_t>
     700             :     void ifft(ContiguousStorage<data_t>& x, const DataDescriptor& desc, FFTNorm norm,
     701             :               FFTPolicy policy)
     702         450 :     {
     703         450 :         const auto& src_shape = desc.getNumberOfCoefficientsPerDimension();
     704         450 :         const auto& src_dims = desc.getNumberOfDimensions();
     705             : 
     706             : #ifdef ELSA_CUDA_TOOLKIT_PRESENT
     707             :         if (policy != FFTPolicy::HOST) {
     708             :             if (detail::fftDevice<data_t, false>(, src_shape, norm)) {
     709             :                 return;
     710             :             }
     711             :         }
     712             : #endif
                               // reached when the device path was not taken or reported failure
     713         450 :         if (policy != FFTPolicy::DEVICE) {
     714         450 :             detail::fftHost<data_t, data_t, false>(, src_shape,,
     715         450 :                                                    src_shape, src_dims, norm);
     716         450 :         } else {
                                   // NOTE(review): "sucessfully" is misspelled in the message below;
                                   // it is runtime text and therefore left untouched here.
     717           0 :             throw Error{"Cannot sucessfully execute IFFT with policy {}", policy};
     718           0 :         }
     719         450 :     }
     720             : 
     // Real-to-complex FFT of `dc`; returns a newly created DataContainer<complex<data_t>>.
     //
     // The output shape and a normalization constant come from
     // detail::computeOutputShapeAndNormalization<FFTType::R2C>(src_shape); the result
     // container is built on a VolumeDescriptor of that output shape. When built with
     // ELSA_CUDA_TOOLKIT_PRESENT and policy != FFTPolicy::HOST, detail::rfftDevice is tried
     // first. Otherwise, unless policy == FFTPolicy::DEVICE, detail::fftHost performs the
     // transform on the host. Throws Error when policy is DEVICE and no device transform
     // succeeded.
     //
     // NOTE(review): the data-pointer arguments of the detail:: calls are missing in this
     // listing; presumably they refer to the storage of `dc` and `out_dc` - confirm.
     721             :     template <typename data_t>
     722             :     DataContainer<complex<data_t>> rfft(const DataContainer<data_t>& dc, FFTNorm norm,
     723             :                                         FFTPolicy policy)
     724           9 :     {
     725           9 :         const auto& desc = dc.getDataDescriptor();
     726           9 :         const auto& src_shape = desc.getNumberOfCoefficientsPerDimension();
     727             : 
                               // normalization_constant is only passed to the device call below
     728           9 :         auto [dst_shape, normalization_constant] =
     729           9 :             detail::computeOutputShapeAndNormalization<detail::FFTType::R2C>(src_shape);
     730           9 :         DataContainer<complex<data_t>> out_dc(*std::make_unique<VolumeDescriptor>(dst_shape));
     731             : 
     732             : #ifdef ELSA_CUDA_TOOLKIT_PRESENT
     733             :         if (policy != FFTPolicy::HOST) {
     734             :             if (detail::rfftDevice<data_t>(, src_shape, dst_shape,
     735             :                                  , norm, normalization_constant)) {
     736             :                 return out_dc;
     737             :             }
     738             :         }
     739             : #endif
                               // reached when the device path was not taken or reported failure
     740           9 :         if (policy != FFTPolicy::DEVICE) {
                                   // src_shape.size() is the number of dimensions
     741           9 :             detail::fftHost<data_t, complex<data_t>, true>(, src_shape,
     742           9 :                                                  , dst_shape,
     743           9 :                                                            src_shape.size(), norm);
     744           9 :             return out_dc;
     745           9 :         } else {
                                   // NOTE(review): "sucessfully" is misspelled in the message below;
                                   // it is runtime text and therefore left untouched here.
     746           0 :             throw Error{"Cannot sucessfully execute RFFT with policy {}", policy};
     747           0 :         }
     748           9 :     }
     749             : 
     // Complex-to-real inverse FFT of `dc`; returns a newly created DataContainer<data_t>.
     //
     // The output shape and a normalization constant come from
     // detail::computeOutputShapeAndNormalization<FFTType::C2R>(src_shape); the result
     // container is built on a VolumeDescriptor of that output shape. When built with
     // ELSA_CUDA_TOOLKIT_PRESENT and policy != FFTPolicy::HOST, detail::irfftDevice is tried
     // first. Otherwise, unless policy == FFTPolicy::DEVICE, detail::fftHost performs the
     // transform on the host. Throws Error when policy is DEVICE and no device transform
     // succeeded.
     //
     // NOTE(review): the data-pointer arguments of the detail:: calls are missing in this
     // listing; presumably they refer to the storage of `dc` and `out_dc` - confirm.
     750             :     template <typename data_t>
     751             :     DataContainer<data_t> irfft(const DataContainer<complex<data_t>>& dc, FFTNorm norm,
     752             :                                 FFTPolicy policy)
     753          12 :     {
     754          12 :         const auto& desc = dc.getDataDescriptor();
     755          12 :         const auto& src_shape = desc.getNumberOfCoefficientsPerDimension();
     756             : 
                               // normalization_constant is only passed to the device call below
     757          12 :         auto [dst_shape, normalization_constant] =
     758          12 :             detail::computeOutputShapeAndNormalization<detail::FFTType::C2R>(src_shape);
     759          12 :         DataContainer<data_t> out_dc(*std::make_unique<VolumeDescriptor>(dst_shape));
     760             : 
     761             : #ifdef ELSA_CUDA_TOOLKIT_PRESENT
     762             :         if (policy != FFTPolicy::HOST) {
     763             :             if (detail::irfftDevice<data_t>(, src_shape, dst_shape,
     764             :                                   , norm,
     765             :                                             normalization_constant)) {
     766             :                 return out_dc;
     767             :             }
     768             :         }
     769             : #endif
                               // reached when the device path was not taken or reported failure
     770          12 :         if (policy != FFTPolicy::DEVICE) {
                                   // src_shape.size() is the number of dimensions
     771          12 :             detail::fftHost<complex<data_t>, data_t, false>(, src_shape,
     772          12 :                                                   ,
     773          12 :                                                             dst_shape, src_shape.size(), norm);
     774          12 :         } else {
                                   // NOTE(review): "sucessfully" is misspelled in the message below;
                                   // it is runtime text and therefore left untouched here.
     775           0 :             throw Error{"Cannot sucessfully execute IRFFT with policy {}", policy};
     776           0 :         }
     777          12 :         return out_dc;
     778          12 :     }
     779             : 
     780             : } // namespace elsa

