rustfft/avx/avx_vector.rs

use std::arch::x86_64::*;
use std::fmt::Debug;
use std::ops::{Deref, DerefMut};

use num_complex::Complex;
use num_traits::Zero;

use crate::{array_utils::DoubleBuf, twiddles, FftDirection};

use super::AvxNum;

/// A SIMD vector of complex numbers, stored with the real values and imaginary values interleaved.
/// Implemented for __m128, __m128d, __m256, __m256d, but these all require the AVX instruction set.
///
/// The goal of this trait is to reduce code duplication by letting code be generic over the vector type
pub trait AvxVector: Copy + Debug + Send + Sync {
    const SCALAR_PER_VECTOR: usize;
    const COMPLEX_PER_VECTOR: usize;

    // useful constants
    unsafe fn zero() -> Self;
    unsafe fn half_root2() -> Self; // an entire vector filled with 0.5.sqrt()

    // Basic operations that map directly to 1-2 AVX intrinsics
    unsafe fn add(left: Self, right: Self) -> Self;
    unsafe fn sub(left: Self, right: Self) -> Self;
    unsafe fn xor(left: Self, right: Self) -> Self;
    unsafe fn neg(self) -> Self;
    unsafe fn mul(left: Self, right: Self) -> Self;
    unsafe fn fmadd(left: Self, right: Self, add: Self) -> Self;
    unsafe fn fnmadd(left: Self, right: Self, add: Self) -> Self;
    unsafe fn fmaddsub(left: Self, right: Self, add: Self) -> Self;
    unsafe fn fmsubadd(left: Self, right: Self, add: Self) -> Self;

    // More basic operations that end up being implemented in 1-2 intrinsics, but unlike the ones above, these have higher-level meaning than just arithmetic
    /// Swap each real number with its corresponding imaginary number
    unsafe fn swap_complex_components(self) -> Self;

    /// first return is the reals duplicated into the imaginaries, second return is the imaginaries duplicated into the reals
    unsafe fn duplicate_complex_components(self) -> (Self, Self);

    /// Reverse the order of complex numbers in the vector, so that the last is the first and the first is the last
    unsafe fn reverse_complex_elements(self) -> Self;

    /// Copies the even elements of rows[1] into the corresponding odd elements of rows[0] and returns the result.
    unsafe fn unpacklo_complex(rows: [Self; 2]) -> Self;
    /// Copies the odd elements of rows[0] into the corresponding even elements of rows[1] and returns the result.
    unsafe fn unpackhi_complex(rows: [Self; 2]) -> Self;

    #[inline(always)]
    unsafe fn unpack_complex(rows: [Self; 2]) -> [Self; 2] {
        [Self::unpacklo_complex(rows), Self::unpackhi_complex(rows)]
    }
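
    // Worked example of the unpack operations for __m256 (4 complex numbers per
    // vector, two per 128-bit lane). The unpacks interleave within each lane:
    //   unpack_complex([[a0, a1, a2, a3], [b0, b1, b2, b3]])
    //     == [[a0, b0, a2, b2], [a1, b1, a3, b3]]
    // i.e. a lane-wise 2x2 transpose, which is why the packed transposes below
    // follow the unpacks with cross-lane permutes.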

    /// Fill a vector by computing a twiddle factor and repeating it across the whole vector
    unsafe fn broadcast_twiddle(index: usize, len: usize, direction: FftDirection) -> Self;

    /// Create a Rotation90 instance to rotate complex numbers either 90 or 270 degrees, based on the value of `direction`
    unsafe fn make_rotation90(direction: FftDirection) -> Rotation90<Self>;

    /// Generates a chunk of twiddle factors starting at (X,Y) and incrementing X `COMPLEX_PER_VECTOR` times.
    /// The result will be [twiddle(x*y, len), twiddle((x+1)*y, len), twiddle((x+2)*y, len), ...] for as many complex numbers fit in a vector
    unsafe fn make_mixedradix_twiddle_chunk(
        x: usize,
        y: usize,
        len: usize,
        direction: FftDirection,
    ) -> Self;

    /// Packed transposes. Used by mixed radix. These all take a NxC array, where C is COMPLEX_PER_VECTOR, and transpose it to a CxN array.
    /// But they also pack the result into as few vectors as possible, with the goal of writing the transposed data out contiguously.
    unsafe fn transpose2_packed(rows: [Self; 2]) -> [Self; 2];
    unsafe fn transpose3_packed(rows: [Self; 3]) -> [Self; 3];
    unsafe fn transpose4_packed(rows: [Self; 4]) -> [Self; 4];
    unsafe fn transpose5_packed(rows: [Self; 5]) -> [Self; 5];
    unsafe fn transpose6_packed(rows: [Self; 6]) -> [Self; 6];
    unsafe fn transpose7_packed(rows: [Self; 7]) -> [Self; 7];
    unsafe fn transpose8_packed(rows: [Self; 8]) -> [Self; 8];
    unsafe fn transpose9_packed(rows: [Self; 9]) -> [Self; 9];
    unsafe fn transpose11_packed(rows: [Self; 11]) -> [Self; 11];
    unsafe fn transpose12_packed(rows: [Self; 12]) -> [Self; 12];
    unsafe fn transpose16_packed(rows: [Self; 16]) -> [Self; 16];

    /// Pairwise multiply the complex numbers in `left` with the complex numbers in `right`.
    #[inline(always)]
    unsafe fn mul_complex(left: Self, right: Self) -> Self {
        // Extract the real and imaginary components from left into 2 separate registers
        let (left_real, left_imag) = Self::duplicate_complex_components(left);

        // create a shuffled version of right where the imaginary values are swapped with the reals
        let right_shuffled = Self::swap_complex_components(right);

        // multiply our duplicated imaginary left vector by our shuffled right vector. that will give us the right side of the traditional complex multiplication formula
        let output_right = Self::mul(left_imag, right_shuffled);

        // use an FMA instruction to multiply together the left side of the complex multiplication formula, then alternately subtract and add the right side
        Self::fmaddsub(left_real, right, output_right)
    }
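
    // Worked example of mul_complex for a single complex pair, left = a+bi and
    // right = c+di:
    //   duplicate_complex_components(left) -> (a, a) and (b, b)
    //   swap_complex_components(right)     -> (d, c)
    //   output_right = (b*d, b*c)
    //   fmaddsub((a, a), (c, d), (b*d, b*c)) = (a*c - b*d, a*d + b*c)
    // which is exactly (a+bi)*(c+di).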

    #[inline(always)]
    unsafe fn rotate90(self, rotation: Rotation90<Self>) -> Self {
        // Use the pre-computed vector stored in the Rotation90 instance to negate either the reals or imaginaries
        let negated = Self::xor(self, rotation.0);

        // Combined with the negation above, swapping the reals with the imaginaries rotates each complex number by 90 or 270 degrees, based on whether we're an inverse or not
        Self::swap_complex_components(negated)
    }
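
    // Worked example of rotate90 for a single complex pair (a, b) = a+bi:
    //   Inverse: rotation.0 holds (0.0, -0.0), so the xor gives (a, -b), and the
    //            swap yields (-b, a) == i*(a+bi), a 90 degree rotation
    //   Forward: rotation.0 holds (-0.0, 0.0), so the xor gives (-a, b), and the
    //            swap yields (b, -a) == -i*(a+bi), a 270 degree rotation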

    #[inline(always)]
    unsafe fn column_butterfly2(rows: [Self; 2]) -> [Self; 2] {
        [Self::add(rows[0], rows[1]), Self::sub(rows[0], rows[1])]
    }

    #[inline(always)]
    unsafe fn column_butterfly3(rows: [Self; 3], twiddles: Self) -> [Self; 3] {
        // This algorithm is derived directly from the definition of the Dft of size 3
        // We'd theoretically have to do 4 complex multiplications, but all of the twiddles we'd be multiplying by are conjugates of each other
        // By doing some algebra to expand the complex multiplications and factor out the multiplications, we get this

        let [mut mid1, mid2] = Self::column_butterfly2([rows[1], rows[2]]);
        let output0 = Self::add(rows[0], mid1);

        let (twiddle_real, twiddle_imag) = Self::duplicate_complex_components(twiddles);

        mid1 = Self::fmadd(mid1, twiddle_real, rows[0]);

        let rotation = Self::make_rotation90(FftDirection::Inverse);
        let mid2_rotated = Self::rotate90(mid2, rotation);

        let output1 = Self::fmadd(mid2_rotated, twiddle_imag, mid1);
        let output2 = Self::fnmadd(mid2_rotated, twiddle_imag, mid1);

        [output0, output1, output2]
    }
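
    // Sketch of the size-3 derivation above, with w = twiddle(1, 3) and using
    // twiddle(2, 3) == conj(w):
    //   X1 = x0 + w*x1 + conj(w)*x2
    //      = x0 + re(w)*(x1 + x2) + i*im(w)*(x1 - x2)
    //      = x0 + re(w)*mid1      + im(w)*rotate90(mid2, Inverse)
    // X2 is the same with the sign of the im(w) term flipped, hence the
    // fmadd/fnmadd pair on mid2_rotated.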

    #[inline(always)]
    unsafe fn column_butterfly4(rows: [Self; 4], rotation: Rotation90<Self>) -> [Self; 4] {
        // Algorithm: 2x2 mixed radix

        // Perform the first set of size-2 FFTs.
        let [mid0, mid2] = Self::column_butterfly2([rows[0], rows[2]]);
        let [mid1, mid3] = Self::column_butterfly2([rows[1], rows[3]]);

        // Apply twiddle factors (in this case just a rotation)
        let mid3_rotated = mid3.rotate90(rotation);

        // Transpose the data and do size-2 FFTs down the columns
        let [output0, output1] = Self::column_butterfly2([mid0, mid1]);
        let [output2, output3] = Self::column_butterfly2([mid2, mid3_rotated]);

        // Swap outputs 1 and 2 in the output to do a square transpose
        [output0, output2, output1, output3]
    }
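
    // Note on the butterfly4 twiddles: the only non-trivial twiddle in the 2x2
    // mixed radix decomposition of a size-4 FFT is twiddle(1, 4), which is -i
    // for a forward transform and +i for an inverse. Both are exactly what
    // rotate90 computes, so no multiplications are needed at all.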

    #[inline(always)]
    unsafe fn column_butterfly5(rows: [Self; 5], twiddles: [Self; 2]) -> [Self; 5] {
        // This algorithm is derived directly from the definition of the Dft of size 5
        // We'd theoretically have to do 16 complex multiplications for the Dft, but many of the twiddles we'd be multiplying by are conjugates of each other
        // By doing some algebra to expand the complex multiplications and factor out the real multiplications, we get this faster formula where we only do the equivalent of 4 multiplications

        // do some prep work before we can start applying twiddle factors
        let [sum1, diff4] = Self::column_butterfly2([rows[1], rows[4]]);
        let [sum2, diff3] = Self::column_butterfly2([rows[2], rows[3]]);

        let rotation = Self::make_rotation90(FftDirection::Inverse);
        let rotated4 = Self::rotate90(diff4, rotation);
        let rotated3 = Self::rotate90(diff3, rotation);

        // to compute the first output, compute the sum of all elements. sum1 and sum2 already have the sum of 1+4 and 2+3 respectively, so if we add them, we'll get the sum of all 4
        let sum1234 = Self::add(sum1, sum2);
        let output0 = Self::add(rows[0], sum1234);

        // apply twiddle factors
        let (twiddles0_re, twiddles0_im) = Self::duplicate_complex_components(twiddles[0]);
        let (twiddles1_re, twiddles1_im) = Self::duplicate_complex_components(twiddles[1]);

        let twiddled1_mid = Self::fmadd(twiddles0_re, sum1, rows[0]);
        let twiddled2_mid = Self::fmadd(twiddles1_re, sum1, rows[0]);
        let twiddled3_mid = Self::mul(twiddles1_im, rotated4);
        let twiddled4_mid = Self::mul(twiddles0_im, rotated4);

        let twiddled1 = Self::fmadd(twiddles1_re, sum2, twiddled1_mid);
        let twiddled2 = Self::fmadd(twiddles0_re, sum2, twiddled2_mid);
        let twiddled3 = Self::fnmadd(twiddles0_im, rotated3, twiddled3_mid); // fnmadd instead of fmadd because we're actually re-using twiddle0 here. remember that this algorithm is all about factoring out conjugated multiplications -- this negation of the twiddle0 imaginaries is a reflection of one of those conjugations
        let twiddled4 = Self::fmadd(twiddles1_im, rotated3, twiddled4_mid);

        // Post-processing to mix the twiddle factors between the rest of the output
        let [output1, output4] = Self::column_butterfly2([twiddled1, twiddled4]);
        let [output2, output3] = Self::column_butterfly2([twiddled2, twiddled3]);

        [output0, output1, output2, output3, output4]
    }

    #[inline(always)]
    unsafe fn column_butterfly7(rows: [Self; 7], twiddles: [Self; 3]) -> [Self; 7] {
        // This algorithm is derived directly from the definition of the Dft of size 7
        // We'd theoretically have to do 36 complex multiplications for the Dft, but many of the twiddles we'd be multiplying by are conjugates of each other
        // By doing some algebra to expand the complex multiplications and factor out the real multiplications, we get this faster formula where we only do the equivalent of 9 multiplications

        // do some prep work before we can start applying twiddle factors
        let [sum1, diff6] = Self::column_butterfly2([rows[1], rows[6]]);
        let [sum2, diff5] = Self::column_butterfly2([rows[2], rows[5]]);
        let [sum3, diff4] = Self::column_butterfly2([rows[3], rows[4]]);

        let rotation = Self::make_rotation90(FftDirection::Inverse);
        let rotated4 = Self::rotate90(diff4, rotation);
        let rotated5 = Self::rotate90(diff5, rotation);
        let rotated6 = Self::rotate90(diff6, rotation);

        // to compute the first output, compute the sum of all elements. sum1, sum2, and sum3 already have the sum of 1+6 and 2+5 and 3+4 respectively, so if we add them, we'll get the sum of all 6
        let output0_left = Self::add(sum1, sum2);
        let output0_right = Self::add(sum3, rows[0]);
        let output0 = Self::add(output0_left, output0_right);

        // apply twiddle factors. This is probably pushing the limit of how much we should do with this technique.
        // We probably shouldn't do a size-11 FFT with this technique, for example, because this block of multiplies would grow quadratically
        let (twiddles0_re, twiddles0_im) = Self::duplicate_complex_components(twiddles[0]);
        let (twiddles1_re, twiddles1_im) = Self::duplicate_complex_components(twiddles[1]);
        let (twiddles2_re, twiddles2_im) = Self::duplicate_complex_components(twiddles[2]);

        // Let's do a plain 7-point Dft
        // | X0 |   | W0  W0  W0  W0  W0  W0  W0  |   | x0 |
        // | X1 |   | W0  W1  W2  W3  W4  W5  W6  |   | x1 |
        // | X2 |   | W0  W2  W4  W6  W8  W10 W12 |   | x2 |
        // | X3 |   | W0  W3  W6  W9  W12 W15 W18 |   | x3 |
        // | X4 |   | W0  W4  W8  W12 W16 W20 W24 |   | x4 |
        // | X5 |   | W0  W5  W10 W15 W20 W25 W30 |   | x5 |
        // | X6 |   | W0  W6  W12 W18 W24 W30 W36 |   | x6 |
        // where Wn = exp(-2*pi*i*n/7) for a forward transform, and exp(+2*pi*i*n/7) for an inverse.

        // Next, take advantage of the fact that twiddle factor indexes for a size-7 Dft are cyclical mod 7
        // | X0 |   | W0  W0  W0  W0  W0  W0  W0  |   | x0 |
        // | X1 |   | W0  W1  W2  W3  W4  W5  W6  |   | x1 |
        // | X2 |   | W0  W2  W4  W6  W1  W3  W5  |   | x2 |
        // | X3 |   | W0  W3  W6  W2  W5  W1  W4  |   | x3 |
        // | X4 |   | W0  W4  W1  W5  W2  W6  W3  |   | x4 |
        // | X5 |   | W0  W5  W3  W1  W6  W4  W2  |   | x5 |
        // | X6 |   | W0  W6  W5  W4  W3  W2  W1  |   | x6 |

        // Finally, take advantage of the fact that for a size-7 Dft,
        // twiddles 4 through 6 are conjugates of twiddles 3 through 1 (Asterisk marks conjugates)
        // | X0 |   | W0  W0  W0  W0  W0  W0  W0  |   | x0 |
        // | X1 |   | W0  W1  W2  W3  W3* W2* W1* |   | x1 |
        // | X2 |   | W0  W2  W3* W1* W1  W3  W2* |   | x2 |
        // | X3 |   | W0  W3  W1* W2  W2* W1  W3* |   | x3 |
        // | X4 |   | W0  W3* W1  W2* W2  W1* W3  |   | x4 |
        // | X5 |   | W0  W2* W3  W1  W1* W3* W2  |   | x5 |
        // | X6 |   | W0  W1* W2* W3* W3  W2  W1  |   | x6 |

        let twiddled1_mid = Self::fmadd(twiddles0_re, sum1, rows[0]);
        let twiddled2_mid = Self::fmadd(twiddles1_re, sum1, rows[0]);
        let twiddled3_mid = Self::fmadd(twiddles2_re, sum1, rows[0]);
        let twiddled4_mid = Self::mul(twiddles2_im, rotated6);
        let twiddled5_mid = Self::mul(twiddles1_im, rotated6);
        let twiddled6_mid = Self::mul(twiddles0_im, rotated6);

        let twiddled1_mid2 = Self::fmadd(twiddles1_re, sum2, twiddled1_mid);
        let twiddled2_mid2 = Self::fmadd(twiddles2_re, sum2, twiddled2_mid);
        let twiddled3_mid2 = Self::fmadd(twiddles0_re, sum2, twiddled3_mid);
        let twiddled4_mid2 = Self::fnmadd(twiddles0_im, rotated5, twiddled4_mid); // fnmadd instead of fmadd because we're actually re-using twiddle0 here. remember that this algorithm is all about factoring out conjugated multiplications -- this negation of the twiddle0 imaginaries is a reflection of one of those conjugations
        let twiddled5_mid2 = Self::fnmadd(twiddles2_im, rotated5, twiddled5_mid);
        let twiddled6_mid2 = Self::fmadd(twiddles1_im, rotated5, twiddled6_mid);

        let twiddled1 = Self::fmadd(twiddles2_re, sum3, twiddled1_mid2);
        let twiddled2 = Self::fmadd(twiddles0_re, sum3, twiddled2_mid2);
        let twiddled3 = Self::fmadd(twiddles1_re, sum3, twiddled3_mid2);
        let twiddled4 = Self::fmadd(twiddles1_im, rotated4, twiddled4_mid2);
        let twiddled5 = Self::fnmadd(twiddles0_im, rotated4, twiddled5_mid2);
        let twiddled6 = Self::fmadd(twiddles2_im, rotated4, twiddled6_mid2);

        // Post-processing to mix the twiddle factors between the rest of the output
        let [output1, output6] = Self::column_butterfly2([twiddled1, twiddled6]);
        let [output2, output5] = Self::column_butterfly2([twiddled2, twiddled5]);
        let [output3, output4] = Self::column_butterfly2([twiddled3, twiddled4]);

        [
            output0, output1, output2, output3, output4, output5, output6,
        ]
    }

    #[inline(always)]
    unsafe fn column_butterfly8(rows: [Self; 8], rotation: Rotation90<Self>) -> [Self; 8] {
        // Algorithm: 4x2 mixed radix

        // Size-4 FFTs down the columns
        let mid0 = Self::column_butterfly4([rows[0], rows[2], rows[4], rows[6]], rotation);
        let mut mid1 = Self::column_butterfly4([rows[1], rows[3], rows[5], rows[7]], rotation);

        // Apply twiddle factors
        mid1[1] = apply_butterfly8_twiddle1(mid1[1], rotation);
        mid1[2] = mid1[2].rotate90(rotation);
        mid1[3] = apply_butterfly8_twiddle3(mid1[3], rotation);

        // Transpose the data and do size-2 FFTs down the columns
        let [output0, output1] = Self::column_butterfly2([mid0[0], mid1[0]]);
        let [output2, output3] = Self::column_butterfly2([mid0[1], mid1[1]]);
        let [output4, output5] = Self::column_butterfly2([mid0[2], mid1[2]]);
        let [output6, output7] = Self::column_butterfly2([mid0[3], mid1[3]]);

        [
            output0, output2, output4, output6, output1, output3, output5, output7,
        ]
    }

    #[inline(always)]
    unsafe fn column_butterfly11(rows: [Self; 11], twiddles: [Self; 5]) -> [Self; 11] {
        // This algorithm is derived directly from the definition of the Dft of size 11
        // We'd theoretically have to do 100 complex multiplications for the Dft, but many of the twiddles we'd be multiplying by are conjugates of each other
        // By doing some algebra to expand the complex multiplications and factor out the real multiplications, we get this faster formula where we only do the equivalent of 25 multiplications

        // do some prep work before we can start applying twiddle factors
        let [sum1, diff10] = Self::column_butterfly2([rows[1], rows[10]]);
        let [sum2, diff9] = Self::column_butterfly2([rows[2], rows[9]]);
        let [sum3, diff8] = Self::column_butterfly2([rows[3], rows[8]]);
        let [sum4, diff7] = Self::column_butterfly2([rows[4], rows[7]]);
        let [sum5, diff6] = Self::column_butterfly2([rows[5], rows[6]]);

        let rotation = Self::make_rotation90(FftDirection::Inverse);
        let rotated10 = Self::rotate90(diff10, rotation);
        let rotated9 = Self::rotate90(diff9, rotation);
        let rotated8 = Self::rotate90(diff8, rotation);
        let rotated7 = Self::rotate90(diff7, rotation);
        let rotated6 = Self::rotate90(diff6, rotation);

        // to compute the first output, compute the sum of all elements. sum1 through sum5 already have the sums of 1+10, 2+9, 3+8, 4+7, and 5+6 respectively, so adding them all plus rows[0] gives the sum of all 11 elements
        let sum01 = Self::add(rows[0], sum1);
        let sum23 = Self::add(sum2, sum3);
        let sum45 = Self::add(sum4, sum5);
        let sum0123 = Self::add(sum01, sum23);
        let output0 = Self::add(sum0123, sum45);

        // apply twiddle factors. This is pushing the limit of how much we should do with this technique --
        // this block of multiplies grows quadratically with FFT size, so we shouldn't go any bigger than this
        let (twiddles0_re, twiddles0_im) = Self::duplicate_complex_components(twiddles[0]);
        let (twiddles1_re, twiddles1_im) = Self::duplicate_complex_components(twiddles[1]);
        let (twiddles2_re, twiddles2_im) = Self::duplicate_complex_components(twiddles[2]);
        let (twiddles3_re, twiddles3_im) = Self::duplicate_complex_components(twiddles[3]);
        let (twiddles4_re, twiddles4_im) = Self::duplicate_complex_components(twiddles[4]);

        let twiddled1 = Self::fmadd(twiddles0_re, sum1, rows[0]);
        let twiddled2 = Self::fmadd(twiddles1_re, sum1, rows[0]);
        let twiddled3 = Self::fmadd(twiddles2_re, sum1, rows[0]);
        let twiddled4 = Self::fmadd(twiddles3_re, sum1, rows[0]);
        let twiddled5 = Self::fmadd(twiddles4_re, sum1, rows[0]);
        let twiddled6 = Self::mul(twiddles4_im, rotated10);
        let twiddled7 = Self::mul(twiddles3_im, rotated10);
        let twiddled8 = Self::mul(twiddles2_im, rotated10);
        let twiddled9 = Self::mul(twiddles1_im, rotated10);
        let twiddled10 = Self::mul(twiddles0_im, rotated10);

        let twiddled1 = Self::fmadd(twiddles1_re, sum2, twiddled1);
        let twiddled2 = Self::fmadd(twiddles3_re, sum2, twiddled2);
        let twiddled3 = Self::fmadd(twiddles4_re, sum2, twiddled3);
        let twiddled4 = Self::fmadd(twiddles2_re, sum2, twiddled4);
        let twiddled5 = Self::fmadd(twiddles0_re, sum2, twiddled5);
        let twiddled6 = Self::fnmadd(twiddles0_im, rotated9, twiddled6);
        let twiddled7 = Self::fnmadd(twiddles2_im, rotated9, twiddled7);
        let twiddled8 = Self::fnmadd(twiddles4_im, rotated9, twiddled8);
        let twiddled9 = Self::fmadd(twiddles3_im, rotated9, twiddled9);
        let twiddled10 = Self::fmadd(twiddles1_im, rotated9, twiddled10);

        let twiddled1 = Self::fmadd(twiddles2_re, sum3, twiddled1);
        let twiddled2 = Self::fmadd(twiddles4_re, sum3, twiddled2);
        let twiddled3 = Self::fmadd(twiddles1_re, sum3, twiddled3);
        let twiddled4 = Self::fmadd(twiddles0_re, sum3, twiddled4);
        let twiddled5 = Self::fmadd(twiddles3_re, sum3, twiddled5);
        let twiddled6 = Self::fmadd(twiddles3_im, rotated8, twiddled6);
        let twiddled7 = Self::fmadd(twiddles0_im, rotated8, twiddled7);
        let twiddled8 = Self::fnmadd(twiddles1_im, rotated8, twiddled8);
        let twiddled9 = Self::fnmadd(twiddles4_im, rotated8, twiddled9);
        let twiddled10 = Self::fmadd(twiddles2_im, rotated8, twiddled10);

        let twiddled1 = Self::fmadd(twiddles3_re, sum4, twiddled1);
        let twiddled2 = Self::fmadd(twiddles2_re, sum4, twiddled2);
        let twiddled3 = Self::fmadd(twiddles0_re, sum4, twiddled3);
        let twiddled4 = Self::fmadd(twiddles4_re, sum4, twiddled4);
        let twiddled5 = Self::fmadd(twiddles1_re, sum4, twiddled5);
        let twiddled6 = Self::fnmadd(twiddles1_im, rotated7, twiddled6);
        let twiddled7 = Self::fmadd(twiddles4_im, rotated7, twiddled7);
        let twiddled8 = Self::fmadd(twiddles0_im, rotated7, twiddled8);
        let twiddled9 = Self::fnmadd(twiddles2_im, rotated7, twiddled9);
        let twiddled10 = Self::fmadd(twiddles3_im, rotated7, twiddled10);

        let twiddled1 = Self::fmadd(twiddles4_re, sum5, twiddled1);
        let twiddled2 = Self::fmadd(twiddles0_re, sum5, twiddled2);
        let twiddled3 = Self::fmadd(twiddles3_re, sum5, twiddled3);
        let twiddled4 = Self::fmadd(twiddles1_re, sum5, twiddled4);
        let twiddled5 = Self::fmadd(twiddles2_re, sum5, twiddled5);
        let twiddled6 = Self::fmadd(twiddles2_im, rotated6, twiddled6);
        let twiddled7 = Self::fnmadd(twiddles1_im, rotated6, twiddled7);
        let twiddled8 = Self::fmadd(twiddles3_im, rotated6, twiddled8);
        let twiddled9 = Self::fnmadd(twiddles0_im, rotated6, twiddled9);
        let twiddled10 = Self::fmadd(twiddles4_im, rotated6, twiddled10);

        // Post-processing to mix the twiddle factors between the rest of the output
        let [output1, output10] = Self::column_butterfly2([twiddled1, twiddled10]);
        let [output2, output9] = Self::column_butterfly2([twiddled2, twiddled9]);
        let [output3, output8] = Self::column_butterfly2([twiddled3, twiddled8]);
        let [output4, output7] = Self::column_butterfly2([twiddled4, twiddled7]);
        let [output5, output6] = Self::column_butterfly2([twiddled5, twiddled6]);

        [
            output0, output1, output2, output3, output4, output5, output6, output7, output8,
            output9, output10,
        ]
    }
}

/// A 256-bit SIMD vector of complex numbers, stored with the real values and imaginary values interleaved.
/// Implemented for __m256, __m256d
///
/// This trait implements things specific to 256-types, like splitting a 256 vector into 128 vectors
/// For compiler-placation reasons, all interaction with/awareness of the scalar type goes here
pub trait AvxVector256: AvxVector {
    type HalfVector: AvxVector128<FullVector = Self>;
    type ScalarType: AvxNum<VectorType = Self>;

    unsafe fn lo(self) -> Self::HalfVector;
    unsafe fn hi(self) -> Self::HalfVector;
    unsafe fn merge(lo: Self::HalfVector, hi: Self::HalfVector) -> Self;

    /// Fill a vector by repeating the provided complex number as many times as possible
    unsafe fn broadcast_complex_elements(value: Complex<Self::ScalarType>) -> Self;

    // loads/stores of complex numbers
    unsafe fn load_complex(ptr: *const Complex<Self::ScalarType>) -> Self;
    unsafe fn store_complex(ptr: *mut Complex<Self::ScalarType>, data: Self);

    // Gather 4 complex numbers (for f32) or 2 complex numbers (for f64) using 4 i32 indexes (for index32) or 4 i64 indexes (for index64).
    // For f32, there should be 1 index per complex. For f64, there should be 2 indexes, each duplicated
    // (So to load the complex<f64> at index 5 and 7, the index vector should contain 5,5,7,7. this api sucks but it's internal so whatever.)
    unsafe fn gather_complex_avx2_index32(
        ptr: *const Complex<Self::ScalarType>,
        indexes: __m128i,
    ) -> Self;
    unsafe fn gather_complex_avx2_index64(
        ptr: *const Complex<Self::ScalarType>,
        indexes: __m256i,
    ) -> Self;

    // loads/stores of partial vectors of complex numbers. When loading, empty elements are zeroed
    // unimplemented!() if Self::COMPLEX_PER_VECTOR is not greater than the partial count
    unsafe fn load_partial1_complex(ptr: *const Complex<Self::ScalarType>) -> Self::HalfVector;
    unsafe fn load_partial2_complex(ptr: *const Complex<Self::ScalarType>) -> Self::HalfVector;
    unsafe fn load_partial3_complex(ptr: *const Complex<Self::ScalarType>) -> Self;
    unsafe fn store_partial1_complex(ptr: *mut Complex<Self::ScalarType>, data: Self::HalfVector);
    unsafe fn store_partial2_complex(ptr: *mut Complex<Self::ScalarType>, data: Self::HalfVector);
    unsafe fn store_partial3_complex(ptr: *mut Complex<Self::ScalarType>, data: Self);

    #[inline(always)]
    unsafe fn column_butterfly6(rows: [Self; 6], twiddles: Self) -> [Self; 6] {
        // Algorithm: 3x2 good-thomas

        // Size-3 FFTs down the columns of our reordered array
        let mid0 = Self::column_butterfly3([rows[0], rows[2], rows[4]], twiddles);
        let mid1 = Self::column_butterfly3([rows[3], rows[5], rows[1]], twiddles);

        // We normally would put twiddle factors right here, but since this is good-thomas algorithm, we don't need twiddle factors

        // Transpose the data and do size-2 FFTs down the columns
        let [output0, output1] = Self::column_butterfly2([mid0[0], mid1[0]]);
        let [output2, output3] = Self::column_butterfly2([mid0[1], mid1[1]]);
        let [output4, output5] = Self::column_butterfly2([mid0[2], mid1[2]]);

        // Reorder into output
        [output0, output3, output4, output1, output2, output5]
    }
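
    // Index mapping note for the Good-Thomas reorder above: because 2 and 3 are
    // coprime, the CRT-style input map n = (3i + 2j) mod 6 (i = size-2 index,
    // j = size-3 index) gives the input columns [0, 2, 4] and [3, 5, 1], and
    // reading the outputs back through the matching CRT output map produces the
    // final [output0, output3, output4, output1, output2, output5] order, with
    // no twiddle factors needed in between.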

    #[inline(always)]
    unsafe fn column_butterfly9(
        rows: [Self; 9],
        twiddles: [Self; 3],
        butterfly3_twiddles: Self,
    ) -> [Self; 9] {
        // Algorithm: 3x3 mixed radix

        // Size-3 FFTs down the columns
        let mid0 = Self::column_butterfly3([rows[0], rows[3], rows[6]], butterfly3_twiddles);
        let mut mid1 = Self::column_butterfly3([rows[1], rows[4], rows[7]], butterfly3_twiddles);
        let mut mid2 = Self::column_butterfly3([rows[2], rows[5], rows[8]], butterfly3_twiddles);

        // Apply twiddle factors. Note that we're re-using twiddles[1]
        mid1[1] = Self::mul_complex(twiddles[0], mid1[1]);
        mid1[2] = Self::mul_complex(twiddles[1], mid1[2]);
        mid2[1] = Self::mul_complex(twiddles[1], mid2[1]);
        mid2[2] = Self::mul_complex(twiddles[2], mid2[2]);

        let [output0, output1, output2] =
            Self::column_butterfly3([mid0[0], mid1[0], mid2[0]], butterfly3_twiddles);
        let [output3, output4, output5] =
            Self::column_butterfly3([mid0[1], mid1[1], mid2[1]], butterfly3_twiddles);
        let [output6, output7, output8] =
            Self::column_butterfly3([mid0[2], mid1[2], mid2[2]], butterfly3_twiddles);

        [
            output0, output3, output6, output1, output4, output7, output2, output5, output8,
        ]
    }

    #[inline(always)]
    unsafe fn column_butterfly12(
        rows: [Self; 12],
        butterfly3_twiddles: Self,
        rotation: Rotation90<Self>,
    ) -> [Self; 12] {
        // Algorithm: 4x3 good-thomas

        // Size-4 FFTs down the columns of our reordered array
        let mid0 = Self::column_butterfly4([rows[0], rows[3], rows[6], rows[9]], rotation);
        let mid1 = Self::column_butterfly4([rows[4], rows[7], rows[10], rows[1]], rotation);
        let mid2 = Self::column_butterfly4([rows[8], rows[11], rows[2], rows[5]], rotation);

        // Since this is good-thomas algorithm, we don't need twiddle factors

        // Transpose the data and do size-3 FFTs down the columns
        let [output0, output1, output2] =
            Self::column_butterfly3([mid0[0], mid1[0], mid2[0]], butterfly3_twiddles);
        let [output3, output4, output5] =
            Self::column_butterfly3([mid0[1], mid1[1], mid2[1]], butterfly3_twiddles);
        let [output6, output7, output8] =
            Self::column_butterfly3([mid0[2], mid1[2], mid2[2]], butterfly3_twiddles);
        let [output9, output10, output11] =
            Self::column_butterfly3([mid0[3], mid1[3], mid2[3]], butterfly3_twiddles);

        [
            output0, output4, output8, output9, output1, output5, output6, output10, output2,
            output3, output7, output11,
        ]
    }
}

/// A 128-bit SIMD vector of complex numbers, stored with the real values and imaginary values interleaved.
/// Implemented for __m128, __m128d, but these are all oriented around AVX, so don't call methods on these from an SSE-only context
///
/// This trait implements things specific to 128-types, like merging 2 128 vectors into a 256 vector
pub trait AvxVector128: AvxVector {
    type FullVector: AvxVector256<HalfVector = Self>;

    unsafe fn merge(lo: Self, hi: Self) -> Self::FullVector;
    unsafe fn zero_extend(self) -> Self::FullVector;

    unsafe fn lo(input: Self::FullVector) -> Self;
    unsafe fn hi(input: Self::FullVector) -> Self;
    unsafe fn split(input: Self::FullVector) -> (Self, Self) {
        (Self::lo(input), Self::hi(input))
    }
    unsafe fn lo_rotation(input: Rotation90<Self::FullVector>) -> Rotation90<Self>;

    /// Fill a vector by repeating the provided complex number as many times as possible
    unsafe fn broadcast_complex_elements(
        value: Complex<<<Self as AvxVector128>::FullVector as AvxVector256>::ScalarType>,
    ) -> Self;

    // Gather 2 complex numbers (for f32) or 1 complex number (for f64) using 2 i32 indexes (for gather32) or 2 i64 indexes (for gather64).
    // For f32, there should be 1 index per complex. For f64, there should be 2 indexes, each duplicated
    // (So to load the complex<f64> at index 5, the index vector should contain 5,5. this api sucks but it's internal so whatever.)
    unsafe fn gather32_complex_avx2(
        ptr: *const Complex<<Self::FullVector as AvxVector256>::ScalarType>,
        indexes: __m128i,
    ) -> Self;
    unsafe fn gather64_complex_avx2(
        ptr: *const Complex<<Self::FullVector as AvxVector256>::ScalarType>,
        indexes: __m128i,
    ) -> Self;

    #[inline(always)]
    unsafe fn column_butterfly6(rows: [Self; 6], twiddles: Self::FullVector) -> [Self; 6] {
        // Algorithm: 3x2 good-thomas

        // if we merge some of our 128 registers into 256 registers, we can do 1 inner butterfly3 instead of 2
        let rows03 = Self::merge(rows[0], rows[3]);
        let rows25 = Self::merge(rows[2], rows[5]);
        let rows41 = Self::merge(rows[4], rows[1]);

        // Size-3 FFTs down the columns of our reordered array
        let mid = Self::FullVector::column_butterfly3([rows03, rows25, rows41], twiddles);

        // We normally would put twiddle factors right here, but since this is good-thomas algorithm, we don't need twiddle factors

        // we can't use our merged columns anymore. so split them back into half vectors
        let (mid0_0, mid1_0) = Self::split(mid[0]);
        let (mid0_1, mid1_1) = Self::split(mid[1]);
        let (mid0_2, mid1_2) = Self::split(mid[2]);

        // Transpose the data and do size-2 FFTs down the columns
        let [output0, output1] = Self::column_butterfly2([mid0_0, mid1_0]);
        let [output2, output3] = Self::column_butterfly2([mid0_1, mid1_1]);
        let [output4, output5] = Self::column_butterfly2([mid0_2, mid1_2]);

        // Reorder into output
        [output0, output3, output4, output1, output2, output5]
    }

    #[inline(always)]
    unsafe fn column_butterfly9(
        rows: [Self; 9],
        twiddles_merged: [Self::FullVector; 2],
        butterfly3_twiddles: Self::FullVector,
    ) -> [Self; 9] {
        // Algorithm: 3x3 mixed radix

        // if we merge some of our 128 registers into 256 registers, we can do 2 inner butterfly3's instead of 3
        let rows12 = Self::merge(rows[1], rows[2]);
        let rows45 = Self::merge(rows[4], rows[5]);
        let rows78 = Self::merge(rows[7], rows[8]);

        let mid0 =
            Self::column_butterfly3([rows[0], rows[3], rows[6]], Self::lo(butterfly3_twiddles));
        let mut mid12 =
            Self::FullVector::column_butterfly3([rows12, rows45, rows78], butterfly3_twiddles);

        // Apply twiddle factors. we're applying them on the merged set of vectors, so we need slightly different twiddle factors
        mid12[1] = Self::FullVector::mul_complex(twiddles_merged[0], mid12[1]);
        mid12[2] = Self::FullVector::mul_complex(twiddles_merged[1], mid12[2]);

        // we can't use our merged columns anymore. so split them back into half vectors
        let (mid1_0, mid2_0) = Self::split(mid12[0]);
        let (mid1_1, mid2_1) = Self::split(mid12[1]);
        let (mid1_2, mid2_2) = Self::split(mid12[2]);

        // Re-merge our half vectors into different, transposed full vectors. Thankfully the compiler is smart enough to combine these inserts and extracts into permutes
        let transposed12 = Self::merge(mid0[1], mid0[2]);
        let transposed45 = Self::merge(mid1_1, mid1_2);
        let transposed78 = Self::merge(mid2_1, mid2_2);

        let [output0, output1, output2] =
            Self::column_butterfly3([mid0[0], mid1_0, mid2_0], Self::lo(butterfly3_twiddles));
        let [output36, output47, output58] = Self::FullVector::column_butterfly3(
            [transposed12, transposed45, transposed78],
            butterfly3_twiddles,
        );

        // Finally, extract our second set of merged columns
        let (output3, output6) = Self::split(output36);
        let (output4, output7) = Self::split(output47);
        let (output5, output8) = Self::split(output58);

        [
            output0, output3, output6, output1, output4, output7, output2, output5, output8,
        ]
    }

    #[inline(always)]
    unsafe fn column_butterfly12(
        rows: [Self; 12],
        butterfly3_twiddles: Self::FullVector,
        rotation: Rotation90<Self::FullVector>,
    ) -> [Self; 12] {
        // Algorithm: 4x3 good-thomas

        // if we merge some of our 128 registers into 256 registers, we can do 2 inner butterfly4's instead of 3
        let rows48 = Self::merge(rows[4], rows[8]);
        let rows711 = Self::merge(rows[7], rows[11]);
        let rows102 = Self::merge(rows[10], rows[2]);
        let rows15 = Self::merge(rows[1], rows[5]);

        // Size-4 FFTs down the columns of our reordered array
        let mid0 = Self::column_butterfly4(
            [rows[0], rows[3], rows[6], rows[9]],
            Self::lo_rotation(rotation),
        );
        let mid12 =
            Self::FullVector::column_butterfly4([rows48, rows711, rows102, rows15], rotation);

        // We normally would put twiddle factors right here, but since this is good-thomas algorithm, we don't need twiddle factors

        // we can't use our merged columns anymore. so split them back into half vectors
        let (mid1_0, mid2_0) = Self::split(mid12[0]);
        let (mid1_1, mid2_1) = Self::split(mid12[1]);
        let (mid1_2, mid2_2) = Self::split(mid12[2]);
        let (mid1_3, mid2_3) = Self::split(mid12[3]);

        // Re-merge our half vectors into different, transposed full vectors. This will let us do 2 inner butterfly 3's instead of 4!
        // Thankfully the compiler is smart enough to combine these inserts and extracts into permutes
        let transposed03 = Self::merge(mid0[0], mid0[1]);
        let transposed14 = Self::merge(mid1_0, mid1_1);
        let transposed25 = Self::merge(mid2_0, mid2_1);

        let transposed69 = Self::merge(mid0[2], mid0[3]);
        let transposed710 = Self::merge(mid1_2, mid1_3);
        let transposed811 = Self::merge(mid2_2, mid2_3);

        // Transpose the data and do size-3 FFTs down the columns
        let [output03, output14, output25] = Self::FullVector::column_butterfly3(
            [transposed03, transposed14, transposed25],
            butterfly3_twiddles,
        );
        let [output69, output710, output811] = Self::FullVector::column_butterfly3(
            [transposed69, transposed710, transposed811],
            butterfly3_twiddles,
        );

        // Finally, extract our second set of merged columns
        let (output0, output3) = Self::split(output03);
        let (output1, output4) = Self::split(output14);
        let (output2, output5) = Self::split(output25);
        let (output6, output9) = Self::split(output69);
        let (output7, output10) = Self::split(output710);
        let (output8, output11) = Self::split(output811);

        [
            output0, output4, output8, output9, output1, output5, output6, output10, output2,
            output3, output7, output11,
        ]
    }
}

#[inline(always)]
pub unsafe fn apply_butterfly8_twiddle1<V: AvxVector>(input: V, rotation: Rotation90<V>) -> V {
    let rotated = input.rotate90(rotation);
    let combined = V::add(rotated, input);
    V::mul(V::half_root2(), combined)
}
#[inline(always)]
pub unsafe fn apply_butterfly8_twiddle3<V: AvxVector>(input: V, rotation: Rotation90<V>) -> V {
    let rotated = input.rotate90(rotation);
    let combined = V::sub(rotated, input);
    V::mul(V::half_root2(), combined)
}
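
// Note on the apply_butterfly8_twiddle helpers: for a forward transform,
// rotate90 multiplies by -i, so twiddle1 computes sqrt(0.5)*(input - i*input)
// == exp(-pi*i/4)*input == twiddle(1, 8)*input, and twiddle3 computes
// sqrt(0.5)*(-input - i*input) == exp(-3*pi*i/4)*input == twiddle(3, 8)*input.
// The inverse direction follows by conjugation, since rotate90 flips to +i.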

#[repr(transparent)]
#[derive(Clone, Copy, Debug)]
pub struct Rotation90<V>(V);
impl<V: AvxVector256> Rotation90<V> {
    #[inline(always)]
    pub unsafe fn lo(self) -> Rotation90<V::HalfVector> {
        Rotation90(self.0.lo())
    }
}

impl AvxVector for __m256 {
    const SCALAR_PER_VECTOR: usize = 8;
    const COMPLEX_PER_VECTOR: usize = 4;

    #[inline(always)]
    unsafe fn zero() -> Self {
        _mm256_setzero_ps()
    }
    #[inline(always)]
    unsafe fn half_root2() -> Self {
        // note: we're computing a square root here, but checking the assembly says the compiler is smart enough to turn this into a constant
        _mm256_broadcast_ss(&0.5f32.sqrt())
    }

    #[inline(always)]
    unsafe fn xor(left: Self, right: Self) -> Self {
        _mm256_xor_ps(left, right)
    }
    #[inline(always)]
    unsafe fn neg(self) -> Self {
        _mm256_xor_ps(self, _mm256_broadcast_ss(&-0.0))
    }
    #[inline(always)]
    unsafe fn add(left: Self, right: Self) -> Self {
        _mm256_add_ps(left, right)
    }
    #[inline(always)]
    unsafe fn sub(left: Self, right: Self) -> Self {
        _mm256_sub_ps(left, right)
    }
    #[inline(always)]
    unsafe fn mul(left: Self, right: Self) -> Self {
        _mm256_mul_ps(left, right)
    }
    #[inline(always)]
    unsafe fn fmadd(left: Self, right: Self, add: Self) -> Self {
        _mm256_fmadd_ps(left, right, add)
    }
    #[inline(always)]
    unsafe fn fnmadd(left: Self, right: Self, add: Self) -> Self {
        _mm256_fnmadd_ps(left, right, add)
    }
    #[inline(always)]
    unsafe fn fmaddsub(left: Self, right: Self, add: Self) -> Self {
        _mm256_fmaddsub_ps(left, right, add)
    }
    #[inline(always)]
    unsafe fn fmsubadd(left: Self, right: Self, add: Self) -> Self {
        _mm256_fmsubadd_ps(left, right, add)
    }
    #[inline(always)]
    unsafe fn reverse_complex_elements(self) -> Self {
        // swap the elements in-lane
        let permuted = _mm256_permute_ps(self, 0x4E);
        // swap the lanes
        _mm256_permute2f128_ps(permuted, permuted, 0x01)
    }
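    // Worked example of reverse_complex_elements: starting from [c0, c1, c2, c3],
    // the in-lane permute (0x4E) swaps the complex pair within each 128-bit lane,
    // giving [c1, c0, c3, c2], and the lane swap (0x01) produces [c3, c2, c1, c0].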
    #[inline(always)]
    unsafe fn unpacklo_complex(rows: [Self; 2]) -> Self {
        let row0_double = _mm256_castps_pd(rows[0]);
        let row1_double = _mm256_castps_pd(rows[1]);
        let unpacked = _mm256_unpacklo_pd(row0_double, row1_double);
        _mm256_castpd_ps(unpacked)
    }
    #[inline(always)]
    unsafe fn unpackhi_complex(rows: [Self; 2]) -> Self {
        let row0_double = _mm256_castps_pd(rows[0]);
        let row1_double = _mm256_castps_pd(rows[1]);
        let unpacked = _mm256_unpackhi_pd(row0_double, row1_double);
        _mm256_castpd_ps(unpacked)
    }

    #[inline(always)]
    unsafe fn swap_complex_components(self) -> Self {
        _mm256_permute_ps(self, 0xB1)
    }

    #[inline(always)]
    unsafe fn duplicate_complex_components(self) -> (Self, Self) {
        (_mm256_moveldup_ps(self), _mm256_movehdup_ps(self))
    }

    #[inline(always)]
    unsafe fn make_rotation90(direction: FftDirection) -> Rotation90<Self> {
        let broadcast = match direction {
            FftDirection::Forward => Complex::new(-0.0, 0.0),
            FftDirection::Inverse => Complex::new(0.0, -0.0),
        };
        Rotation90(Self::broadcast_complex_elements(broadcast))
    }

    #[inline(always)]
    unsafe fn make_mixedradix_twiddle_chunk(
        x: usize,
        y: usize,
        len: usize,
        direction: FftDirection,
    ) -> Self {
        let mut twiddle_chunk = [Complex::<f32>::zero(); Self::COMPLEX_PER_VECTOR];
        for i in 0..Self::COMPLEX_PER_VECTOR {
            twiddle_chunk[i] = twiddles::compute_twiddle(y * (x + i), len, direction);
        }

        twiddle_chunk.as_slice().load_complex(0)
    }

    #[inline(always)]
    unsafe fn broadcast_twiddle(index: usize, len: usize, direction: FftDirection) -> Self {
        Self::broadcast_complex_elements(twiddles::compute_twiddle(index, len, direction))
    }

    #[inline(always)]
    unsafe fn transpose2_packed(rows: [Self; 2]) -> [Self; 2] {
        let unpacked = Self::unpack_complex(rows);
        let output0 = _mm256_permute2f128_ps(unpacked[0], unpacked[1], 0x20);
        let output1 = _mm256_permute2f128_ps(unpacked[0], unpacked[1], 0x31);

        [output0, output1]
    }
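    // Worked example of transpose2_packed, treating the input as a 2x4 matrix of
    // complex numbers: rows = [[a0, a1, a2, a3], [b0, b1, b2, b3]].
    //   unpack_complex -> [[a0, b0, a2, b2], [a1, b1, a3, b3]]
    //   permute2f128 0x20 (both low lanes)  -> [a0, b0, a1, b1]
    //   permute2f128 0x31 (both high lanes) -> [a2, b2, a3, b3]
    // which is the 4x2 transpose, packed contiguously into 2 vectors.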
    #[inline(always)]
    unsafe fn transpose3_packed(rows: [Self; 3]) -> [Self; 3] {
        let unpacked0 = Self::unpacklo_complex([rows[0], rows[1]]);
        let unpacked2 = Self::unpackhi_complex([rows[1], rows[2]]);

        // output0 and output2 each need to swap some elements. thankfully we can blend those elements into the same intermediate value, and then do a permute 128 from there
        let blended = _mm256_blend_ps(rows[0], rows[2], 0x33);

        let output1 = _mm256_permute2f128_ps(unpacked0, unpacked2, 0x12);

        let output0 = _mm256_permute2f128_ps(unpacked0, blended, 0x20);
        let output2 = _mm256_permute2f128_ps(unpacked2, blended, 0x13);

        [output0, output1, output2]
    }
    #[inline(always)]
    unsafe fn transpose4_packed(rows: [Self; 4]) -> [Self; 4] {
        let permute0 = _mm256_permute2f128_ps(rows[0], rows[2], 0x20);
        let permute1 = _mm256_permute2f128_ps(rows[1], rows[3], 0x20);
        let permute2 = _mm256_permute2f128_ps(rows[0], rows[2], 0x31);
        let permute3 = _mm256_permute2f128_ps(rows[1], rows[3], 0x31);

        let [unpacked0, unpacked1] = Self::unpack_complex([permute0, permute1]);
        let [unpacked2, unpacked3] = Self::unpack_complex([permute2, permute3]);

        [unpacked0, unpacked1, unpacked2, unpacked3]
    }
    #[inline(always)]
    unsafe fn transpose5_packed(rows: [Self; 5]) -> [Self; 5] {
        let unpacked0 = Self::unpacklo_complex([rows[0], rows[1]]);
        let unpacked1 = Self::unpackhi_complex([rows[1], rows[2]]);
        let unpacked2 = Self::unpacklo_complex([rows[2], rows[3]]);
        let unpacked3 = Self::unpackhi_complex([rows[3], rows[4]]);
        let blended04 = _mm256_blend_ps(rows[0], rows[4], 0x33);

        [
            _mm256_permute2f128_ps(unpacked0, unpacked2, 0x20),
            _mm256_permute2f128_ps(blended04, unpacked1, 0x20),
            _mm256_blend_ps(unpacked0, unpacked3, 0x0f),
            _mm256_permute2f128_ps(unpacked2, blended04, 0x31),
            _mm256_permute2f128_ps(unpacked1, unpacked3, 0x31),
        ]
    }
    #[inline(always)]
    unsafe fn transpose6_packed(rows: [Self; 6]) -> [Self; 6] {
        let [unpacked0, unpacked1] = Self::unpack_complex([rows[0], rows[1]]);
        let [unpacked2, unpacked3] = Self::unpack_complex([rows[2], rows[3]]);
        let [unpacked4, unpacked5] = Self::unpack_complex([rows[4], rows[5]]);

        [
            _mm256_permute2f128_ps(unpacked0, unpacked2, 0x20),
            _mm256_permute2f128_ps(unpacked1, unpacked4, 0x02),
            _mm256_permute2f128_ps(unpacked3, unpacked5, 0x20),
            _mm256_permute2f128_ps(unpacked0, unpacked2, 0x31),
            _mm256_permute2f128_ps(unpacked1, unpacked4, 0x13),
            _mm256_permute2f128_ps(unpacked3, unpacked5, 0x31),
        ]
    }
    #[inline(always)]
    unsafe fn transpose7_packed(rows: [Self; 7]) -> [Self; 7] {
        let unpacked0 = Self::unpacklo_complex([rows[0], rows[1]]);
        let unpacked1 = Self::unpackhi_complex([rows[1], rows[2]]);
        let unpacked2 = Self::unpacklo_complex([rows[2], rows[3]]);
        let unpacked3 = Self::unpackhi_complex([rows[3], rows[4]]);
        let unpacked4 = Self::unpacklo_complex([rows[4], rows[5]]);
        let unpacked5 = Self::unpackhi_complex([rows[5], rows[6]]);
        let blended06 = _mm256_blend_ps(rows[0], rows[6], 0x33);

        [
            _mm256_permute2f128_ps(unpacked0, unpacked2, 0x20),
            _mm256_permute2f128_ps(unpacked4, blended06, 0x20),
            _mm256_permute2f128_ps(unpacked1, unpacked3, 0x20),
            _mm256_blend_ps(unpacked0, unpacked5, 0x0f),
            _mm256_permute2f128_ps(unpacked2, unpacked4, 0x31),
            _mm256_permute2f128_ps(blended06, unpacked1, 0x31),
            _mm256_permute2f128_ps(unpacked3, unpacked5, 0x31),
        ]
    }
    #[inline(always)]
    unsafe fn transpose8_packed(rows: [Self; 8]) -> [Self; 8] {
        let chunk0 = [rows[0], rows[1], rows[2], rows[3]];
        let chunk1 = [rows[4], rows[5], rows[6], rows[7]];

        let output0 = Self::transpose4_packed(chunk0);
        let output1 = Self::transpose4_packed(chunk1);

        [
            output0[0], output1[0], output0[1], output1[1], output0[2], output1[2], output0[3],
            output1[3],
        ]
    }
    #[inline(always)]
    unsafe fn transpose9_packed(rows: [Self; 9]) -> [Self; 9] {
        let unpacked0 = Self::unpacklo_complex([rows[0], rows[1]]);
        let unpacked1 = Self::unpackhi_complex([rows[1], rows[2]]);
        let unpacked2 = Self::unpacklo_complex([rows[2], rows[3]]);
        let unpacked3 = Self::unpackhi_complex([rows[3], rows[4]]);
        let unpacked5 = Self::unpacklo_complex([rows[4], rows[5]]);
        let unpacked6 = Self::unpackhi_complex([rows[5], rows[6]]);
        let unpacked7 = Self::unpacklo_complex([rows[6], rows[7]]);
        let unpacked8 = Self::unpackhi_complex([rows[7], rows[8]]);
        let blended9 = _mm256_blend_ps(rows[0], rows[8], 0x33);

        [
            _mm256_permute2f128_ps(unpacked0, unpacked2, 0x20),
            _mm256_permute2f128_ps(unpacked5, unpacked7, 0x20),
            _mm256_permute2f128_ps(blended9, unpacked1, 0x20),
            _mm256_permute2f128_ps(unpacked3, unpacked6, 0x20),
            _mm256_blend_ps(unpacked0, unpacked8, 0x0f),
            _mm256_permute2f128_ps(unpacked2, unpacked5, 0x31),
            _mm256_permute2f128_ps(unpacked7, blended9, 0x31),
            _mm256_permute2f128_ps(unpacked1, unpacked3, 0x31),
            _mm256_permute2f128_ps(unpacked6, unpacked8, 0x31),
        ]
    }
    #[inline(always)]
    unsafe fn transpose11_packed(rows: [Self; 11]) -> [Self; 11] {
        let unpacked0 = Self::unpacklo_complex([rows[0], rows[1]]);
        let unpacked1 = Self::unpackhi_complex([rows[1], rows[2]]);
        let unpacked2 = Self::unpacklo_complex([rows[2], rows[3]]);
        let unpacked3 = Self::unpackhi_complex([rows[3], rows[4]]);
        let unpacked4 = Self::unpacklo_complex([rows[4], rows[5]]);
        let unpacked5 = Self::unpackhi_complex([rows[5], rows[6]]);
        let unpacked6 = Self::unpacklo_complex([rows[6], rows[7]]);
        let unpacked7 = Self::unpackhi_complex([rows[7], rows[8]]);
        let unpacked8 = Self::unpacklo_complex([rows[8], rows[9]]);
        let unpacked9 = Self::unpackhi_complex([rows[9], rows[10]]);
        let blended10 = _mm256_blend_ps(rows[0], rows[10], 0x33);

        [
            _mm256_permute2f128_ps(unpacked0, unpacked2, 0x20),
            _mm256_permute2f128_ps(unpacked4, unpacked6, 0x20),
            _mm256_permute2f128_ps(unpacked8, blended10, 0x20),
            _mm256_permute2f128_ps(unpacked1, unpacked3, 0x20),
            _mm256_permute2f128_ps(unpacked5, unpacked7, 0x20),
            _mm256_blend_ps(unpacked0, unpacked9, 0x0f),
            _mm256_permute2f128_ps(unpacked2, unpacked4, 0x31),
            _mm256_permute2f128_ps(unpacked6, unpacked8, 0x31),
            _mm256_permute2f128_ps(blended10, unpacked1, 0x31),
            _mm256_permute2f128_ps(unpacked3, unpacked5, 0x31),
            _mm256_permute2f128_ps(unpacked7, unpacked9, 0x31),
        ]
    }
    #[inline(always)]
    unsafe fn transpose12_packed(rows: [Self; 12]) -> [Self; 12] {
        let chunk0 = [rows[0], rows[1], rows[2], rows[3]];
        let chunk1 = [rows[4], rows[5], rows[6], rows[7]];
        let chunk2 = [rows[8], rows[9], rows[10], rows[11]];

        let output0 = Self::transpose4_packed(chunk0);
        let output1 = Self::transpose4_packed(chunk1);
        let output2 = Self::transpose4_packed(chunk2);

        [
            output0[0], output1[0], output2[0], output0[1], output1[1], output2[1], output0[2],
            output1[2], output2[2], output0[3], output1[3], output2[3],
        ]
    }
    #[inline(always)]
    unsafe fn transpose16_packed(rows: [Self; 16]) -> [Self; 16] {
        let chunk0 = [
            rows[0], rows[1], rows[2], rows[3], rows[4], rows[5], rows[6], rows[7],
        ];
        let chunk1 = [
            rows[8], rows[9], rows[10], rows[11], rows[12], rows[13], rows[14], rows[15],
        ];

        let output0 = Self::transpose8_packed(chunk0);
        let output1 = Self::transpose8_packed(chunk1);

        [
            output0[0], output0[1], output1[0], output1[1], output0[2], output0[3], output1[2],
            output1[3], output0[4], output0[5], output1[4], output1[5], output0[6], output0[7],
            output1[6], output1[7],
        ]
    }
}
impl AvxVector256 for __m256 {
    type ScalarType = f32;
    type HalfVector = __m128;

    #[inline(always)]
    unsafe fn lo(self) -> Self::HalfVector {
        _mm256_castps256_ps128(self)
    }
    #[inline(always)]
    unsafe fn hi(self) -> Self::HalfVector {
        _mm256_extractf128_ps(self, 1)
    }
    #[inline(always)]
    unsafe fn merge(lo: Self::HalfVector, hi: Self::HalfVector) -> Self {
        _mm256_insertf128_ps(_mm256_castps128_ps256(lo), hi, 1)
    }

    #[inline(always)]
    unsafe fn broadcast_complex_elements(value: Complex<Self::ScalarType>) -> Self {
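        // _mm256_set_ps takes its arguments from most-significant element to
        // least-significant, so listing im before re puts each `re` in the low
        // (even) slot of its complex pair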
        _mm256_set_ps(
            value.im, value.re, value.im, value.re, value.im, value.re, value.im, value.re,
        )
    }

    #[inline(always)]
    unsafe fn load_complex(ptr: *const Complex<Self::ScalarType>) -> Self {
        _mm256_loadu_ps(ptr as *const Self::ScalarType)
    }
    #[inline(always)]
    unsafe fn store_complex(ptr: *mut Complex<Self::ScalarType>, data: Self) {
        _mm256_storeu_ps(ptr as *mut Self::ScalarType, data)
    }
    #[inline(always)]
    unsafe fn gather_complex_avx2_index32(
        ptr: *const Complex<Self::ScalarType>,
        indexes: __m128i,
    ) -> Self {
        _mm256_castpd_ps(_mm256_i32gather_pd(ptr as *const f64, indexes, 8))
    }
    #[inline(always)]
    unsafe fn gather_complex_avx2_index64(
        ptr: *const Complex<Self::ScalarType>,
        indexes: __m256i,
    ) -> Self {
        _mm256_castpd_ps(_mm256_i64gather_pd(ptr as *const f64, indexes, 8))
    }

    #[inline(always)]
    unsafe fn load_partial1_complex(ptr: *const Complex<Self::ScalarType>) -> Self::HalfVector {
        let data = _mm_load_sd(ptr as *const f64);
        _mm_castpd_ps(data)
    }
    #[inline(always)]
    unsafe fn load_partial2_complex(ptr: *const Complex<Self::ScalarType>) -> Self::HalfVector {
        _mm_loadu_ps(ptr as *const f32)
    }
    #[inline(always)]
    unsafe fn load_partial3_complex(ptr: *const Complex<Self::ScalarType>) -> Self {
        let lo = Self::load_partial2_complex(ptr);
        let hi = Self::load_partial1_complex(ptr.add(2));
        Self::merge(lo, hi)
    }
    #[inline(always)]
    unsafe fn store_partial1_complex(ptr: *mut Complex<Self::ScalarType>, data: Self::HalfVector) {
        _mm_store_sd(ptr as *mut f64, _mm_castps_pd(data));
    }
    #[inline(always)]
    unsafe fn store_partial2_complex(ptr: *mut Complex<Self::ScalarType>, data: Self::HalfVector) {
        _mm_storeu_ps(ptr as *mut f32, data);
    }
    #[inline(always)]
    unsafe fn store_partial3_complex(ptr: *mut Complex<Self::ScalarType>, data: Self) {
        Self::store_partial2_complex(ptr, data.lo());
        Self::store_partial1_complex(ptr.add(2), data.hi());
    }
}
1103
1104impl AvxVector for __m128 {
1105    const SCALAR_PER_VECTOR: usize = 4;
1106    const COMPLEX_PER_VECTOR: usize = 2;
1107
1108    #[inline(always)]
1109    unsafe fn zero() -> Self {
1110        _mm_setzero_ps()
1111    }
1112    #[inline(always)]
1113    unsafe fn half_root2() -> Self {
1114        // note: we're computing a square root here, but checking the generated assembly shows that the compiler is smart enough to turn this into a constant
1115        _mm_broadcast_ss(&0.5f32.sqrt())
1116    }
1117
1118    #[inline(always)]
1119    unsafe fn xor(left: Self, right: Self) -> Self {
1120        _mm_xor_ps(left, right)
1121    }
1122    #[inline(always)]
1123    unsafe fn neg(self) -> Self {
1124        _mm_xor_ps(self, _mm_broadcast_ss(&-0.0))
1125    }
1126    #[inline(always)]
1127    unsafe fn add(left: Self, right: Self) -> Self {
1128        _mm_add_ps(left, right)
1129    }
1130    #[inline(always)]
1131    unsafe fn sub(left: Self, right: Self) -> Self {
1132        _mm_sub_ps(left, right)
1133    }
1134    #[inline(always)]
1135    unsafe fn mul(left: Self, right: Self) -> Self {
1136        _mm_mul_ps(left, right)
1137    }
1138    #[inline(always)]
1139    unsafe fn fmadd(left: Self, right: Self, add: Self) -> Self {
1140        _mm_fmadd_ps(left, right, add)
1141    }
1142    #[inline(always)]
1143    unsafe fn fnmadd(left: Self, right: Self, add: Self) -> Self {
1144        _mm_fnmadd_ps(left, right, add)
1145    }
1146    #[inline(always)]
1147    unsafe fn fmaddsub(left: Self, right: Self, add: Self) -> Self {
1148        _mm_fmaddsub_ps(left, right, add)
1149    }
1150    #[inline(always)]
1151    unsafe fn fmsubadd(left: Self, right: Self, add: Self) -> Self {
1152        _mm_fmsubadd_ps(left, right, add)
1153    }
1154
1155    #[inline(always)]
1156    unsafe fn reverse_complex_elements(self) -> Self {
1157        // swap the elements in-lane
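        // 0x4E selects source elements (2, 3, 0, 1), exchanging the two complex numbers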
1158        _mm_permute_ps(self, 0x4E)
1159    }
1160
1161    #[inline(always)]
1162    unsafe fn unpacklo_complex(rows: [Self; 2]) -> Self {
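        // cast to f64 so the unpack moves each complex number (a pair of f32s) as one unit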
1163        let row0_double = _mm_castps_pd(rows[0]);
1164        let row1_double = _mm_castps_pd(rows[1]);
1165        let unpacked = _mm_unpacklo_pd(row0_double, row1_double);
1166        _mm_castpd_ps(unpacked)
1167    }
1168    #[inline(always)]
1169    unsafe fn unpackhi_complex(rows: [Self; 2]) -> Self {
1170        let row0_double = _mm_castps_pd(rows[0]);
1171        let row1_double = _mm_castps_pd(rows[1]);
1172        let unpacked = _mm_unpackhi_pd(row0_double, row1_double);
1173        _mm_castpd_ps(unpacked)
1174    }
1175
1176    #[inline(always)]
1177    unsafe fn swap_complex_components(self) -> Self {
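        // 0xB1 selects source elements (1, 0, 3, 2), swapping re/im within each complex number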
1178        _mm_permute_ps(self, 0xB1)
1179    }
1180    #[inline(always)]
1181    unsafe fn duplicate_complex_components(self) -> (Self, Self) {
1182        (_mm_moveldup_ps(self), _mm_movehdup_ps(self))
1183    }
1184
1185    #[inline(always)]
1186    unsafe fn make_rotation90(direction: FftDirection) -> Rotation90<Self> {
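        // rotate90 combines this ±0.0 sign mask (applied with xor) with a re/im swap; which
        // component holds the -0.0 determines which component gets negated for each direction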
1187        let broadcast = match direction {
1188            FftDirection::Forward => Complex::new(-0.0, 0.0),
1189            FftDirection::Inverse => Complex::new(0.0, -0.0),
1190        };
1191        Rotation90(Self::broadcast_complex_elements(broadcast))
1192    }
1193    #[inline(always)]
1194    unsafe fn make_mixedradix_twiddle_chunk(
1195        x: usize,
1196        y: usize,
1197        len: usize,
1198        direction: FftDirection,
1199    ) -> Self {
1200        let mut twiddle_chunk = [Complex::<f32>::zero(); Self::COMPLEX_PER_VECTOR];
1201        for i in 0..Self::COMPLEX_PER_VECTOR {
1202            twiddle_chunk[i] = twiddles::compute_twiddle(y * (x + i), len, direction);
1203        }
1204
1205        _mm_loadu_ps(twiddle_chunk.as_ptr() as *const f32)
1206    }
1207    #[inline(always)]
1208    unsafe fn broadcast_twiddle(index: usize, len: usize, direction: FftDirection) -> Self {
1209        Self::broadcast_complex_elements(twiddles::compute_twiddle(index, len, direction))
1210    }
1211
1212    #[inline(always)]
1213    unsafe fn transpose2_packed(rows: [Self; 2]) -> [Self; 2] {
1214        Self::unpack_complex(rows)
1215    }
1216    #[inline(always)]
1217    unsafe fn transpose3_packed(rows: [Self; 3]) -> [Self; 3] {
1218        let unpacked0 = Self::unpacklo_complex([rows[0], rows[1]]);
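        // blend 0x03 takes the low complex number from rows[2] and the high one from rows[0]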
1219        let blended = _mm_blend_ps(rows[0], rows[2], 0x03);
1220        let unpacked2 = Self::unpackhi_complex([rows[1], rows[2]]);
1221
1222        [unpacked0, blended, unpacked2]
1223    }
1224    #[inline(always)]
1225    unsafe fn transpose4_packed(rows: [Self; 4]) -> [Self; 4] {
1226        let [unpacked0, unpacked1] = Self::unpack_complex([rows[0], rows[1]]);
1227        let [unpacked2, unpacked3] = Self::unpack_complex([rows[2], rows[3]]);
1228
1229        [unpacked0, unpacked2, unpacked1, unpacked3]
1230    }
1231    #[inline(always)]
1232    unsafe fn transpose5_packed(rows: [Self; 5]) -> [Self; 5] {
1233        [
1234            Self::unpacklo_complex([rows[0], rows[1]]),
1235            Self::unpacklo_complex([rows[2], rows[3]]),
1236            _mm_blend_ps(rows[0], rows[4], 0x03),
1237            Self::unpackhi_complex([rows[1], rows[2]]),
1238            Self::unpackhi_complex([rows[3], rows[4]]),
1239        ]
1240    }
1241    #[inline(always)]
1242    unsafe fn transpose6_packed(rows: [Self; 6]) -> [Self; 6] {
1243        let [unpacked0, unpacked1] = Self::unpack_complex([rows[0], rows[1]]);
1244        let [unpacked2, unpacked3] = Self::unpack_complex([rows[2], rows[3]]);
1245        let [unpacked4, unpacked5] = Self::unpack_complex([rows[4], rows[5]]);
1246
1247        [
1248            unpacked0, unpacked2, unpacked4, unpacked1, unpacked3, unpacked5,
1249        ]
1250    }
1251    #[inline(always)]
1252    unsafe fn transpose7_packed(rows: [Self; 7]) -> [Self; 7] {
1253        [
1254            Self::unpacklo_complex([rows[0], rows[1]]),
1255            Self::unpacklo_complex([rows[2], rows[3]]),
1256            Self::unpacklo_complex([rows[4], rows[5]]),
1257            _mm_shuffle_ps(rows[6], rows[0], 0xE4),
1258            Self::unpackhi_complex([rows[1], rows[2]]),
1259            Self::unpackhi_complex([rows[3], rows[4]]),
1260            Self::unpackhi_complex([rows[5], rows[6]]),
1261        ]
1262    }
1263    #[inline(always)]
1264    unsafe fn transpose8_packed(rows: [Self; 8]) -> [Self; 8] {
1265        let chunk0 = [rows[0], rows[1], rows[2], rows[3]];
1266        let chunk1 = [rows[4], rows[5], rows[6], rows[7]];
1267
1268        let output0 = Self::transpose4_packed(chunk0);
1269        let output1 = Self::transpose4_packed(chunk1);
1270
1271        [
1272            output0[0], output0[1], output1[0], output1[1], output0[2], output0[3], output1[2],
1273            output1[3],
1274        ]
1275    }
1276    #[inline(always)]
1277    unsafe fn transpose9_packed(rows: [Self; 9]) -> [Self; 9] {
1278        [
1279            Self::unpacklo_complex([rows[0], rows[1]]),
1280            Self::unpacklo_complex([rows[2], rows[3]]),
1281            Self::unpacklo_complex([rows[4], rows[5]]),
1282            Self::unpacklo_complex([rows[6], rows[7]]),
1283            _mm_shuffle_ps(rows[8], rows[0], 0xE4),
1284            Self::unpackhi_complex([rows[1], rows[2]]),
1285            Self::unpackhi_complex([rows[3], rows[4]]),
1286            Self::unpackhi_complex([rows[5], rows[6]]),
1287            Self::unpackhi_complex([rows[7], rows[8]]),
1288        ]
1289    }
1290    #[inline(always)]
1291    unsafe fn transpose11_packed(rows: [Self; 11]) -> [Self; 11] {
1292        [
1293            Self::unpacklo_complex([rows[0], rows[1]]),
1294            Self::unpacklo_complex([rows[2], rows[3]]),
1295            Self::unpacklo_complex([rows[4], rows[5]]),
1296            Self::unpacklo_complex([rows[6], rows[7]]),
1297            Self::unpacklo_complex([rows[8], rows[9]]),
1298            _mm_shuffle_ps(rows[10], rows[0], 0xE4),
1299            Self::unpackhi_complex([rows[1], rows[2]]),
1300            Self::unpackhi_complex([rows[3], rows[4]]),
1301            Self::unpackhi_complex([rows[5], rows[6]]),
1302            Self::unpackhi_complex([rows[7], rows[8]]),
1303            Self::unpackhi_complex([rows[9], rows[10]]),
1304        ]
1305    }
1306    #[inline(always)]
1307    unsafe fn transpose12_packed(rows: [Self; 12]) -> [Self; 12] {
1308        let chunk0 = [rows[0], rows[1], rows[2], rows[3]];
1309        let chunk1 = [rows[4], rows[5], rows[6], rows[7]];
1310        let chunk2 = [rows[8], rows[9], rows[10], rows[11]];
1311
1312        let output0 = Self::transpose4_packed(chunk0);
1313        let output1 = Self::transpose4_packed(chunk1);
1314        let output2 = Self::transpose4_packed(chunk2);
1315
1316        [
1317            output0[0], output0[1], output1[0], output1[1], output2[0], output2[1], output0[2],
1318            output0[3], output1[2], output1[3], output2[2], output2[3],
1319        ]
1320    }
1321    #[inline(always)]
1322    unsafe fn transpose16_packed(rows: [Self; 16]) -> [Self; 16] {
1323        let chunk0 = [
1324            rows[0], rows[1], rows[2], rows[3], rows[4], rows[5], rows[6], rows[7],
1325        ];
1326        let chunk1 = [
1327            rows[8], rows[9], rows[10], rows[11], rows[12], rows[13], rows[14], rows[15],
1328        ];
1329
1330        let output0 = Self::transpose8_packed(chunk0);
1331        let output1 = Self::transpose8_packed(chunk1);
1332
1333        [
1334            output0[0], output0[1], output0[2], output0[3], output1[0], output1[1], output1[2],
1335            output1[3], output0[4], output0[5], output0[6], output0[7], output1[4], output1[5],
1336            output1[6], output1[7],
1337        ]
1338    }
1339}
1340impl AvxVector128 for __m128 {
1341    type FullVector = __m256;
1342
1343    #[inline(always)]
1344    unsafe fn lo(input: Self::FullVector) -> Self {
1345        _mm256_castps256_ps128(input)
1346    }
1347    #[inline(always)]
1348    unsafe fn hi(input: Self::FullVector) -> Self {
1349        _mm256_extractf128_ps(input, 1)
1350    }
1351    #[inline(always)]
1352    unsafe fn merge(lo: Self, hi: Self) -> Self::FullVector {
1353        _mm256_insertf128_ps(_mm256_castps128_ps256(lo), hi, 1)
1354    }
1355    #[inline(always)]
1356    unsafe fn zero_extend(self) -> Self::FullVector {
1357        _mm256_zextps128_ps256(self)
1358    }
1359    #[inline(always)]
1360    unsafe fn lo_rotation(input: Rotation90<Self::FullVector>) -> Rotation90<Self> {
1361        input.lo()
1362    }
1363    #[inline(always)]
1364    unsafe fn broadcast_complex_elements(value: Complex<f32>) -> Self {
1365        _mm_set_ps(value.im, value.re, value.im, value.re)
1366    }
1367    #[inline(always)]
1368    unsafe fn gather32_complex_avx2(ptr: *const Complex<f32>, indexes: __m128i) -> Self {
1369        _mm_castpd_ps(_mm_i32gather_pd(ptr as *const f64, indexes, 8))
1370    }
1371    #[inline(always)]
1372    unsafe fn gather64_complex_avx2(ptr: *const Complex<f32>, indexes: __m128i) -> Self {
1373        _mm_castpd_ps(_mm_i64gather_pd(ptr as *const f64, indexes, 8))
1374    }
1375}
1376
1377impl AvxVector for __m256d {
1378    const SCALAR_PER_VECTOR: usize = 4;
1379    const COMPLEX_PER_VECTOR: usize = 2;
1380
1381    #[inline(always)]
1382    unsafe fn zero() -> Self {
1383        _mm256_setzero_pd()
1384    }
1385    #[inline(always)]
1386    unsafe fn half_root2() -> Self {
1387        // note: we're computing a square root here, but checking the generated assembly shows that the compiler is smart enough to turn this into a constant
1388        _mm256_broadcast_sd(&0.5f64.sqrt())
1389    }
1390
1391    #[inline(always)]
1392    unsafe fn xor(left: Self, right: Self) -> Self {
1393        _mm256_xor_pd(left, right)
1394    }
1395    #[inline(always)]
1396    unsafe fn neg(self) -> Self {
1397        _mm256_xor_pd(self, _mm256_broadcast_sd(&-0.0))
1398    }
1399    #[inline(always)]
1400    unsafe fn add(left: Self, right: Self) -> Self {
1401        _mm256_add_pd(left, right)
1402    }
1403    #[inline(always)]
1404    unsafe fn sub(left: Self, right: Self) -> Self {
1405        _mm256_sub_pd(left, right)
1406    }
1407    #[inline(always)]
1408    unsafe fn mul(left: Self, right: Self) -> Self {
1409        _mm256_mul_pd(left, right)
1410    }
1411    #[inline(always)]
1412    unsafe fn fmadd(left: Self, right: Self, add: Self) -> Self {
1413        _mm256_fmadd_pd(left, right, add)
1414    }
1415    #[inline(always)]
1416    unsafe fn fnmadd(left: Self, right: Self, add: Self) -> Self {
1417        _mm256_fnmadd_pd(left, right, add)
1418    }
1419    #[inline(always)]
1420    unsafe fn fmaddsub(left: Self, right: Self, add: Self) -> Self {
1421        _mm256_fmaddsub_pd(left, right, add)
1422    }
1423    #[inline(always)]
1424    unsafe fn fmsubadd(left: Self, right: Self, add: Self) -> Self {
1425        _mm256_fmsubadd_pd(left, right, add)
1426    }
1427
1428    #[inline(always)]
1429    unsafe fn reverse_complex_elements(self) -> Self {
1430        _mm256_permute2f128_pd(self, self, 0x01)
1431    }
1432    #[inline(always)]
1433    unsafe fn unpacklo_complex(rows: [Self; 2]) -> Self {
1434        _mm256_permute2f128_pd(rows[0], rows[1], 0x20)
1435    }
1436    #[inline(always)]
1437    unsafe fn unpackhi_complex(rows: [Self; 2]) -> Self {
1438        _mm256_permute2f128_pd(rows[0], rows[1], 0x31)
1439    }
1440
1441    #[inline(always)]
1442    unsafe fn swap_complex_components(self) -> Self {
1443        _mm256_permute_pd(self, 0x05)
1444    }
1445    #[inline(always)]
1446    unsafe fn duplicate_complex_components(self) -> (Self, Self) {
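        // movedup copies each real into the adjacent imaginary slot; permute 0x0F does the reverse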
1447        (_mm256_movedup_pd(self), _mm256_permute_pd(self, 0x0F))
1448    }
1449
1450    #[inline(always)]
1451    unsafe fn make_rotation90(direction: FftDirection) -> Rotation90<Self> {
1452        let broadcast = match direction {
1453            FftDirection::Forward => Complex::new(-0.0, 0.0),
1454            FftDirection::Inverse => Complex::new(0.0, -0.0),
1455        };
1456        Rotation90(Self::broadcast_complex_elements(broadcast))
1457    }
1458    #[inline(always)]
1459    unsafe fn make_mixedradix_twiddle_chunk(
1460        x: usize,
1461        y: usize,
1462        len: usize,
1463        direction: FftDirection,
1464    ) -> Self {
1465        let mut twiddle_chunk = [Complex::<f64>::zero(); Self::COMPLEX_PER_VECTOR];
1466        for i in 0..Self::COMPLEX_PER_VECTOR {
1467            twiddle_chunk[i] = twiddles::compute_twiddle(y * (x + i), len, direction);
1468        }
1469
1470        twiddle_chunk.as_slice().load_complex(0)
1471    }
1472    #[inline(always)]
1473    unsafe fn broadcast_twiddle(index: usize, len: usize, direction: FftDirection) -> Self {
1474        Self::broadcast_complex_elements(twiddles::compute_twiddle(index, len, direction))
1475    }
1476
1477    #[inline(always)]
1478    unsafe fn transpose2_packed(rows: [Self; 2]) -> [Self; 2] {
1479        Self::unpack_complex(rows)
1480    }
1481    #[inline(always)]
1482    unsafe fn transpose3_packed(rows: [Self; 3]) -> [Self; 3] {
1483        let unpacked0 = Self::unpacklo_complex([rows[0], rows[1]]);
1484        let blended = _mm256_blend_pd(rows[0], rows[2], 0x03);
1485        let unpacked2 = Self::unpackhi_complex([rows[1], rows[2]]);
1486
1487        [unpacked0, blended, unpacked2]
1488    }
1489    #[inline(always)]
1490    unsafe fn transpose4_packed(rows: [Self; 4]) -> [Self; 4] {
1491        let [unpacked0, unpacked1] = Self::unpack_complex([rows[0], rows[1]]);
1492        let [unpacked2, unpacked3] = Self::unpack_complex([rows[2], rows[3]]);
1493
1494        [unpacked0, unpacked2, unpacked1, unpacked3]
1495    }
1496    #[inline(always)]
1497    unsafe fn transpose5_packed(rows: [Self; 5]) -> [Self; 5] {
1498        [
1499            Self::unpacklo_complex([rows[0], rows[1]]),
1500            Self::unpacklo_complex([rows[2], rows[3]]),
1501            _mm256_blend_pd(rows[0], rows[4], 0x03),
1502            Self::unpackhi_complex([rows[1], rows[2]]),
1503            Self::unpackhi_complex([rows[3], rows[4]]),
1504        ]
1505    }
1506    #[inline(always)]
1507    unsafe fn transpose6_packed(rows: [Self; 6]) -> [Self; 6] {
1508        let [unpacked0, unpacked1] = Self::unpack_complex([rows[0], rows[1]]);
1509        let [unpacked2, unpacked3] = Self::unpack_complex([rows[2], rows[3]]);
1510        let [unpacked4, unpacked5] = Self::unpack_complex([rows[4], rows[5]]);
1511
1512        [
1513            unpacked0, unpacked2, unpacked4, unpacked1, unpacked3, unpacked5,
1514        ]
1515    }
1516    #[inline(always)]
1517    unsafe fn transpose7_packed(rows: [Self; 7]) -> [Self; 7] {
1518        [
1519            Self::unpacklo_complex([rows[0], rows[1]]),
1520            Self::unpacklo_complex([rows[2], rows[3]]),
1521            Self::unpacklo_complex([rows[4], rows[5]]),
1522            _mm256_blend_pd(rows[0], rows[6], 0x03),
1523            Self::unpackhi_complex([rows[1], rows[2]]),
1524            Self::unpackhi_complex([rows[3], rows[4]]),
1525            Self::unpackhi_complex([rows[5], rows[6]]),
1526        ]
1527    }
1528    #[inline(always)]
1529    unsafe fn transpose8_packed(rows: [Self; 8]) -> [Self; 8] {
1530        let chunk0 = [rows[0], rows[1], rows[2], rows[3]];
1531        let chunk1 = [rows[4], rows[5], rows[6], rows[7]];
1532
1533        let output0 = Self::transpose4_packed(chunk0);
1534        let output1 = Self::transpose4_packed(chunk1);
1535
1536        [
1537            output0[0], output0[1], output1[0], output1[1], output0[2], output0[3], output1[2],
1538            output1[3],
1539        ]
1540    }
1541    #[inline(always)]
1542    unsafe fn transpose9_packed(rows: [Self; 9]) -> [Self; 9] {
1543        [
1544            _mm256_permute2f128_pd(rows[0], rows[1], 0x20),
1545            _mm256_permute2f128_pd(rows[2], rows[3], 0x20),
1546            _mm256_permute2f128_pd(rows[4], rows[5], 0x20),
1547            _mm256_permute2f128_pd(rows[6], rows[7], 0x20),
1548            _mm256_permute2f128_pd(rows[8], rows[0], 0x30),
1549            _mm256_permute2f128_pd(rows[1], rows[2], 0x31),
1550            _mm256_permute2f128_pd(rows[3], rows[4], 0x31),
1551            _mm256_permute2f128_pd(rows[5], rows[6], 0x31),
1552            _mm256_permute2f128_pd(rows[7], rows[8], 0x31),
1553        ]
1554    }
1555    #[inline(always)]
1556    unsafe fn transpose11_packed(rows: [Self; 11]) -> [Self; 11] {
1557        [
1558            _mm256_permute2f128_pd(rows[0], rows[1], 0x20),
1559            _mm256_permute2f128_pd(rows[2], rows[3], 0x20),
1560            _mm256_permute2f128_pd(rows[4], rows[5], 0x20),
1561            _mm256_permute2f128_pd(rows[6], rows[7], 0x20),
1562            _mm256_permute2f128_pd(rows[8], rows[9], 0x20),
1563            _mm256_permute2f128_pd(rows[10], rows[0], 0x30),
1564            _mm256_permute2f128_pd(rows[1], rows[2], 0x31),
1565            _mm256_permute2f128_pd(rows[3], rows[4], 0x31),
1566            _mm256_permute2f128_pd(rows[5], rows[6], 0x31),
1567            _mm256_permute2f128_pd(rows[7], rows[8], 0x31),
1568            _mm256_permute2f128_pd(rows[9], rows[10], 0x31),
1569        ]
1570    }
1571    #[inline(always)]
1572    unsafe fn transpose12_packed(rows: [Self; 12]) -> [Self; 12] {
1573        let chunk0 = [rows[0], rows[1], rows[2], rows[3]];
1574        let chunk1 = [rows[4], rows[5], rows[6], rows[7]];
1575        let chunk2 = [rows[8], rows[9], rows[10], rows[11]];
1576
1577        let output0 = Self::transpose4_packed(chunk0);
1578        let output1 = Self::transpose4_packed(chunk1);
1579        let output2 = Self::transpose4_packed(chunk2);
1580
1581        [
1582            output0[0], output0[1], output1[0], output1[1], output2[0], output2[1], output0[2],
1583            output0[3], output1[2], output1[3], output2[2], output2[3],
1584        ]
1585    }
1586    #[inline(always)]
1587    unsafe fn transpose16_packed(rows: [Self; 16]) -> [Self; 16] {
1588        let chunk0 = [
1589            rows[0], rows[1], rows[2], rows[3], rows[4], rows[5], rows[6], rows[7],
1590        ];
1591        let chunk1 = [
1592            rows[8], rows[9], rows[10], rows[11], rows[12], rows[13], rows[14], rows[15],
1593        ];
1594
1595        let output0 = Self::transpose8_packed(chunk0);
1596        let output1 = Self::transpose8_packed(chunk1);
1597
1598        [
1599            output0[0], output0[1], output0[2], output0[3], output1[0], output1[1], output1[2],
1600            output1[3], output0[4], output0[5], output0[6], output0[7], output1[4], output1[5],
1601            output1[6], output1[7],
1602        ]
1603    }
1604}
1605impl AvxVector256 for __m256d {
1606    type ScalarType = f64;
1607    type HalfVector = __m128d;
1608
1609    #[inline(always)]
1610    unsafe fn lo(self) -> Self::HalfVector {
1611        _mm256_castpd256_pd128(self)
1612    }
1613    #[inline(always)]
1614    unsafe fn hi(self) -> Self::HalfVector {
1615        _mm256_extractf128_pd(self, 1)
1616    }
1617    #[inline(always)]
1618    unsafe fn merge(lo: Self::HalfVector, hi: Self::HalfVector) -> Self {
1619        _mm256_insertf128_pd(_mm256_castpd128_pd256(lo), hi, 1)
1620    }
1621
1622    #[inline(always)]
1623    unsafe fn broadcast_complex_elements(value: Complex<Self::ScalarType>) -> Self {
1624        _mm256_set_pd(value.im, value.re, value.im, value.re)
1625    }
1626
1627    #[inline(always)]
1628    unsafe fn load_complex(ptr: *const Complex<Self::ScalarType>) -> Self {
1629        _mm256_loadu_pd(ptr as *const Self::ScalarType)
1630    }
1631    #[inline(always)]
1632    unsafe fn store_complex(ptr: *mut Complex<Self::ScalarType>, data: Self) {
1633        _mm256_storeu_pd(ptr as *mut Self::ScalarType, data)
1634    }
1635    #[inline(always)]
1636    unsafe fn gather_complex_avx2_index32(
1637        ptr: *const Complex<Self::ScalarType>,
1638        indexes: __m128i,
1639    ) -> Self {
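        // each Complex<f64> occupies two f64 slots, so convert complex indexes into scalar
        // indexes: multiply by 2, then add the interleaved 0/1 offsets to address re and im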
1640        let offsets = _mm_set_epi32(1, 0, 1, 0);
1641        let shifted = _mm_slli_epi32(indexes, 1);
1642        let modified_indexes = _mm_add_epi32(offsets, shifted);
1643
1644        _mm256_i32gather_pd(ptr as *const f64, modified_indexes, 8)
1645    }
1646    #[inline(always)]
1647    unsafe fn gather_complex_avx2_index64(
1648        ptr: *const Complex<Self::ScalarType>,
1649        indexes: __m256i,
1650    ) -> Self {
1651        let offsets = _mm256_set_epi64x(1, 0, 1, 0);
1652        let shifted = _mm256_slli_epi64(indexes, 1);
1653        let modified_indexes = _mm256_add_epi64(offsets, shifted);
1654
1655        _mm256_i64gather_pd(ptr as *const f64, modified_indexes, 8)
1656    }
1657    #[inline(always)]
1658    unsafe fn load_partial1_complex(ptr: *const Complex<Self::ScalarType>) -> Self::HalfVector {
1659        _mm_loadu_pd(ptr as *const f64)
1660    }
1661    #[inline(always)]
1662    unsafe fn load_partial2_complex(_ptr: *const Complex<Self::ScalarType>) -> Self::HalfVector {
1663        unimplemented!("Impossible to do a partial load of 2 complex f64's")
1664    }
1665    #[inline(always)]
1666    unsafe fn load_partial3_complex(_ptr: *const Complex<Self::ScalarType>) -> Self {
1667        unimplemented!("Impossible to do a partial load of 3 complex f64's")
1668    }
1669    #[inline(always)]
1670    unsafe fn store_partial1_complex(ptr: *mut Complex<Self::ScalarType>, data: Self::HalfVector) {
1671        _mm_storeu_pd(ptr as *mut f64, data);
1672    }
1673    #[inline(always)]
1674    unsafe fn store_partial2_complex(
1675        _ptr: *mut Complex<Self::ScalarType>,
1676        _data: Self::HalfVector,
1677    ) {
1678        unimplemented!("Impossible to do a partial store of 2 complex f64's")
1679    }
1680    #[inline(always)]
1681    unsafe fn store_partial3_complex(_ptr: *mut Complex<Self::ScalarType>, _data: Self) {
1682        unimplemented!("Impossible to do a partial store of 3 complex f64's")
1683    }
1684}
1685
1686impl AvxVector for __m128d {
1687    const SCALAR_PER_VECTOR: usize = 2;
1688    const COMPLEX_PER_VECTOR: usize = 1;
1689
1690    #[inline(always)]
1691    unsafe fn zero() -> Self {
1692        _mm_setzero_pd()
1693    }
1694    #[inline(always)]
1695    unsafe fn half_root2() -> Self {
1696        // note: we're computing a square root here, but checking the generated assembly shows that the compiler is smart enough to turn this into a constant
1697        _mm_load1_pd(&0.5f64.sqrt())
1698    }
1699
1700    #[inline(always)]
1701    unsafe fn xor(left: Self, right: Self) -> Self {
1702        _mm_xor_pd(left, right)
1703    }
1704    #[inline(always)]
1705    unsafe fn neg(self) -> Self {
1706        _mm_xor_pd(self, _mm_load1_pd(&-0.0))
1707    }
1708    #[inline(always)]
1709    unsafe fn add(left: Self, right: Self) -> Self {
1710        _mm_add_pd(left, right)
1711    }
1712    #[inline(always)]
1713    unsafe fn sub(left: Self, right: Self) -> Self {
1714        _mm_sub_pd(left, right)
1715    }
1716    #[inline(always)]
1717    unsafe fn mul(left: Self, right: Self) -> Self {
1718        _mm_mul_pd(left, right)
1719    }
1720    #[inline(always)]
1721    unsafe fn fmadd(left: Self, right: Self, add: Self) -> Self {
1722        _mm_fmadd_pd(left, right, add)
1723    }
1724    #[inline(always)]
1725    unsafe fn fnmadd(left: Self, right: Self, add: Self) -> Self {
1726        _mm_fnmadd_pd(left, right, add)
1727    }
1728    #[inline(always)]
1729    unsafe fn fmaddsub(left: Self, right: Self, add: Self) -> Self {
1730        _mm_fmaddsub_pd(left, right, add)
1731    }
1732    #[inline(always)]
1733    unsafe fn fmsubadd(left: Self, right: Self, add: Self) -> Self {
1734        _mm_fmsubadd_pd(left, right, add)
1735    }
1736
1737    #[inline(always)]
1738    unsafe fn reverse_complex_elements(self) -> Self {
1739        // nothing to reverse
1740        self
1741    }
1742    #[inline(always)]
1743    unsafe fn unpacklo_complex(_rows: [Self; 2]) -> Self {
1744        unimplemented!(); // this operation doesn't make sense when each vector holds only one complex number. TODO: decide whether panicking here or returning the input unchanged is more useful to callers.
1745    }
1746    #[inline(always)]
1747    unsafe fn unpackhi_complex(_rows: [Self; 2]) -> Self {
1748        unimplemented!(); // this operation doesn't make sense when each vector holds only one complex number. TODO: decide whether panicking here or returning the input unchanged is more useful to callers.
1749    }
1750
1751    #[inline(always)]
1752    unsafe fn swap_complex_components(self) -> Self {
1753        _mm_permute_pd(self, 0x01)
1754    }
1755    #[inline(always)]
1756    unsafe fn duplicate_complex_components(self) -> (Self, Self) {
1757        (_mm_movedup_pd(self), _mm_permute_pd(self, 0x03))
1758    }
1759
1760    #[inline(always)]
1761    unsafe fn make_rotation90(direction: FftDirection) -> Rotation90<Self> {
1762        let broadcast = match direction {
1763            FftDirection::Forward => Complex::new(-0.0, 0.0),
1764            FftDirection::Inverse => Complex::new(0.0, -0.0),
1765        };
1766        Rotation90(Self::broadcast_complex_elements(broadcast))
1767    }
1768    #[inline(always)]
1769    unsafe fn make_mixedradix_twiddle_chunk(
1770        x: usize,
1771        y: usize,
1772        len: usize,
1773        direction: FftDirection,
1774    ) -> Self {
1775        let mut twiddle_chunk = [Complex::<f64>::zero(); Self::COMPLEX_PER_VECTOR];
1776        for i in 0..Self::COMPLEX_PER_VECTOR {
1777            twiddle_chunk[i] = twiddles::compute_twiddle(y * (x + i), len, direction);
1778        }
1779
1780        _mm_loadu_pd(twiddle_chunk.as_ptr() as *const f64)
1781    }
1782    #[inline(always)]
1783    unsafe fn broadcast_twiddle(index: usize, len: usize, direction: FftDirection) -> Self {
1784        Self::broadcast_complex_elements(twiddles::compute_twiddle(index, len, direction))
1785    }
1786
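    // With COMPLEX_PER_VECTOR = 1, a packed transpose is the identity, so every
    // transposeN_packed below just returns its input unchanged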
1787    #[inline(always)]
1788    unsafe fn transpose2_packed(rows: [Self; 2]) -> [Self; 2] {
1789        rows
1790    }
1791    #[inline(always)]
1792    unsafe fn transpose3_packed(rows: [Self; 3]) -> [Self; 3] {
1793        rows
1794    }
1795    #[inline(always)]
1796    unsafe fn transpose4_packed(rows: [Self; 4]) -> [Self; 4] {
1797        rows
1798    }
1799    #[inline(always)]
1800    unsafe fn transpose5_packed(rows: [Self; 5]) -> [Self; 5] {
1801        rows
1802    }
1803    #[inline(always)]
1804    unsafe fn transpose6_packed(rows: [Self; 6]) -> [Self; 6] {
1805        rows
1806    }
1807    #[inline(always)]
1808    unsafe fn transpose7_packed(rows: [Self; 7]) -> [Self; 7] {
1809        rows
1810    }
1811    #[inline(always)]
1812    unsafe fn transpose8_packed(rows: [Self; 8]) -> [Self; 8] {
1813        rows
1814    }
1815    #[inline(always)]
1816    unsafe fn transpose9_packed(rows: [Self; 9]) -> [Self; 9] {
1817        rows
1818    }
1819    #[inline(always)]
1820    unsafe fn transpose11_packed(rows: [Self; 11]) -> [Self; 11] {
1821        rows
1822    }
1823    #[inline(always)]
1824    unsafe fn transpose12_packed(rows: [Self; 12]) -> [Self; 12] {
1825        rows
1826    }
1827    #[inline(always)]
1828    unsafe fn transpose16_packed(rows: [Self; 16]) -> [Self; 16] {
1829        rows
1830    }
1831}
1832impl AvxVector128 for __m128d {
1833    type FullVector = __m256d;
1834
1835    #[inline(always)]
1836    unsafe fn lo(input: Self::FullVector) -> Self {
1837        _mm256_castpd256_pd128(input)
1838    }
1839    #[inline(always)]
1840    unsafe fn hi(input: Self::FullVector) -> Self {
1841        _mm256_extractf128_pd(input, 1)
1842    }
1843    #[inline(always)]
1844    unsafe fn merge(lo: Self, hi: Self) -> Self::FullVector {
1845        _mm256_insertf128_pd(_mm256_castpd128_pd256(lo), hi, 1)
1846    }
1847    #[inline(always)]
1848    unsafe fn zero_extend(self) -> Self::FullVector {
1849        _mm256_zextpd128_pd256(self)
1850    }
1851    #[inline(always)]
1852    unsafe fn lo_rotation(input: Rotation90<Self::FullVector>) -> Rotation90<Self> {
1853        input.lo()
1854    }
1855    #[inline(always)]
1856    unsafe fn broadcast_complex_elements(value: Complex<f64>) -> Self {
1857        _mm_set_pd(value.im, value.re)
1858    }
1859    #[inline(always)]
1860    unsafe fn gather32_complex_avx2(ptr: *const Complex<f64>, indexes: __m128i) -> Self {
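        // a __m128d holds exactly one Complex<f64>, so only the first index is used:
        // spill the indexes to memory and do a plain load from that offset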
1861        let mut index_storage: [i32; 4] = [0; 4];
1862        _mm_storeu_si128(index_storage.as_mut_ptr() as *mut __m128i, indexes);
1863
1864        _mm_loadu_pd(ptr.offset(index_storage[0] as isize) as *const f64)
1865    }
1866    #[inline(always)]
1867    unsafe fn gather64_complex_avx2(ptr: *const Complex<f64>, indexes: __m128i) -> Self {
1868        let mut index_storage: [i64; 2] = [0; 2]; // a __m128i holds two i64s, not four
1869        _mm_storeu_si128(index_storage.as_mut_ptr() as *mut __m128i, indexes);
1870
1871        _mm_loadu_pd(ptr.offset(index_storage[0] as isize) as *const f64)
1872    }
1873}
1874
1875pub trait AvxArray<T: AvxNum>: Deref {
1876    unsafe fn load_complex(&self, index: usize) -> T::VectorType;
1877    unsafe fn load_partial1_complex(
1878        &self,
1879        index: usize,
1880    ) -> <T::VectorType as AvxVector256>::HalfVector;
1881    unsafe fn load_partial2_complex(
1882        &self,
1883        index: usize,
1884    ) -> <T::VectorType as AvxVector256>::HalfVector;
1885    unsafe fn load_partial3_complex(&self, index: usize) -> T::VectorType;
1886
1887    // some avx operations need bespoke one-off things that don't fit into the methods above, so we should provide an escape hatch for them
1888    fn input_ptr(&self) -> *const Complex<T>;
1889}
1890pub trait AvxArrayMut<T: AvxNum>: AvxArray<T> + DerefMut {
1891    unsafe fn store_complex(&mut self, data: T::VectorType, index: usize);
1892    unsafe fn store_partial1_complex(
1893        &mut self,
1894        data: <T::VectorType as AvxVector256>::HalfVector,
1895        index: usize,
1896    );
1897    unsafe fn store_partial2_complex(
1898        &mut self,
1899        data: <T::VectorType as AvxVector256>::HalfVector,
1900        index: usize,
1901    );
1902    unsafe fn store_partial3_complex(&mut self, data: T::VectorType, index: usize);
1903
1904    // some avx operations need bespoke one-off things that don't fit into the methods above, so we should provide an escape hatch for them
1905    fn output_ptr(&mut self) -> *mut Complex<T>;
1906}
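// A minimal sketch of how generic code can drive these traits (illustrative only, not
// part of this file; assumes `buffer` holds at least one full vector of complex numbers):
//     unsafe fn double_first_chunk<T: AvxNum>(mut buffer: &mut [Complex<T>]) {
//         let v = buffer.load_complex(0);
//         buffer.store_complex(T::VectorType::add(v, v), 0);
//     }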
1907
1908impl<T: AvxNum> AvxArray<T> for &[Complex<T>] {
1909    #[inline(always)]
1910    unsafe fn load_complex(&self, index: usize) -> T::VectorType {
1911        debug_assert!(self.len() >= index + T::VectorType::COMPLEX_PER_VECTOR);
1912        T::VectorType::load_complex(self.as_ptr().add(index))
1913    }
1914    #[inline(always)]
1915    unsafe fn load_partial1_complex(
1916        &self,
1917        index: usize,
1918    ) -> <T::VectorType as AvxVector256>::HalfVector {
1919        debug_assert!(self.len() >= index + 1);
1920        T::VectorType::load_partial1_complex(self.as_ptr().add(index))
1921    }
1922    #[inline(always)]
1923    unsafe fn load_partial2_complex(
1924        &self,
1925        index: usize,
1926    ) -> <T::VectorType as AvxVector256>::HalfVector {
1927        debug_assert!(self.len() >= index + 2);
1928        T::VectorType::load_partial2_complex(self.as_ptr().add(index))
1929    }
1930    #[inline(always)]
1931    unsafe fn load_partial3_complex(&self, index: usize) -> T::VectorType {
1932        debug_assert!(self.len() >= index + 3);
1933        T::VectorType::load_partial3_complex(self.as_ptr().add(index))
1934    }
1935    #[inline(always)]
1936    fn input_ptr(&self) -> *const Complex<T> {
1937        self.as_ptr()
1938    }
1939}
1940impl<T: AvxNum> AvxArray<T> for &mut [Complex<T>] {
1941    #[inline(always)]
1942    unsafe fn load_complex(&self, index: usize) -> T::VectorType {
1943        debug_assert!(self.len() >= index + T::VectorType::COMPLEX_PER_VECTOR);
1944        T::VectorType::load_complex(self.as_ptr().add(index))
1945    }
1946    #[inline(always)]
1947    unsafe fn load_partial1_complex(
1948        &self,
1949        index: usize,
1950    ) -> <T::VectorType as AvxVector256>::HalfVector {
1951        debug_assert!(self.len() >= index + 1);
1952        T::VectorType::load_partial1_complex(self.as_ptr().add(index))
1953    }
1954    #[inline(always)]
1955    unsafe fn load_partial2_complex(
1956        &self,
1957        index: usize,
1958    ) -> <T::VectorType as AvxVector256>::HalfVector {
1959        debug_assert!(self.len() >= index + 2);
1960        T::VectorType::load_partial2_complex(self.as_ptr().add(index))
1961    }
1962    #[inline(always)]
1963    unsafe fn load_partial3_complex(&self, index: usize) -> T::VectorType {
1964        debug_assert!(self.len() >= index + 3);
1965        T::VectorType::load_partial3_complex(self.as_ptr().add(index))
1966    }
1967    #[inline(always)]
1968    fn input_ptr(&self) -> *const Complex<T> {
1969        self.as_ptr()
1970    }
1971}
1972impl<'a, T: AvxNum> AvxArray<T> for DoubleBuf<'a, T>
1973where
1974    &'a [Complex<T>]: AvxArray<T>,
1975{
1976    #[inline(always)]
1977    unsafe fn load_complex(&self, index: usize) -> T::VectorType {
1978        self.input.load_complex(index)
1979    }
1980    #[inline(always)]
1981    unsafe fn load_partial1_complex(
1982        &self,
1983        index: usize,
1984    ) -> <T::VectorType as AvxVector256>::HalfVector {
1985        self.input.load_partial1_complex(index)
1986    }
1987    #[inline(always)]
1988    unsafe fn load_partial2_complex(
1989        &self,
1990        index: usize,
1991    ) -> <T::VectorType as AvxVector256>::HalfVector {
1992        self.input.load_partial2_complex(index)
1993    }
1994    #[inline(always)]
1995    unsafe fn load_partial3_complex(&self, index: usize) -> T::VectorType {
1996        self.input.load_partial3_complex(index)
1997    }
1998    #[inline(always)]
1999    fn input_ptr(&self) -> *const Complex<T> {
2000        self.input.input_ptr()
2001    }
2002}
2003
2004impl<T: AvxNum> AvxArrayMut<T> for &mut [Complex<T>] {
2005    #[inline(always)]
2006    unsafe fn store_complex(&mut self, data: T::VectorType, index: usize) {
2007        debug_assert!(self.len() >= index + T::VectorType::COMPLEX_PER_VECTOR);
2008        T::VectorType::store_complex(self.as_mut_ptr().add(index), data);
2009    }
2010    #[inline(always)]
2011    unsafe fn store_partial1_complex(
2012        &mut self,
2013        data: <T::VectorType as AvxVector256>::HalfVector,
2014        index: usize,
2015    ) {
2016        debug_assert!(self.len() >= index + 1);
2017        T::VectorType::store_partial1_complex(self.as_mut_ptr().add(index), data);
2018    }
2019    #[inline(always)]
2020    unsafe fn store_partial2_complex(
2021        &mut self,
2022        data: <T::VectorType as AvxVector256>::HalfVector,
2023        index: usize,
2024    ) {
2025        debug_assert!(self.len() >= index + 2);
2026        T::VectorType::store_partial2_complex(self.as_mut_ptr().add(index), data);
2027    }
2028    #[inline(always)]
2029    unsafe fn store_partial3_complex(&mut self, data: T::VectorType, index: usize) {
2030        debug_assert!(self.len() >= index + 3);
2031        T::VectorType::store_partial3_complex(self.as_mut_ptr().add(index), data);
2032    }
2033    #[inline(always)]
2034    fn output_ptr(&mut self) -> *mut Complex<T> {
2035        self.as_mut_ptr()
2036    }
2037}
2038impl<'a, T: AvxNum> AvxArrayMut<T> for DoubleBuf<'a, T>
2039where
2040    Self: AvxArray<T>,
2041    &'a mut [Complex<T>]: AvxArrayMut<T>,
2042{
2043    #[inline(always)]
2044    unsafe fn store_complex(&mut self, data: T::VectorType, index: usize) {
2045        self.output.store_complex(data, index);
2046    }
2047    #[inline(always)]
2048    unsafe fn store_partial1_complex(
2049        &mut self,
2050        data: <T::VectorType as AvxVector256>::HalfVector,
2051        index: usize,
2052    ) {
2053        self.output.store_partial1_complex(data, index);
2054    }
2055    #[inline(always)]
2056    unsafe fn store_partial2_complex(
2057        &mut self,
2058        data: <T::VectorType as AvxVector256>::HalfVector,
2059        index: usize,
2060    ) {
2061        self.output.store_partial2_complex(data, index);
2062    }
2063    #[inline(always)]
2064    unsafe fn store_partial3_complex(&mut self, data: T::VectorType, index: usize) {
2065        self.output.store_partial3_complex(data, index);
2066    }
2067    #[inline(always)]
2068    fn output_ptr(&mut self) -> *mut Complex<T> {
2069        self.output.output_ptr()
2070    }
2071}
2072
2073// A custom butterfly-16 function that calls a lambda to load/store data instead of taking an array
2074// This is particularly useful for butterfly 16, because the whole problem doesn't fit into registers, and the compiler isn't smart enough to only load data when it's needed
2075// So the version that takes an array ends up loading data and immediately re-storing it on the stack. By lazily loading and storing exactly when we need to, we can avoid some data reshuffling
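// A minimal usage sketch (the names below are illustrative, not from this file): callers
// pass callables that map a row index to the right load or store, e.g.
//     let load = |i| input.load_complex(i * stride);
//     let mut store = |data, i| output.store_complex(data, i * stride);
//     column_butterfly16_loadfn!(load, store, twiddles, rotation);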
2076macro_rules! column_butterfly16_loadfn{
2077    ($load_expr: expr, $store_expr: expr, $twiddles: expr, $rotation: expr) => (
2078        // Size-4 FFTs down the columns
2079        let input1 = [$load_expr(1), $load_expr(5), $load_expr(9), $load_expr(13)];
2080        let mut mid1 = AvxVector::column_butterfly4(input1, $rotation);
2081
2082        mid1[1] = AvxVector::mul_complex(mid1[1], $twiddles[0]);
2083        mid1[2] = avx_vector::apply_butterfly8_twiddle1(mid1[2], $rotation);
2084        mid1[3] = AvxVector::mul_complex(mid1[3], $twiddles[1]);
2085
2086        let input2 = [$load_expr(2), $load_expr(6), $load_expr(10), $load_expr(14)];
2087        let mut mid2 = AvxVector::column_butterfly4(input2, $rotation);
2088
2089        mid2[1] = avx_vector::apply_butterfly8_twiddle1(mid2[1], $rotation);
2090        mid2[2] = mid2[2].rotate90($rotation);
2091        mid2[3] = avx_vector::apply_butterfly8_twiddle3(mid2[3], $rotation);
2092
2093        let input3 = [$load_expr(3), $load_expr(7), $load_expr(11), $load_expr(15)];
2094        let mut mid3 = AvxVector::column_butterfly4(input3, $rotation);
2095
2096        mid3[1] = AvxVector::mul_complex(mid3[1], $twiddles[1]);
2097        mid3[2] = avx_vector::apply_butterfly8_twiddle3(mid3[2], $rotation);
2098        mid3[3] = AvxVector::mul_complex(mid3[3], $twiddles[0].neg());
2099
2100        // do the first row last, because it doesn't need twiddles and therefore requires fewer intermediates
2101        let input0 = [$load_expr(0), $load_expr(4), $load_expr(8), $load_expr(12)];
2102        let mid0 = AvxVector::column_butterfly4(input0, $rotation);
2103
2104        // All of the data is now in the right format to just do a bunch of butterfly 4's in a loop.
2105        // Write the data out to the final output as we go so that the compiler can stop worrying about finding stack space for it
2106        for i in 0..4 {
2107            let output = AvxVector::column_butterfly4([mid0[i], mid1[i], mid2[i], mid3[i]], $rotation);
2108            $store_expr(output[0], i);
2109            $store_expr(output[1], i + 4);
2110            $store_expr(output[2], i + 8);
2111            $store_expr(output[3], i + 12);
2112        }
2113    )
2114}
2115
2116// A custom butterfly-32 function that calls a lambda to load/store data instead of taking an array
2117// This is particularly useful for butterfly 32, because the whole problem doesn't fit into registers, and the compiler isn't smart enough to only load data when it's needed
2118// So the version that takes an array ends up loading data and immediately re-storing it on the stack. By lazily loading and storing exactly when we need to, we can avoid some data reshuffling
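// As with column_butterfly16_loadfn above, $load_expr and $store_expr are expected to be
// callables (e.g. closures) that translate a row index into the appropriate load or store.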
2119macro_rules! column_butterfly32_loadfn{
2120    ($load_expr: expr, $store_expr: expr, $twiddles: expr, $rotation: expr) => (
2121        // Size-4 FFTs down the columns
2122        let input1 = [$load_expr(1), $load_expr(9), $load_expr(17), $load_expr(25)];
2123        let mut mid1 = AvxVector::column_butterfly4(input1, $rotation);
2124
2125        mid1[1] = AvxVector::mul_complex(mid1[1], $twiddles[0]);
2126        mid1[2] = AvxVector::mul_complex(mid1[2], $twiddles[1]);
2127        mid1[3] = AvxVector::mul_complex(mid1[3], $twiddles[2]);
2128
2129        let input2 = [$load_expr(2), $load_expr(10), $load_expr(18), $load_expr(26)];
2130        let mut mid2 = AvxVector::column_butterfly4(input2, $rotation);
2131
2132        mid2[1] = AvxVector::mul_complex(mid2[1], $twiddles[1]);
2133        mid2[2] = avx_vector::apply_butterfly8_twiddle1(mid2[2], $rotation);
2134        mid2[3] = AvxVector::mul_complex(mid2[3], $twiddles[4]);
2135
2136        let input3 = [$load_expr(3), $load_expr(11), $load_expr(19), $load_expr(27)];
2137        let mut mid3 = AvxVector::column_butterfly4(input3, $rotation);
2138
2139        mid3[1] = AvxVector::mul_complex(mid3[1], $twiddles[2]);
2140        mid3[2] = AvxVector::mul_complex(mid3[2], $twiddles[4]);
2141        mid3[3] = AvxVector::mul_complex(mid3[3], $twiddles[0].rotate90($rotation));
2142
2143        let input4 = [$load_expr(4), $load_expr(12), $load_expr(20), $load_expr(28)];
2144        let mut mid4 = AvxVector::column_butterfly4(input4, $rotation);
2145
2146        mid4[1] = avx_vector::apply_butterfly8_twiddle1(mid4[1], $rotation);
2147        mid4[2] = mid4[2].rotate90($rotation);
2148        mid4[3] = avx_vector::apply_butterfly8_twiddle3(mid4[3], $rotation);
2149
2150        let input5 = [$load_expr(5), $load_expr(13), $load_expr(21), $load_expr(29)];
2151        let mut mid5 = AvxVector::column_butterfly4(input5, $rotation);
2152
2153        mid5[1] = AvxVector::mul_complex(mid5[1], $twiddles[3]);
2154        mid5[2] = AvxVector::mul_complex(mid5[2], $twiddles[1].rotate90($rotation));
2155        mid5[3] = AvxVector::mul_complex(mid5[3], $twiddles[5].rotate90($rotation));
2156
2157        let input6 = [$load_expr(6), $load_expr(14), $load_expr(22), $load_expr(30)];
2158        let mut mid6 = AvxVector::column_butterfly4(input6, $rotation);
2159
2160        mid6[1] = AvxVector::mul_complex(mid6[1], $twiddles[4]);
2161        mid6[2] = avx_vector::apply_butterfly8_twiddle3(mid6[2], $rotation);
2162        mid6[3] = AvxVector::mul_complex(mid6[3], $twiddles[1].neg());
2163
2164        let input7 = [$load_expr(7), $load_expr(15), $load_expr(23), $load_expr(31)];
2165        let mut mid7 = AvxVector::column_butterfly4(input7, $rotation);
2166
2167        mid7[1] = AvxVector::mul_complex(mid7[1], $twiddles[5]);
2168        mid7[2] = AvxVector::mul_complex(mid7[2], $twiddles[4].rotate90($rotation));
2169        mid7[3] = AvxVector::mul_complex(mid7[3], $twiddles[3].neg());
2170
2171        let input0 = [$load_expr(0), $load_expr(8), $load_expr(16), $load_expr(24)];
2172        let mid0 = AvxVector::column_butterfly4(input0, $rotation);
2173
2174        // All of the data is now in the right format to just do a bunch of butterfly 8's in a loop.
2175        // Write the data out to the final output as we go so that the compiler can stop worrying about finding stack space for it
2176        for i in 0..4 {
2177            let output = AvxVector::column_butterfly8([mid0[i], mid1[i], mid2[i], mid3[i], mid4[i], mid5[i], mid6[i], mid7[i]], $rotation);
2178            $store_expr(output[0], i);
2179            $store_expr(output[1], i + 4);
2180            $store_expr(output[2], i + 8);
2181            $store_expr(output[3], i + 12);
2182            $store_expr(output[4], i + 16);
2183            $store_expr(output[5], i + 20);
2184            $store_expr(output[6], i + 24);
2185            $store_expr(output[7], i + 28);
2186        }
2187    )
2188}
2189
2190/// Multiply the complex numbers in `left` by the complex numbers in `right`.
2191/// Identical to `mul_complex` in `AvxVector`, except that this implementation conjugates the `left` input before multiplying.
2192#[inline(always)]
2193unsafe fn mul_complex_conjugated<V: AvxVector>(left: V, right: V) -> V {
2194    // Extract the real and imaginary components from left into 2 separate registers
2195    let (left_real, left_imag) = V::duplicate_complex_components(left);
2196
2197    // create a shuffled version of right where the imaginary values are swapped with the reals
2198    let right_shuffled = V::swap_complex_components(right);
2199
2200    // multiply our duplicated imaginary left vector by our shuffled right vector. that will give us the right side of the traditional complex multiplication formula
2201    let output_right = V::mul(left_imag, right_shuffled);
2202
2203    // use an FMA instruction to multiply the duplicated real components of `left` by `right`, alternately adding and subtracting the elements of `output_right`
2204    // By using subadd instead of addsub, we conjugate the `left` input for free.
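    // Worked example for one complex pair, with left = a+bi and right = c+di:
    //     output_right = (b*d, b*c)
    //     fmsubadd((a, a), (c, d), output_right) = (a*c + b*d, a*d - b*c) = conj(left) * right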
2205    V::fmsubadd(left_real, right, output_right)
2206}
2207
2208// Compute buffer[i] = buffer[i].conj() * multiplier[i] (a pairwise complex multiplication) for each element.
2209#[target_feature(enable = "avx", enable = "fma")]
2210pub unsafe fn pairwise_complex_mul_assign_conjugated<T: AvxNum>(
2211    mut buffer: &mut [Complex<T>],
2212    multiplier: &[T::VectorType],
2213) {
2214    assert!(multiplier.len() * T::VectorType::COMPLEX_PER_VECTOR >= buffer.len()); // Assert to convince the compiler to omit bounds checks inside the loop
2215
2216    for (i, mut buffer_chunk) in buffer
2217        .chunks_exact_mut(T::VectorType::COMPLEX_PER_VECTOR)
2218        .enumerate()
2219    {
2220        let left = buffer_chunk.load_complex(0);
2221
2222        // Do a complex multiplication between `left` and `right`
2223        let product = mul_complex_conjugated(left, multiplier[i]);
2224
2225        // Store the result
2226        buffer_chunk.store_complex(product, 0);
2227    }
2228
2229    // Process the remainder, if there is one
2230    let remainder_count = buffer.len() % T::VectorType::COMPLEX_PER_VECTOR;
2231    if remainder_count > 0 {
2232        let remainder_index = buffer.len() - remainder_count;
2233        let remainder_multiplier = multiplier.last().unwrap();
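        // the tail is multiplied against the last multiplier vector; when fewer than a
        // full vector of elements remain, only its low half is used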
2234        match remainder_count {
2235            1 => {
2236                let left = buffer.load_partial1_complex(remainder_index);
2237                let product = mul_complex_conjugated(left, remainder_multiplier.lo());
2238                buffer.store_partial1_complex(product, remainder_index);
2239            }
2240            2 => {
2241                let left = buffer.load_partial2_complex(remainder_index);
2242                let product = mul_complex_conjugated(left, remainder_multiplier.lo());
2243                buffer.store_partial2_complex(product, remainder_index);
2244            }
2245            3 => {
2246                let left = buffer.load_partial3_complex(remainder_index);
2247                let product = mul_complex_conjugated(left, *remainder_multiplier);
2248                buffer.store_partial3_complex(product, remainder_index);
2249            }
2250            _ => unreachable!(),
2251        }
2252    }
2253}
2254
2255// Compute output[i] = input[i].conj() * multiplier[i] (a pairwise complex multiplication) for each element.
2256#[target_feature(enable = "avx", enable = "fma")]
2257pub unsafe fn pairwise_complex_mul_conjugated<T: AvxNum>(
2258    input: &[Complex<T>],
2259    mut output: &mut [Complex<T>],
2260    multiplier: &[T::VectorType],
2261) {
2262    assert!(
2263        multiplier.len() * T::VectorType::COMPLEX_PER_VECTOR >= input.len(),
2264        "multiplier len = {}, input len = {}",
2265        multiplier.len(),
2266        input.len()
2267    ); // Assert to convince the compiler to omit bounds checks inside the loop
2268    assert!(input.len() == output.len()); // Assert to convince the compiler to omit bounds checks inside the loop
2269    let main_loop_count = input.len() / T::VectorType::COMPLEX_PER_VECTOR;
2270    let remainder_count = input.len() % T::VectorType::COMPLEX_PER_VECTOR;
2271
2272    for (i, m) in multiplier[..main_loop_count].iter().enumerate() {
2273        let left = input.load_complex(i * T::VectorType::COMPLEX_PER_VECTOR);
2274
2275        // Do a complex multiplication between `left` and `right`
2276        let product = mul_complex_conjugated(left, *m);
2277
2278        // Store the result
2279        output.store_complex(product, i * T::VectorType::COMPLEX_PER_VECTOR);
2280    }
2281
2282    // Process the remainder, if there is one
2283    if remainder_count > 0 {
2284        let remainder_index = input.len() - remainder_count;
2285        let remainder_multiplier = multiplier.last().unwrap();
2286        match remainder_count {
2287            1 => {
2288                let left = input.load_partial1_complex(remainder_index);
2289                let product = mul_complex_conjugated(left, remainder_multiplier.lo());
2290                output.store_partial1_complex(product, remainder_index);
2291            }
2292            2 => {
2293                let left = input.load_partial2_complex(remainder_index);
2294                let product = mul_complex_conjugated(left, remainder_multiplier.lo());
2295                output.store_partial2_complex(product, remainder_index);
2296            }
2297            3 => {
2298                let left = input.load_partial3_complex(remainder_index);
2299                let product = mul_complex_conjugated(left, *remainder_multiplier);
2300                output.store_partial3_complex(product, remainder_index);
2301            }
2302            _ => unreachable!(),
2303        }
2304    }
2305}