// rustfft/avx/avx_mixed_radix.rs

1use std::any::TypeId;
2use std::sync::Arc;
3
4use num_complex::Complex;
5use num_integer::div_ceil;
6
7use crate::array_utils;
8use crate::{Direction, Fft, FftDirection, FftNum, Length};
9
10use super::{AvxNum, CommonSimdData};
11
12use super::avx_vector;
13use super::avx_vector::{AvxArray, AvxArrayMut, AvxVector, AvxVector128, AvxVector256, Rotation90};
14
// Generates the shared `new()` constructor and the three `perform_fft_*` entry points for
// every mixed-radix AVX struct in this file. Expanded inside each struct's impl block, so it
// relies on the surrounding impl providing `new_with_avx`, `perform_column_butterflies`,
// `perform_column_butterflies_immut`, and `transpose`.
macro_rules! boilerplate_mixedradix {
    () => {
        /// Preallocates necessary arrays and precomputes necessary data to efficiently compute the FFT
        /// Returns Ok() if this machine has the required instruction sets, Err() if some instruction sets are missing
        #[inline]
        pub fn new(inner_fft: Arc<dyn Fft<T>>) -> Result<Self, ()> {
            // Internal sanity check: Make sure that A == T.
            // This struct has two generic parameters A and T, but they must always be the same, and are only kept separate to help work around the lack of specialization.
            // It would be cool if we could do this as a static_assert instead
            let id_a = TypeId::of::<A>();
            let id_t = TypeId::of::<T>();
            assert_eq!(id_a, id_t);

            // Runtime CPU feature detection: the kernels below require both AVX and FMA
            let has_avx = is_x86_feature_detected!("avx");
            let has_fma = is_x86_feature_detected!("fma");
            if has_avx && has_fma {
                // Safety: new_with_avx requires the "avx" feature set. Since we know it's present, we're safe
                Ok(unsafe { Self::new_with_avx(inner_fft) })
            } else {
                Err(())
            }
        }

        // In-place FFT: column butterflies in-place on `buffer`, inner row FFTs out-of-place
        // from `buffer` into `scratch`, then transpose from `scratch` back into `buffer`.
        #[target_feature(enable = "avx", enable = "fma")]
        unsafe fn perform_fft_inplace(
            &self,
            buffer: &mut [Complex<T>],
            scratch: &mut [Complex<T>],
        ) {
            // Perform the column FFTs
            // Safety: self.perform_column_butterflies() requires the "avx" and "fma" instruction sets, and we return Err() in our constructor if the instructions aren't available
            unsafe {
                // Specialization workaround: See the comments in FftPlannerAvx::new() for why these calls to array_utils::workaround_transmute are necessary
                let transmuted_buffer: &mut [Complex<A>] =
                    array_utils::workaround_transmute_mut(buffer);

                self.perform_column_butterflies(transmuted_buffer)
            }

            // process the row FFTs: the first self.len() elements of scratch receive the
            // out-of-place results, the rest is handed to the inner FFT as its own scratch
            let (scratch, inner_scratch) = scratch.split_at_mut(self.len());
            self.common_data.inner_fft.process_outofplace_with_scratch(
                buffer,
                scratch,
                inner_scratch,
            );

            // Transpose
            // Safety: self.transpose() requires the "avx" instruction set, and we return Err() in our constructor if the instructions aren't available
            unsafe {
                // Specialization workaround: See the comments in FftPlannerAvx::new() for why these calls to array_utils::workaround_transmute are necessary
                let transmuted_scratch: &mut [Complex<A>] =
                    array_utils::workaround_transmute_mut(scratch);
                let transmuted_buffer: &mut [Complex<A>] =
                    array_utils::workaround_transmute_mut(buffer);

                self.transpose(transmuted_scratch, transmuted_buffer)
            }
        }

        // Immutable-input FFT: column butterflies from `input` into `scratch`, inner row FFTs
        // in-place on `scratch`, then transpose from `scratch` into `output`.
        #[target_feature(enable = "avx", enable = "fma")]
        unsafe fn perform_fft_immut(
            &self,
            input: &[Complex<T>],
            output: &mut [Complex<T>],
            scratch: &mut [Complex<T>],
        ) {
            // Perform the column FFTs, writing the results into the first input.len()
            // elements of scratch; the remainder is the inner FFT's scratch
            let (scratch, inner_scratch) = scratch.split_at_mut(input.len());
            {
                // Specialization workaround: See the comments in FftPlannerAvx::new() for why these calls to array_utils::workaround_transmute are necessary
                let transmuted_input: &[Complex<A>] = array_utils::workaround_transmute(input);
                let transmuted_output: &mut [Complex<A>] =
                    array_utils::workaround_transmute_mut(scratch);

                self.perform_column_butterflies_immut(transmuted_input, transmuted_output);
            }

            // process the row FFTs in-place on the scratch buffer
            self.common_data
                .inner_fft
                .process_with_scratch(scratch, inner_scratch);

            // Transpose
            {
                // Specialization workaround: See the comments in FftPlannerAvx::new() for why these calls to array_utils::workaround_transmute are necessary
                let transmuted_input: &mut [Complex<A>] =
                    array_utils::workaround_transmute_mut(scratch);
                let transmuted_output: &mut [Complex<A>] =
                    array_utils::workaround_transmute_mut(output);

                self.transpose(transmuted_input, transmuted_output)
            }
        }

        // Out-of-place FFT: column butterflies in-place on `input`, inner row FFTs in-place
        // on `input`, then transpose from `input` into `output`.
        #[target_feature(enable = "avx", enable = "fma")]
        unsafe fn perform_fft_out_of_place(
            &self,
            input: &mut [Complex<T>],
            output: &mut [Complex<T>],
            scratch: &mut [Complex<T>],
        ) {
            // Perform the column FFTs
            {
                // Specialization workaround: See the comments in FftPlannerAvx::new() for why these calls to array_utils::workaround_transmute are necessary
                let transmuted_input: &mut [Complex<A>] =
                    array_utils::workaround_transmute_mut(input);
                self.perform_column_butterflies(transmuted_input);
            }

            // process the row FFTs. If extra scratch was provided, pass it in. Otherwise, use the output.
            let inner_scratch = if scratch.len() > 0 {
                scratch
            } else {
                &mut output[..]
            };
            self.common_data
                .inner_fft
                .process_with_scratch(input, inner_scratch);

            // Transpose
            {
                // Specialization workaround: See the comments in FftPlannerAvx::new() for why these calls to array_utils::workaround_transmute are necessary
                let transmuted_input: &mut [Complex<A>] =
                    array_utils::workaround_transmute_mut(input);
                let transmuted_output: &mut [Complex<A>] =
                    array_utils::workaround_transmute_mut(output);

                self.transpose(transmuted_input, transmuted_output)
            }
        }
    };
}
148
// Builds the `CommonSimdData` shared by all mixed-radix AVX structs: precomputed twiddle
// factors (arranged one column of AVX vectors at a time) plus the scratch-length
// requirements derived from the inner FFT.
macro_rules! mixedradix_gen_data {
    ($row_count: expr, $inner_fft:expr) => {{
        // Important constants
        const ROW_COUNT : usize = $row_count;
        const TWIDDLES_PER_COLUMN : usize = ROW_COUNT - 1;

        // derive some info from our inner FFT
        let direction = $inner_fft.fft_direction();
        let len_per_row = $inner_fft.len();
        let len = len_per_row * ROW_COUNT;

        // We're going to process each row of the FFT one AVX register at a time. We need to know how many AVX registers each row can fit,
        // and if the last register in each row is going to have partial data (ie a remainder)
        let quotient = len_per_row / A::VectorType::COMPLEX_PER_VECTOR;
        let remainder = len_per_row % A::VectorType::COMPLEX_PER_VECTOR;

        // Compute our twiddle factors, and arrange them so that we can access them one column of AVX vectors at a time
        // (one extra partial column when the row length isn't a multiple of the vector width)
        let num_twiddle_columns = quotient + div_ceil(remainder, A::VectorType::COMPLEX_PER_VECTOR);
        let mut twiddles = Vec::with_capacity(num_twiddle_columns * TWIDDLES_PER_COLUMN);
        for x in 0..num_twiddle_columns {
            for y in 1..ROW_COUNT {
                twiddles.push(AvxVector::make_mixedradix_twiddle_chunk(x * A::VectorType::COMPLEX_PER_VECTOR, y, len, direction));
            }
        }

        // Query the inner FFT's scratch requirements once each, and reuse the values below
        // (the in-place path runs the inner FFT out-of-place and vice versa, so the
        // inplace/outofplace scratch lengths deliberately cross over)
        let inner_outofplace_scratch = $inner_fft.get_outofplace_scratch_len();
        let inner_inplace_scratch = $inner_fft.get_inplace_scratch_len();

        CommonSimdData {
            twiddles: twiddles.into_boxed_slice(),
            inplace_scratch_len: len + inner_outofplace_scratch,
            outofplace_scratch_len: if inner_inplace_scratch > len { inner_inplace_scratch } else { 0 },
            // the immutable-input path needs `len` elements for its copy plus the inner
            // FFT's in-place scratch
            immut_scratch_len: len + inner_inplace_scratch,
            inner_fft: $inner_fft,
            len,
            direction,
        }
    }}
}
189
// Generates `perform_column_butterflies_immut` (reads `input`, writes `buffer`) and
// `perform_column_butterflies` (in-place on `buffer`) for a given ROW_COUNT.
// `$butterfly_fn` applies the size-ROW_COUNT butterfly to an array of full AVX vectors;
// `$butterfly_fn_lo` is the half-vector variant used for 1- or 2-column remainders.
macro_rules! mixedradix_column_butterflies {
    ($row_count: expr, $butterfly_fn: expr, $butterfly_fn_lo: expr) => {
        #[target_feature(enable = "avx", enable = "fma")]
        unsafe fn perform_column_butterflies_immut(
            &self,
            input: impl AvxArray<A>,
            mut buffer: impl AvxArrayMut<A>,
        ) {
            // How many rows this FFT has, ie 2 for 2xn, 4 for 4xn, etc
            const ROW_COUNT: usize = $row_count;
            const TWIDDLES_PER_COLUMN: usize = ROW_COUNT - 1;

            let len_per_row = self.len() / ROW_COUNT;
            let chunk_count = len_per_row / A::VectorType::COMPLEX_PER_VECTOR;

            // process the column FFTs, one AVX vector's worth of columns at a time.
            // Each chunk of TWIDDLES_PER_COLUMN twiddles belongs to one vector-column.
            for (c, twiddle_chunk) in self
                .common_data
                .twiddles
                .chunks_exact(TWIDDLES_PER_COLUMN)
                .take(chunk_count)
                .enumerate()
            {
                let index_base = c * A::VectorType::COMPLEX_PER_VECTOR;

                // Load columns from the input into registers
                let mut columns = [AvxVector::zero(); ROW_COUNT];
                for i in 0..ROW_COUNT {
                    columns[i] = input.load_complex(index_base + len_per_row * i);
                }

                // apply our butterfly function down the columns
                let output = $butterfly_fn(columns, self);

                // always write the first row directly back without twiddles
                buffer.store_complex(output[0], index_base);

                // for every other row, apply twiddle factors and then write back to memory
                for i in 1..ROW_COUNT {
                    let twiddle = twiddle_chunk[i - 1];
                    let output = AvxVector::mul_complex(twiddle, output[i]);
                    buffer.store_complex(output, index_base + len_per_row * i);
                }
            }

            // finally, we might have a remainder chunk
            // Normally, we can fit COMPLEX_PER_VECTOR complex numbers into an AVX register, but we only have `partial_remainder` columns left, so we need special logic to handle these final columns
            let partial_remainder = len_per_row % A::VectorType::COMPLEX_PER_VECTOR;
            if partial_remainder > 0 {
                let partial_remainder_base = chunk_count * A::VectorType::COMPLEX_PER_VECTOR;
                // the twiddle table's final column was generated for exactly this remainder
                let partial_remainder_twiddle_base =
                    self.common_data.twiddles.len() - TWIDDLES_PER_COLUMN;
                let final_twiddle_chunk =
                    &self.common_data.twiddles[partial_remainder_twiddle_base..];

                if partial_remainder > 2 {
                    // Load 3 columns into full AVX vectors to process our remainder
                    let mut columns = [AvxVector::zero(); ROW_COUNT];
                    for i in 0..ROW_COUNT {
                        columns[i] =
                            input.load_partial3_complex(partial_remainder_base + len_per_row * i);
                    }

                    // apply our butterfly function down the columns
                    let mid = $butterfly_fn(columns, self);

                    // always write the first row without twiddles
                    buffer.store_partial3_complex(mid[0], partial_remainder_base);

                    // for the remaining rows, apply twiddle factors and then write back to memory
                    for i in 1..ROW_COUNT {
                        let twiddle = final_twiddle_chunk[i - 1];
                        let output = AvxVector::mul_complex(twiddle, mid[i]);
                        buffer.store_partial3_complex(
                            output,
                            partial_remainder_base + len_per_row * i,
                        );
                    }
                } else {
                    // Load 1 or 2 columns into half vectors to process our remainder. Thankfully, the compiler is smart enough to eliminate this branch on f64, since the partial remainder can only possibly be 1
                    let mut columns = [AvxVector::zero(); ROW_COUNT];
                    if partial_remainder == 1 {
                        for i in 0..ROW_COUNT {
                            columns[i] = input
                                .load_partial1_complex(partial_remainder_base + len_per_row * i);
                        }
                    } else {
                        for i in 0..ROW_COUNT {
                            columns[i] = input
                                .load_partial2_complex(partial_remainder_base + len_per_row * i);
                        }
                    }

                    // apply our butterfly function down the columns
                    let mut mid = $butterfly_fn_lo(columns, self);

                    // apply twiddle factors (half-vector twiddles via .lo())
                    for i in 1..ROW_COUNT {
                        mid[i] = AvxVector::mul_complex(final_twiddle_chunk[i - 1].lo(), mid[i]);
                    }

                    // store output
                    if partial_remainder == 1 {
                        for i in 0..ROW_COUNT {
                            buffer.store_partial1_complex(
                                mid[i],
                                partial_remainder_base + len_per_row * i,
                            );
                        }
                    } else {
                        for i in 0..ROW_COUNT {
                            buffer.store_partial2_complex(
                                mid[i],
                                partial_remainder_base + len_per_row * i,
                            );
                        }
                    }
                }
            }
        }
        // In-place variant: identical structure to perform_column_butterflies_immut, but
        // loads and stores through the same buffer.
        #[target_feature(enable = "avx", enable = "fma")]
        unsafe fn perform_column_butterflies(&self, mut buffer: impl AvxArrayMut<A>) {
            // How many rows this FFT has, ie 2 for 2xn, 4 for 4xn, etc
            const ROW_COUNT: usize = $row_count;
            const TWIDDLES_PER_COLUMN: usize = ROW_COUNT - 1;

            let len_per_row = self.len() / ROW_COUNT;
            let chunk_count = len_per_row / A::VectorType::COMPLEX_PER_VECTOR;

            // process the column FFTs
            for (c, twiddle_chunk) in self
                .common_data
                .twiddles
                .chunks_exact(TWIDDLES_PER_COLUMN)
                .take(chunk_count)
                .enumerate()
            {
                let index_base = c * A::VectorType::COMPLEX_PER_VECTOR;

                // Load columns from the buffer into registers
                let mut columns = [AvxVector::zero(); ROW_COUNT];
                for i in 0..ROW_COUNT {
                    columns[i] = buffer.load_complex(index_base + len_per_row * i);
                }

                // apply our butterfly function down the columns
                let output = $butterfly_fn(columns, self);

                // always write the first row directly back without twiddles
                buffer.store_complex(output[0], index_base);

                // for every other row, apply twiddle factors and then write back to memory
                for i in 1..ROW_COUNT {
                    let twiddle = twiddle_chunk[i - 1];
                    let output = AvxVector::mul_complex(twiddle, output[i]);
                    buffer.store_complex(output, index_base + len_per_row * i);
                }
            }

            // finally, we might have a remainder chunk
            // Normally, we can fit COMPLEX_PER_VECTOR complex numbers into an AVX register, but we only have `partial_remainder` columns left, so we need special logic to handle these final columns
            let partial_remainder = len_per_row % A::VectorType::COMPLEX_PER_VECTOR;
            if partial_remainder > 0 {
                let partial_remainder_base = chunk_count * A::VectorType::COMPLEX_PER_VECTOR;
                let partial_remainder_twiddle_base =
                    self.common_data.twiddles.len() - TWIDDLES_PER_COLUMN;
                let final_twiddle_chunk =
                    &self.common_data.twiddles[partial_remainder_twiddle_base..];

                if partial_remainder > 2 {
                    // Load 3 columns into full AVX vectors to process our remainder
                    let mut columns = [AvxVector::zero(); ROW_COUNT];
                    for i in 0..ROW_COUNT {
                        columns[i] =
                            buffer.load_partial3_complex(partial_remainder_base + len_per_row * i);
                    }

                    // apply our butterfly function down the columns
                    let mid = $butterfly_fn(columns, self);

                    // always write the first row without twiddles
                    buffer.store_partial3_complex(mid[0], partial_remainder_base);

                    // for the remaining rows, apply twiddle factors and then write back to memory
                    for i in 1..ROW_COUNT {
                        let twiddle = final_twiddle_chunk[i - 1];
                        let output = AvxVector::mul_complex(twiddle, mid[i]);
                        buffer.store_partial3_complex(
                            output,
                            partial_remainder_base + len_per_row * i,
                        );
                    }
                } else {
                    // Load 1 or 2 columns into half vectors to process our remainder. Thankfully, the compiler is smart enough to eliminate this branch on f64, since the partial remainder can only possibly be 1
                    let mut columns = [AvxVector::zero(); ROW_COUNT];
                    if partial_remainder == 1 {
                        for i in 0..ROW_COUNT {
                            columns[i] = buffer
                                .load_partial1_complex(partial_remainder_base + len_per_row * i);
                        }
                    } else {
                        for i in 0..ROW_COUNT {
                            columns[i] = buffer
                                .load_partial2_complex(partial_remainder_base + len_per_row * i);
                        }
                    }

                    // apply our butterfly function down the columns
                    let mut mid = $butterfly_fn_lo(columns, self);

                    // apply twiddle factors (half-vector twiddles via .lo())
                    for i in 1..ROW_COUNT {
                        mid[i] = AvxVector::mul_complex(final_twiddle_chunk[i - 1].lo(), mid[i]);
                    }

                    // store output
                    if partial_remainder == 1 {
                        for i in 0..ROW_COUNT {
                            buffer.store_partial1_complex(
                                mid[i],
                                partial_remainder_base + len_per_row * i,
                            );
                        }
                    } else {
                        for i in 0..ROW_COUNT {
                            buffer.store_partial2_complex(
                                mid[i],
                                partial_remainder_base + len_per_row * i,
                            );
                        }
                    }
                }
            }
        }
    };
}
426
// Generates the `transpose` method for a given ROW_COUNT. The two trailing `$(...);*` index
// lists are a manual loop-unrolling hack (see the rustc issue linked below): the first list
// must be 0..ROW_COUNT, the second covers the full vectors written in the remainder-3 case.
macro_rules! mixedradix_transpose{
    ($row_count: expr, $transpose_fn: path, $transpose_fn_lo: path, $($unroll_workaround_index:expr);*, $($remainder3_unroll_workaround_index:expr);*) => (

    // Transpose the input (treated as a nxc array) into the output (as a cxn array)
    #[target_feature(enable = "avx")]
    unsafe fn transpose(&self, input: &[Complex<A>], mut output: &mut [Complex<A>]) {
        const ROW_COUNT : usize = $row_count;

        let len_per_row = self.len() / ROW_COUNT;
        let chunk_count = len_per_row / A::VectorType::COMPLEX_PER_VECTOR;

        // transpose the input as a (n x ROW_COUNT) array into the output as a (ROW_COUNT x n) array,
        // one AVX vector's worth of columns per iteration
        for c in 0..chunk_count {
            let input_index_base = c*A::VectorType::COMPLEX_PER_VECTOR;
            let output_index_base = input_index_base * ROW_COUNT;

            // Load rows from the input into registers
            let mut rows : [A::VectorType; ROW_COUNT] = [AvxVector::zero(); ROW_COUNT];
            for i in 0..ROW_COUNT {
                rows[i] = input.load_complex(input_index_base + len_per_row*i);
            }

            // transpose the rows to the columns
            let transposed = $transpose_fn(rows);

            // store the transposed rows contiguously
            // IE, unlike the way we loaded the data, which was to load it strided across each of our rows
            // we will not output it strided, but instead writing it out as a contiguous block

            // we are using a macro hack to manually unroll the loop, to work around this rustc bug:
            // https://github.com/rust-lang/rust/issues/71025

            // if we don't manually unroll the loop, the compiler will insert unnecessary writes+reads to the stack which tank performance
            // once the compiler bug is fixed, this can be replaced by a "for i in 0..ROW_COUNT" loop
            $(
                output.store_complex(transposed[$unroll_workaround_index], output_index_base + A::VectorType::COMPLEX_PER_VECTOR * $unroll_workaround_index);
            )*
        }

        // transpose the remainder
        let input_index_base = chunk_count * A::VectorType::COMPLEX_PER_VECTOR;
        let output_index_base = input_index_base * ROW_COUNT;

        let partial_remainder = len_per_row % A::VectorType::COMPLEX_PER_VECTOR;
        if partial_remainder == 1 {
            // If the partial remainder is 1, there's no transposing to do - just gather from across the rows and store contiguously
            for i in 0..ROW_COUNT {
                let input_cell = input.get_unchecked(input_index_base + len_per_row*i);
                let output_cell = output.get_unchecked_mut(output_index_base + i);
                *output_cell = *input_cell;
            }
        } else if partial_remainder == 2 {
            // If the partial remainder is 2, use the provided transpose_lo function to do a transpose on half-vectors
            let mut rows = [AvxVector::zero(); ROW_COUNT];
            for i in 0..ROW_COUNT {
                rows[i] = input.load_partial2_complex(input_index_base + len_per_row*i);
            }

            let transposed = $transpose_fn_lo(rows);

            // use the same macro hack as above to unroll the loop
            $(
                output.store_partial2_complex(transposed[$unroll_workaround_index], output_index_base + <A::VectorType as AvxVector256>::HalfVector::COMPLEX_PER_VECTOR * $unroll_workaround_index);
            )*
        }
        else if partial_remainder == 3 {
            // If the partial remainder is 3, we have to load full vectors, use the full transpose, and then write out a variable number of outputs
            let mut rows = [AvxVector::zero(); ROW_COUNT];
            for i in 0..ROW_COUNT {
                rows[i] = input.load_partial3_complex(input_index_base + len_per_row*i);
            }

            // transpose the rows to the columns
            let transposed = $transpose_fn(rows);

            // We're going to write constant number of full vectors, and then some constant-sized partial vector
            // Sadly, because of rust limitations, we can't make full_vector_count a const, so we have to cross our fingers that the compiler optimizes it to a constant
            let element_count = 3*ROW_COUNT;
            let full_vector_count = element_count / A::VectorType::COMPLEX_PER_VECTOR;
            let final_remainder_count = element_count % A::VectorType::COMPLEX_PER_VECTOR;

            // write out our full vectors
            // we are using a macro hack to manually unroll the loop, to work around this rustc bug:
            // https://github.com/rust-lang/rust/issues/71025

            // if we don't manually unroll the loop, the compiler will insert unnecessary writes+reads to the stack which tank performance
            // once the compiler bug is fixed, this can be replaced by a "for i in 0..full_vector_count" loop
            $(
                output.store_complex(transposed[$remainder3_unroll_workaround_index], output_index_base + A::VectorType::COMPLEX_PER_VECTOR * $remainder3_unroll_workaround_index);
            )*

            // write out our partial vector. again, this is a compile-time constant, even if we can't represent that within rust yet
            match final_remainder_count {
                0 => {},
                1 => output.store_partial1_complex(transposed[full_vector_count].lo(), output_index_base + full_vector_count * A::VectorType::COMPLEX_PER_VECTOR),
                2 => output.store_partial2_complex(transposed[full_vector_count].lo(), output_index_base + full_vector_count * A::VectorType::COMPLEX_PER_VECTOR),
                3 => output.store_partial3_complex(transposed[full_vector_count], output_index_base + full_vector_count * A::VectorType::COMPLEX_PER_VECTOR),
                _ => unreachable!(),
            }
        }
    }
)}
529
/// Computes a size-2N FFT by splitting it into 2 interleaved size-N inner FFTs,
/// using AVX+FMA instructions.
///
/// `A` is the AVX-capable float type and `T` is the public FftNum type; the constructor
/// asserts they are the same type (they're only kept separate as a specialization workaround).
pub struct MixedRadix2xnAvx<A: AvxNum, T> {
    // twiddles, scratch lengths, inner FFT handle, len, and direction shared by all
    // mixed-radix AVX algorithms
    common_data: CommonSimdData<T, A::VectorType>,
    _phantom: std::marker::PhantomData<T>,
}
// Generates the shared trait boilerplate for this struct (macro defined elsewhere)
boilerplate_avx_fft_commondata!(MixedRadix2xnAvx);

