tfhe_ntt/prime64/
generic_solinas.rs

1use super::RECURSION_THRESHOLD;
2use crate::fastdiv::Div64;
3use core::{fmt::Debug, iter::zip};
4
5#[allow(unused_imports)]
6use pulp::*;
7
/// Scalar arithmetic modulo a 64-bit prime.
///
/// The implementing value identifies the modulus used by [`add`](Self::add) and
/// [`sub`](Self::sub). [`mul`](Self::mul) is an associated function that instead receives
/// an implementation-defined [`Div`](Self::Div) value.
pub(crate) trait PrimeModulus: Copy + Debug {
    /// Data handed to [`mul`](Self::mul); its meaning is chosen by each implementation.
    type Div: Copy + Debug;

    /// Returns `a + b` reduced by the modulus.
    fn add(self, a: u64, b: u64) -> u64;

    /// Returns `a - b` reduced by the modulus.
    fn sub(self, a: u64, b: u64) -> u64;

    /// Returns `a * b` reduced by the modulus described by `p`.
    fn mul(p: Self::Div, a: u64, b: u64) -> u64;
}
15
16#[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
17pub(crate) trait PrimeModulusV3: Debug + Copy {
18    type Div: Debug + Copy;
19
20    fn add(self, simd: crate::V3, a: u64x4, b: u64x4) -> u64x4;
21    fn sub(self, simd: crate::V3, a: u64x4, b: u64x4) -> u64x4;
22    fn mul(p: Self::Div, simd: crate::V3, a: u64x4, b: u64x4) -> u64x4;
23}
24
#[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
#[cfg(feature = "nightly")]
/// Eight-lane (`u64x8`) modular arithmetic, executed through a `crate::V4` SIMD token.
///
/// Each method applies the corresponding modular operation independently to every lane.
pub(crate) trait PrimeModulusV4: Copy + Debug {
    /// Data handed to [`mul`](Self::mul); its meaning is chosen by each implementation.
    type Div: Copy + Debug;

    /// Lane-wise `a + b` reduced by the modulus.
    fn add(self, simd: crate::V4, a: u64x8, b: u64x8) -> u64x8;

    /// Lane-wise `a - b` reduced by the modulus.
    fn sub(self, simd: crate::V4, a: u64x8, b: u64x8) -> u64x8;

    /// Lane-wise `a * b` reduced by the modulus described by `p`.
    fn mul(p: Self::Div, simd: crate::V4, a: u64x8, b: u64x8) -> u64x8;
}
34
/// Zero-sized marker selecting the fixed modulus [`Solinas::P`].
#[derive(Clone, Copy, Debug)]
pub struct Solinas;

impl Solinas {
    /// The prime `2^64 - 2^32 + 1`.
    pub const P: u64 = 0xFFFF_FFFF_0000_0001;
}
41
42impl PrimeModulus for u64 {
43    type Div = Div64;
44
45    #[inline(always)]
46    fn add(self, a: u64, b: u64) -> u64 {
47        let p = self;
48        // a + b >= p
49        // implies
50        // a >= p - b
51
52        let neg_b = p - b;
53        if a >= neg_b {
54            a - neg_b
55        } else {
56            a + b
57        }
58    }
59
60    #[inline(always)]
61    fn sub(self, a: u64, b: u64) -> u64 {
62        let p = self;
63        let neg_b = p - b;
64        if a >= b {
65            a - b
66        } else {
67            a + neg_b
68        }
69    }
70
71    #[inline(always)]
72    fn mul(p: Self::Div, a: u64, b: u64) -> u64 {
73        Div64::rem_u128(a as u128 * b as u128, p)
74    }
75}
76
77impl PrimeModulus for Solinas {
78    type Div = ();
79
80    #[inline(always)]
81    fn add(self, a: u64, b: u64) -> u64 {
82        let p = Self::P;
83        let neg_b = p - b;
84        if a >= neg_b {
85            a - neg_b
86        } else {
87            a + b
88        }
89    }
90
91    #[inline(always)]
92    fn sub(self, a: u64, b: u64) -> u64 {
93        let p = Self::P;
94        let neg_b = p - b;
95        if a >= b {
96            a - b
97        } else {
98            a + neg_b
99        }
100    }
101
102    #[inline(always)]
103    fn mul(p: Self::Div, a: u64, b: u64) -> u64 {
104        let _ = p;
105        let p = Self::P;
106
107        let wide = a as u128 * b as u128;
108
109        // https://cp4space.hatsya.com/2021/09/01/an-efficient-prime-for-number-theoretic-transforms/
110        let lo = wide as u64;
111        let hi = (wide >> 64) as u64;
112        let mid = hi & 0x0000_0000_FFFF_FFFF;
113        let hi = (hi & 0xFFFF_FFFF_0000_0000) >> 32;
114
115        let mut low2 = lo.wrapping_sub(hi);
116        if hi > lo {
117            low2 = low2.wrapping_add(p);
118        }
119
120        let mut product = mid << 32;
121        product -= mid;
122
123        let mut result = low2.wrapping_add(product);
124        if (result < product) || (result >= p) {
125            result = result.wrapping_sub(p);
126        }
127        result
128    }
129}
130
131#[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
132impl PrimeModulusV3 for u64 {
133    type Div = (u64, u64, u64, u64, u64);
134
135    #[inline(always)]
136    fn add(self, simd: crate::V3, a: u64x4, b: u64x4) -> u64x4 {
137        let p = simd.splat_u64x4(self);
138        let neg_b = simd.wrapping_sub_u64x4(p, b);
139        let not_a_ge_neg_b = simd.cmp_gt_u64x4(neg_b, a);
140        simd.select_u64x4(
141            not_a_ge_neg_b,
142            simd.wrapping_add_u64x4(a, b),
143            simd.wrapping_sub_u64x4(a, neg_b),
144        )
145    }
146
147    #[inline(always)]
148    fn sub(self, simd: crate::V3, a: u64x4, b: u64x4) -> u64x4 {
149        let p = simd.splat_u64x4(self);
150        let neg_b = simd.wrapping_sub_u64x4(p, b);
151        let not_a_ge_b = simd.cmp_gt_u64x4(b, a);
152        simd.select_u64x4(
153            not_a_ge_b,
154            simd.wrapping_add_u64x4(a, neg_b),
155            simd.wrapping_sub_u64x4(a, b),
156        )
157    }
158
159    #[inline(always)]
160    fn mul(p: Self::Div, simd: crate::V3, a: u64x4, b: u64x4) -> u64x4 {
161        #[inline(always)]
162        fn mul_with_carry(simd: crate::V3, l: u64x4, r: u64x4, c: u64x4) -> (u64x4, u64x4) {
163            let (lo, hi) = simd.widening_mul_u64x4(l, r);
164            let lo_plus_c = simd.wrapping_add_u64x4(lo, c);
165            let overflow = cast(simd.cmp_gt_u64x4(lo, lo_plus_c));
166            (lo_plus_c, simd.wrapping_sub_u64x4(hi, overflow))
167        }
168
169        #[inline(always)]
170        fn mul_u256_u64(
171            simd: crate::V3,
172            lhs0: u64x4,
173            lhs1: u64x4,
174            lhs2: u64x4,
175            lhs3: u64x4,
176            rhs: u64x4,
177        ) -> (u64x4, u64x4, u64x4, u64x4, u64x4) {
178            let (x0, carry) = simd.widening_mul_u64x4(lhs0, rhs);
179            let (x1, carry) = mul_with_carry(simd, lhs1, rhs, carry);
180            let (x2, carry) = mul_with_carry(simd, lhs2, rhs, carry);
181            let (x3, carry) = mul_with_carry(simd, lhs3, rhs, carry);
182            (x0, x1, x2, x3, carry)
183        }
184
185        #[inline(always)]
186        fn wrapping_mul_u256_u128(
187            simd: crate::V3,
188            lhs0: u64x4,
189            lhs1: u64x4,
190            lhs2: u64x4,
191            lhs3: u64x4,
192            rhs0: u64x4,
193            rhs1: u64x4,
194        ) -> (u64x4, u64x4, u64x4, u64x4) {
195            let (x0, x1, x2, x3, _) = mul_u256_u64(simd, lhs0, lhs1, lhs2, lhs3, rhs0);
196            let (y0, y1, y2, _, _) = mul_u256_u64(simd, lhs0, lhs1, lhs2, lhs3, rhs1);
197
198            let z0 = x0;
199
200            let z1 = simd.wrapping_add_u64x4(x1, y0);
201            let carry = cast(simd.cmp_gt_u64x4(x1, z1));
202
203            let z2 = simd.wrapping_add_u64x4(x2, y1);
204            let o0 = cast(simd.cmp_gt_u64x4(x2, z2));
205            let o1 = cast(simd.cmp_eq_u64x4(z2, carry));
206            let z2 = simd.wrapping_sub_u64x4(z2, carry);
207            let carry = simd.or_u64x4(o0, o1);
208
209            let z3 = simd.wrapping_add_u64x4(x3, y2);
210            let z3 = simd.wrapping_sub_u64x4(z3, carry);
211
212            (z0, z1, z2, z3)
213        }
214
215        let (p, p_div0, p_div1, p_div2, p_div3) = p;
216
217        let p = simd.splat_u64x4(p as _);
218        let p_div0 = simd.splat_u64x4(p_div0 as _);
219        let p_div1 = simd.splat_u64x4(p_div1 as _);
220        let p_div2 = simd.splat_u64x4(p_div2 as _);
221        let p_div3 = simd.splat_u64x4(p_div3 as _);
222
223        let (lo, hi) = simd.widening_mul_u64x4(a, b);
224        let (low_bits0, low_bits1, low_bits2, low_bits3) =
225            wrapping_mul_u256_u128(simd, p_div0, p_div1, p_div2, p_div3, lo, hi);
226
227        mul_u256_u64(simd, low_bits0, low_bits1, low_bits2, low_bits3, p).4
228    }
229}
230
231#[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
232impl PrimeModulusV3 for Solinas {
233    type Div = ();
234
235    #[inline(always)]
236    fn add(self, simd: crate::V3, a: u64x4, b: u64x4) -> u64x4 {
237        let p = simd.splat_u64x4(Self::P);
238        let neg_b = simd.wrapping_sub_u64x4(p, b);
239        let not_a_ge_neg_b = simd.cmp_gt_u64x4(neg_b, a);
240        simd.select_u64x4(
241            not_a_ge_neg_b,
242            simd.wrapping_add_u64x4(a, b),
243            simd.wrapping_sub_u64x4(a, neg_b),
244        )
245    }
246
247    #[inline(always)]
248    fn sub(self, simd: crate::V3, a: u64x4, b: u64x4) -> u64x4 {
249        let p = simd.splat_u64x4(Self::P);
250        let neg_b = simd.wrapping_sub_u64x4(p, b);
251        let not_a_ge_b = simd.cmp_gt_u64x4(b, a);
252        simd.select_u64x4(
253            not_a_ge_b,
254            simd.wrapping_add_u64x4(a, neg_b),
255            simd.wrapping_sub_u64x4(a, b),
256        )
257    }
258
259    #[inline(always)]
260    fn mul(p: Self::Div, simd: crate::V3, a: u64x4, b: u64x4) -> u64x4 {
261        let _ = p;
262
263        let p = simd.splat_u64x4(Self::P as _);
264
265        // https://cp4space.hatsya.com/2021/09/01/an-efficient-prime-for-number-theoretic-transforms/
266        let (lo, hi) = simd.widening_mul_u64x4(a, b);
267        let mid = simd.and_u64x4(hi, simd.splat_u64x4(0x0000_0000_FFFF_FFFF));
268        let hi = simd.and_u64x4(hi, simd.splat_u64x4(0xFFFF_FFFF_0000_0000));
269        let hi = simd.shr_const_u64x4::<32>(hi);
270
271        let low2 = simd.wrapping_sub_u64x4(lo, hi);
272        let low2 = simd.select_u64x4(
273            simd.cmp_gt_u64x4(hi, lo),
274            simd.wrapping_add_u64x4(low2, p),
275            low2,
276        );
277
278        let product = simd.shl_const_u64x4::<32>(mid);
279        let product = simd.wrapping_sub_u64x4(product, mid);
280
281        let result = simd.wrapping_add_u64x4(low2, product);
282
283        // (result < product) || (result >= p)
284        // (result < product) || !(p > result)
285        // !(!(result < product) && (p > result))
286        let product_gt_result = simd.cmp_gt_u64x4(product, result);
287        let p_gt_result = simd.cmp_gt_u64x4(p, result);
288        let not_cond = simd.andnot_m64x4(product_gt_result, p_gt_result);
289
290        simd.select_u64x4(not_cond, result, simd.wrapping_sub_u64x4(result, p))
291    }
292}
293
#[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
#[cfg(feature = "nightly")]
/// Eight-lane arithmetic modulo `self`.
impl PrimeModulusV4 for u64 {
    /// `(p, p_div0, p_div1, p_div2, p_div3)`: the modulus followed by the four 64-bit limbs,
    /// least significant first, of a 256-bit constant consumed by `mul`.
    ///
    /// NOTE(review): the constant is presumably a precomputed reciprocal of `p` supplied by
    /// the caller — confirm against the code that builds this tuple.
    type Div = (u64, u64, u64, u64, u64);

