1use std::any::TypeId;
2use std::sync::Arc;
3
4use num_complex::Complex;
5use num_integer::div_ceil;
6
7use crate::array_utils;
8use crate::{Direction, Fft, FftDirection, FftNum, Length};
9
10use super::{AvxNum, CommonSimdData};
11
12use super::avx_vector;
13use super::avx_vector::{AvxArray, AvxArrayMut, AvxVector, AvxVector128, AvxVector256, Rotation90};
14
macro_rules! boilerplate_mixedradix {
    () => {
        /// Builds the mixed-radix FFT around `inner_fft`.
        ///
        /// Returns `Err(())` when the running CPU does not report both the "avx" and "fma"
        /// features.
        ///
        /// # Panics
        ///
        /// Panics if `A` and `T` are not the same type.
        #[inline]
        pub fn new(inner_fft: Arc<dyn Fft<T>>) -> Result<Self, ()> {
            // `A` and `T` are separate generic parameters, but the `perform_fft_*` methods
            // reinterpret `Complex<T>` slices as `Complex<A>` slices, so refuse to build an
            // instance in which the two types differ.
            let id_a = TypeId::of::<A>();
            let id_t = TypeId::of::<T>();
            assert_eq!(id_a, id_t);

            let has_avx = is_x86_feature_detected!("avx");
            let has_fma = is_x86_feature_detected!("fma");
            if has_avx && has_fma {
                // SAFETY: `new_with_avx` is compiled with the "avx" target feature, which was
                // just detected at runtime.
                Ok(unsafe { Self::new_with_avx(inner_fft) })
            } else {
                Err(())
            }
        }

        /// In-place transform of `buffer`.
        ///
        /// Steps: column butterflies on `buffer`, inner FFT from `buffer` into the first
        /// `self.len()` elements of `scratch` (the rest of `scratch` is passed to the inner FFT
        /// as its scratch), then a transpose from that scratch region back into `buffer`.
        ///
        /// # Panics
        ///
        /// Panics if `scratch.len() < self.len()`.
        ///
        /// # Safety
        ///
        /// The CPU must support the "avx" and "fma" features.
        #[target_feature(enable = "avx", enable = "fma")]
        unsafe fn perform_fft_inplace(
            &self,
            buffer: &mut [Complex<T>],
            scratch: &mut [Complex<T>],
        ) {
            // SAFETY: this function enables the same target features as
            // `perform_column_butterflies`; `new()` asserts that `A` and `T` are the same type,
            // which is what the slice reinterpretation relies on.
            unsafe {
                let transmuted_buffer: &mut [Complex<A>] =
                    array_utils::workaround_transmute_mut(buffer);

                self.perform_column_butterflies(transmuted_buffer)
            }

            // Inner FFT: `buffer` is the input, the head of `scratch` receives the output.
            let (scratch, inner_scratch) = scratch.split_at_mut(self.len());
            self.common_data.inner_fft.process_outofplace_with_scratch(
                buffer,
                scratch,
                inner_scratch,
            );

            // SAFETY: `transpose` needs "avx", which this function enables; `A == T` as above.
            unsafe {
                let transmuted_scratch: &mut [Complex<A>] =
                    array_utils::workaround_transmute_mut(scratch);
                let transmuted_buffer: &mut [Complex<A>] =
                    array_utils::workaround_transmute_mut(buffer);

                self.transpose(transmuted_scratch, transmuted_buffer)
            }
        }

        /// Transform that leaves `input` untouched and writes the result to `output`.
        ///
        /// The first `input.len()` elements of `scratch` hold the intermediate data; the rest
        /// of `scratch` is passed to the inner FFT as its scratch.
        ///
        /// # Panics
        ///
        /// Panics if `scratch.len() < input.len()`.
        ///
        /// # Safety
        ///
        /// The CPU must support the "avx" and "fma" features.
        #[target_feature(enable = "avx", enable = "fma")]
        unsafe fn perform_fft_immut(
            &self,
            input: &[Complex<T>],
            output: &mut [Complex<T>],
            scratch: &mut [Complex<T>],
        ) {
            let (scratch, inner_scratch) = scratch.split_at_mut(input.len());

            // SAFETY: this function enables the same target features as
            // `perform_column_butterflies_immut`; `new()` asserts that `A` and `T` are the same
            // type, which is what the slice reinterpretation relies on.
            unsafe {
                let transmuted_input: &[Complex<A>] = array_utils::workaround_transmute(input);
                let transmuted_output: &mut [Complex<A>] =
                    array_utils::workaround_transmute_mut(scratch);

                self.perform_column_butterflies_immut(transmuted_input, transmuted_output);
            }

            // Inner FFT, in place on the intermediate data.
            self.common_data
                .inner_fft
                .process_with_scratch(scratch, inner_scratch);

            // SAFETY: `transpose` needs "avx", which this function enables; `A == T` as above.
            unsafe {
                let transmuted_input: &mut [Complex<A>] =
                    array_utils::workaround_transmute_mut(scratch);
                let transmuted_output: &mut [Complex<A>] =
                    array_utils::workaround_transmute_mut(output);

                self.transpose(transmuted_input, transmuted_output)
            }
        }

        /// Out-of-place transform that overwrites `input` with intermediate data and writes the
        /// result to `output`.
        ///
        /// When `scratch` is empty, `output` doubles as the inner FFT's scratch space.
        ///
        /// # Safety
        ///
        /// The CPU must support the "avx" and "fma" features.
        #[target_feature(enable = "avx", enable = "fma")]
        unsafe fn perform_fft_out_of_place(
            &self,
            input: &mut [Complex<T>],
            output: &mut [Complex<T>],
            scratch: &mut [Complex<T>],
        ) {
            // SAFETY: this function enables the same target features as
            // `perform_column_butterflies`; `new()` asserts that `A` and `T` are the same type,
            // which is what the slice reinterpretation relies on.
            unsafe {
                let transmuted_input: &mut [Complex<A>] =
                    array_utils::workaround_transmute_mut(input);
                self.perform_column_butterflies(transmuted_input);
            }

            // Inner FFT, in place on `input`. Use the caller's scratch if any was provided,
            // otherwise borrow `output`, which is overwritten by the transpose below anyway.
            let inner_scratch = if scratch.is_empty() {
                &mut output[..]
            } else {
                scratch
            };
            self.common_data
                .inner_fft
                .process_with_scratch(input, inner_scratch);

            // SAFETY: `transpose` needs "avx", which this function enables; `A == T` as above.
            unsafe {
                let transmuted_input: &mut [Complex<A>] =
                    array_utils::workaround_transmute_mut(input);
                let transmuted_output: &mut [Complex<A>] =
                    array_utils::workaround_transmute_mut(output);

                self.transpose(transmuted_input, transmuted_output)
            }
        }
    };
}
148
macro_rules! mixedradix_gen_data {
    ($row_count: expr, $inner_fft:expr) => {{
        // Expands to a `CommonSimdData` value for a transform of size
        // `$row_count * $inner_fft.len()`: the per-column twiddle vectors plus the scratch
        // lengths derived from the inner FFT's own requirements.
        const ROW_COUNT: usize = $row_count;
        const TWIDDLES_PER_COLUMN: usize = ROW_COUNT - 1;

        let direction = $inner_fft.fft_direction();
        let len_per_row = $inner_fft.len();
        let len = len_per_row * ROW_COUNT;

        // One twiddle column per vector-wide chunk of a row; a trailing partial chunk still
        // gets a full column of its own.
        let complex_per_vector = A::VectorType::COMPLEX_PER_VECTOR;
        let num_twiddle_columns = div_ceil(len_per_row, complex_per_vector);

        // Layout: column-major, `TWIDDLES_PER_COLUMN` entries per column (row 0 needs none).
        let mut twiddles = Vec::with_capacity(num_twiddle_columns * TWIDDLES_PER_COLUMN);
        for column in 0..num_twiddle_columns {
            let first_element = column * complex_per_vector;
            for row in 1..ROW_COUNT {
                twiddles.push(AvxVector::make_mixedradix_twiddle_chunk(
                    first_element,
                    row,
                    len,
                    direction,
                ));
            }
        }

        let inner_outofplace_scratch = $inner_fft.get_outofplace_scratch_len();
        let inner_inplace_scratch = $inner_fft.get_inplace_scratch_len();

        // Out-of-place processing only asks for scratch when the inner FFT wants more than
        // `len` elements; otherwise it requests none.
        let outofplace_scratch_len = if inner_inplace_scratch > len {
            inner_inplace_scratch
        } else {
            0
        };

        CommonSimdData {
            twiddles: twiddles.into_boxed_slice(),
            inplace_scratch_len: len + inner_outofplace_scratch,
            outofplace_scratch_len,
            immut_scratch_len: len + inner_inplace_scratch,
            inner_fft: $inner_fft,
            len,
            direction,
        }
    }};
}
189
macro_rules! mixedradix_column_butterflies {
    ($row_count: expr, $butterfly_fn: expr, $butterfly_fn_lo: expr) => {
        /// Column pass that reads from `input` and writes to `buffer`.
        ///
        /// For every vector-wide chunk of a row, one vector is loaded from each of the
        /// `$row_count` rows, `$butterfly_fn` is applied, and every output except row 0 is
        /// multiplied by its precomputed twiddle before being stored. A trailing partial chunk
        /// is handled with partial loads/stores; remainders of 1 or 2 go through
        /// `$butterfly_fn_lo`.
        #[target_feature(enable = "avx", enable = "fma")]
        unsafe fn perform_column_butterflies_immut(
            &self,
            input: impl AvxArray<A>,
            mut buffer: impl AvxArrayMut<A>,
        ) {
            const ROW_COUNT: usize = $row_count;
            const TWIDDLES_PER_COLUMN: usize = ROW_COUNT - 1;

            let row_len = self.len() / ROW_COUNT;
            let full_chunks = row_len / A::VectorType::COMPLEX_PER_VECTOR;
            let twiddle_table = &self.common_data.twiddles;

            // Full-width chunks.
            let full_columns = twiddle_table
                .chunks_exact(TWIDDLES_PER_COLUMN)
                .take(full_chunks)
                .enumerate();
            for (chunk, column_twiddles) in full_columns {
                let base = chunk * A::VectorType::COMPLEX_PER_VECTOR;

                let mut columns = [AvxVector::zero(); ROW_COUNT];
                for row in 0..ROW_COUNT {
                    columns[row] = input.load_complex(base + row_len * row);
                }

                let butterflied = $butterfly_fn(columns, self);

                // Row 0 has no twiddle.
                buffer.store_complex(butterflied[0], base);
                for row in 1..ROW_COUNT {
                    let rotated =
                        AvxVector::mul_complex(column_twiddles[row - 1], butterflied[row]);
                    buffer.store_complex(rotated, base + row_len * row);
                }
            }

            let leftover = row_len % A::VectorType::COMPLEX_PER_VECTOR;
            if leftover == 0 {
                return;
            }

            // The partial chunk uses the last twiddle column in the table.
            let tail_base = full_chunks * A::VectorType::COMPLEX_PER_VECTOR;
            let tail_twiddles = &twiddle_table[twiddle_table.len() - TWIDDLES_PER_COLUMN..];

            if leftover > 2 {
                let mut columns = [AvxVector::zero(); ROW_COUNT];
                for row in 0..ROW_COUNT {
                    columns[row] = input.load_partial3_complex(tail_base + row_len * row);
                }

                let butterflied = $butterfly_fn(columns, self);

                buffer.store_partial3_complex(butterflied[0], tail_base);
                for row in 1..ROW_COUNT {
                    let rotated = AvxVector::mul_complex(tail_twiddles[row - 1], butterflied[row]);
                    buffer.store_partial3_complex(rotated, tail_base + row_len * row);
                }
            } else {
                let single = leftover == 1;

                let mut columns = [AvxVector::zero(); ROW_COUNT];
                if single {
                    for row in 0..ROW_COUNT {
                        columns[row] = input.load_partial1_complex(tail_base + row_len * row);
                    }
                } else {
                    for row in 0..ROW_COUNT {
                        columns[row] = input.load_partial2_complex(tail_base + row_len * row);
                    }
                }

                let mut butterflied = $butterfly_fn_lo(columns, self);

                for row in 1..ROW_COUNT {
                    butterflied[row] =
                        AvxVector::mul_complex(tail_twiddles[row - 1].lo(), butterflied[row]);
                }

                if single {
                    for row in 0..ROW_COUNT {
                        buffer.store_partial1_complex(butterflied[row], tail_base + row_len * row);
                    }
                } else {
                    for row in 0..ROW_COUNT {
                        buffer.store_partial2_complex(butterflied[row], tail_base + row_len * row);
                    }
                }
            }
        }

        /// In-place variant of the column pass: identical to
        /// `perform_column_butterflies_immut`, except that `buffer` is both the source and the
        /// destination.
        #[target_feature(enable = "avx", enable = "fma")]
        unsafe fn perform_column_butterflies(&self, mut buffer: impl AvxArrayMut<A>) {
            const ROW_COUNT: usize = $row_count;
            const TWIDDLES_PER_COLUMN: usize = ROW_COUNT - 1;

            let row_len = self.len() / ROW_COUNT;
            let full_chunks = row_len / A::VectorType::COMPLEX_PER_VECTOR;
            let twiddle_table = &self.common_data.twiddles;

            // Full-width chunks.
            let full_columns = twiddle_table
                .chunks_exact(TWIDDLES_PER_COLUMN)
                .take(full_chunks)
                .enumerate();
            for (chunk, column_twiddles) in full_columns {
                let base = chunk * A::VectorType::COMPLEX_PER_VECTOR;

                let mut columns = [AvxVector::zero(); ROW_COUNT];
                for row in 0..ROW_COUNT {
                    columns[row] = buffer.load_complex(base + row_len * row);
                }

                let butterflied = $butterfly_fn(columns, self);

                // Row 0 has no twiddle.
                buffer.store_complex(butterflied[0], base);
                for row in 1..ROW_COUNT {
                    let rotated =
                        AvxVector::mul_complex(column_twiddles[row - 1], butterflied[row]);
                    buffer.store_complex(rotated, base + row_len * row);
                }
            }

            let leftover = row_len % A::VectorType::COMPLEX_PER_VECTOR;
            if leftover == 0 {
                return;
            }

            // The partial chunk uses the last twiddle column in the table.
            let tail_base = full_chunks * A::VectorType::COMPLEX_PER_VECTOR;
            let tail_twiddles = &twiddle_table[twiddle_table.len() - TWIDDLES_PER_COLUMN..];

            if leftover > 2 {
                let mut columns = [AvxVector::zero(); ROW_COUNT];
                for row in 0..ROW_COUNT {
                    columns[row] = buffer.load_partial3_complex(tail_base + row_len * row);
                }

                let butterflied = $butterfly_fn(columns, self);

                buffer.store_partial3_complex(butterflied[0], tail_base);
                for row in 1..ROW_COUNT {
                    let rotated = AvxVector::mul_complex(tail_twiddles[row - 1], butterflied[row]);
                    buffer.store_partial3_complex(rotated, tail_base + row_len * row);
                }
            } else {
                let single = leftover == 1;

                let mut columns = [AvxVector::zero(); ROW_COUNT];
                if single {
                    for row in 0..ROW_COUNT {
                        columns[row] = buffer.load_partial1_complex(tail_base + row_len * row);
                    }
                } else {
                    for row in 0..ROW_COUNT {
                        columns[row] = buffer.load_partial2_complex(tail_base + row_len * row);
                    }
                }

                let mut butterflied = $butterfly_fn_lo(columns, self);

                for row in 1..ROW_COUNT {
                    butterflied[row] =
                        AvxVector::mul_complex(tail_twiddles[row - 1].lo(), butterflied[row]);
                }

                if single {
                    for row in 0..ROW_COUNT {
                        buffer.store_partial1_complex(butterflied[row], tail_base + row_len * row);
                    }
                } else {
                    for row in 0..ROW_COUNT {
                        buffer.store_partial2_complex(butterflied[row], tail_base + row_len * row);
                    }
                }
            }
        }
    };
}
426
macro_rules! mixedradix_transpose{
    ($row_count: expr, $transpose_fn: path, $transpose_fn_lo: path, $($unroll_workaround_index:expr);*, $($remainder3_unroll_workaround_index:expr);*) => (
        /// Transposes `input`, viewed as `$row_count` rows of `self.len() / $row_count`
        /// elements, into `output`.
        ///
        /// Full vector-wide chunks go through `$transpose_fn`; a trailing partial chunk of 1, 2
        /// or 3 elements per row is handled separately. The two index lists are expanded by
        /// macro repetition, so the stores are emitted as straight-line code rather than a loop.
        #[target_feature(enable = "avx")]
        unsafe fn transpose(&self, input: &[Complex<A>], mut output: &mut [Complex<A>]) {
            const ROW_COUNT : usize = $row_count;

            let row_len = self.len() / ROW_COUNT;
            let full_chunks = row_len / A::VectorType::COMPLEX_PER_VECTOR;

            for chunk in 0..full_chunks {
                let src_base = chunk * A::VectorType::COMPLEX_PER_VECTOR;
                let dst_base = src_base * ROW_COUNT;

                let mut rows : [A::VectorType; ROW_COUNT] = [AvxVector::zero(); ROW_COUNT];
                for row in 0..ROW_COUNT {
                    rows[row] = input.load_complex(src_base + row_len * row);
                }

                let transposed = $transpose_fn(rows);

                $(
                    output.store_complex(transposed[$unroll_workaround_index], dst_base + A::VectorType::COMPLEX_PER_VECTOR * $unroll_workaround_index);
                )*
            }

            // Whatever is left after the full chunks starts here.
            let src_base = full_chunks * A::VectorType::COMPLEX_PER_VECTOR;
            let dst_base = src_base * ROW_COUNT;

            match row_len % A::VectorType::COMPLEX_PER_VECTOR {
                1 => {
                    // A single column: copy it element by element.
                    for row in 0..ROW_COUNT {
                        *output.get_unchecked_mut(dst_base + row) =
                            *input.get_unchecked(src_base + row_len * row);
                    }
                }
                2 => {
                    let mut rows = [AvxVector::zero(); ROW_COUNT];
                    for row in 0..ROW_COUNT {
                        rows[row] = input.load_partial2_complex(src_base + row_len * row);
                    }

                    let transposed = $transpose_fn_lo(rows);

                    $(
                        output.store_partial2_complex(transposed[$unroll_workaround_index], dst_base + <A::VectorType as AvxVector256>::HalfVector::COMPLEX_PER_VECTOR * $unroll_workaround_index);
                    )*
                }
                3 => {
                    let mut rows = [AvxVector::zero(); ROW_COUNT];
                    for row in 0..ROW_COUNT {
                        rows[row] = input.load_partial3_complex(src_base + row_len * row);
                    }

                    let transposed = $transpose_fn(rows);

                    // 3 elements per row: some number of full vectors, then possibly a
                    // partial one.
                    let total_elements = 3 * ROW_COUNT;
                    let whole_vectors = total_elements / A::VectorType::COMPLEX_PER_VECTOR;
                    let trailing_elements = total_elements % A::VectorType::COMPLEX_PER_VECTOR;

                    $(
                        output.store_complex(transposed[$remainder3_unroll_workaround_index], dst_base + A::VectorType::COMPLEX_PER_VECTOR * $remainder3_unroll_workaround_index);
                    )*

                    let trailing_base = dst_base + whole_vectors * A::VectorType::COMPLEX_PER_VECTOR;
                    match trailing_elements {
                        0 => {},
                        1 => output.store_partial1_complex(transposed[whole_vectors].lo(), trailing_base),
                        2 => output.store_partial2_complex(transposed[whole_vectors].lo(), trailing_base),
                        3 => output.store_partial3_complex(transposed[whole_vectors], trailing_base),
                        _ => unreachable!(),
                    }
                }
                _ => {}
            }
        }
    )
}
529
530pub struct MixedRadix2xnAvx<A: AvxNum, T> {
531 common_data: CommonSimdData<T, A::VectorType>,
532 _phantom: std::marker::PhantomData<T>,
533}
534boilerplate_avx_fft_commondata!(MixedRadix2xnAvx);
535
536impl<A: AvxNum, T: FftNum> MixedRadix2xnAvx<A, T> {
537 #[target_feature(enable = "avx")]
538 unsafe fn new_with_avx(inner_fft: Arc<dyn Fft<T>>) -> Self {
539 Self {
540 common_data: mixedradix_gen_data!(2, inner_fft),
541 _phantom: std::marker::PhantomData,
542 }
543 }
544 mixedradix_column_butterflies!(
545 2,
546 |columns, _: _| AvxVector::column_butterfly2(columns),
547 |columns, _: _| AvxVector::column_butterfly2(columns)
548 );
549 mixedradix_transpose!(2,
550 AvxVector::transpose2_packed,
551 AvxVector::transpose2_packed,
552 0;1, 0
553 );
554 boilerplate_mixedradix!();
555}
556
557pub struct MixedRadix3xnAvx<A: AvxNum, T> {
558 twiddles_butterfly3: A::VectorType,
559 common_data: CommonSimdData<T, A::VectorType>,
560 _phantom: std::marker::PhantomData<T>,
561}
562boilerplate_avx_fft_commondata!(MixedRadix3xnAvx);
563
564impl<A: AvxNum, T: FftNum> MixedRadix3xnAvx<A, T> {
565 #[target_feature(enable = "avx")]
566 unsafe fn new_with_avx(inner_fft: Arc<dyn Fft<T>>) -> Self {
567 Self {
568 twiddles_butterfly3: AvxVector::broadcast_twiddle(1, 3, inner_fft.fft_direction()),
569 common_data: mixedradix_gen_data!(3, inner_fft),
570 _phantom: std::marker::PhantomData,
571 }
572 }
573 mixedradix_column_butterflies!(
574 3,
575 |columns, this: &Self| AvxVector::column_butterfly3(columns, this.twiddles_butterfly3),
576 |columns, this: &Self| AvxVector::column_butterfly3(columns, this.twiddles_butterfly3.lo())
577 );
578 mixedradix_transpose!(3,
579 AvxVector::transpose3_packed,
580 AvxVector::transpose3_packed,
581 0;1;2, 0;1
582 );
583 boilerplate_mixedradix!();
584}
585
586pub struct MixedRadix4xnAvx<A: AvxNum, T> {
587 twiddles_butterfly4: Rotation90<A::VectorType>,
588 common_data: CommonSimdData<T, A::VectorType>,
589 _phantom: std::marker::PhantomData<T>,
590}
591boilerplate_avx_fft_commondata!(MixedRadix4xnAvx);
592
593impl<A: AvxNum, T: FftNum> MixedRadix4xnAvx<A, T> {
594 #[target_feature(enable = "avx")]
595 unsafe fn new_with_avx(inner_fft: Arc<dyn Fft<T>>) -> Self {
596 Self {
597 twiddles_butterfly4: AvxVector::make_rotation90(inner_fft.fft_direction()),
598 common_data: mixedradix_gen_data!(4, inner_fft),
599 _phantom: std::marker::PhantomData,
600 }
601 }
602 mixedradix_column_butterflies!(
603 4,
604 |columns, this: &Self| AvxVector::column_butterfly4(columns, this.twiddles_butterfly4),
605 |columns, this: &Self| AvxVector::column_butterfly4(columns, this.twiddles_butterfly4.lo())
606 );
607 mixedradix_transpose!(4,
608 AvxVector::transpose4_packed,
609 AvxVector::transpose4_packed,
610 0;1;2;3, 0;1;2
611 );
612 boilerplate_mixedradix!();
613}
614
615pub struct MixedRadix5xnAvx<A: AvxNum, T> {
616 twiddles_butterfly5: [A::VectorType; 2],
617 common_data: CommonSimdData<T, A::VectorType>,
618 _phantom: std::marker::PhantomData<T>,
619}
620boilerplate_avx_fft_commondata!(MixedRadix5xnAvx);
621
622impl<A: AvxNum, T: FftNum> MixedRadix5xnAvx<A, T> {
623 #[target_feature(enable = "avx")]
624 unsafe fn new_with_avx(inner_fft: Arc<dyn Fft<T>>) -> Self {
625 Self {
626 twiddles_butterfly5: [
627 AvxVector::broadcast_twiddle(1, 5, inner_fft.fft_direction()),
628 AvxVector::broadcast_twiddle(2, 5, inner_fft.fft_direction()),
629 ],
630 common_data: mixedradix_gen_data!(5, inner_fft),
631 _phantom: std::marker::PhantomData,
632 }
633 }
634 mixedradix_column_butterflies!(
635 5,
636 |columns, this: &Self| AvxVector::column_butterfly5(columns, this.twiddles_butterfly5),
637 |columns, this: &Self| AvxVector::column_butterfly5(
638 columns,
639 [
640 this.twiddles_butterfly5[0].lo(),
641 this.twiddles_butterfly5[1].lo()
642 ]
643 )
644 );
645 mixedradix_transpose!(5,
646 AvxVector::transpose5_packed,
647 AvxVector::transpose5_packed,
648 0;1;2;3;4, 0;1;2
649 );
650 boilerplate_mixedradix!();
651}
652
653pub struct MixedRadix6xnAvx<A: AvxNum, T> {
654 twiddles_butterfly3: A::VectorType,
655 common_data: CommonSimdData<T, A::VectorType>,
656 _phantom: std::marker::PhantomData<T>,
657}
658boilerplate_avx_fft_commondata!(MixedRadix6xnAvx);
659
660impl<A: AvxNum, T: FftNum> MixedRadix6xnAvx<A, T> {
661 #[target_feature(enable = "avx")]
662 unsafe fn new_with_avx(inner_fft: Arc<dyn Fft<T>>) -> Self {
663 Self {
664 twiddles_butterfly3: AvxVector::broadcast_twiddle(1, 3, inner_fft.fft_direction()),
665 common_data: mixedradix_gen_data!(6, inner_fft),
666 _phantom: std::marker::PhantomData,
667 }
668 }
669 mixedradix_column_butterflies!(
670 6,
671 |columns, this: &Self| AvxVector256::column_butterfly6(columns, this.twiddles_butterfly3),
672 |columns, this: &Self| AvxVector128::column_butterfly6(columns, this.twiddles_butterfly3)
673 );
674 mixedradix_transpose!(6,
675 AvxVector::transpose6_packed,
676 AvxVector::transpose6_packed,
677 0;1;2;3;4;5, 0;1;2;3
678 );
679 boilerplate_mixedradix!();
680}
681
682pub struct MixedRadix7xnAvx<A: AvxNum, T> {
683 twiddles_butterfly7: [A::VectorType; 3],
684 common_data: CommonSimdData<T, A::VectorType>,
685 _phantom: std::marker::PhantomData<T>,
686}
687boilerplate_avx_fft_commondata!(MixedRadix7xnAvx);
688
689impl<A: AvxNum, T: FftNum> MixedRadix7xnAvx<A, T> {
690 #[target_feature(enable = "avx")]
691 unsafe fn new_with_avx(inner_fft: Arc<dyn Fft<T>>) -> Self {
692 Self {
693 twiddles_butterfly7: [
694 AvxVector::broadcast_twiddle(1, 7, inner_fft.fft_direction()),
695 AvxVector::broadcast_twiddle(2, 7, inner_fft.fft_direction()),
696 AvxVector::broadcast_twiddle(3, 7, inner_fft.fft_direction()),
697 ],
698 common_data: mixedradix_gen_data!(7, inner_fft),
699 _phantom: std::marker::PhantomData,
700 }
701 }
702 mixedradix_column_butterflies!(
703 7,
704 |columns, this: &Self| AvxVector::column_butterfly7(columns, this.twiddles_butterfly7),
705 |columns, this: &Self| AvxVector::column_butterfly7(
706 columns,
707 [
708 this.twiddles_butterfly7[0].lo(),
709 this.twiddles_butterfly7[1].lo(),
710 this.twiddles_butterfly7[2].lo()
711 ]
712 )
713 );
714 mixedradix_transpose!(7,
715 AvxVector::transpose7_packed,
716 AvxVector::transpose7_packed,
717 0;1;2;3;4;5;6, 0;1;2;3;4
718 );
719 boilerplate_mixedradix!();
720}
721
722pub struct MixedRadix8xnAvx<A: AvxNum, T> {
723 twiddles_butterfly4: Rotation90<A::VectorType>,
724 common_data: CommonSimdData<T, A::VectorType>,
725 _phantom: std::marker::PhantomData<T>,
726}
727boilerplate_avx_fft_commondata!(MixedRadix8xnAvx);
728
729impl<A: AvxNum, T: FftNum> MixedRadix8xnAvx<A, T> {
730 #[target_feature(enable = "avx")]
731 unsafe fn new_with_avx(inner_fft: Arc<dyn Fft<T>>) -> Self {
732 Self {
733 twiddles_butterfly4: AvxVector::make_rotation90(inner_fft.fft_direction()),
734 common_data: mixedradix_gen_data!(8, inner_fft),
735 _phantom: std::marker::PhantomData,
736 }
737 }
738
739 mixedradix_column_butterflies!(
740 8,
741 |columns, this: &Self| AvxVector::column_butterfly8(columns, this.twiddles_butterfly4),
742 |columns, this: &Self| AvxVector::column_butterfly8(columns, this.twiddles_butterfly4.lo())
743 );
744 mixedradix_transpose!(8,
745 AvxVector::transpose8_packed,
746 AvxVector::transpose8_packed,
747 0;1;2;3;4;5;6;7, 0;1;2;3;4;5
748 );
749 boilerplate_mixedradix!();
750}
751
752pub struct MixedRadix9xnAvx<A: AvxNum, T> {
753 twiddles_butterfly9: [A::VectorType; 3],
754 twiddles_butterfly9_lo: [A::VectorType; 2],
755 twiddles_butterfly3: A::VectorType,
756 common_data: CommonSimdData<T, A::VectorType>,
757 _phantom: std::marker::PhantomData<T>,
758}
759boilerplate_avx_fft_commondata!(MixedRadix9xnAvx);
760
761impl<A: AvxNum, T: FftNum> MixedRadix9xnAvx<A, T> {
762 #[target_feature(enable = "avx")]
763 unsafe fn new_with_avx(inner_fft: Arc<dyn Fft<T>>) -> Self {
764 let inverse = inner_fft.fft_direction();
765
766 let twiddle1 = AvxVector::broadcast_twiddle(1, 9, inner_fft.fft_direction());
767 let twiddle2 = AvxVector::broadcast_twiddle(2, 9, inner_fft.fft_direction());
768 let twiddle4 = AvxVector::broadcast_twiddle(4, 9, inner_fft.fft_direction());
769
770 Self {
771 twiddles_butterfly9: [
772 AvxVector::broadcast_twiddle(1, 9, inverse),
773 AvxVector::broadcast_twiddle(2, 9, inverse),
774 AvxVector::broadcast_twiddle(4, 9, inverse),
775 ],
776 twiddles_butterfly9_lo: [
777 AvxVector256::merge(twiddle1, twiddle2),
778 AvxVector256::merge(twiddle2, twiddle4),
779 ],
780 twiddles_butterfly3: AvxVector::broadcast_twiddle(1, 3, inner_fft.fft_direction()),
781 common_data: mixedradix_gen_data!(9, inner_fft),
782 _phantom: std::marker::PhantomData,
783 }
784 }
785
786 mixedradix_column_butterflies!(
787 9,
788 |columns, this: &Self| AvxVector256::column_butterfly9(
789 columns,
790 this.twiddles_butterfly9,
791 this.twiddles_butterfly3
792 ),
793 |columns, this: &Self| AvxVector128::column_butterfly9(
794 columns,
795 this.twiddles_butterfly9_lo,
796 this.twiddles_butterfly3
797 )
798 );
799 mixedradix_transpose!(9,
800 AvxVector::transpose9_packed,
801 AvxVector::transpose9_packed,
802 0;1;2;3;4;5;6;7;8, 0;1;2;3;4;5
803 );
804 boilerplate_mixedradix!();
805}
806
807pub struct MixedRadix11xnAvx<A: AvxNum, T> {
808 twiddles_butterfly11: [A::VectorType; 5],
809 common_data: CommonSimdData<T, A::VectorType>,
810 _phantom: std::marker::PhantomData<T>,
811}
812boilerplate_avx_fft_commondata!(MixedRadix11xnAvx);
813
814impl<A: AvxNum, T: FftNum> MixedRadix11xnAvx<A, T> {
815 #[target_feature(enable = "avx")]
816 unsafe fn new_with_avx(inner_fft: Arc<dyn Fft<T>>) -> Self {
817 Self {
818 twiddles_butterfly11: [
819 AvxVector::broadcast_twiddle(1, 11, inner_fft.fft_direction()),
820 AvxVector::broadcast_twiddle(2, 11, inner_fft.fft_direction()),
821 AvxVector::broadcast_twiddle(3, 11, inner_fft.fft_direction()),
822 AvxVector::broadcast_twiddle(4, 11, inner_fft.fft_direction()),
823 AvxVector::broadcast_twiddle(5, 11, inner_fft.fft_direction()),
824 ],
825 common_data: mixedradix_gen_data!(11, inner_fft),
826 _phantom: std::marker::PhantomData,
827 }
828 }
829 mixedradix_column_butterflies!(
830 11,
831 |columns, this: &Self| AvxVector::column_butterfly11(columns, this.twiddles_butterfly11),
832 |columns, this: &Self| AvxVector::column_butterfly11(
833 columns,
834 [
835 this.twiddles_butterfly11[0].lo(),
836 this.twiddles_butterfly11[1].lo(),
837 this.twiddles_butterfly11[2].lo(),
838 this.twiddles_butterfly11[3].lo(),
839 this.twiddles_butterfly11[4].lo()
840 ]
841 )
842 );
843 mixedradix_transpose!(11,
844 AvxVector::transpose11_packed,
845 AvxVector::transpose11_packed,
846 0;1;2;3;4;5;6;7;8;9;10, 0;1;2;3;4;5;6;7
847 );
848 boilerplate_mixedradix!();
849}
850
851pub struct MixedRadix12xnAvx<A: AvxNum, T> {
852 twiddles_butterfly4: Rotation90<A::VectorType>,
853 twiddles_butterfly3: A::VectorType,
854 common_data: CommonSimdData<T, A::VectorType>,
855 _phantom: std::marker::PhantomData<T>,
856}
857boilerplate_avx_fft_commondata!(MixedRadix12xnAvx);
858
859impl<A: AvxNum, T: FftNum> MixedRadix12xnAvx<A, T> {
860 #[target_feature(enable = "avx")]
861 unsafe fn new_with_avx(inner_fft: Arc<dyn Fft<T>>) -> Self {
862 let inverse = inner_fft.fft_direction();
863 Self {
864 twiddles_butterfly4: AvxVector::make_rotation90(inverse),
865 twiddles_butterfly3: AvxVector::broadcast_twiddle(1, 3, inverse),
866 common_data: mixedradix_gen_data!(12, inner_fft),
867 _phantom: std::marker::PhantomData,
868 }
869 }
870
871 mixedradix_column_butterflies!(
872 12,
873 |columns, this: &Self| AvxVector256::column_butterfly12(
874 columns,
875 this.twiddles_butterfly3,
876 this.twiddles_butterfly4
877 ),
878 |columns, this: &Self| AvxVector128::column_butterfly12(
879 columns,
880 this.twiddles_butterfly3,
881 this.twiddles_butterfly4
882 )
883 );
884 mixedradix_transpose!(12,
885 AvxVector::transpose12_packed,
886 AvxVector::transpose12_packed,
887 0;1;2;3;4;5;6;7;8;9;10;11, 0;1;2;3;4;5;6;7;8
888 );
889 boilerplate_mixedradix!();
890}
891
892pub struct MixedRadix16xnAvx<A: AvxNum, T> {
893 twiddles_butterfly4: Rotation90<A::VectorType>,
894 twiddles_butterfly16: [A::VectorType; 2],
895 common_data: CommonSimdData<T, A::VectorType>,
896 _phantom: std::marker::PhantomData<T>,
897}
898boilerplate_avx_fft_commondata!(MixedRadix16xnAvx);
899
900impl<A: AvxNum, T: FftNum> MixedRadix16xnAvx<A, T> {
901 #[target_feature(enable = "avx")]
902 unsafe fn new_with_avx(inner_fft: Arc<dyn Fft<T>>) -> Self {
903 let inverse = inner_fft.fft_direction();
904 Self {
905 twiddles_butterfly4: AvxVector::make_rotation90(inner_fft.fft_direction()),
906 twiddles_butterfly16: [
907 AvxVector::broadcast_twiddle(1, 16, inverse),
908 AvxVector::broadcast_twiddle(3, 16, inverse),
909 ],
910 common_data: mixedradix_gen_data!(16, inner_fft),
911 _phantom: std::marker::PhantomData,
912 }
913 }
914
915 #[target_feature(enable = "avx", enable = "fma")]
916 unsafe fn perform_column_butterflies(&self, mut buffer: impl AvxArrayMut<A>) {
917 const ROW_COUNT: usize = 16;
919 const TWIDDLES_PER_COLUMN: usize = ROW_COUNT - 1;
920
921 let len_per_row = self.len() / ROW_COUNT;
922 let chunk_count = len_per_row / A::VectorType::COMPLEX_PER_VECTOR;
923
924 for (c, twiddle_chunk) in self
926 .common_data
927 .twiddles
928 .chunks_exact(TWIDDLES_PER_COLUMN)
929 .take(chunk_count)
930 .enumerate()
931 {
932 let index_base = c * A::VectorType::COMPLEX_PER_VECTOR;
933
934 column_butterfly16_loadfn!(
935 |index| buffer.load_complex(index_base + len_per_row * index),
936 |mut data, index| {
937 if index > 0 {
938 data = AvxVector::mul_complex(data, twiddle_chunk[index - 1]);
939 }
940 buffer.store_complex(data, index_base + len_per_row * index)
941 },
942 self.twiddles_butterfly16,
943 self.twiddles_butterfly4
944 );
945 }
946
947 let partial_remainder = len_per_row % A::VectorType::COMPLEX_PER_VECTOR;
950 if partial_remainder > 0 {
951 let partial_remainder_base = chunk_count * A::VectorType::COMPLEX_PER_VECTOR;
952 let partial_remainder_twiddle_base =
953 self.common_data.twiddles.len() - TWIDDLES_PER_COLUMN;
954 let final_twiddle_chunk = &self.common_data.twiddles[partial_remainder_twiddle_base..];
955
956 match partial_remainder {
957 1 => {
958 column_butterfly16_loadfn!(
959 |index| buffer
960 .load_partial1_complex(partial_remainder_base + len_per_row * index),
961 |mut data, index| {
962 if index > 0 {
963 let twiddle: A::VectorType = final_twiddle_chunk[index - 1];
964 data = AvxVector::mul_complex(data, twiddle.lo());
965 }
966 buffer.store_partial1_complex(
967 data,
968 partial_remainder_base + len_per_row * index,
969 )
970 },
971 [
972 self.twiddles_butterfly16[0].lo(),
973 self.twiddles_butterfly16[1].lo()
974 ],
975 self.twiddles_butterfly4.lo()
976 );
977 }
978 2 => {
979 column_butterfly16_loadfn!(
980 |index| buffer
981 .load_partial2_complex(partial_remainder_base + len_per_row * index),
982 |mut data, index| {
983 if index > 0 {
984 let twiddle: A::VectorType = final_twiddle_chunk[index - 1];
985 data = AvxVector::mul_complex(data, twiddle.lo());
986 }
987 buffer.store_partial2_complex(
988 data,
989 partial_remainder_base + len_per_row * index,
990 )
991 },
992 [
993 self.twiddles_butterfly16[0].lo(),
994 self.twiddles_butterfly16[1].lo()
995 ],
996 self.twiddles_butterfly4.lo()
997 );
998 }
999 3 => {
1000 column_butterfly16_loadfn!(
1001 |index| buffer
1002 .load_partial3_complex(partial_remainder_base + len_per_row * index),
1003 |mut data, index| {
1004 if index > 0 {
1005 data = AvxVector::mul_complex(data, final_twiddle_chunk[index - 1]);
1006 }
1007 buffer.store_partial3_complex(
1008 data,
1009 partial_remainder_base + len_per_row * index,
1010 )
1011 },
1012 self.twiddles_butterfly16,
1013 self.twiddles_butterfly4
1014 );
1015 }
1016 _ => unreachable!(),
1017 }
1018 }
1019 }
1020 #[target_feature(enable = "avx", enable = "fma")]
1021 unsafe fn perform_column_butterflies_immut(
1022 &self,
1023 input: impl AvxArray<A>,
1024 mut buffer: impl AvxArrayMut<A>,
1025 ) {
1026 const ROW_COUNT: usize = 16;
1028 const TWIDDLES_PER_COLUMN: usize = ROW_COUNT - 1;
1029
1030 let len_per_row = self.len() / ROW_COUNT;
1031 let chunk_count = len_per_row / A::VectorType::COMPLEX_PER_VECTOR;
1032
1033 for (c, twiddle_chunk) in self
1035 .common_data
1036 .twiddles
1037 .chunks_exact(TWIDDLES_PER_COLUMN)
1038 .take(chunk_count)
1039 .enumerate()
1040 {
1041 let index_base = c * A::VectorType::COMPLEX_PER_VECTOR;
1042
1043 column_butterfly16_loadfn!(
1044 |index| input.load_complex(index_base + len_per_row * index),
1045 |mut data, index| {
1046 if index > 0 {
1047 data = AvxVector::mul_complex(data, twiddle_chunk[index - 1]);
1048 }
1049 buffer.store_complex(data, index_base + len_per_row * index)
1050 },
1051 self.twiddles_butterfly16,
1052 self.twiddles_butterfly4
1053 );
1054 }
1055
1056 let partial_remainder = len_per_row % A::VectorType::COMPLEX_PER_VECTOR;
1059 if partial_remainder > 0 {
1060 let partial_remainder_base = chunk_count * A::VectorType::COMPLEX_PER_VECTOR;
1061 let partial_remainder_twiddle_base =
1062 self.common_data.twiddles.len() - TWIDDLES_PER_COLUMN;
1063 let final_twiddle_chunk = &self.common_data.twiddles[partial_remainder_twiddle_base..];
1064
1065 match partial_remainder {
1066 1 => {
1067 for c in 0..self.len() / len_per_row {
1068 let cs = c * len_per_row + len_per_row - partial_remainder;
1069 buffer.store_partial1_complex(input.load_partial1_complex(cs), cs);
1070 }
1071 column_butterfly16_loadfn!(
1072 |index| buffer
1073 .load_partial1_complex(partial_remainder_base + len_per_row * index),
1074 |mut data, index| {
1075 if index > 0 {
1076 let twiddle: A::VectorType = final_twiddle_chunk[index - 1];
1077 data = AvxVector::mul_complex(data, twiddle.lo());
1078 }
1079 buffer.store_partial1_complex(
1080 data,
1081 partial_remainder_base + len_per_row * index,
1082 )
1083 },
1084 [
1085 self.twiddles_butterfly16[0].lo(),
1086 self.twiddles_butterfly16[1].lo()
1087 ],
1088 self.twiddles_butterfly4.lo()
1089 );
1090 }
1091 2 => {
1092 for c in 0..self.len() / len_per_row {
1093 let cs = c * len_per_row + len_per_row - partial_remainder;
1094 buffer.store_partial2_complex(input.load_partial2_complex(cs), cs);
1095 }
1096 column_butterfly16_loadfn!(
1097 |index| buffer
1098 .load_partial2_complex(partial_remainder_base + len_per_row * index),
1099 |mut data, index| {
1100 if index > 0 {
1101 let twiddle: A::VectorType = final_twiddle_chunk[index - 1];
1102 data = AvxVector::mul_complex(data, twiddle.lo());
1103 }
1104 buffer.store_partial2_complex(
1105 data,
1106 partial_remainder_base + len_per_row * index,
1107 )
1108 },
1109 [
1110 self.twiddles_butterfly16[0].lo(),
1111 self.twiddles_butterfly16[1].lo()
1112 ],
1113 self.twiddles_butterfly4.lo()
1114 );
1115 }
1116 3 => {
1117 for c in 0..self.len() / len_per_row {
1118 let cs = c * len_per_row + len_per_row - partial_remainder;
1119 buffer.store_partial3_complex(input.load_partial3_complex(cs), cs);
1120 }
1121 column_butterfly16_loadfn!(
1122 |index| buffer
1123 .load_partial3_complex(partial_remainder_base + len_per_row * index),
1124 |mut data, index| {
1125 if index > 0 {
1126 data = AvxVector::mul_complex(data, final_twiddle_chunk[index - 1]);
1127 }
1128 buffer.store_partial3_complex(
1129 data,
1130 partial_remainder_base + len_per_row * index,
1131 )
1132 },
1133 self.twiddles_butterfly16,
1134 self.twiddles_butterfly4
1135 );
1136 }
1137 _ => unreachable!(),
1138 }
1139 }
1140 }
    // Expands `mixedradix_transpose!` for this 16-row type. The macro definition is
    // outside this view; presumably it generates the `transpose` method that the
    // boilerplate FFT drivers call after the inner FFT has run.
    // NOTE(review): `AvxVector::transpose16_packed` is passed for both function
    // arguments, and the two `;`-separated index lists cover 0..=15 and 0..=11.
    // Confirm their exact meaning against the macro definition.
    mixedradix_transpose!(16,
        AvxVector::transpose16_packed,
        AvxVector::transpose16_packed,
        0;1;2;3;4;5;6;7;8;9;10;11;12;13;14;15, 0;1;2;3;4;5;6;7;8;9;10;11
    );
    // Expands the shared `boilerplate_mixedradix!` macro defined at the top of this
    // file. It provides the constructor `new` (which returns `Err(())` unless both
    // the "avx" and "fma" CPU features are detected at runtime) and the
    // `perform_fft_inplace` / `perform_fft_immut` drivers, which run the column
    // butterflies, the inner FFT, and the transpose.
    boilerplate_mixedradix!();
1147}
1148
#[cfg(test)]
mod unit_tests {
    use super::*;
    use crate::algorithm::*;
    use crate::test_utils::check_fft_algorithm;
    use std::sync::Arc;