impl<A: AvxNum, T: FftNum> MixedRadix2xnAvx<A, T> {
    /// Constructs the instance. Unsafe because it requires the "avx" instruction set;
    /// callers go through `new()` (from boilerplate_mixedradix!), which checks CPU support.
    #[target_feature(enable = "avx")]
    unsafe fn new_with_avx(inner_fft: Arc<dyn Fft<T>>) -> Self {
        Self {
            common_data: mixedradix_gen_data!(2, inner_fft),
            _phantom: std::marker::PhantomData,
        }
    }
    // Size-2 butterflies need no precomputed twiddles, so the full-vector and
    // half-vector (lo) closures are identical
    mixedradix_column_butterflies!(
        2,
        |columns, _: _| AvxVector::column_butterfly2(columns),
        |columns, _: _| AvxVector::column_butterfly2(columns)
    );
    // trailing lists are the manual unroll indices required by mixedradix_transpose!:
    // 0..ROW_COUNT for the main loop, then the remainder-3 full-vector indices
    mixedradix_transpose!(2,
        AvxVector::transpose2_packed,
        AvxVector::transpose2_packed,
        0;1, 0
    );
    boilerplate_mixedradix!();
}
556
/// Computes a size-3N FFT by splitting it into 3 interleaved size-N inner FFTs,
/// using AVX+FMA instructions.
pub struct MixedRadix3xnAvx<A: AvxNum, T> {
    // broadcast twiddle factor (1/3) used inside the size-3 butterfly
    twiddles_butterfly3: A::VectorType,
    // twiddles, scratch lengths, inner FFT handle, len, and direction shared by all
    // mixed-radix AVX algorithms
    common_data: CommonSimdData<T, A::VectorType>,
    _phantom: std::marker::PhantomData<T>,
}
// Generates the shared trait boilerplate for this struct (macro defined elsewhere)
boilerplate_avx_fft_commondata!(MixedRadix3xnAvx);