    /// Lane-wise `a + b` reduced by `self`.
    #[inline(always)]
    fn add(self, simd: crate::V4, a: u64x8, b: u64x8) -> u64x8 {
        let p = simd.splat_u64x8(self);
        // `a + b >= p` is tested as `a >= p - b`, which cannot overflow.
        let neg_b = simd.wrapping_sub_u64x8(p, b);
        let a_ge_neg_b = simd.cmp_ge_u64x8(a, neg_b);
        simd.select_u64x8(
            a_ge_neg_b,
            simd.wrapping_sub_u64x8(a, neg_b),
            simd.wrapping_add_u64x8(a, b),
        )
    }

    /// Lane-wise `a - b` reduced by `self`.
    #[inline(always)]
    fn sub(self, simd: crate::V4, a: u64x8, b: u64x8) -> u64x8 {
        let p = simd.splat_u64x8(self);
        let neg_b = simd.wrapping_sub_u64x8(p, b);
        let a_ge_b = simd.cmp_ge_u64x8(a, b);
        simd.select_u64x8(
            a_ge_b,
            simd.wrapping_sub_u64x8(a, b),
            simd.wrapping_add_u64x8(a, neg_b),
        )
    }

    /// Lane-wise modular product.
    ///
    /// The 128-bit product `a * b` is multiplied by the 256-bit constant (keeping the low
    /// 256 bits), and the result is the top limb of that value multiplied by `p`.
    #[inline(always)]
    fn mul(p: Self::Div, simd: crate::V4, a: u64x8, b: u64x8) -> u64x8 {
        /// Per lane: `l * r + c` as a `(low, high)` pair of 64-bit halves.
        #[inline(always)]
        fn mul_with_carry(simd: crate::V4, l: u64x8, r: u64x8, c: u64x8) -> (u64x8, u64x8) {
            let (lo, hi) = simd.widening_mul_u64x8(l, r);
            let lo_plus_c = simd.wrapping_add_u64x8(lo, c);
            let overflow = simd.cmp_gt_u64x8(lo, lo_plus_c);

            // The mask is widened to a vector that is subtracted from the high half;
            // the carry is propagated by that subtraction.
            (
                lo_plus_c,
                simd.wrapping_sub_u64x8(hi, simd.convert_mask_b8_to_u64x8(overflow)),
            )
        }

        /// Per lane: the five-limb product of a four-limb value and a single limb.
        #[inline(always)]
        fn mul_u256_u64(
            simd: crate::V4,
            lhs0: u64x8,
            lhs1: u64x8,
            lhs2: u64x8,
            lhs3: u64x8,
            rhs: u64x8,
        ) -> (u64x8, u64x8, u64x8, u64x8, u64x8) {
            let (x0, carry) = simd.widening_mul_u64x8(lhs0, rhs);
            let (x1, carry) = mul_with_carry(simd, lhs1, rhs, carry);
            let (x2, carry) = mul_with_carry(simd, lhs2, rhs, carry);
            let (x3, carry) = mul_with_carry(simd, lhs3, rhs, carry);
            (x0, x1, x2, x3, carry)
        }

        /// Per lane: the low four limbs of a four-limb value times a two-limb value.
        #[inline(always)]
        fn wrapping_mul_u256_u128(
            simd: crate::V4,
            lhs0: u64x8,
            lhs1: u64x8,
            lhs2: u64x8,
            lhs3: u64x8,
            rhs0: u64x8,
            rhs1: u64x8,
        ) -> (u64x8, u64x8, u64x8, u64x8) {
            let (x0, x1, x2, x3, _) = mul_u256_u64(simd, lhs0, lhs1, lhs2, lhs3, rhs0);
            let (y0, y1, y2, _, _) = mul_u256_u64(simd, lhs0, lhs1, lhs2, lhs3, rhs1);

            // z = x + (y << 64), truncated to four limbs.
            let z0 = x0;

            let z1 = simd.wrapping_add_u64x8(x1, y0);
            let carry = simd.convert_mask_b8_to_u64x8(simd.cmp_gt_u64x8(x1, z1));

            let z2 = simd.wrapping_add_u64x8(x2, y1);
            let o0 = simd.cmp_gt_u64x8(x2, z2);
            // Adding the carry-in overflows exactly when the value gets smaller. Testing
            // `z2 == carry` instead would also fire for `z2 == 0` with no carry-in and
            // push a spurious carry into the top limb.
            let z2_with_carry = simd.wrapping_sub_u64x8(z2, carry);
            let o1 = simd.cmp_gt_u64x8(z2, z2_with_carry);
            let z2 = z2_with_carry;
            let carry = simd.convert_mask_b8_to_u64x8(b8(o0.0 | o1.0));

            let z3 = simd.wrapping_add_u64x8(x3, y2);
            let z3 = simd.wrapping_sub_u64x8(z3, carry);

            (z0, z1, z2, z3)
        }

        let (p, p_div0, p_div1, p_div2, p_div3) = p;

        let p = simd.splat_u64x8(p);
        let p_div0 = simd.splat_u64x8(p_div0);
        let p_div1 = simd.splat_u64x8(p_div1);
        let p_div2 = simd.splat_u64x8(p_div2);
        let p_div3 = simd.splat_u64x8(p_div3);

        let (lo, hi) = simd.widening_mul_u64x8(a, b);
        let (low_bits0, low_bits1, low_bits2, low_bits3) =
            wrapping_mul_u256_u128(simd, p_div0, p_div1, p_div2, p_div3, lo, hi);

        // Top limb of `low_bits * p`.
        mul_u256_u64(simd, low_bits0, low_bits1, low_bits2, low_bits3, p).4
    }
}
398
#[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
#[cfg(feature = "nightly")]
/// Eight-lane arithmetic modulo the fixed prime [`Solinas::P`].
impl PrimeModulusV4 for Solinas {
    type Div = ();