    // Generates two `#[test]` functions (one for `f32`, one for `f64`) for an AVX
    // mixed-radix struct. Each test wraps a `Dft` of every inner length in `1..32`,
    // first in the forward and then in the inverse direction, and verifies the
    // combined FFT of length `inner length * $inner_count` with `check_fft_algorithm`.
    macro_rules! test_avx_mixed_radix {
        // Internal rule: build the wrapper around a `Dft` inner FFT for a single
        // direction and check it.
        (@check $struct_name:ident, $float:ty, $inner_len:expr, $total_len:expr, $direction:expr) => {{
            let inner_fft = Arc::new(Dft::new($inner_len, $direction)) as Arc<dyn Fft<$float>>;
            let fft = $struct_name::<$float, $float>::new(inner_fft)
                .expect("Can't run test because this machine doesn't have the required instruction sets");
            check_fft_algorithm(&fft, $total_len, $direction);
        }};

        // Internal rule: emit one test function for one floating-point type.
        (@case $test_name:ident, $struct_name:ident, $float:ty, $inner_count:expr) => {
            #[test]
            fn $test_name() {
                for inner_len in 1..32 {
                    let total_len = inner_len * $inner_count;

                    test_avx_mixed_radix!(
                        @check $struct_name, $float, inner_len, total_len, FftDirection::Forward
                    );
                    test_avx_mixed_radix!(
                        @check $struct_name, $float, inner_len, total_len, FftDirection::Inverse
                    );
                }
            }
        };

        // Public entry point: one test per float type.
        ($f32_test_name:ident, $f64_test_name:ident, $struct_name:ident, $inner_count:expr) => {
            test_avx_mixed_radix!(@case $f32_test_name, $struct_name, f32, $inner_count);
            test_avx_mixed_radix!(@case $f64_test_name, $struct_name, f64, $inner_count);
        };
    }