impl<A: AvxNum, T: FftNum> MixedRadix3xnAvx<A, T> {
    /// Constructs the instance. Unsafe because it requires the "avx" instruction set;
    /// callers go through `new()` (from boilerplate_mixedradix!), which checks CPU support.
    #[target_feature(enable = "avx")]
    unsafe fn new_with_avx(inner_fft: Arc<dyn Fft<T>>) -> Self {
        Self {
            twiddles_butterfly3: AvxVector::broadcast_twiddle(1, 3, inner_fft.fft_direction()),
            common_data: mixedradix_gen_data!(3, inner_fft),
            _phantom: std::marker::PhantomData,
        }
    }
    // half-vector variant passes the low half of the broadcast twiddle
    mixedradix_column_butterflies!(
        3,
        |columns, this: &Self| AvxVector::column_butterfly3(columns, this.twiddles_butterfly3),
        |columns, this: &Self| AvxVector::column_butterfly3(columns, this.twiddles_butterfly3.lo())
    );
    // trailing lists are the manual unroll indices required by mixedradix_transpose!
    mixedradix_transpose!(3,
        AvxVector::transpose3_packed,
        AvxVector::transpose3_packed,
        0;1;2, 0;1
    );
    boilerplate_mixedradix!();
}
585
/// Computes a size-4N FFT by splitting it into 4 interleaved size-N inner FFTs,
/// using AVX+FMA instructions.
pub struct MixedRadix4xnAvx<A: AvxNum, T> {
    // precomputed 90-degree rotation used by the size-4 butterfly
    twiddles_butterfly4: Rotation90<A::VectorType>,
    // twiddles, scratch lengths, inner FFT handle, len, and direction shared by all
    // mixed-radix AVX algorithms
    common_data: CommonSimdData<T, A::VectorType>,
    _phantom: std::marker::PhantomData<T>,
}
// Generates the shared trait boilerplate for this struct (macro defined elsewhere)
boilerplate_avx_fft_commondata!(MixedRadix4xnAvx);