    #[inline(always)]
    fn add(self, simd: crate::V4, a: u64x8, b: u64x8) -> u64x8 {
        // Same computation as the generic `u64` modulus, specialised to `Self::P`.
        <u64 as PrimeModulusV4>::add(Self::P, simd, a, b)
    }

    #[inline(always)]
    fn sub(self, simd: crate::V4, a: u64x8, b: u64x8) -> u64x8 {
        <u64 as PrimeModulusV4>::sub(Self::P, simd, a, b)
    }

    #[inline(always)]
    fn mul(p: Self::Div, simd: crate::V4, a: u64x8, b: u64x8) -> u64x8 {
        let () = p;
        let modulus = simd.splat_u64x8(Self::P);

        // https://cp4space.hatsya.com/2021/09/01/an-efficient-prime-for-number-theoretic-transforms/
        //
        // Split each 128-bit product as `lo + 2^64 * mid + 2^96 * hi`, where `mid` and
        // `hi` are 32 bits wide.
        let (lo, upper) = simd.widening_mul_u64x8(a, b);
        let mid = simd.and_u64x8(upper, simd.splat_u64x8(0x0000_0000_FFFF_FFFF));
        // A right shift by 32 already discards the low half, so no mask is needed first.
        let hi = simd.shr_const_u64x8::<32>(upper);

        // lo - hi, adding the modulus back in the lanes where the subtraction borrowed.
        let difference = simd.wrapping_sub_u64x8(lo, hi);
        let borrowed = simd.cmp_gt_u64x8(hi, lo);
        let low2 = simd.select_u64x8(
            borrowed,
            simd.wrapping_add_u64x8(difference, modulus),
            difference,
        );

        // mid * (2^32 - 1)
        let product = simd.wrapping_sub_u64x8(simd.shl_const_u64x8::<32>(mid), mid);

        let sum = simd.wrapping_add_u64x8(low2, product);

        // A lane needs one more subtraction of the modulus when
        //     (sum < product) || (sum >= modulus).
        // The mask below is the negation of that condition:
        //     !(product > sum) && (modulus > sum).
        let wrapped = simd.cmp_gt_u64x8(product, sum);
        let below_modulus = simd.cmp_gt_u64x8(modulus, sum);
        let keep = b8(!wrapped.0 & below_modulus.0);

        simd.select_u64x8(keep, sum, simd.wrapping_sub_u64x8(sum, modulus))
    }
}
448
449pub(crate) fn fwd_breadth_first_scalar<P: PrimeModulus>(
450    data: &mut [u64],
451    p: P,
452    p_div: P::Div,
453    twid: &[u64],
454    recursion_depth: usize,
455    recursion_half: usize,
456) {
457    let n = data.len();
458    debug_assert!(n.is_power_of_two());
459
460    let mut t = n / 2;
461    let mut m = 1;
462    let mut w_idx = (m << recursion_depth) + recursion_half * m;
463
464    while m < n {
465        let w = &twid[w_idx..];
466
467        for (data, &w1) in zip(data.chunks_exact_mut(2 * t), w) {
468            let (z0, z1) = data.split_at_mut(t);
469
470            for (z0, z1) in zip(z0, z1) {
471                let z1w = P::mul(p_div, *z1, w1);
472
473                (*z0, *z1) = (p.add(*z0, z1w), p.sub(*z0, z1w));
474            }
475        }
476
477        t /= 2;
478        m *= 2;
479        w_idx *= 2;
480    }
481}
482
483pub(crate) fn inv_breadth_first_scalar<P: PrimeModulus>(
484    data: &mut [u64],
485    p: P,
486    p_div: P::Div,
487    inv_twid: &[u64],
488    recursion_depth: usize,
489    recursion_half: usize,
490) {
491    let n = data.len();
492    debug_assert!(n.is_power_of_two());
493
494    let mut t = 1;
495    let mut m = n;
496    let mut w_idx = (m << recursion_depth) + recursion_half * m;
497
498    while m > 1 {
499        m /= 2;
500        w_idx /= 2;
501
502        let w = &inv_twid[w_idx..];
503
504        for (data, &w1) in zip(data.chunks_exact_mut(2 * t), w) {
505            let (z0, z1) = data.split_at_mut(t);
506
507            for (z0, z1) in zip(z0, z1) {
508                (*z0, *z1) = (p.add(*z0, *z1), P::mul(p_div, p.sub(*z0, *z1), w1));
509            }
510        }
511
512        t *= 2;
513    }
514}
515
516pub(crate) fn inv_depth_first_scalar<P: PrimeModulus>(
517    data: &mut [u64],
518    p: P,
519    p_div: P::Div,
520    inv_twid: &[u64],
521    recursion_depth: usize,
522    recursion_half: usize,
523) {
524    let n = data.len();
525    debug_assert!(n.is_power_of_two());
526    if n <= RECURSION_THRESHOLD {
527        inv_breadth_first_scalar(data, p, p_div, inv_twid, recursion_depth, recursion_half);
528    } else {
529        let (data0, data1) = data.split_at_mut(n / 2);
530        inv_depth_first_scalar(
531            data0,
532            p,
533            p_div,
534            inv_twid,
535            recursion_depth + 1,
536            recursion_half * 2,
537        );
538        inv_depth_first_scalar(
539            data1,
540            p,
541            p_div,
542            inv_twid,
543            recursion_depth + 1,
544            recursion_half * 2 + 1,
545        );
546
547        let t = n / 2;
548        let m = 1;
549        let w_idx = (m << recursion_depth) + m * recursion_half;
550
551        let w = &inv_twid[w_idx..];
552
553        for (data, &w1) in zip(data.chunks_exact_mut(2 * t), w) {
554            let (z0, z1) = data.split_at_mut(t);
555
556            for (z0, z1) in zip(z0, z1) {
557                (*z0, *z1) = (p.add(*z0, *z1), P::mul(p_div, p.sub(*z0, *z1), w1));
558            }
559        }
560    }
561}
562
563#[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
564pub(crate) fn fwd_breadth_first_avx2<P: PrimeModulusV3>(
565    simd: crate::V3,
566    data: &mut [u64],
567    p: P,
568    p_div: P::Div,
569    twid: &[u64],
570    recursion_depth: usize,
571    recursion_half: usize,
572) {
573    struct Impl<'a, P: PrimeModulusV3> {
574        simd: crate::V3,
575        data: &'a mut [u64],
576        p: P,
577        p_div: P::Div,
578        twid: &'a [u64],
579        recursion_depth: usize,
580        recursion_half: usize,
581    }
582    impl<P: PrimeModulusV3> pulp::NullaryFnOnce for Impl<'_, P> {
583        type Output = ();
584
585        #[inline(always)]
586        fn call(self) -> Self::Output {
587            let Self {
588                simd,
589                data,
590                p,
591                p_div,
592                twid,
593                recursion_depth,
594                recursion_half,
595            } = self;
596            let n = data.len();
597            debug_assert!(n.is_power_of_two());
598
599            let mut t = n / 2;
600            let mut m = 1;
601            let mut w_idx = (m << recursion_depth) + recursion_half * m;
602            while m < n / 4 {
603                let w = &twid[w_idx..];
604
605                for (data, &w1) in zip(data.chunks_exact_mut(2 * t), w) {
606                    let (z0, z1) = data.split_at_mut(t);
607                    let z0 = pulp::as_arrays_mut::<4, _>(z0).0;
608                    let z1 = pulp::as_arrays_mut::<4, _>(z1).0;
609                    let w1 = simd.splat_u64x4(w1);
610
611                    for (z0_, z1_) in zip(z0, z1) {
612                        let mut z0 = cast(*z0_);
613                        let mut z1 = cast(*z1_);
614                        let z1w = P::mul(p_div, simd, z1, w1);
615                        (z0, z1) = (p.add(simd, z0, z1w), p.sub(simd, z0, z1w));
616                        *z0_ = cast(z0);
617                        *z1_ = cast(z1);
618                    }
619                }
620
621                t /= 2;
622                m *= 2;
623                w_idx *= 2;
624            }
625
626            // m = n / 4
627            // t = 2
628            {
629                let w = pulp::as_arrays::<2, _>(&twid[w_idx..]).0;
630                let data = pulp::as_arrays_mut::<4, _>(data).0;
631                let data = pulp::as_arrays_mut::<2, _>(data).0;
632
633                for (z0z0z1z1, w1) in zip(data, w) {
634                    let w1 = simd.permute2_u64x4(*w1);
635                    let [mut z0, mut z1] = simd.interleave2_u64x4(cast(*z0z0z1z1));
636                    let z1w = P::mul(p_div, simd, z1, w1);
637                    (z0, z1) = (p.add(simd, z0, z1w), p.sub(simd, z0, z1w));
638                    *z0z0z1z1 = cast(simd.interleave2_u64x4([z0, z1]));
639                }
640
641                w_idx *= 2;
642            }
643
644            // m = n / 2
645            // t = 1
646            {
647                let w = pulp::as_arrays::<4, _>(&twid[w_idx..]).0;
648                let data = pulp::as_arrays_mut::<4, _>(data).0;
649                let data = pulp::as_arrays_mut::<2, _>(data).0;
650
651                for (z0z1, w1) in zip(data, w) {
652                    let w1 = simd.permute1_u64x4(*w1);
653                    let [mut z0, mut z1] = simd.interleave1_u64x4(cast(*z0z1));
654                    let z1w = P::mul(p_div, simd, z1, w1);
655                    (z0, z1) = (p.add(simd, z0, z1w), p.sub(simd, z0, z1w));
656                    *z0z1 = cast(simd.interleave1_u64x4([z0, z1]));
657                }
658            }
659        }
660    }
661    simd.vectorize(Impl {
662        simd,
663        data,
664        p,
665        p_div,
666        twid,
667        recursion_depth,
668        recursion_half,
669    });
670}
671
672#[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
673pub(crate) fn inv_breadth_first_avx2<P: PrimeModulusV3>(
674    simd: crate::V3,
675    data: &mut [u64],
676    p: P,
677    p_div: P::Div,
678    inv_twid: &[u64],
679    recursion_depth: usize,
680    recursion_half: usize,
681) {
682    struct Impl<'a, P: PrimeModulusV3> {
683        simd: crate::V3,
684        data: &'a mut [u64],
685        p: P,
686        p_div: P::Div,
687        inv_twid: &'a [u64],
688        recursion_depth: usize,
689        recursion_half: usize,
690    }
691    impl<P: PrimeModulusV3> pulp::NullaryFnOnce for Impl<'_, P> {
692        type Output = ();
693
694        #[inline(always)]
695        fn call(self) -> Self::Output {
696            let Self {
697                simd,
698                data,
699                p,
700                p_div,
701                inv_twid,
702                recursion_depth,
703                recursion_half,
704            } = self;
705
706            let n = data.len();
707            debug_assert!(n.is_power_of_two());
708
709            let mut t = 1;
710            let mut m = n;
711            let mut w_idx = (m << recursion_depth) + recursion_half * m;
712
713            // m = n / 2
714            // t = 1
715            {
716                m /= 2;
717                w_idx /= 2;
718
719                let w = pulp::as_arrays::<4, _>(&inv_twid[w_idx..]).0;
720                let data = pulp::as_arrays_mut::<4, _>(data).0;
721                let data = pulp::as_arrays_mut::<2, _>(data).0;
722
723                for (z0z1, w1) in zip(data, w) {
724                    let w1 = simd.permute1_u64x4(*w1);
725                    let [mut z0, mut z1] = simd.interleave1_u64x4(cast(*z0z1));
726                    (z0, z1) = (
727                        p.add(simd, z0, z1),
728                        P::mul(p_div, simd, p.sub(simd, z0, z1), w1),
729                    );
730                    *z0z1 = cast(simd.interleave1_u64x4([z0, z1]));
731                }
732
733                t *= 2;
734            }
735
736            // m = n / 4
737            // t = 2
738            {
739                m /= 2;
740                w_idx /= 2;
741
742                let w = pulp::as_arrays::<2, _>(&inv_twid[w_idx..]).0;
743                let data = pulp::as_arrays_mut::<4, _>(data).0;
744                let data = pulp::as_arrays_mut::<2, _>(data).0;
745
746                for (z0z0z1z1, w1) in zip(data, w) {
747                    let w1 = simd.permute2_u64x4(*w1);
748                    let [mut z0, mut z1] = simd.interleave2_u64x4(cast(*z0z0z1z1));
749                    (z0, z1) = (
750                        p.add(simd, z0, z1),
751                        P::mul(p_div, simd, p.sub(simd, z0, z1), w1),
752                    );
753                    *z0z0z1z1 = cast(simd.interleave2_u64x4([z0, z1]));
754                }
755
756                t *= 2;
757            }
758
759            while m > 1 {
760                m /= 2;
761                w_idx /= 2;
762
763                let w = &inv_twid[w_idx..];
764
765                for (data, &w1) in zip(data.chunks_exact_mut(2 * t), w) {
766                    let (z0, z1) = data.split_at_mut(t);
767                    let z0 = pulp::as_arrays_mut::<4, _>(z0).0;
768                    let z1 = pulp::as_arrays_mut::<4, _>(z1).0;
769                    let w1 = simd.splat_u64x4(w1);
770
771                    for (z0_, z1_) in zip(z0, z1) {
772                        let mut z0 = cast(*z0_);
773                        let mut z1 = cast(*z1_);
774                        (z0, z1) = (
775                            p.add(simd, z0, z1),
776                            P::mul(p_div, simd, p.sub(simd, z0, z1), w1),
777                        );
778                        *z0_ = cast(z0);
779                        *z1_ = cast(z1);
780                    }
781                }
782
783                t *= 2;
784            }
785        }
786    }
787
788    simd.vectorize(Impl {
789        simd,
790        data,
791        p,
792        p_div,
793        inv_twid,
794        recursion_depth,
795        recursion_half,
796    });
797}
798
#[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
#[cfg(feature = "nightly")]
/// Eight-lane version of the forward layer-by-layer transform.
///
/// Each butterfly maps `(x, y)` to `(x + w * y, x - w * y)`. Layers whose half-block size
/// is at least eight lanes use one broadcast twiddle per block. The last three layers
/// (half-block sizes 4, 2 and 1) instead regroup lanes with the `interleave*` /
/// `permute*` helpers so that a full vector is still processed per step.
///
/// The twiddle index starts at `(1 << recursion_depth) + recursion_half` and doubles after
/// every layer. The work is dispatched through `simd.vectorize`.
pub(crate) fn fwd_breadth_first_avx512<P: PrimeModulusV4>(
    simd: crate::V4,
    data: &mut [u64],
    p: P,
    p_div: P::Div,
    twid: &[u64],
    recursion_depth: usize,
    recursion_half: usize,
) {
    struct Kernel<'a, P: PrimeModulusV4> {
        simd: crate::V4,
        data: &'a mut [u64],
        p: P,
        p_div: P::Div,
        twid: &'a [u64],
        recursion_depth: usize,
        recursion_half: usize,
    }

    impl<P: PrimeModulusV4> pulp::NullaryFnOnce for Kernel<'_, P> {
        type Output = ();

        #[inline(always)]
        fn call(self) -> Self::Output {
            let Self {
                simd,
                data,
                p,
                p_div,
                twid,
                recursion_depth: depth,
                recursion_half: which_half,
            } = self;

            let n = data.len();
            debug_assert!(n.is_power_of_two());

            let mut half = n / 2;
            let mut blocks = 1usize;
            let mut twid_start = (1usize << depth) + which_half;

            // Layers with at least eight elements per half-block: one twiddle per block.
            while blocks < n / 8 {
                let layer_twid = &twid[twid_start..];

                for (block, &w) in zip(data.chunks_exact_mut(2 * half), layer_twid) {
                    let (upper, lower) = block.split_at_mut(half);
                    let upper = pulp::as_arrays_mut::<8, _>(upper).0;
                    let lower = pulp::as_arrays_mut::<8, _>(lower).0;
                    let w = simd.splat_u64x8(w);

                    for (x, y) in zip(upper, lower) {
                        let u = cast(*x);
                        let v = P::mul(p_div, simd, cast(*y), w);
                        *x = cast(p.add(simd, u, v));
                        *y = cast(p.sub(simd, u, v));
                    }
                }

                half /= 2;
                blocks *= 2;
                twid_start *= 2;
            }

            // blocks = n / 8, half-block size 4: two twiddles per pair of vectors.
            {
                let layer_twid = pulp::as_arrays::<2, _>(&twid[twid_start..]).0;
                let vectors = pulp::as_arrays_mut::<8, _>(data).0;
                let pairs = pulp::as_arrays_mut::<2, _>(vectors).0;

                for (pair, w) in zip(pairs, layer_twid) {
                    let w = simd.permute4_u64x8(*w);
                    let [u, y] = simd.interleave4_u64x8(cast(*pair));
                    let v = P::mul(p_div, simd, y, w);
                    let sum = p.add(simd, u, v);
                    let difference = p.sub(simd, u, v);
                    *pair = cast(simd.interleave4_u64x8([sum, difference]));
                }

                twid_start *= 2;
            }

            // blocks = n / 4, half-block size 2: four twiddles per pair of vectors.
            {
                let layer_twid = pulp::as_arrays::<4, _>(&twid[twid_start..]).0;
                let vectors = pulp::as_arrays_mut::<8, _>(data).0;
                let pairs = pulp::as_arrays_mut::<2, _>(vectors).0;

                for (pair, w) in zip(pairs, layer_twid) {
                    let w = simd.permute2_u64x8(*w);
                    let [u, y] = simd.interleave2_u64x8(cast(*pair));
                    let v = P::mul(p_div, simd, y, w);
                    let sum = p.add(simd, u, v);
                    let difference = p.sub(simd, u, v);
                    *pair = cast(simd.interleave2_u64x8([sum, difference]));
                }

                twid_start *= 2;
            }

            // blocks = n / 2, half-block size 1: eight twiddles per pair of vectors.
            {
                let layer_twid = pulp::as_arrays::<8, _>(&twid[twid_start..]).0;
                let vectors = pulp::as_arrays_mut::<8, _>(data).0;
                let pairs = pulp::as_arrays_mut::<2, _>(vectors).0;

                for (pair, w) in zip(pairs, layer_twid) {
                    let w = simd.permute1_u64x8(*w);
                    let [u, y] = simd.interleave1_u64x8(cast(*pair));
                    let v = P::mul(p_div, simd, y, w);
                    let sum = p.add(simd, u, v);
                    let difference = p.sub(simd, u, v);
                    *pair = cast(simd.interleave1_u64x8([sum, difference]));
                }
            }
        }
    }

    simd.vectorize(Kernel {
        simd,
        data,
        p,
        p_div,
        twid,
        recursion_depth,
        recursion_half,
    });
}
928
/// Depth-first forward transform of `data` using 8-lane `u64` vectors.
///
/// Inputs of at most `RECURSION_THRESHOLD` elements are handed to
/// `fwd_breadth_first_avx512`. Larger inputs first get one butterfly layer between
/// their lower and upper halves, `(z0, z1) <- (z0 + w * z1, z0 - w * z1)`, and then
/// each half is transformed recursively.
///
/// * `p`, `p_div` - modulus handle providing `add`/`sub`, and the value passed to `P::mul`.
/// * `twid` - twiddle table; this level reads the entry at
///   `(1 << recursion_depth) + recursion_half`.
/// * `recursion_depth`, `recursion_half` - position of this call in the recursion tree;
///   the children are invoked with `recursion_depth + 1` and
///   `2 * recursion_half` / `2 * recursion_half + 1`.
///
/// `data.len()` is expected to be a power of two (only checked by a `debug_assert!`).
#[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
#[cfg(feature = "nightly")]
pub(crate) fn fwd_depth_first_avx512<P: PrimeModulusV4>(
    simd: crate::V4,
    data: &mut [u64],
    p: P,
    p_div: P::Div,
    twid: &[u64],
    recursion_depth: usize,
    recursion_half: usize,
) {
    // Bundles the arguments so the body can be run through `simd.vectorize`.
    struct FwdStep<'a, P: PrimeModulusV4> {
        simd: crate::V4,
        data: &'a mut [u64],
        p: P,
        p_div: P::Div,
        twid: &'a [u64],
        recursion_depth: usize,
        recursion_half: usize,
    }

    impl<P: PrimeModulusV4> pulp::NullaryFnOnce for FwdStep<'_, P> {
        type Output = ();

        #[inline(always)]
        fn call(self) -> Self::Output {
            let FwdStep {
                simd,
                data,
                p,
                p_div,
                twid,
                recursion_depth: depth,
                recursion_half: half_idx,
            } = self;

            let len = data.len();
            debug_assert!(len.is_power_of_two());

            // Small enough: finish with the breadth-first kernel.
            if len <= RECURSION_THRESHOLD {
                fwd_breadth_first_avx512(simd, data, p, p_div, twid, depth, half_idx);
                return;
            }

            let half_len = len / 2;
            let tw_start = (1usize << depth) + half_idx;
            let twiddles = &twid[tw_start..];

            // The chunk size equals the whole (even-length) input, so this visits a
            // single block paired with the first twiddle of `twiddles`.
            for (block, &w) in data.chunks_exact_mut(2 * half_len).zip(twiddles) {
                let (lo, hi) = block.split_at_mut(half_len);
                let lo = pulp::as_arrays_mut::<8, _>(lo).0;
                let hi = pulp::as_arrays_mut::<8, _>(hi).0;
                let w = simd.splat_u64x8(w);

                for (lo, hi) in lo.iter_mut().zip(hi.iter_mut()) {
                    let a = cast(*lo);
                    let bw = P::mul(p_div, simd, cast(*hi), w);
                    *lo = cast(p.add(simd, a, bw));
                    *hi = cast(p.sub(simd, a, bw));
                }
            }

            // Recurse into both halves, one level deeper.
            let (lower, upper) = data.split_at_mut(len / 2);
            fwd_depth_first_avx512(simd, lower, p, p_div, twid, depth + 1, half_idx * 2);
            fwd_depth_first_avx512(
                simd,
                upper,
                p,
                p_div,
                twid,
                depth + 1,
                half_idx * 2 + 1,
            );
        }
    }

    simd.vectorize(FwdStep {
        simd,
        data,
        p,
        p_div,
        twid,
        recursion_depth,
        recursion_half,
    });
}
1033
/// Depth-first inverse transform of `data` using 8-lane `u64` vectors.
///
/// Inputs of at most `RECURSION_THRESHOLD` elements are handed to
/// `inv_breadth_first_avx512`. Larger inputs are processed in the reverse order of
/// the forward routine: both halves are transformed recursively first, then one
/// butterfly layer `(z0, z1) <- (z0 + z1, (z0 - z1) * w)` combines them.
///
/// * `p`, `p_div` - modulus handle providing `add`/`sub`, and the value passed to `P::mul`.
/// * `inv_twid` - inverse twiddle table; this level reads the entry at
///   `(1 << recursion_depth) + recursion_half`.
/// * `recursion_depth`, `recursion_half` - position of this call in the recursion tree;
///   the children are invoked with `recursion_depth + 1` and
///   `2 * recursion_half` / `2 * recursion_half + 1`.
///
/// `data.len()` is expected to be a power of two (only checked by a `debug_assert!`).
#[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
#[cfg(feature = "nightly")]
pub(crate) fn inv_depth_first_avx512<P: PrimeModulusV4>(
    simd: crate::V4,
    data: &mut [u64],
    p: P,
    p_div: P::Div,
    inv_twid: &[u64],
    recursion_depth: usize,
    recursion_half: usize,
) {
    // Bundles the arguments so the body can be run through `simd.vectorize`.
    struct InvStep<'a, P: PrimeModulusV4> {
        simd: crate::V4,
        data: &'a mut [u64],
        p: P,
        p_div: P::Div,
        inv_twid: &'a [u64],
        recursion_depth: usize,
        recursion_half: usize,
    }

    impl<P: PrimeModulusV4> pulp::NullaryFnOnce for InvStep<'_, P> {
        type Output = ();

        #[inline(always)]
        fn call(self) -> Self::Output {
            let InvStep {
                simd,
                data,
                p,
                p_div,
                inv_twid,
                recursion_depth: depth,
                recursion_half: half_idx,
            } = self;

            let len = data.len();
            debug_assert!(len.is_power_of_two());

            // Small enough: finish with the breadth-first kernel.
            if len <= RECURSION_THRESHOLD {
                inv_breadth_first_avx512(simd, data, p, p_div, inv_twid, depth, half_idx);
                return;
            }

            // Children first: each half is fully processed before being combined.
            {
                let (lower, upper) = data.split_at_mut(len / 2);
                inv_depth_first_avx512(simd, lower, p, p_div, inv_twid, depth + 1, half_idx * 2);
                inv_depth_first_avx512(
                    simd,
                    upper,
                    p,
                    p_div,
                    inv_twid,
                    depth + 1,
                    half_idx * 2 + 1,
                );
            }

            let half_len = len / 2;
            let tw_start = (1usize << depth) + half_idx;
            let twiddles = &inv_twid[tw_start..];

            // The chunk size equals the whole (even-length) input, so this visits a
            // single block paired with the first twiddle of `twiddles`.
            for (block, &w) in data.chunks_exact_mut(2 * half_len).zip(twiddles) {
                let (lo, hi) = block.split_at_mut(half_len);
                let lo = pulp::as_arrays_mut::<8, _>(lo).0;
                let hi = pulp::as_arrays_mut::<8, _>(hi).0;
                let w = simd.splat_u64x8(w);

                for (lo, hi) in lo.iter_mut().zip(hi.iter_mut()) {
                    let a = cast(*lo);
                    let b = cast(*hi);
                    let sum = p.add(simd, a, b);
                    let scaled_diff = P::mul(p_div, simd, p.sub(simd, a, b), w);
                    *lo = cast(sum);
                    *hi = cast(scaled_diff);
                }
            }
        }
    }

    simd.vectorize(InvStep {
        simd,
        data,
        p,
        p_div,
        inv_twid,
        recursion_depth,
        recursion_half,
    });
}
1139
1140#[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
1141pub(crate) fn inv_depth_first_avx2<P: PrimeModulusV3>(
1142    simd: crate::V3,
1143    data: &mut [u64],
1144    p: P,
1145    p_div: P::Div,
1146    inv_twid: &[u64],
1147    recursion_depth: usize,
1148    recursion_half: usize,
1149) {
1150    struct Impl<'a, P: PrimeModulusV3> {
1151        simd: crate::V3,
1152        data: &'a mut [u64],
1153        p: P,
1154        p_div: P::Div,
1155        inv_twid: &'a [u64],
1156        recursion_depth: usize,
1157        recursion_half: usize,
1158    }
1159    impl<P: PrimeModulusV3> pulp::NullaryFnOnce for Impl<'_, P> {
1160        type Output = ();
1161
1162        #[inline(always)]
1163        fn call(self) -> Self::Output {
1164            let Self {
1165                simd,
1166                data,
1167                p,
1168                p_div,
1169                inv_twid,
1170                recursion_depth,
1171                recursion_half,
1172            } = self;
1173            let n = data.len();
1174            debug_assert!(n.is_power_of_two());
1175
1176            if n <= RECURSION_THRESHOLD {
1177                inv_breadth_first_avx2(
1178                    simd,
1179                    data,
1180                    p,
1181                    p_div,
1182                    inv_twid,
1183                    recursion_depth,
1184                    recursion_half,
1185                );
1186            } else {
1187                let (data0, data1) = data.split_at_mut(n / 2);
1188                inv_depth_first_avx2(
1189                    simd,
1190                    data0,
1191                    p,
1192                    p_div,
1193                    inv_twid,
1194                    recursion_depth + 1,
1195                    recursion_half * 2,
1196                );
1197                inv_depth_first_avx2(
1198                    simd,
1199                    data1,
1200                    p,
1201                    p_div,
1202                    inv_twid,
1203                    recursion_depth + 1,
1204                    recursion_half * 2 + 1,
1205                );
1206
1207                let t = n / 2;
1208                let m = 1;
1209                let w_idx = (m << recursion_depth) + m * recursion_half;
1210
1211                let w = &inv_twid[w_idx..];
1212
1213                for (data, &w1) in zip(data.chunks_exact_mut(2 * t), w) {
1214                    let (z0, z1) = data.split_at_mut(t);
1215                    let z0 = pulp::as_arrays_mut::<4, _>(z0).0;
1216                    let z1 = pulp::as_arrays_mut::<4, _>(z1).0;
1217                    let w1 = simd.splat_u64x4(w1);
1218
1219                    for (z0_, z1_) in zip(z0, z1) {
1220                        let mut z0 = cast(*z0_);
1221                        let mut z1 = cast(*z1_);
1222                        (z0, z1) = (
1223                            p.add(simd, z0, z1),
1224                            P::mul(p_div, simd, p.sub(simd, z0, z1), w1),
1225                        );
1226                        *z0_ = cast(z0);
1227                        *z1_ = cast(z1);
1228                    }
1229                }
1230            }
1231        }
1232    }
1233    simd.vectorize(Impl {
1234        simd,
1235        data,
1236        p,
1237        p_div,
1238        inv_twid,
1239        recursion_depth,
1240        recursion_half,
1241    });
1242}
1243
1244#[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
1245pub(crate) fn fwd_depth_first_avx2<P: PrimeModulusV3>(
1246    simd: crate::V3,
1247    data: &mut [u64],
1248    p: P,
1249    p_div: P::Div,
1250    twid: &[u64],
1251    recursion_depth: usize,
1252    recursion_half: usize,
1253) {
1254    struct Impl<'a, P: PrimeModulusV3> {
1255        simd: crate::V3,
1256        data: &'a mut [u64],
1257        p: P,
1258        p_div: P::Div,
1259        twid: &'a [u64],
1260        recursion_depth: usize,
1261        recursion_half: usize,
1262    }
1263    impl<P: PrimeModulusV3> pulp::NullaryFnOnce for Impl<'_, P> {
1264        type Output = ();
1265
1266        #[inline(always)]
1267        fn call(self) -> Self::Output {
1268            let Self {
1269                simd,
1270                data,
1271                p,
1272                p_div,
1273                twid,
1274                recursion_depth,
1275                recursion_half,
1276            } = self;
1277            let n = data.len();
1278            debug_assert!(n.is_power_of_two());
1279
1280            if n <= RECURSION_THRESHOLD {
1281                fwd_breadth_first_avx2(simd, data, p, p_div, twid, recursion_depth, recursion_half);
1282            } else {
1283                let t = n / 2;
1284                let m = 1;
1285                let w_idx = (m << recursion_depth) + m * recursion_half;
1286
1287                let w = &twid[w_idx..];
1288
1289                for (data, &w1) in zip(data.chunks_exact_mut(2 * t), w) {
1290                    let (z0, z1) = data.split_at_mut(t);
1291                    let z0 = pulp::as_arrays_mut::<4, _>(z0).0;
1292                    let z1 = pulp::as_arrays_mut::<4, _>(z1).0;
1293                    let w1 = simd.splat_u64x4(w1);
1294
1295                    for (z0_, z1_) in zip(z0, z1) {
1296                        let mut z0 = cast(*z0_);
1297                        let mut z1 = cast(*z1_);
1298                        let z1w = P::mul(p_div, simd, z1, w1);
1299                        (z0, z1) = (p.add(simd, z0, z1w), p.sub(simd, z0, z1w));
1300                        *z0_ = cast(z0);
1301                        *z1_ = cast(z1);
1302                    }
1303                }
1304
1305                let (data0, data1) = data.split_at_mut(n / 2);
1306                fwd_depth_first_avx2(
1307                    simd,
1308                    data0,
1309                    p,
1310                    p_div,
1311                    twid,
1312                    recursion_depth + 1,
1313                    recursion_half * 2,
1314                );
1315                fwd_depth_first_avx2(
1316                    simd,
1317                    data1,
1318                    p,
1319                    p_div,
1320                    twid,
1321                    recursion_depth + 1,
1322                    recursion_half * 2 + 1,
1323                );
1324            }
1325        }
1326    }
1327    simd.vectorize(Impl {
1328        simd,
1329        data,
1330        p,
1331        p_div,
1332        twid,
1333        recursion_depth,
1334        recursion_half,
1335    });
1336}
1337
1338pub(crate) fn fwd_depth_first_scalar<P: PrimeModulus>(
1339    data: &mut [u64],
1340    p: P,
1341    p_div: P::Div,
1342    twid: &[u64],
1343    recursion_depth: usize,
1344    recursion_half: usize,
1345) {
1346    let n = data.len();
1347    debug_assert!(n.is_power_of_two());
1348
1349    if n <= RECURSION_THRESHOLD {
1350        fwd_breadth_first_scalar(data, p, p_div, twid, recursion_depth, recursion_half);
1351    } else {
1352        let t = n / 2;
1353        let m = 1;
1354        let w_idx = (m << recursion_depth) + m * recursion_half;
1355
1356        let w = &twid[w_idx..];
1357
1358        for (data, &w1) in zip(data.chunks_exact_mut(2 * t), w) {
1359            let (z0, z1) = data.split_at_mut(t);
1360
1361            for (z0, z1) in zip(z0, z1) {
1362                let z1w = P::mul(p_div, *z1, w1);
1363
1364                (*z0, *z1) = (p.add(*z0, z1w), p.sub(*z0, z1w));
1365            }
1366        }
1367
1368        let (data0, data1) = data.split_at_mut(n / 2);
1369        fwd_depth_first_scalar(
1370            data0,
1371            p,
1372            p_div,
1373            twid,
1374            recursion_depth + 1,
1375            recursion_half * 2,
1376        );
1377        fwd_depth_first_scalar(
1378            data1,
1379            p,
1380            p_div,
1381            twid,
1382            recursion_depth + 1,
1383            recursion_half * 2 + 1,
1384        );
1385    }
1386}
1387
/// Forward transform entry point for 8-lane `u64` vectors.
///
/// Starts `fwd_depth_first_avx512` at the root of the recursion
/// (`recursion_depth = 0`, `recursion_half = 0`), transforming `data` in place.
#[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
#[cfg(feature = "nightly")]
pub(crate) fn fwd_avx512<P: PrimeModulusV4>(
    simd: crate::V4,
    data: &mut [u64],
    p: P,
    p_div: P::Div,
    twid: &[u64],
) {
    const ROOT_DEPTH: usize = 0;
    const ROOT_HALF: usize = 0;

    fwd_depth_first_avx512(simd, data, p, p_div, twid, ROOT_DEPTH, ROOT_HALF);
}
1399
1400#[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
1401pub(crate) fn fwd_avx2<P: PrimeModulusV3>(
1402    simd: crate::V3,
1403    data: &mut [u64],
1404    p: P,
1405    p_div: P::Div,
1406    twid: &[u64],
1407) {
1408    fwd_depth_first_avx2(simd, data, p, p_div, twid, 0, 0);
1409}
1410
1411pub(crate) fn fwd_scalar<P: PrimeModulus>(data: &mut [u64], p: P, p_div: P::Div, twid: &[u64]) {
1412    fwd_depth_first_scalar(data, p, p_div, twid, 0, 0);
1413}
1414
/// Inverse transform entry point for 8-lane `u64` vectors.
///
/// Starts `inv_depth_first_avx512` at the root of the recursion
/// (`recursion_depth = 0`, `recursion_half = 0`), transforming `data` in place.
/// `twid` is forwarded as that function's `inv_twid` argument.
#[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
#[cfg(feature = "nightly")]
pub(crate) fn inv_avx512<P: PrimeModulusV4>(
    simd: crate::V4,
    data: &mut [u64],
    p: P,
    p_div: P::Div,
    twid: &[u64],
) {
    const ROOT_DEPTH: usize = 0;
    const ROOT_HALF: usize = 0;

    inv_depth_first_avx512(simd, data, p, p_div, twid, ROOT_DEPTH, ROOT_HALF);
}
1426
1427#[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
1428pub(crate) fn inv_avx2<P: PrimeModulusV3>(
1429    simd: crate::V3,
1430    data: &mut [u64],
1431    p: P,
1432    p_div: P::Div,
1433    twid: &[u64],
1434) {
1435    inv_depth_first_avx2(simd, data, p, p_div, twid, 0, 0);
1436}
1437
1438pub(crate) fn inv_scalar<P: PrimeModulus>(data: &mut [u64], p: P, p_div: P::Div, twid: &[u64]) {
1439    inv_depth_first_scalar(data, p, p_div, twid, 0, 0);
1440}
1441
/// Iterative (layer by layer) inverse transform of `data` using 8-lane `u64` vectors.
///
/// Every layer applies the butterfly `(z0, z1) <- (z0 + z1, (z0 - z1) * w)`.
/// The butterfly distance starts at 1 and doubles after each layer:
///
/// * distances 1, 2 and 4 are handled on pairs of 8-lane vectors, using the
///   `interleave{1,2,4}_u64x8` / `permute{1,2,4}_u64x8` helpers to line up the
///   operands and their twiddles;
/// * the remaining layers work on whole vectors with a broadcast twiddle, until the
///   group counter reaches 1.
///
/// The twiddle offset starts at `(n << recursion_depth) + recursion_half * n`
/// (with `n = data.len()`) and is halved before each layer.
///
/// `data.len()` is expected to be a power of two (only checked by a `debug_assert!`).
// NOTE(review): the first three layers only touch complete pairs of 8-lane vectors,
// so callers presumably pass at least 16 elements - confirm against the call sites.
#[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
#[cfg(feature = "nightly")]
pub(crate) fn inv_breadth_first_avx512<P: PrimeModulusV4>(
    simd: crate::V4,
    data: &mut [u64],
    p: P,
    p_div: P::Div,
    inv_twid: &[u64],
    recursion_depth: usize,
    recursion_half: usize,
) {
    // Bundles the arguments so the body can be run through `simd.vectorize`.
    struct InvLayers<'a, P: PrimeModulusV4> {
        simd: crate::V4,
        data: &'a mut [u64],
        p: P,
        p_div: P::Div,
        inv_twid: &'a [u64],
        recursion_depth: usize,
        recursion_half: usize,
    }

    impl<P: PrimeModulusV4> pulp::NullaryFnOnce for InvLayers<'_, P> {
        type Output = ();

        #[inline(always)]
        fn call(self) -> Self::Output {
            let InvLayers {
                simd,
                data,
                p,
                p_div,
                inv_twid,
                recursion_depth,
                recursion_half,
            } = self;

            let len = data.len();
            debug_assert!(len.is_power_of_two());

            // `stride` is the butterfly distance, `groups` the layer counter and
            // `tw_base` the offset of the current layer's twiddles.
            let mut stride = 1;
            let mut groups = len;
            let mut tw_base = (groups << recursion_depth) + recursion_half * groups;

            // Layer with groups = len / 2, stride = 1.
            {
                groups /= 2;
                tw_base /= 2;

                let tw = pulp::as_arrays::<8, _>(&inv_twid[tw_base..]).0;
                let lanes = pulp::as_arrays_mut::<8, _>(data).0;
                let pairs = pulp::as_arrays_mut::<2, _>(lanes).0;

                for (pair, w) in pairs.iter_mut().zip(tw) {
                    let w = simd.permute1_u64x8(*w);
                    let [a, b] = simd.interleave1_u64x8(cast(*pair));
                    let sum = p.add(simd, a, b);
                    let scaled_diff = P::mul(p_div, simd, p.sub(simd, a, b), w);
                    *pair = cast(simd.interleave1_u64x8([sum, scaled_diff]));
                }

                stride *= 2;
            }

            // Layer with groups = len / 4, stride = 2.
            {
                groups /= 2;
                tw_base /= 2;

                let tw = pulp::as_arrays::<4, _>(&inv_twid[tw_base..]).0;
                let lanes = pulp::as_arrays_mut::<8, _>(data).0;
                let pairs = pulp::as_arrays_mut::<2, _>(lanes).0;

                for (pair, w) in pairs.iter_mut().zip(tw) {
                    let w = simd.permute2_u64x8(*w);
                    let [a, b] = simd.interleave2_u64x8(cast(*pair));
                    let sum = p.add(simd, a, b);
                    let scaled_diff = P::mul(p_div, simd, p.sub(simd, a, b), w);
                    *pair = cast(simd.interleave2_u64x8([sum, scaled_diff]));
                }

                stride *= 2;
            }

            // Layer with groups = len / 8, stride = 4.
            {
                groups /= 2;
                tw_base /= 2;

                let tw = pulp::as_arrays::<2, _>(&inv_twid[tw_base..]).0;
                let lanes = pulp::as_arrays_mut::<8, _>(data).0;
                let pairs = pulp::as_arrays_mut::<2, _>(lanes).0;

                for (pair, w) in pairs.iter_mut().zip(tw) {
                    let w = simd.permute4_u64x8(*w);
                    let [a, b] = simd.interleave4_u64x8(cast(*pair));
                    let sum = p.add(simd, a, b);
                    let scaled_diff = P::mul(p_div, simd, p.sub(simd, a, b), w);
                    *pair = cast(simd.interleave4_u64x8([sum, scaled_diff]));
                }

                stride *= 2;
            }

            // Remaining layers: the distance is a multiple of 8, so each block splits
            // into whole vectors and shares a single broadcast twiddle.
            while groups > 1 {
                groups /= 2;
                tw_base /= 2;

                let tw = &inv_twid[tw_base..];

                for (block, &w) in data.chunks_exact_mut(2 * stride).zip(tw) {
                    let (lo, hi) = block.split_at_mut(stride);
                    let lo = pulp::as_arrays_mut::<8, _>(lo).0;
                    let hi = pulp::as_arrays_mut::<8, _>(hi).0;
                    let w = simd.splat_u64x8(w);

                    for (lo, hi) in lo.iter_mut().zip(hi.iter_mut()) {
                        let a = cast(*lo);
                        let b = cast(*hi);
                        let sum = p.add(simd, a, b);
                        let scaled_diff = P::mul(p_div, simd, p.sub(simd, a, b), w);
                        *lo = cast(sum);
                        *hi = cast(scaled_diff);
                    }
                }

                stride *= 2;
            }
        }
    }

    simd.vectorize(InvLayers {
        simd,
        data,
        p,
        p_div,
        inv_twid,
        recursion_depth,
        recursion_half,
    });
}
1592
#[cfg(test)]
mod tests {
    use super::*;
    use crate::prime64::{
        init_negacyclic_twiddles,
        tests::{mul, random_lhs_rhs_with_negacyclic_convolution},
    };
    use alloc::vec;