    test_avx_mixed_radix!(
        test_mixedradix_2xn_avx_f32,
        test_mixedradix_2xn_avx_f64,
        MixedRadix2xnAvx,
        2
    );
    test_avx_mixed_radix!(
        test_mixedradix_3xn_avx_f32,
        test_mixedradix_3xn_avx_f64,
        MixedRadix3xnAvx,
        3
    );
    test_avx_mixed_radix!(
        test_mixedradix_4xn_avx_f32,
        test_mixedradix_4xn_avx_f64,
        MixedRadix4xnAvx,
        4
    );
    test_avx_mixed_radix!(
        test_mixedradix_5xn_avx_f32,
        test_mixedradix_5xn_avx_f64,
        MixedRadix5xnAvx,
        5
    );
    test_avx_mixed_radix!(
        test_mixedradix_6xn_avx_f32,
        test_mixedradix_6xn_avx_f64,
        MixedRadix6xnAvx,
        6
    );
    test_avx_mixed_radix!(
        test_mixedradix_7xn_avx_f32,
        test_mixedradix_7xn_avx_f64,
        MixedRadix7xnAvx,
        7
    );
    test_avx_mixed_radix!(
        test_mixedradix_8xn_avx_f32,
        test_mixedradix_8xn_avx_f64,
        MixedRadix8xnAvx,
        8
    );
    test_avx_mixed_radix!(
        test_mixedradix_9xn_avx_f32,
        test_mixedradix_9xn_avx_f64,
        MixedRadix9xnAvx,
        9
    );
    test_avx_mixed_radix!(
        test_mixedradix_11xn_avx_f32,
        test_mixedradix_11xn_avx_f64,
        MixedRadix11xnAvx,
        11
    );
    test_avx_mixed_radix!(
        test_mixedradix_12xn_avx_f32,
        test_mixedradix_12xn_avx_f64,
        MixedRadix12xnAvx,
        12
    );
    test_avx_mixed_radix!(
        test_mixedradix_16xn_avx_f32,
        test_mixedradix_16xn_avx_f64,
        MixedRadix16xnAvx,
        16
    );
}