impl<A: AvxNum, T: FftNum> MixedRadix4xnAvx<A, T> {
    /// Constructs the instance. Unsafe because it requires the "avx" instruction set;
    /// callers go through `new()` (from boilerplate_mixedradix!), which checks CPU support.
    #[target_feature(enable = "avx")]
    unsafe fn new_with_avx(inner_fft: Arc<dyn Fft<T>>) -> Self {
        Self {
            twiddles_butterfly4: AvxVector::make_rotation90(inner_fft.fft_direction()),
            common_data: mixedradix_gen_data!(4, inner_fft),
            _phantom: std::marker::PhantomData,
        }
    }
    // half-vector variant passes the low half of the rotation constant
    mixedradix_column_butterflies!(
        4,
        |columns, this: &Self| AvxVector::column_butterfly4(columns, this.twiddles_butterfly4),
        |columns, this: &Self| AvxVector::column_butterfly4(columns, this.twiddles_butterfly4.lo())
    );
    // trailing lists are the manual unroll indices required by mixedradix_transpose!
    mixedradix_transpose!(4,
        AvxVector::transpose4_packed,
        AvxVector::transpose4_packed,
        0;1;2;3, 0;1;2
    );
    boilerplate_mixedradix!();
}
614
/// Computes a size-5N FFT by splitting it into 5 interleaved size-N inner FFTs,
/// using AVX+FMA instructions.
pub struct MixedRadix5xnAvx<A: AvxNum, T> {
    // broadcast twiddle factors (1/5 and 2/5) used inside the size-5 butterfly
    twiddles_butterfly5: [A::VectorType; 2],
    // twiddles, scratch lengths, inner FFT handle, len, and direction shared by all
    // mixed-radix AVX algorithms
    common_data: CommonSimdData<T, A::VectorType>,
    _phantom: std::marker::PhantomData<T>,
}
// Generates the shared trait boilerplate for this struct (macro defined elsewhere)
boilerplate_avx_fft_commondata!(MixedRadix5xnAvx);