    extern crate alloc;

    /// Shared body of every product test in this module, for the modulus `Solinas::P`.
    ///
    /// Draws random operands of length `n` together with their reference negacyclic
    /// convolution, applies `fwd` to both operands, multiplies them pointwise with
    /// `mul`, applies `inv` to the product, and asserts that every coefficient equals
    /// `mul(p, reference[i], n)`.
    ///
    /// `fwd` receives `(data, twid)` and `inv` receives `(data, inv_twid)`, with the
    /// tables filled by `init_negacyclic_twiddles`.
    fn check_negacyclic_product(
        n: usize,
        fwd: impl Fn(&mut [u64], &[u64]),
        inv: impl Fn(&mut [u64], &[u64]),
    ) {
        let p = Solinas::P;

        let (lhs, rhs, negacyclic_convolution) = random_lhs_rhs_with_negacyclic_convolution(n, p);

        let mut twid = vec![0u64; n];
        let mut inv_twid = vec![0u64; n];
        init_negacyclic_twiddles(p, n, &mut twid, &mut inv_twid);

        let mut lhs_fourier = lhs.clone();
        let mut rhs_fourier = rhs.clone();
        fwd(&mut lhs_fourier, &twid);
        fwd(&mut rhs_fourier, &twid);

        let mut prod = vec![0u64; n];
        for (i, slot) in prod.iter_mut().enumerate() {
            *slot = mul(p, lhs_fourier[i], rhs_fourier[i]);
        }

        inv(&mut prod, &inv_twid);

        for (i, &actual) in prod.iter().enumerate() {
            assert_eq!(actual, mul(p, negacyclic_convolution[i], n as u64));
        }
    }

