Line data Source code
1 : #pragma once
2 :
3 : #include "FFTPolicy.h"
4 :
5 : #ifdef ELSA_CUDA_TOOLKIT_PRESENT
6 : #include <cuda_runtime.h>
7 : #include <cufftXt.h>
8 :
9 : #include <thrust/universal_ptr.h>
10 : #include <thrust/functional.h>
11 : #include <thrust/transform.h>
12 : #include <thrust/iterator/constant_iterator.h>
13 : #include <thrust/iterator/counting_iterator.h>
14 : #include <thrust/gather.h>
15 : #include <thrust/execution_policy.h>
16 :
17 : #include <list>
18 : #include <unordered_map>
19 : #include <optional>
20 : #endif
21 :
22 : #include "Complex.h"
23 : #include "elsaDefines.h"
24 : #include "Error.h"
25 : #include "DataDescriptor.h"
26 : #include "VolumeDescriptor.h"
27 : #include "ContiguousStorage.h"
28 : #include "DataContainer.h"
29 :
30 : #if WITH_FFTW
31 : #define EIGEN_FFTW_DEFAULT
32 : #endif
33 : #include <unsupported/Eigen/FFT>
34 :
35 : namespace elsa
36 : {
37 : namespace detail
38 : {
/// The kinds of Fourier transform distinguished in this header, classified by whether
/// the input and the output element types are complex.
enum class FFTType : uint8_t {
    C2C,     ///< complex input, complex output
    R2C,     ///< real input, complex output
    C2R,     ///< complex input, real output
    INVALID, ///< neither input nor output is complex
};

/// Lifts an FFTType enumerator into a type, so that it can be chosen with type-level
/// utilities and read back through the `value` member.
template <FFTType tvalue>
struct FFTType_t {
    static constexpr FFTType value{tvalue};
};
50 :
51 : #ifdef ELSA_CUDA_TOOLKIT_PRESENT
52 : template <FFTType fft_type, typename inner>
53 : struct FFTCuFFTType {
54 : /* not valid */
55 : static constexpr cufftType value = CUFFT_C2C;
56 : };
57 :
58 : template <>
59 : struct FFTCuFFTType<FFTType::C2C, float> {
60 : static constexpr cufftType value = CUFFT_C2C;
61 : };
62 :
63 : template <>
64 : struct FFTCuFFTType<FFTType::C2C, double> {
65 : static constexpr cufftType value = CUFFT_Z2Z;
66 : };
67 :
68 : template <>
69 : struct FFTCuFFTType<FFTType::R2C, float> {
70 : static constexpr cufftType value = CUFFT_R2C;
71 : };
72 :
73 : template <>
74 : struct FFTCuFFTType<FFTType::R2C, double> {
75 : static constexpr cufftType value = CUFFT_D2Z;
76 : };
77 :
78 : template <>
79 : struct FFTCuFFTType<FFTType::C2R, float> {
80 : static constexpr cufftType value = CUFFT_C2R;
81 : };
82 :
83 : template <>
84 : struct FFTCuFFTType<FFTType::C2R, double> {
85 : static constexpr cufftType value = CUFFT_Z2D;
86 : };
87 : #endif
88 :
89 : template <typename in, typename out>
90 : struct FFTInfo {
91 : using in_t = in;
92 : using out_t = out;
93 : using real_t =
94 : std::conditional_t<std::is_same_v<value_type_of_t<in>, value_type_of_t<out>>,
95 : value_type_of_t<in>, void>;
96 : using complex_t = std::conditional_t<is_complex_v<in>, in, out>;
97 : using eigen_in_t =
98 : std::conditional_t<is_complex_v<in>, std::complex<value_type_of_t<in>>, in>;
99 : using eigen_out_t =
100 : std::conditional_t<is_complex_v<out>, std::complex<value_type_of_t<out>>, out>;
101 : static constexpr FFTType kind =
102 : std::conditional_t < is_complex_v<in> && is_complex_v<out>,
103 : FFTType_t<FFTType::C2C>, std::conditional_t < is_complex_v<in>,
104 : FFTType_t<FFTType::C2R>,
105 : std::conditional_t < is_complex_v<out>,
106 : FFTType_t<FFTType::R2C>,
107 : FFTType_t < FFTType::INVALID >>>> ::value;
108 : static constexpr bool valid = !std::is_same_v<real_t, void> && kind != FFTType::INVALID;
109 : #ifdef ELSA_CUDA_TOOLKIT_PRESENT
110 : /* cuFFT can only work with single precision and double precision floats; Since pointers
111 : * have to be cast to the corresponding cufft type pointers, its important to make sure
112 : * the conversions are valid. (Eigen's FFT is templated, so I'd hope that it complains
113 : * when it is unhappy with its input types)
114 : */
115 : static constexpr bool cufft_valid =
116 : valid && (std::is_same_v<real_t, float> || std::is_same_v<real_t, double>);
117 : static constexpr cufftType cufft_type = FFTCuFFTType<kind, real_t>::value;
118 : #endif
119 : };
120 :
121 : /// @return [out_shape, normalization_constant]
122 : template <FFTType type>
123 : std::pair<IndexVector_t, index_t>
124 : computeOutputShapeAndNormalization(const IndexVector_t& src_shape)
125 21 : {
126 21 : static_assert(type == FFTType::C2C || type == FFTType::C2R || type == FFTType::R2C,
127 21 : "Invalid fft type!");
128 21 : if constexpr (type == FFTType::C2C) {
129 21 : return std::make_pair(src_shape, src_shape.prod());
130 21 : } else if constexpr (type == FFTType::C2R) {
131 9 : index_t last_dim = src_shape.size() - 1;
132 9 : IndexVector_t dst_shape(src_shape.size());
133 9 : dst_shape << src_shape.head(last_dim), (src_shape(last_dim) - 1) * 2;
134 9 : return std::make_pair(dst_shape, dst_shape.prod());
135 9 : } else if constexpr (type == FFTType::R2C) {
136 9 : index_t last_dim = src_shape.size() - 1;
137 9 : IndexVector_t dst_shape(src_shape.size());
138 9 : dst_shape << src_shape.head(last_dim), src_shape(last_dim) / 2 + 1;
139 9 : return std::make_pair(dst_shape, src_shape.prod());
140 9 : }
141 21 : }
142 :
143 : #ifdef ELSA_CUDA_TOOLKIT_PRESENT
144 :
145 : /* Create a cufft handle and plan an FFT for the given type and dimensions.
146 : * This handle is to be freed via cufftDestroy(...)
147 : * Returns a result indicating whether creating the plan was successful. If not,
148 : * the plan is not initialized and does not need to be freed.
149 : */
150 : cufftResult createPlan(cufftHandle* plan, cufftType type, const IndexVector_t& shape);
151 :
152 : /* Normalize the result after an fft with the factor size or sqrt(size), depending on
153 : the flag applySqrt. */
154 : template <typename data_t>
155 : void fftNormalize(data_t* ptr, index_t size, index_t normalization_constant, bool applySqrt)
156 : {
157 : data_t normalizing_factor = static_cast<data_t>(
158 : applySqrt ? std::sqrt(normalization_constant) : normalization_constant);
159 : thrust::transform(thrust::device, ptr, ptr + size,
160 : thrust::make_constant_iterator(normalizing_factor), ptr,
161 : thrust::divides<data_t>());
162 : }
163 :
164 : /* Assuming row-major layout! */
165 : struct TransposeIndexFunctor3D : public thrust::unary_function<size_t, size_t> {
166 : size_t rows, columns, depth;
167 :
168 : __host__ __device__ TransposeIndexFunctor3D(size_t rows, size_t columns, size_t depth)
169 : : rows(rows), columns(columns), depth(depth)
170 : {
171 : }
172 :
173 : __host__ __device__ size_t operator()(size_t linear_index)
174 : {
175 : size_t src_row = linear_index / (columns * depth);
176 : size_t remainder = linear_index % (columns * depth);
177 : size_t src_column = remainder / depth;
178 : size_t src_depth = remainder % depth;
179 :
180 : return src_depth * (rows * columns) + src_column * rows + src_row;
181 : }
182 : };
183 :
184 : /* src and dst must not alias! */
185 : template <typename data_t>
186 : bool transposeDevice(const data_t* src, data_t* dst, const IndexVector_t& src_shape)
187 : {
188 :
189 : switch (src_shape.size()) {
190 : case 1:
191 : /* no need to transpose, but must copy from src to dst */
192 : thrust::copy(thrust::device, src, src + src_shape.prod(), dst);
193 : return true;
194 : case 2: {
195 : TransposeIndexFunctor3D functor(src_shape(0), src_shape(1), 1);
196 : thrust::counting_iterator<size_t> indices(0);
197 : auto transpose_index_iterator =
198 : thrust::make_transform_iterator(indices, functor);
199 : thrust::gather(thrust::device, transpose_index_iterator,
200 : transpose_index_iterator + src_shape.prod(), src, dst);
201 : return true;
202 : }
203 : case 3: {
204 : TransposeIndexFunctor3D functor(src_shape(0), src_shape(1), src_shape(2));
205 : thrust::counting_iterator<size_t> indices(0);
206 : auto transpose_index_iterator =
207 : thrust::make_transform_iterator(indices, functor);
208 : thrust::gather(thrust::device, transpose_index_iterator,
209 : transpose_index_iterator + src_shape.prod(), src, dst);
210 : return true;
211 : }
212 : default:
213 : return false;
214 : }
215 : }
216 :
217 : /* Cache is not thread safe, also plans should not be used across multiple threads at the
218 : * same time! Hence, we use a thread local instance. Potential optimizations, should the
219 : * caches GPU memory consumption ever be a problem: Aside from disabling the caching
220 : * mechanism, there are two potential optimizations.
221 : * - Currently, when no new plan can be allocated, the cache is flushed. This could be
222 : * extended also flush the caches of all other threads.
223 : * - cuFFT allows us to manage the work area memory. Before planning,
224 : * cufftSetAutoAllocation(false) could be called, then the work area can be explicitely set.
225 : * This way, all elements of the cache could share a work area, fitting the requirements of
226 : * the largest cached plan. This would increase the management overhead, but reduce memory
227 : * consumption. Currently, the cache has to few elements for this to be worth it.
228 : */
229 : class CuFFTPlanCache
230 : {
231 : private:
232 : using CacheElement = std::tuple<cufftHandle, IndexVector_t, cufftType>;
233 : using CacheList = std::list<CacheElement>;
234 : CacheList _cache;
235 : /* this should be very low, to conserve GPU memory!
236 : Initialized via ELSA_CUFFT_CACHE_SIZE */
237 : size_t _limit;
238 :
239 : void flush();
240 : void evict();
241 :
242 : public:
243 : CuFFTPlanCache();
244 :
245 : CuFFTPlanCache(const CuFFTPlanCache& other) = delete;
246 : CuFFTPlanCache& operator=(const CuFFTPlanCache& other) = delete;
247 : /* This performs linear search, because that should be less overhead than a
248 : map lookup for very small cache sizes*/
249 : std::optional<cufftHandle> get(cufftType type, const IndexVector_t& shape);
250 : };
251 :
252 : extern thread_local CuFFTPlanCache cufftCache;
253 :
254 : /* ATTENTION! Expects data to be layed out in ROW MAJOR order!
255 : *
256 : * | in_data_t | out_data_t |
257 : * -----+-----------------+-----------------+
258 : * C2C | complex<float> | complex<float> |
259 : * Z2Z | complex<double> | complex<double> |
260 : * C2R | complex<float> | float |
261 : * Z2D | complex<double> | double |
262 : * R2C | float | complex<float> |
263 : * D2Z | double | complex<float> |
264 : *
265 : */
266 : template <bool is_forward, typename in_data_t, typename out_data_t>
267 : bool doFftDevice(in_data_t* in_data, out_data_t* out_data, const IndexVector_t& shape,
268 : const IndexVector_t& out_shape, FFTNorm norm,
269 : index_t normalization_constant)
270 : {
271 : /* According to this example:
272 : * https://github.com/NVIDIA/CUDALibrarySamples/blob/master/cuFFT/1d_c2c/1d_c2c_example.cpp
273 : * it is fine to reinterpret_cast std::complex to cufftComplex.
274 : * The same applies to thrust::complex, which is what elsa::complex maps to.
275 : */
276 :
277 : using info = FFTInfo<in_data_t, out_data_t>;
278 : if constexpr (!info::cufft_valid) {
279 : static_cast<void>(in_data);
280 : static_cast<void>(out_data);
281 : static_cast<void>(shape);
282 : static_cast<void>(out_shape);
283 : static_cast<void>(norm);
284 : static_cast<void>(normalization_constant);
285 : return false;
286 : } else {
287 : constexpr cufftType type = info::cufft_type;
288 :
289 : cufftHandle plan;
290 : #if ELSA_CUFFT_CACHE_SIZE != 0
291 : std::optional<cufftHandle> planOpt = cufftCache.get(type, shape);
292 : if (planOpt) {
293 : plan = *planOpt;
294 : } else {
295 : return false;
296 : }
297 : #else
298 : if (createPlan(&plan, type, shape) != CUFFT_SUCCESS) {
299 : return false;
300 : }
301 : #endif
302 :
303 : bool success;
304 : if constexpr (type == CUFFT_C2C) {
305 : int direction = is_forward ? CUFFT_FORWARD : CUFFT_INVERSE;
306 : /* cuFFT can handle in-place transforms */
307 : success = cufftExecC2C(plan, reinterpret_cast<cufftComplex*>(in_data),
308 : reinterpret_cast<cufftComplex*>(out_data), direction)
309 : == CUFFT_SUCCESS;
310 : } else if constexpr (type == CUFFT_Z2Z) {
311 : int direction = is_forward ? CUFFT_FORWARD : CUFFT_INVERSE;
312 : success =
313 : cufftExecZ2Z(plan, reinterpret_cast<cufftDoubleComplex*>(in_data),
314 : reinterpret_cast<cufftDoubleComplex*>(out_data), direction)
315 : == CUFFT_SUCCESS;
316 : } else if constexpr (type == CUFFT_R2C) {
317 : static_assert(is_forward);
318 : success = cufftExecR2C(plan, reinterpret_cast<cufftReal*>(in_data),
319 : reinterpret_cast<cufftComplex*>(out_data))
320 : == CUFFT_SUCCESS;
321 : } else if constexpr (type == CUFFT_C2R) {
322 : static_assert(!is_forward);
323 : success = cufftExecC2R(plan, reinterpret_cast<cufftComplex*>(in_data),
324 : reinterpret_cast<cufftReal*>(out_data))
325 : == CUFFT_SUCCESS;
326 : } else if constexpr (type == CUFFT_D2Z) {
327 : static_assert(is_forward);
328 : success = cufftExecD2Z(plan, reinterpret_cast<cufftDoubleReal*>(in_data),
329 : reinterpret_cast<cufftDoubleComplex*>(out_data))
330 : == CUFFT_SUCCESS;
331 : } else if constexpr (type == CUFFT_Z2D) {
332 : static_assert(!is_forward);
333 : success = cufftExecZ2D(plan, reinterpret_cast<cufftDoubleComplex*>(in_data),
334 : reinterpret_cast<cufftDoubleReal*>(out_data))
335 : == CUFFT_SUCCESS;
336 : }
337 : cudaDeviceSynchronize();
338 :
339 : if (likely(success)) {
340 : /* cuFFT performs unnormalized FFTs, therefore we are left to do scaling
341 : according to FFTNorm */
342 : if (norm == FFTNorm::FORWARD && is_forward || norm == FFTNorm::ORTHO
343 : || norm == FFTNorm::BACKWARD && !is_forward) {
344 : fftNormalize(out_data, out_shape.prod(), normalization_constant,
345 : norm == FFTNorm::ORTHO);
346 : }
347 : }
348 :
349 : #if ELSA_CUFFT_CACHE_SIZE == 0
350 : cufftDestroy(plan);
351 : #endif
352 :
353 : return success;
354 : }
355 : }
356 :
357 : /* Perform am fft on the GPU using cuFFT if possible.
358 : * Template parameter is_forward determines whether an fft or ifft is performed.
359 : * true -> fft
360 : * false -> ifft
361 : *
362 : * A return value of false indicates that the operation was not successful and the
363 : * input data is unchanged.
364 : */
365 : template <class data_t, bool is_forward>
366 : bool fftDevice(thrust::universal_ptr<data_t> this_data, const IndexVector_t& src_shape,
367 : FFTNorm norm)
368 : {
369 : /* Rationale: cuFFT expects the data in row major layout, but DataContainer uses column
370 : major layout (and the CPU dft implementation expects it, as well)
371 : => We use the following identity: DFT(A^T)^T = DFT(A) (the proof is left as an
372 : exercise to the reader ;)) Hence, the data does not need to be transposed, as long as
373 : the shape is reversed, because to cuFFT our input is transposed. */
374 : IndexVector_t src_shape_transposed = src_shape.reverse();
375 : auto [_, normalization_constant] =
376 : computeOutputShapeAndNormalization<FFTType::C2C>(src_shape_transposed);
377 :
378 : return doFftDevice<is_forward>(this_data.get(), this_data.get(), src_shape_transposed,
379 : src_shape_transposed, norm, normalization_constant);
380 : }
381 :
382 : /* shape is the actual size of the volume being transformed, not the size of the output
383 : * (which is less due to the symmetry that R2C exploits) */
384 : template <class data_t>
385 : bool rfftDevice(thrust::universal_ptr<const data_t> in_data, const IndexVector_t& shape,
386 : const IndexVector_t& out_shape,
387 : thrust::universal_ptr<elsa::complex<data_t>> out_data, FFTNorm norm,
388 : index_t normalization_constant)
389 : {
390 : using info = FFTInfo<data_t, elsa::complex<data_t>>;
391 : if constexpr (!info::cufft_valid) {
392 : /* only single and double precision floats are supported */
393 : static_cast<void>(in_data);
394 : static_cast<void>(shape);
395 : static_cast<void>(out_shape);
396 : static_cast<void>(out_data);
397 : static_cast<void>(norm);
398 : static_cast<void>(normalization_constant);
399 : return false;
400 : }
401 :
402 : elsa::complex<data_t>* wc_buf;
403 : if (cudaMalloc(&wc_buf, sizeof(elsa::complex<data_t>) * out_shape.prod())
404 : != cudaSuccess) {
405 : return false;
406 : }
407 :
408 : /* making use of the fact that, for R2C, the input is at most as big as the output*/
409 : transposeDevice(in_data.get(), reinterpret_cast<data_t*>(out_data.get()), shape);
410 : bool success = doFftDevice<true>(reinterpret_cast<data_t*>(out_data.get()), wc_buf,
411 : shape, out_shape, norm, normalization_constant);
412 :
413 : if (success) {
414 : transposeDevice(wc_buf, out_data.get(), out_shape.reverse());
415 : }
416 : cudaFree(wc_buf);
417 : return success;
418 : }
419 :
420 : template <class data_t>
421 : bool irfftDevice(thrust::universal_ptr<const elsa::complex<data_t>> in_data,
422 : const IndexVector_t& src_shape, const IndexVector_t& out_shape,
423 : thrust::universal_ptr<data_t> out_data, FFTNorm norm,
424 : index_t normalization_constant)
425 : {
426 : using info = FFTInfo<elsa::complex<data_t>, data_t>;
427 : if constexpr (!info::cufft_valid) {
428 : /* only single and double precision are supported */
429 : static_cast<void>(in_data);
430 : static_cast<void>(src_shape);
431 : static_cast<void>(out_shape);
432 : static_cast<void>(out_data);
433 : static_cast<void>(norm);
434 : static_cast<void>(normalization_constant);
435 : return false;
436 : }
437 :
438 : elsa::complex<data_t>* wc_buf1;
439 : if (cudaMalloc(&wc_buf1, sizeof(elsa::complex<data_t>) * src_shape.prod())
440 : != cudaSuccess) {
441 : return false;
442 : }
443 :
444 : data_t* wc_buf2;
445 : if (cudaMalloc(&wc_buf2, sizeof(data_t) * out_shape.prod()) != cudaSuccess) {
446 : cudaFree(wc_buf1);
447 : return false;
448 : }
449 :
450 : transposeDevice(in_data.get(), wc_buf1, src_shape);
451 :
452 : bool success = doFftDevice<false>(wc_buf1, wc_buf2, out_shape, out_shape, norm,
453 : normalization_constant);
454 :
455 : cudaFree(wc_buf1);
456 : if (success) {
457 : transposeDevice(wc_buf2, out_data.get(), out_shape.reverse());
458 : }
459 : cudaFree(wc_buf2);
460 : return success;
461 : }
462 : #endif
463 :
464 : template <bool is_forward, typename in_data_t, typename out_data_t>
465 : void fftHostSingleDim(const in_data_t* in_data, const IndexVector_t& in_shape,
466 : out_data_t* out_data, const IndexVector_t& out_shape, index_t dim_idx,
467 : FFTNorm norm)
468 1640 : {
469 1640 : using info = FFTInfo<in_data_t, out_data_t>;
470 1640 : using InputVector_t = const Eigen::Matrix<in_data_t, Eigen::Dynamic, 1>;
471 1640 : using OutputVector_t = Eigen::Matrix<out_data_t, Eigen::Dynamic, 1>;
472 :
473 : // jumps in the data for the current dimension's data
474 : // dim_size[0] * dim_size[1] * ...
475 : // 1 for dim_idx == 0.
476 1640 : const index_t in_stride = in_shape.head(dim_idx).prod();
477 1640 : const index_t out_stride = out_shape.head(dim_idx).prod();
478 :
479 : // number of coefficients for the current dimension
480 1640 : const index_t in_dim_size = in_shape(dim_idx);
481 1640 : const index_t out_dim_size = out_shape(dim_idx);
482 1640 : index_t true_dim_size;
483 1640 : if constexpr (info::kind == FFTType::R2C) {
484 1631 : true_dim_size = in_dim_size;
485 1631 : } else {
486 1631 : true_dim_size = out_dim_size;
487 1631 : }
488 :
489 : // number of coefficients for the other dimensions
490 : // this is the number of 1d-ffts we'll do
491 : // e.g. shape=[2, 3, 4] and we do dim_idx=2 (=shape 4)
492 : // -> in_shape.prod() == 24 / 4 = 6 == 2*3
493 1640 : const index_t other_dims_size = in_shape.prod() / in_dim_size;
494 1640 : const index_t out_other_dims_size = out_shape.prod() / out_dim_size;
495 1640 : if (unlikely(other_dims_size != out_other_dims_size
496 1640 : || (info::kind == FFTType::C2C && in_dim_size != out_dim_size)
497 1640 : || (info::kind == FFTType::C2R && in_dim_size != out_dim_size / 2 + 1)
498 1640 : || (info::kind == FFTType::R2C && in_dim_size / 2 + 1 != out_dim_size))) {
499 : // if input & output do not have the same shape (in a C2R or R2C transform),
500 : // only the transform axis may be different
501 0 : throw Error{"FFT has incompatible input & output dimensions"};
502 0 : }
503 1640 : #ifndef EIGEN_FFTW_DEFAULT
504 : // when using eigen+fftw, this corrupts the memory, so don't parallelize.
505 : // error messages may include:
506 : // * double free or corruption (fasttop)
507 : // * malloc_consolidate(): unaligned fastbin chunk detected
508 1640 : #pragma omp parallel for
509 1640 : #endif
510 : // do all the 1d-ffts along the current dimensions axis
511 1640 : for (index_t i = 0; i < other_dims_size; ++i) {
512 :
513 0 : index_t in_ray_start = i;
514 : // each time i is a multiple of stride,
515 : // jump forward the current+previous dimensions' shape product
516 : // (draw an indexed 3d cube to visualize this)
517 0 : in_ray_start +=
518 0 : (in_stride * (in_dim_size - 1)) * ((i - (i % in_stride)) / in_stride);
519 :
520 0 : index_t out_ray_start = i;
521 0 : out_ray_start +=
522 0 : (out_stride * (out_dim_size - 1)) * ((i - (i % out_stride)) / out_stride);
523 :
524 : // this is one "ray" through the volume
525 0 : Eigen::Map<InputVector_t, Eigen::AlignmentType::Unaligned, Eigen::InnerStride<>>
526 0 : input_map(in_data + in_ray_start, in_dim_size, Eigen::InnerStride<>(in_stride));
527 :
528 0 : Eigen::Map<OutputVector_t, Eigen::AlignmentType::Unaligned, Eigen::InnerStride<>>
529 0 : output_map(out_data + out_ray_start, out_dim_size,
530 0 : Eigen::InnerStride<>(out_stride));
531 0 : using inner_t = typename info::real_t;
532 :
533 0 : Eigen::FFT<inner_t> fft_op;
534 :
535 : // disable any scaling in eigen - normally it does 1/n for ifft
536 0 : fft_op.SetFlag(Eigen::FFT<inner_t>::Flag::Unscaled);
537 : // use half spectrum as input/output for complex-to-real/real-to-complex
538 : // transforms
539 0 : fft_op.SetFlag(Eigen::FFT<inner_t>::Flag::HalfSpectrum);
540 :
541 0 : Eigen::Matrix<typename info::eigen_in_t, Eigen::Dynamic, 1> fft_in{in_dim_size};
542 0 : Eigen::Matrix<typename info::eigen_out_t, Eigen::Dynamic, 1> fft_out{out_dim_size};
543 :
544 : // eigen internally copies the fwd input matrix anyway if
545 : // it doesn't have stride == 1
546 0 : fft_in = input_map.block(0, 0, in_dim_size, 1)
547 0 : .template cast<const typename info::eigen_in_t>();
548 :
549 21127 : if (unlikely(in_dim_size == 1)) {
550 : // eigen kiss-fft crashes for size=1...
551 : // TODO: resolve for R2C/C2R
552 21127 : if constexpr (info::kind == FFTType::C2C) {
553 0 : fft_out = fft_in;
554 0 : } else if constexpr (info::kind == FFTType::C2R) {
555 0 : fft_out(0) = fft_in(0).real();
556 0 : } else {
557 0 : fft_out(0) = fft_in(0);
558 0 : }
559 >1844*10^16 : } else {
560 : // arguments for in and out _must not_ be the same matrix!
561 : // they will corrupt wildly otherwise.
562 >1844*10^16 : if constexpr (is_forward) {
563 >1844*10^16 : fft_op.fwd(fft_out, fft_in);
564 >1844*10^16 : if (norm == FFTNorm::FORWARD) {
565 31 : fft_out /= true_dim_size;
566 15791 : } else if (norm == FFTNorm::ORTHO) {
567 31 : fft_out /= std::sqrt(true_dim_size);
568 31 : }
569 >1844*10^16 : } else {
570 >1844*10^16 : fft_op.inv(fft_out, fft_in);
571 >1844*10^16 : if (norm == FFTNorm::BACKWARD) {
572 19771 : fft_out /= true_dim_size;
573 >1844*10^16 : } else if (norm == FFTNorm::ORTHO) {
574 31 : fft_out /= std::sqrt(true_dim_size);
575 31 : }
576 >1844*10^16 : }
577 >1844*10^16 : }
578 :
579 : // we can't directly use the map as fft output,
580 : // since Eigen internally just uses the pointer to
581 : // the map's first element, and doesn't respect stride at all..
582 0 : output_map.block(0, 0, out_dim_size, 1) = fft_out.template cast<out_data_t>();
583 0 : }
584 1640 : }
585 :
586 : // Algorithm sketch for different FFTs:
587 : //
588 : // if(R2C)
589 : // perform R2C FFTs along dimension N // input -> output
590 : // start_dim = 0
591 : // end_dim = N-1
592 : // else if(C2R)
593 : // start_dim = 0
594 : // end_dim = N-1
595 : // else
596 : // start_dim = 0
597 : // end_dim = N
598 : //
599 : // perform FFTs for dims start_dim..end_dim // C2C: input -> output -> ... output
600 : // // C2R: input -> tmp -> ... tmp
601 : // // R2C: output -> output -> ... output
602 : //
603 : // if(C2R)
604 : // perform C2R FFTs along dimension N // tmp -> output
605 : //
606 : // ATTENTION: in_data == out_data is allowed, assuming the input & output buffer is large
607 : // enough to hold the input and the output (not simultaneously). In- and output
608 : // size differs for C2R & R2C transforms
609 : template <class in_data_t, class out_data_t, bool is_forward>
610 : void fftHost(const in_data_t* in_data, const IndexVector_t& in_shape, out_data_t* out_data,
611 : const IndexVector_t& out_shape, index_t dims, FFTNorm norm)
612 823 : {
613 823 : using info = FFTInfo<in_data_t, out_data_t>;
614 :
615 823 : if constexpr (info::valid) {
616 : // TODO: fftw variant
617 :
618 : // R2C is implicitely forward, C2R is implicitely backward
619 0 : static_assert((info::kind != FFTType::R2C || is_forward)
620 0 : && (info::kind != FFTType::C2R || !is_forward));
621 :
622 : // For C2C, half_spectrum_shape == in_shape == out_shape
623 0 : IndexVector_t half_spectrum_shape;
624 823 : if constexpr (info::kind == FFTType::C2R) {
625 811 : half_spectrum_shape = in_shape;
626 811 : } else {
627 811 : half_spectrum_shape = out_shape;
628 811 : }
629 :
630 : // exclusive
631 0 : index_t end_dim;
632 0 : const typename info::complex_t* current_input;
633 0 : typename info::complex_t* working_area;
634 0 : std::unique_ptr<typename info::complex_t[]> tmp;
635 823 : if constexpr (info::kind == FFTType::R2C) {
636 814 : end_dim = dims - 1;
637 814 : fftHostSingleDim<is_forward>(in_data, in_shape, out_data, out_shape, end_dim,
638 814 : norm);
639 814 : current_input = out_data;
640 814 : working_area = out_data;
641 814 : } else if constexpr (info::kind == FFTType::C2R) {
642 802 : end_dim = dims - 1;
643 802 : current_input = in_data;
644 802 : tmp = std::make_unique<typename info::complex_t[]>(in_shape.prod());
645 802 : working_area = tmp.get();
646 802 : } else {
647 802 : end_dim = dims;
648 802 : current_input = in_data;
649 802 : working_area = out_data;
650 802 : }
651 :
652 : // generalization of an 1D-FFT
653 : // walk over each dimension and 1d-fft one 'line' of data
654 2442 : for (index_t dim_idx = 0; dim_idx < end_dim; ++dim_idx) {
655 1619 : fftHostSingleDim<is_forward>(current_input, half_spectrum_shape, working_area,
656 1619 : half_spectrum_shape, dim_idx, norm);
657 1619 : current_input = working_area;
658 1619 : }
659 :
660 823 : if constexpr (info::kind == FFTType::C2R) {
661 12 : fftHostSingleDim<is_forward>(current_input, in_shape, out_data, out_shape,
662 12 : end_dim, norm);
663 12 : }
664 823 : } else {
665 0 : static_cast<void>(in_data);
666 0 : static_cast<void>(out_data);
667 0 : static_cast<void>(in_shape);
668 0 : static_cast<void>(out_shape);
669 0 : static_cast<void>(dims);
670 0 : static_cast<void>(norm);
671 : // TODO: more concrete error with specific types
672 0 : throw Error{"unsupported FFT types"};
673 0 : }
674 823 : }
675 : } // namespace detail
676 :
677 : template <class data_t>
678 : void fft(ContiguousStorage<data_t>& x, const DataDescriptor& desc, FFTNorm norm,
679 : FFTPolicy policy)
680 352 : {
681 352 : const auto& src_shape = desc.getNumberOfCoefficientsPerDimension();
682 352 : const auto& src_dims = desc.getNumberOfDimensions();
683 :
684 : #ifdef ELSA_CUDA_TOOLKIT_PRESENT
685 : if (policy != FFTPolicy::HOST) {
686 : if (detail::fftDevice<data_t, true>(x.data(), src_shape, norm)) {
687 : return;
688 : }
689 : }
690 : #endif
691 352 : if (policy != FFTPolicy::DEVICE) {
692 352 : detail::fftHost<data_t, data_t, true>(x.data().get(), src_shape, x.data().get(),
693 352 : src_shape, src_dims, norm);
694 352 : } else {
695 0 : throw Error{"Cannot sucessfully execute FFT with policy {}", policy};
696 0 : }
697 352 : }
698 :
699 : template <class data_t>
700 : void ifft(ContiguousStorage<data_t>& x, const DataDescriptor& desc, FFTNorm norm,
701 : FFTPolicy policy)
702 450 : {
703 450 : const auto& src_shape = desc.getNumberOfCoefficientsPerDimension();
704 450 : const auto& src_dims = desc.getNumberOfDimensions();
705 :
706 : #ifdef ELSA_CUDA_TOOLKIT_PRESENT
707 : if (policy != FFTPolicy::HOST) {
708 : if (detail::fftDevice<data_t, false>(x.data(), src_shape, norm)) {
709 : return;
710 : }
711 : }
712 : #endif
713 450 : if (policy != FFTPolicy::DEVICE) {
714 450 : detail::fftHost<data_t, data_t, false>(x.data().get(), src_shape, x.data().get(),
715 450 : src_shape, src_dims, norm);
716 450 : } else {
717 0 : throw Error{"Cannot sucessfully execute IFFT with policy {}", policy};
718 0 : }
719 450 : }
720 :
721 : template <typename data_t>
722 : DataContainer<complex<data_t>> rfft(const DataContainer<data_t>& dc, FFTNorm norm,
723 : FFTPolicy policy)
724 9 : {
725 9 : const auto& desc = dc.getDataDescriptor();
726 9 : const auto& src_shape = desc.getNumberOfCoefficientsPerDimension();
727 :
728 9 : auto [dst_shape, normalization_constant] =
729 9 : detail::computeOutputShapeAndNormalization<detail::FFTType::R2C>(src_shape);
730 9 : DataContainer<complex<data_t>> out_dc(*std::make_unique<VolumeDescriptor>(dst_shape));
731 :
732 : #ifdef ELSA_CUDA_TOOLKIT_PRESENT
733 : if (policy != FFTPolicy::HOST) {
734 : if (detail::rfftDevice<data_t>(dc.storage().data(), src_shape, dst_shape,
735 : out_dc.storage().data(), norm, normalization_constant)) {
736 : return out_dc;
737 : }
738 : }
739 : #endif
740 9 : if (policy != FFTPolicy::DEVICE) {
741 9 : detail::fftHost<data_t, complex<data_t>, true>(dc.storage().data().get(), src_shape,
742 9 : out_dc.storage().data().get(), dst_shape,
743 9 : src_shape.size(), norm);
744 9 : return out_dc;
745 9 : } else {
746 0 : throw Error{"Cannot sucessfully execute RFFT with policy {}", policy};
747 0 : }
748 9 : }
749 :
750 : template <typename data_t>
751 : DataContainer<data_t> irfft(const DataContainer<complex<data_t>>& dc, FFTNorm norm,
752 : FFTPolicy policy)
753 12 : {
754 12 : const auto& desc = dc.getDataDescriptor();
755 12 : const auto& src_shape = desc.getNumberOfCoefficientsPerDimension();
756 :
757 12 : auto [dst_shape, normalization_constant] =
758 12 : detail::computeOutputShapeAndNormalization<detail::FFTType::C2R>(src_shape);
759 12 : DataContainer<data_t> out_dc(*std::make_unique<VolumeDescriptor>(dst_shape));
760 :
761 : #ifdef ELSA_CUDA_TOOLKIT_PRESENT
762 : if (policy != FFTPolicy::HOST) {
763 : if (detail::irfftDevice<data_t>(dc.storage().data(), src_shape, dst_shape,
764 : out_dc.storage().data(), norm,
765 : normalization_constant)) {
766 : return out_dc;
767 : }
768 : }
769 : #endif
770 12 : if (policy != FFTPolicy::DEVICE) {
771 12 : detail::fftHost<complex<data_t>, data_t, false>(dc.storage().data().get(), src_shape,
772 12 : out_dc.storage().data().get(),
773 12 : dst_shape, src_shape.size(), norm);
774 12 : } else {
775 0 : throw Error{"Cannot sucessfully execute IRFFT with policy {}", policy};
776 0 : }
777 12 : return out_dc;
778 12 : }
779 :
780 : } // namespace elsa
|