impl<A: AvxNum, T: FftNum> MixedRadix5xnAvx<A, T> {
    /// Constructs the instance. Unsafe because it requires the "avx" instruction set;
    /// callers go through `new()` (from boilerplate_mixedradix!), which checks CPU support.
    #[target_feature(enable = "avx")]
    unsafe fn new_with_avx(inner_fft: Arc<dyn Fft<T>>) -> Self {
        Self {
            twiddles_butterfly5: [
                AvxVector::broadcast_twiddle(1, 5, inner_fft.fft_direction()),
                AvxVector::broadcast_twiddle(2, 5, inner_fft.fft_direction()),
            ],
            common_data: mixedradix_gen_data!(5, inner_fft),
            _phantom: std::marker::PhantomData,
        }
    }
    // half-vector variant passes the low halves of both broadcast twiddles
    mixedradix_column_butterflies!(
        5,
        |columns, this: &Self| AvxVector::column_butterfly5(columns, this.twiddles_butterfly5),
        |columns, this: &Self| AvxVector::column_butterfly5(
            columns,
            [
                this.twiddles_butterfly5[0].lo(),
                this.twiddles_butterfly5[1].lo()
            ]
        )
    );
    // trailing lists are the manual unroll indices required by mixedradix_transpose!
    mixedradix_transpose!(5,
        AvxVector::transpose5_packed,
        AvxVector::transpose5_packed,
        0;1;2;3;4, 0;1;2
    );
    boilerplate_mixedradix!();
}
652
/// Computes a size-6N FFT by splitting it into 6 interleaved size-N inner FFTs,
/// using AVX+FMA instructions.
pub struct MixedRadix6xnAvx<A: AvxNum, T> {
    // broadcast twiddle factor (1/3) used by the size-6 butterfly
    twiddles_butterfly3: A::VectorType,
    // twiddles, scratch lengths, inner FFT handle, len, and direction shared by all
    // mixed-radix AVX algorithms
    common_data: CommonSimdData<T, A::VectorType>,
    _phantom: std::marker::PhantomData<T>,
}
// Generates the shared trait boilerplate for this struct (macro defined elsewhere)
boilerplate_avx_fft_commondata!(MixedRadix6xnAvx);

impl<A: AvxNum, T: FftNum> MixedRadix6xnAvx<A, T> {
    /// Constructs the instance. Unsafe because it requires the "avx" instruction set;
    /// callers go through `new()` (from boilerplate_mixedradix!), which checks CPU support.
    #[target_feature(enable = "avx")]
    unsafe fn new_with_avx(inner_fft: Arc<dyn Fft<T>>) -> Self {
        Self {
            twiddles_butterfly3: AvxVector::broadcast_twiddle(1, 3, inner_fft.fft_direction()),
            common_data: mixedradix_gen_data!(6, inner_fft),
            _phantom: std::marker::PhantomData,
        }
    }
    // unlike the other radixes, the full- and half-vector butterfly6 live on separate
    // traits (AvxVector256 vs AvxVector128), so the two closures name them explicitly
    mixedradix_column_butterflies!(
        6,
        |columns, this: &Self| AvxVector256::column_butterfly6(columns, this.twiddles_butterfly3),
        |columns, this: &Self| AvxVector128::column_butterfly6(columns, this.twiddles_butterfly3)
    );
    // trailing lists are the manual unroll indices required by mixedradix_transpose!
    mixedradix_transpose!(6,
        AvxVector::transpose6_packed,
        AvxVector::transpose6_packed,
        0;1;2;3;4;5, 0;1;2;3
    );
    boilerplate_mixedradix!();
}
681
/// Mixed-radix FFT algorithm with one dimension of the decomposition hard-coded
/// to 7, implemented with AVX. The other dimension is handled by the `inner_fft`
/// passed to the constructor.
pub struct MixedRadix7xnAvx<A: AvxNum, T> {
    // Broadcast twiddles twiddle(1,7), twiddle(2,7), twiddle(3,7), passed to the
    // size-7 column butterflies.
    twiddles_butterfly7: [A::VectorType; 3],
    // Inter-FFT twiddles, inner FFT handle, and scratch metadata shared by all
    // mixed-radix AVX algorithms in this file.
    common_data: CommonSimdData<T, A::VectorType>,
    // A and T are always the same type (asserted in `new`); PhantomData ties T in.
    _phantom: std::marker::PhantomData<T>,
}
// Generates the shared trait implementations for this algorithm from `common_data`
// (macro defined elsewhere in this crate).
boilerplate_avx_fft_commondata!(MixedRadix7xnAvx);
688
689impl<A: AvxNum, T: FftNum> MixedRadix7xnAvx<A, T> {
690    #[target_feature(enable = "avx")]
691    unsafe fn new_with_avx(inner_fft: Arc<dyn Fft<T>>) -> Self {
692        Self {
693            twiddles_butterfly7: [
694                AvxVector::broadcast_twiddle(1, 7, inner_fft.fft_direction()),
695                AvxVector::broadcast_twiddle(2, 7, inner_fft.fft_direction()),
696                AvxVector::broadcast_twiddle(3, 7, inner_fft.fft_direction()),
697            ],
698            common_data: mixedradix_gen_data!(7, inner_fft),
699            _phantom: std::marker::PhantomData,
700        }
701    }
702    mixedradix_column_butterflies!(
703        7,
704        |columns, this: &Self| AvxVector::column_butterfly7(columns, this.twiddles_butterfly7),
705        |columns, this: &Self| AvxVector::column_butterfly7(
706            columns,
707            [
708                this.twiddles_butterfly7[0].lo(),
709                this.twiddles_butterfly7[1].lo(),
710                this.twiddles_butterfly7[2].lo()
711            ]
712        )
713    );
714    mixedradix_transpose!(7,
715        AvxVector::transpose7_packed,
716        AvxVector::transpose7_packed,
717        0;1;2;3;4;5;6, 0;1;2;3;4
718    );
719    boilerplate_mixedradix!();
720}
721
/// Mixed-radix FFT algorithm with one dimension of the decomposition hard-coded
/// to 8, implemented with AVX. The other dimension is handled by the `inner_fft`
/// passed to the constructor.
pub struct MixedRadix8xnAvx<A: AvxNum, T> {
    // Precomputed 90-degree rotation constants, passed to column_butterfly8 by
    // the column-butterfly closures below.
    twiddles_butterfly4: Rotation90<A::VectorType>,
    // Inter-FFT twiddles, inner FFT handle, and scratch metadata shared by all
    // mixed-radix AVX algorithms in this file.
    common_data: CommonSimdData<T, A::VectorType>,
    // A and T are always the same type (asserted in `new`); PhantomData ties T in.
    _phantom: std::marker::PhantomData<T>,
}
// Generates the shared trait implementations for this algorithm from `common_data`
// (macro defined elsewhere in this crate).
boilerplate_avx_fft_commondata!(MixedRadix8xnAvx);
728
impl<A: AvxNum, T: FftNum> MixedRadix8xnAvx<A, T> {
    // Precomputes the rotation constants for the size-8 column butterflies plus
    // the shared mixed-radix data.
    // Safety: caller must guarantee the "avx" feature is available (checked by `new`).
    #[target_feature(enable = "avx")]
    unsafe fn new_with_avx(inner_fft: Arc<dyn Fft<T>>) -> Self {
        Self {
            twiddles_butterfly4: AvxVector::make_rotation90(inner_fft.fft_direction()),
            common_data: mixedradix_gen_data!(8, inner_fft),
            _phantom: std::marker::PhantomData,
        }
    }