    #[test]
    fn test_product() {
        let p = Solinas::P;

        for n in [2, 4, 8, 16, 32, 64, 128, 256, 512, 1024] {
            check_negacyclic_product(
                n,
                |data: &mut [u64], twid: &[u64]| {
                    fwd_breadth_first_scalar(data, p, Div64::new(p), twid, 0, 0)
                },
                |data: &mut [u64], inv_twid: &[u64]| {
                    inv_breadth_first_scalar(data, p, Div64::new(p), inv_twid, 0, 0)
                },
            );
        }
    }

    #[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
    #[test]
    fn test_product_avx2() {
        // Silently passes when the CPU lacks the required instruction set.
        if let Some(simd) = crate::V3::try_new() {
            let p = Solinas::P;
            let crate::u256 { x0, x1, x2, x3 } = Div64::new(p).double_reciprocal;
            let p_div = (p, x0, x1, x2, x3);

            for n in [8, 16, 32, 64, 128, 256, 512, 1024] {
                check_negacyclic_product(
                    n,
                    |data: &mut [u64], twid: &[u64]| {
                        fwd_breadth_first_avx2(simd, data, p, p_div, twid, 0, 0)
                    },
                    |data: &mut [u64], inv_twid: &[u64]| {
                        inv_breadth_first_avx2(simd, data, p, p_div, inv_twid, 0, 0)
                    },
                );
            }
        }
    }

    #[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
    #[cfg(feature = "nightly")]
    #[test]
    fn test_product_avx512() {
        // Silently passes when the CPU lacks the required instruction set.
        if let Some(simd) = crate::V4::try_new() {
            let p = Solinas::P;
            let crate::u256 { x0, x1, x2, x3 } = Div64::new(p).double_reciprocal;
            let p_div = (p, x0, x1, x2, x3);

            for n in [16, 32, 64, 128, 256, 512, 1024] {
                check_negacyclic_product(
                    n,
                    |data: &mut [u64], twid: &[u64]| {
                        fwd_breadth_first_avx512(simd, data, p, p_div, twid, 0, 0)
                    },
                    |data: &mut [u64], inv_twid: &[u64]| {
                        inv_breadth_first_avx512(simd, data, p, p_div, inv_twid, 0, 0)
                    },
                );
            }
        }
    }

    #[test]
    fn test_product_solinas() {
        for n in [2, 4, 8, 16, 32, 64, 128, 256, 512, 1024] {
            check_negacyclic_product(
                n,
                |data: &mut [u64], twid: &[u64]| {
                    fwd_breadth_first_scalar(data, Solinas, (), twid, 0, 0)
                },
                |data: &mut [u64], inv_twid: &[u64]| {
                    inv_breadth_first_scalar(data, Solinas, (), inv_twid, 0, 0)
                },
            );
        }
    }