    // Size-8 butterflies down each column: the first closure is the full-width
    // vector path, the second the half-width path (note `.lo()`) used for
    // partial-remainder columns.
    mixedradix_column_butterflies!(
        8,
        |columns, this: &Self| AvxVector::column_butterfly8(columns, this.twiddles_butterfly4),
        |columns, this: &Self| AvxVector::column_butterfly8(columns, this.twiddles_butterfly4.lo())
    );
    // Transpose the 8xN data before handing rows to the inner FFT. The two index
    // lists parameterize the generated code -- see the macro definition.
    mixedradix_transpose!(8,
        AvxVector::transpose8_packed,
        AvxVector::transpose8_packed,
        0;1;2;3;4;5;6;7, 0;1;2;3;4;5
    );
    boilerplate_mixedradix!();
}
751
/// Mixed-radix FFT algorithm with one dimension of the decomposition hard-coded
/// to 9, implemented with AVX. The other dimension is handled by the `inner_fft`
/// passed to the constructor.
pub struct MixedRadix9xnAvx<A: AvxNum, T> {
    // Full-width broadcast twiddles twiddle(1,9), twiddle(2,9), twiddle(4,9),
    // used by the full-width (AvxVector256) butterfly path.
    twiddles_butterfly9: [A::VectorType; 3],
    // Merged pairs of the same twiddles, used by the half-width (AvxVector128)
    // remainder path -- see the constructor's AvxVector256::merge calls.
    twiddles_butterfly9_lo: [A::VectorType; 2],
    // Broadcast twiddle twiddle(1,3), also consumed by the size-9 butterflies.
    twiddles_butterfly3: A::VectorType,
    // Inter-FFT twiddles, inner FFT handle, and scratch metadata shared by all
    // mixed-radix AVX algorithms in this file.
    common_data: CommonSimdData<T, A::VectorType>,
    // A and T are always the same type (asserted in `new`); PhantomData ties T in.
    _phantom: std::marker::PhantomData<T>,
}
// Generates the shared trait implementations for this algorithm from `common_data`
// (macro defined elsewhere in this crate).
boilerplate_avx_fft_commondata!(MixedRadix9xnAvx);
760
761impl<A: AvxNum, T: FftNum> MixedRadix9xnAvx<A, T> {
762    #[target_feature(enable = "avx")]
763    unsafe fn new_with_avx(inner_fft: Arc<dyn Fft<T>>) -> Self {
764        let inverse = inner_fft.fft_direction();
765
766        let twiddle1 = AvxVector::broadcast_twiddle(1, 9, inner_fft.fft_direction());
767        let twiddle2 = AvxVector::broadcast_twiddle(2, 9, inner_fft.fft_direction());
768        let twiddle4 = AvxVector::broadcast_twiddle(4, 9, inner_fft.fft_direction());
769
770        Self {
771            twiddles_butterfly9: [
772                AvxVector::broadcast_twiddle(1, 9, inverse),
773                AvxVector::broadcast_twiddle(2, 9, inverse),
774                AvxVector::broadcast_twiddle(4, 9, inverse),
775            ],
776            twiddles_butterfly9_lo: [
777                AvxVector256::merge(twiddle1, twiddle2),
778                AvxVector256::merge(twiddle2, twiddle4),
779            ],
780            twiddles_butterfly3: AvxVector::broadcast_twiddle(1, 3, inner_fft.fft_direction()),
781            common_data: mixedradix_gen_data!(9, inner_fft),
782            _phantom: std::marker::PhantomData,
783        }
784    }
785
786    mixedradix_column_butterflies!(
787        9,
788        |columns, this: &Self| AvxVector256::column_butterfly9(
789            columns,
790            this.twiddles_butterfly9,
791            this.twiddles_butterfly3
792        ),
793        |columns, this: &Self| AvxVector128::column_butterfly9(
794            columns,
795            this.twiddles_butterfly9_lo,
796            this.twiddles_butterfly3
797        )
798    );
799    mixedradix_transpose!(9,
800        AvxVector::transpose9_packed,
801        AvxVector::transpose9_packed,
802        0;1;2;3;4;5;6;7;8, 0;1;2;3;4;5
803    );
804    boilerplate_mixedradix!();
805}
806
/// Mixed-radix FFT algorithm with one dimension of the decomposition hard-coded
/// to 11, implemented with AVX. The other dimension is handled by the `inner_fft`
/// passed to the constructor.
pub struct MixedRadix11xnAvx<A: AvxNum, T> {
    // Broadcast twiddles twiddle(k,11) for k = 1..=5, passed to the size-11
    // column butterflies.
    twiddles_butterfly11: [A::VectorType; 5],
    // Inter-FFT twiddles, inner FFT handle, and scratch metadata shared by all
    // mixed-radix AVX algorithms in this file.
    common_data: CommonSimdData<T, A::VectorType>,
    // A and T are always the same type (asserted in `new`); PhantomData ties T in.
    _phantom: std::marker::PhantomData<T>,
}
// Generates the shared trait implementations for this algorithm from `common_data`
// (macro defined elsewhere in this crate).
boilerplate_avx_fft_commondata!(MixedRadix11xnAvx);
813
814impl<A: AvxNum, T: FftNum> MixedRadix11xnAvx<A, T> {
815    #[target_feature(enable = "avx")]
816    unsafe fn new_with_avx(inner_fft: Arc<dyn Fft<T>>) -> Self {
817        Self {
818            twiddles_butterfly11: [
819                AvxVector::broadcast_twiddle(1, 11, inner_fft.fft_direction()),
820                AvxVector::broadcast_twiddle(2, 11, inner_fft.fft_direction()),
821                AvxVector::broadcast_twiddle(3, 11, inner_fft.fft_direction()),
822                AvxVector::broadcast_twiddle(4, 11, inner_fft.fft_direction()),
823                AvxVector::broadcast_twiddle(5, 11, inner_fft.fft_direction()),
824            ],
825            common_data: mixedradix_gen_data!(11, inner_fft),
826            _phantom: std::marker::PhantomData,
827        }
828    }
829    mixedradix_column_butterflies!(
830        11,
831        |columns, this: &Self| AvxVector::column_butterfly11(columns, this.twiddles_butterfly11),
832        |columns, this: &Self| AvxVector::column_butterfly11(
833            columns,
834            [
835                this.twiddles_butterfly11[0].lo(),
836                this.twiddles_butterfly11[1].lo(),
837                this.twiddles_butterfly11[2].lo(),
838                this.twiddles_butterfly11[3].lo(),
839                this.twiddles_butterfly11[4].lo()
840            ]
841        )
842    );
843    mixedradix_transpose!(11,
844        AvxVector::transpose11_packed,
845        AvxVector::transpose11_packed,
846        0;1;2;3;4;5;6;7;8;9;10, 0;1;2;3;4;5;6;7
847    );
848    boilerplate_mixedradix!();
849}
850
/// Mixed-radix FFT algorithm with one dimension of the decomposition hard-coded
/// to 12, implemented with AVX. The other dimension is handled by the `inner_fft`
/// passed to the constructor.
pub struct MixedRadix12xnAvx<A: AvxNum, T> {
    // Precomputed 90-degree rotation constants, passed to column_butterfly12.
    twiddles_butterfly4: Rotation90<A::VectorType>,
    // Broadcast twiddle twiddle(1,3), also passed to column_butterfly12.
    twiddles_butterfly3: A::VectorType,
    // Inter-FFT twiddles, inner FFT handle, and scratch metadata shared by all
    // mixed-radix AVX algorithms in this file.
    common_data: CommonSimdData<T, A::VectorType>,
    // A and T are always the same type (asserted in `new`); PhantomData ties T in.
    _phantom: std::marker::PhantomData<T>,
}
// Generates the shared trait implementations for this algorithm from `common_data`
// (macro defined elsewhere in this crate).
boilerplate_avx_fft_commondata!(MixedRadix12xnAvx);
858
impl<A: AvxNum, T: FftNum> MixedRadix12xnAvx<A, T> {
    // Precomputes the rotation constants and twiddle for the size-12 column
    // butterflies, plus the shared mixed-radix data.
    // Safety: caller must guarantee the "avx" feature is available (checked by `new`).
    #[target_feature(enable = "avx")]
    unsafe fn new_with_avx(inner_fft: Arc<dyn Fft<T>>) -> Self {
        // Query the direction once; reused for both precomputed constants below.
        let inverse = inner_fft.fft_direction();
        Self {
            twiddles_butterfly4: AvxVector::make_rotation90(inverse),
            twiddles_butterfly3: AvxVector::broadcast_twiddle(1, 3, inverse),
            common_data: mixedradix_gen_data!(12, inner_fft),
            _phantom: std::marker::PhantomData,
        }
    }