    #[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
    #[test]
    fn test_product_solinas_avx2() {
        // Silently passes when the CPU lacks the required instruction set.
        if let Some(simd) = crate::V3::try_new() {
            for n in [8, 16, 32, 64, 128, 256, 512, 1024] {
                check_negacyclic_product(
                    n,
                    |data: &mut [u64], twid: &[u64]| {
                        fwd_breadth_first_avx2(simd, data, Solinas, (), twid, 0, 0)
                    },
                    |data: &mut [u64], inv_twid: &[u64]| {
                        inv_breadth_first_avx2(simd, data, Solinas, (), inv_twid, 0, 0)
                    },
                );
            }
        }
    }

    #[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
    #[cfg(feature = "nightly")]
    #[test]
    fn test_product_solinas_avx512() {
        // Silently passes when the CPU lacks the required instruction set.
        if let Some(simd) = crate::V4::try_new() {
            for n in [16, 32, 64, 128, 256, 512, 1024] {
                check_negacyclic_product(
                    n,
                    |data: &mut [u64], twid: &[u64]| {
                        fwd_breadth_first_avx512(simd, data, Solinas, (), twid, 0, 0)
                    },
                    |data: &mut [u64], inv_twid: &[u64]| {
                        inv_breadth_first_avx512(simd, data, Solinas, (), inv_twid, 0, 0)
                    },
                );
            }
        }
    }
}