    // Size-12 butterflies down each column: the first closure is the full-width
    // (AvxVector256) path, the second the half-width (AvxVector128) path used for
    // partial-remainder columns.
    mixedradix_column_butterflies!(
        12,
        |columns, this: &Self| AvxVector256::column_butterfly12(
            columns,
            this.twiddles_butterfly3,
            this.twiddles_butterfly4
        ),
        |columns, this: &Self| AvxVector128::column_butterfly12(
            columns,
            this.twiddles_butterfly3,
            this.twiddles_butterfly4
        )
    );
    // Transpose the 12xN data before handing rows to the inner FFT. The two index
    // lists parameterize the generated code -- see the macro definition.
    mixedradix_transpose!(12,
        AvxVector::transpose12_packed,
        AvxVector::transpose12_packed,
        0;1;2;3;4;5;6;7;8;9;10;11, 0;1;2;3;4;5;6;7;8
    );
    boilerplate_mixedradix!();
}
891
/// Mixed-radix FFT algorithm with one dimension of the decomposition hard-coded
/// to 16, implemented with AVX. Unlike the smaller radices, this one hand-writes
/// its column-butterfly loops (see the impl) instead of using the macro.
pub struct MixedRadix16xnAvx<A: AvxNum, T> {
    // Precomputed 90-degree rotation constants, passed to column_butterfly16_loadfn!.
    twiddles_butterfly4: Rotation90<A::VectorType>,
    // Broadcast twiddles twiddle(1,16) and twiddle(3,16), also passed to the
    // size-16 butterfly.
    twiddles_butterfly16: [A::VectorType; 2],
    // Inter-FFT twiddles, inner FFT handle, and scratch metadata shared by all
    // mixed-radix AVX algorithms in this file.
    common_data: CommonSimdData<T, A::VectorType>,
    // A and T are always the same type (asserted in `new`); PhantomData ties T in.
    _phantom: std::marker::PhantomData<T>,
}
// Generates the shared trait implementations for this algorithm from `common_data`
// (macro defined elsewhere in this crate).
boilerplate_avx_fft_commondata!(MixedRadix16xnAvx);
899
900impl<A: AvxNum, T: FftNum> MixedRadix16xnAvx<A, T> {
901    #[target_feature(enable = "avx")]
902    unsafe fn new_with_avx(inner_fft: Arc<dyn Fft<T>>) -> Self {
903        let inverse = inner_fft.fft_direction();
904        Self {
905            twiddles_butterfly4: AvxVector::make_rotation90(inner_fft.fft_direction()),
906            twiddles_butterfly16: [
907                AvxVector::broadcast_twiddle(1, 16, inverse),
908                AvxVector::broadcast_twiddle(3, 16, inverse),
909            ],
910            common_data: mixedradix_gen_data!(16, inner_fft),
911            _phantom: std::marker::PhantomData,
912        }
913    }
914
915    #[target_feature(enable = "avx", enable = "fma")]
916    unsafe fn perform_column_butterflies(&self, mut buffer: impl AvxArrayMut<A>) {
917        // How many rows this FFT has, ie 2 for 2xn, 4 for 4xn, etc
918        const ROW_COUNT: usize = 16;
919        const TWIDDLES_PER_COLUMN: usize = ROW_COUNT - 1;
920
921        let len_per_row = self.len() / ROW_COUNT;
922        let chunk_count = len_per_row / A::VectorType::COMPLEX_PER_VECTOR;
923
924        // process the column FFTs
925        for (c, twiddle_chunk) in self
926            .common_data
927            .twiddles
928            .chunks_exact(TWIDDLES_PER_COLUMN)
929            .take(chunk_count)
930            .enumerate()
931        {
932            let index_base = c * A::VectorType::COMPLEX_PER_VECTOR;
933
934            column_butterfly16_loadfn!(
935                |index| buffer.load_complex(index_base + len_per_row * index),
936                |mut data, index| {
937                    if index > 0 {
938                        data = AvxVector::mul_complex(data, twiddle_chunk[index - 1]);
939                    }
940                    buffer.store_complex(data, index_base + len_per_row * index)
941                },
942                self.twiddles_butterfly16,
943                self.twiddles_butterfly4
944            );
945        }
946
947        // finally, we might have a single partial chunk.
948        // Normally, we can fit 4 complex numbers into an AVX register, but we only have `partial_remainder` columns left, so we need special logic to handle these final columns
949        let partial_remainder = len_per_row % A::VectorType::COMPLEX_PER_VECTOR;
950        if partial_remainder > 0 {
951            let partial_remainder_base = chunk_count * A::VectorType::COMPLEX_PER_VECTOR;
952            let partial_remainder_twiddle_base =
953                self.common_data.twiddles.len() - TWIDDLES_PER_COLUMN;
954            let final_twiddle_chunk = &self.common_data.twiddles[partial_remainder_twiddle_base..];
955
956            match partial_remainder {
957                1 => {
958                    column_butterfly16_loadfn!(
959                        |index| buffer
960                            .load_partial1_complex(partial_remainder_base + len_per_row * index),
961                        |mut data, index| {
962                            if index > 0 {
963                                let twiddle: A::VectorType = final_twiddle_chunk[index - 1];
964                                data = AvxVector::mul_complex(data, twiddle.lo());
965                            }
966                            buffer.store_partial1_complex(
967                                data,
968                                partial_remainder_base + len_per_row * index,
969                            )
970                        },
971                        [
972                            self.twiddles_butterfly16[0].lo(),
973                            self.twiddles_butterfly16[1].lo()
974                        ],
975                        self.twiddles_butterfly4.lo()
976                    );
977                }
978                2 => {
979                    column_butterfly16_loadfn!(
980                        |index| buffer
981                            .load_partial2_complex(partial_remainder_base + len_per_row * index),
982                        |mut data, index| {
983                            if index > 0 {
984                                let twiddle: A::VectorType = final_twiddle_chunk[index - 1];
985                                data = AvxVector::mul_complex(data, twiddle.lo());
986                            }
987                            buffer.store_partial2_complex(
988                                data,
989                                partial_remainder_base + len_per_row * index,
990                            )
991                        },
992                        [
993                            self.twiddles_butterfly16[0].lo(),
994                            self.twiddles_butterfly16[1].lo()
995                        ],
996                        self.twiddles_butterfly4.lo()
997                    );
998                }
999                3 => {
1000                    column_butterfly16_loadfn!(
1001                        |index| buffer
1002                            .load_partial3_complex(partial_remainder_base + len_per_row * index),
1003                        |mut data, index| {
1004                            if index > 0 {
1005                                data = AvxVector::mul_complex(data, final_twiddle_chunk[index - 1]);
1006                            }
1007                            buffer.store_partial3_complex(
1008                                data,
1009                                partial_remainder_base + len_per_row * index,
1010                            )
1011                        },
1012                        self.twiddles_butterfly16,
1013                        self.twiddles_butterfly4
1014                    );
1015                }
1016                _ => unreachable!(),
1017            }
1018        }
1019    }
1020    #[target_feature(enable = "avx", enable = "fma")]
1021    unsafe fn perform_column_butterflies_immut(
1022        &self,
1023        input: impl AvxArray<A>,
1024        mut buffer: impl AvxArrayMut<A>,
1025    ) {
1026        // How many rows this FFT has, ie 2 for 2xn, 4 for 4xn, etc
1027        const ROW_COUNT: usize = 16;
1028        const TWIDDLES_PER_COLUMN: usize = ROW_COUNT - 1;
1029
1030        let len_per_row = self.len() / ROW_COUNT;
1031        let chunk_count = len_per_row / A::VectorType::COMPLEX_PER_VECTOR;
1032
1033        // process the column FFTs
1034        for (c, twiddle_chunk) in self
1035            .common_data
1036            .twiddles
1037            .chunks_exact(TWIDDLES_PER_COLUMN)
1038            .take(chunk_count)
1039            .enumerate()
1040        {
1041            let index_base = c * A::VectorType::COMPLEX_PER_VECTOR;
1042
1043            column_butterfly16_loadfn!(
1044                |index| input.load_complex(index_base + len_per_row * index),
1045                |mut data, index| {
1046                    if index > 0 {
1047                        data = AvxVector::mul_complex(data, twiddle_chunk[index - 1]);
1048                    }
1049                    buffer.store_complex(data, index_base + len_per_row * index)
1050                },
1051                self.twiddles_butterfly16,
1052                self.twiddles_butterfly4
1053            );
1054        }
1055
1056        // finally, we might have a single partial chunk.
1057        // Normally, we can fit 4 complex numbers into an AVX register, but we only have `partial_remainder` columns left, so we need special logic to handle these final columns
1058        let partial_remainder = len_per_row % A::VectorType::COMPLEX_PER_VECTOR;
1059        if partial_remainder > 0 {
1060            let partial_remainder_base = chunk_count * A::VectorType::COMPLEX_PER_VECTOR;
1061            let partial_remainder_twiddle_base =
1062                self.common_data.twiddles.len() - TWIDDLES_PER_COLUMN;
1063            let final_twiddle_chunk = &self.common_data.twiddles[partial_remainder_twiddle_base..];
1064
1065            match partial_remainder {
1066                1 => {
1067                    for c in 0..self.len() / len_per_row {
1068                        let cs = c * len_per_row + len_per_row - partial_remainder;
1069                        buffer.store_partial1_complex(input.load_partial1_complex(cs), cs);
1070                    }
1071                    column_butterfly16_loadfn!(
1072                        |index| buffer
1073                            .load_partial1_complex(partial_remainder_base + len_per_row * index),
1074                        |mut data, index| {
1075                            if index > 0 {
1076                                let twiddle: A::VectorType = final_twiddle_chunk[index - 1];
1077                                data = AvxVector::mul_complex(data, twiddle.lo());
1078                            }
1079                            buffer.store_partial1_complex(
1080                                data,
1081                                partial_remainder_base + len_per_row * index,
1082                            )
1083                        },
1084                        [
1085                            self.twiddles_butterfly16[0].lo(),
1086                            self.twiddles_butterfly16[1].lo()
1087                        ],
1088                        self.twiddles_butterfly4.lo()
1089                    );
1090                }
1091                2 => {
1092                    for c in 0..self.len() / len_per_row {
1093                        let cs = c * len_per_row + len_per_row - partial_remainder;
1094                        buffer.store_partial2_complex(input.load_partial2_complex(cs), cs);
1095                    }
1096                    column_butterfly16_loadfn!(
1097                        |index| buffer
1098                            .load_partial2_complex(partial_remainder_base + len_per_row * index),
1099                        |mut data, index| {
1100                            if index > 0 {
1101                                let twiddle: A::VectorType = final_twiddle_chunk[index - 1];
1102                                data = AvxVector::mul_complex(data, twiddle.lo());
1103                            }
1104                            buffer.store_partial2_complex(
1105                                data,
1106                                partial_remainder_base + len_per_row * index,
1107                            )
1108                        },
1109                        [
1110                            self.twiddles_butterfly16[0].lo(),
1111                            self.twiddles_butterfly16[1].lo()
1112                        ],
1113                        self.twiddles_butterfly4.lo()
1114                    );
1115                }
1116                3 => {
1117                    for c in 0..self.len() / len_per_row {
1118                        let cs = c * len_per_row + len_per_row - partial_remainder;
1119                        buffer.store_partial3_complex(input.load_partial3_complex(cs), cs);
1120                    }
1121                    column_butterfly16_loadfn!(
1122                        |index| buffer
1123                            .load_partial3_complex(partial_remainder_base + len_per_row * index),
1124                        |mut data, index| {
1125                            if index > 0 {
1126                                data = AvxVector::mul_complex(data, final_twiddle_chunk[index - 1]);
1127                            }
1128                            buffer.store_partial3_complex(
1129                                data,
1130                                partial_remainder_base + len_per_row * index,
1131                            )
1132                        },
1133                        self.twiddles_butterfly16,
1134                        self.twiddles_butterfly4
1135                    );
1136                }
1137                _ => unreachable!(),
1138            }
1139        }
1140    }
1141    mixedradix_transpose!(16,
1142        AvxVector::transpose16_packed,
1143        AvxVector::transpose16_packed,
1144        0;1;2;3;4;5;6;7;8;9;10;11;12;13;14;15, 0;1;2;3;4;5;6;7;8;9;10;11
1145    );
1146    boilerplate_mixedradix!();
1147}
1148
#[cfg(test)]
mod unit_tests {
    use super::*;
    use crate::algorithm::*;
    use crate::test_utils::check_fft_algorithm;
    use std::sync::Arc;

    // Generates an f32 and an f64 test for one MixedRadixKxnAvx struct. Each test
    // sweeps inner FFT lengths 1..32, builds the algorithm in both directions
    // around a naive Dft inner FFT, and validates it with check_fft_algorithm.
    // Tests panic (via expect) on machines without AVX+FMA rather than skipping.
    macro_rules! test_avx_mixed_radix {
        ($f32_test_name:ident, $f64_test_name:ident, $struct_name:ident, $inner_count:expr) => (
            #[test]
            fn $f32_test_name() {
                for inner_fft_len in 1..32 {
                    // total FFT length = (hardcoded row count) * (inner FFT length)
                    let len = inner_fft_len * $inner_count;

                    let inner_fft_forward = Arc::new(Dft::new(inner_fft_len, FftDirection::Forward)) as Arc<dyn Fft<f32>>;
                    let fft_forward = $struct_name::<f32, f32>::new(inner_fft_forward).expect("Can't run test because this machine doesn't have the required instruction sets");
                    check_fft_algorithm(&fft_forward, len, FftDirection::Forward);

                    let inner_fft_inverse = Arc::new(Dft::new(inner_fft_len, FftDirection::Inverse)) as Arc<dyn Fft<f32>>;
                    let fft_inverse = $struct_name::<f32, f32>::new(inner_fft_inverse).expect("Can't run test because this machine doesn't have the required instruction sets");
                    check_fft_algorithm(&fft_inverse, len, FftDirection::Inverse);
                }
            }
            #[test]
            fn $f64_test_name() {
                for inner_fft_len in 1..32 {
                    let len = inner_fft_len * $inner_count;

                    let inner_fft_forward = Arc::new(Dft::new(inner_fft_len, FftDirection::Forward)) as Arc<dyn Fft<f64>>;
                    let fft_forward = $struct_name::<f64, f64>::new(inner_fft_forward).expect("Can't run test because this machine doesn't have the required instruction sets");
                    check_fft_algorithm(&fft_forward, len, FftDirection::Forward);

                    let inner_fft_inverse = Arc::new(Dft::new(inner_fft_len, FftDirection::Inverse)) as Arc<dyn Fft<f64>>;
                    let fft_inverse = $struct_name::<f64, f64>::new(inner_fft_inverse).expect("Can't run test because this machine doesn't have the required instruction sets");
                    check_fft_algorithm(&fft_inverse, len, FftDirection::Inverse);
                }
            }
        )
    }

    // One instantiation per algorithm in this file; the final argument is the
    // hardcoded row count of that algorithm.
    test_avx_mixed_radix!(
        test_mixedradix_2xn_avx_f32,
        test_mixedradix_2xn_avx_f64,
        MixedRadix2xnAvx,
        2
    );
    test_avx_mixed_radix!(
        test_mixedradix_3xn_avx_f32,
        test_mixedradix_3xn_avx_f64,
        MixedRadix3xnAvx,
        3
    );
    test_avx_mixed_radix!(
        test_mixedradix_4xn_avx_f32,
        test_mixedradix_4xn_avx_f64,
        MixedRadix4xnAvx,
        4
    );
    test_avx_mixed_radix!(
        test_mixedradix_5xn_avx_f32,
        test_mixedradix_5xn_avx_f64,
        MixedRadix5xnAvx,
        5
    );
    test_avx_mixed_radix!(
        test_mixedradix_6xn_avx_f32,
        test_mixedradix_6xn_avx_f64,
        MixedRadix6xnAvx,
        6
    );
    test_avx_mixed_radix!(
        test_mixedradix_7xn_avx_f32,
        test_mixedradix_7xn_avx_f64,
        MixedRadix7xnAvx,
        7
    );
    test_avx_mixed_radix!(
        test_mixedradix_8xn_avx_f32,
        test_mixedradix_8xn_avx_f64,
        MixedRadix8xnAvx,
        8
    );
    test_avx_mixed_radix!(
        test_mixedradix_9xn_avx_f32,
        test_mixedradix_9xn_avx_f64,
        MixedRadix9xnAvx,
        9
    );
    test_avx_mixed_radix!(
        test_mixedradix_11xn_avx_f32,
        test_mixedradix_11xn_avx_f64,
        MixedRadix11xnAvx,
        11
    );
    test_avx_mixed_radix!(
        test_mixedradix_12xn_avx_f32,
        test_mixedradix_12xn_avx_f64,
        MixedRadix12xnAvx,
        12
    );
    test_avx_mixed_radix!(
        test_mixedradix_16xn_avx_f32,
        test_mixedradix_16xn_avx_f64,
        MixedRadix16xnAvx,
        16
    );
}