tfhe_ntt/prime64.rs

use crate::{bit_rev, fastdiv::Div64, prime::is_prime64, roots::find_primitive_root64};
use aligned_vec::{avec, ABox};

#[allow(unused_imports)]
use pulp::*;

const RECURSION_THRESHOLD: usize = 1024;
pub(crate) const SOLINAS_PRIME: u64 = ((1_u128 << 64) - (1_u128 << 32) + 1) as u64;
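// Note: `SOLINAS_PRIME` is the "Goldilocks" prime p = 2^64 - 2^32 + 1. Since
// 2^64 ≡ 2^32 - 1 (mod p), a 128-bit product can be folded back into 64 bits
// with shifts and additions instead of a full division (illustrative note;
// the dedicated reduction code lives in the `generic_solinas` module).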

mod generic_solinas;
mod shoup;

#[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
#[cfg(feature = "nightly")]
mod less_than_50bit;
#[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
#[cfg(feature = "nightly")]
mod less_than_51bit;

mod less_than_62bit;
mod less_than_63bit;

use self::generic_solinas::PrimeModulus;
use crate::roots::find_root_solinas_64;
pub use generic_solinas::Solinas;

#[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
impl crate::V3 {
    // Given [a0 a1 a2 a3] and [b0 b1 b2 b3], returns [a0 a1 b0 b1] and [a2 a3 b2 b3].
    #[inline(always)]
    fn interleave2_u64x4(self, z0z0z1z1: [u64x4; 2]) -> [u64x4; 2] {
        let avx = self.avx;
        [
            cast(
                avx._mm256_permute2f128_si256::<0b0010_0000>(cast(z0z0z1z1[0]), cast(z0z0z1z1[1])),
            ),
            cast(
                avx._mm256_permute2f128_si256::<0b0011_0001>(cast(z0z0z1z1[0]), cast(z0z0z1z1[1])),
            ),
        ]
    }

    // Broadcasts [w0 w1] to [w0 w0 w1 w1].
    #[inline(always)]
    fn permute2_u64x4(self, w: [u64; 2]) -> u64x4 {
        let avx = self.avx;
        let w00 = self.sse2._mm_set1_epi64x(w[0] as _);
        let w11 = self.sse2._mm_set1_epi64x(w[1] as _);
        cast(avx._mm256_insertf128_si256::<0b1>(avx._mm256_castsi128_si256(w00), w11))
    }

    // Given [a0 a1 a2 a3] and [b0 b1 b2 b3], returns [a0 b0 a2 b2] and [a1 b1 a3 b3].
    #[inline(always)]
    fn interleave1_u64x4(self, z0z1: [u64x4; 2]) -> [u64x4; 2] {
        let avx = self.avx2;
        [
            cast(avx._mm256_unpacklo_epi64(cast(z0z1[0]), cast(z0z1[1]))),
            cast(avx._mm256_unpackhi_epi64(cast(z0z1[0]), cast(z0z1[1]))),
        ]
    }

    // Rearranges [w0 w1 w2 w3] into [w0 w2 w1 w3].
    #[inline(always)]
    fn permute1_u64x4(self, w: [u64; 4]) -> u64x4 {
        let avx = self.avx;
        let w0123 = pulp::cast(w);
        let w0101 = avx._mm256_permute2f128_si256::<0b0000_0000>(w0123, w0123);
        let w2323 = avx._mm256_permute2f128_si256::<0b0011_0011>(w0123, w0123);
        cast(avx._mm256_castpd_si256(avx._mm256_shuffle_pd::<0b1100>(
            avx._mm256_castsi256_pd(w0101),
            avx._mm256_castsi256_pd(w2323),
        )))
    }

    // Conditional subtraction: assumes x ∈ [0, 2p) and returns x mod p.
    #[inline(always)]
    pub fn small_mod_u64x4(self, modulus: u64x4, x: u64x4) -> u64x4 {
        self.select_u64x4(
            self.cmp_gt_u64x4(modulus, x),
            x,
            self.wrapping_sub_u64x4(x, modulus),
        )
    }
}

#[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
#[cfg(feature = "nightly")]
impl crate::V4 {
    // Given [a0..a7] and [b0..b7], returns [a0 a1 a2 a3 b0 b1 b2 b3] and
    // [a4 a5 a6 a7 b4 b5 b6 b7].
    #[inline(always)]
    fn interleave4_u64x8(self, z0z0z0z0z1z1z1z1: [u64x8; 2]) -> [u64x8; 2] {
        let avx = self.avx512f;
        let idx_0 = avx._mm512_setr_epi64(0x0, 0x1, 0x2, 0x3, 0x8, 0x9, 0xa, 0xb);
        let idx_1 = avx._mm512_setr_epi64(0x4, 0x5, 0x6, 0x7, 0xc, 0xd, 0xe, 0xf);
        [
            cast(avx._mm512_permutex2var_epi64(
                cast(z0z0z0z0z1z1z1z1[0]),
                idx_0,
                cast(z0z0z0z0z1z1z1z1[1]),
            )),
            cast(avx._mm512_permutex2var_epi64(
                cast(z0z0z0z0z1z1z1z1[0]),
                idx_1,
                cast(z0z0z0z0z1z1z1z1[1]),
            )),
        ]
    }

    // Broadcasts [w0 w1] to [w0 w0 w0 w0 w1 w1 w1 w1].
    #[inline(always)]
    fn permute4_u64x8(self, w: [u64; 2]) -> u64x8 {
        let avx = self.avx512f;
        let w = pulp::cast(w);
        let w01xxxxxx = avx._mm512_castsi128_si512(w);
        let idx = avx._mm512_setr_epi64(0, 0, 0, 0, 1, 1, 1, 1);
        cast(avx._mm512_permutexvar_epi64(idx, w01xxxxxx))
    }

    // Given [a0..a7] and [b0..b7], returns [a0 a1 b0 b1 a4 a5 b4 b5] and
    // [a2 a3 b2 b3 a6 a7 b6 b7].
    #[inline(always)]
    fn interleave2_u64x8(self, z0z0z1z1: [u64x8; 2]) -> [u64x8; 2] {
        let avx = self.avx512f;
        let idx_0 = avx._mm512_setr_epi64(0x0, 0x1, 0x8, 0x9, 0x4, 0x5, 0xc, 0xd);
        let idx_1 = avx._mm512_setr_epi64(0x2, 0x3, 0xa, 0xb, 0x6, 0x7, 0xe, 0xf);
        [
            cast(avx._mm512_permutex2var_epi64(cast(z0z0z1z1[0]), idx_0, cast(z0z0z1z1[1]))),
            cast(avx._mm512_permutex2var_epi64(cast(z0z0z1z1[0]), idx_1, cast(z0z0z1z1[1]))),
        ]
    }

    // Rearranges [w0 w1 w2 w3] into [w0 w0 w2 w2 w1 w1 w3 w3].
    #[inline(always)]
    fn permute2_u64x8(self, w: [u64; 4]) -> u64x8 {
        let avx = self.avx512f;
        let w = pulp::cast(w);
        let w0123xxxx = avx._mm512_castsi256_si512(w);
        let idx = avx._mm512_setr_epi64(0, 0, 2, 2, 1, 1, 3, 3);
        cast(avx._mm512_permutexvar_epi64(idx, w0123xxxx))
    }

    // Given [a0..a7] and [b0..b7], returns [a0 b0 a2 b2 a4 b4 a6 b6] and
    // [a1 b1 a3 b3 a5 b5 a7 b7].
    #[inline(always)]
    fn interleave1_u64x8(self, z0z1: [u64x8; 2]) -> [u64x8; 2] {
        let avx = self.avx512f;
        [
            cast(avx._mm512_unpacklo_epi64(cast(z0z1[0]), cast(z0z1[1]))),
            cast(avx._mm512_unpackhi_epi64(cast(z0z1[0]), cast(z0z1[1]))),
        ]
    }

    // Rearranges [w0..w7] into [w0 w4 w1 w5 w2 w6 w3 w7].
    #[inline(always)]
    fn permute1_u64x8(self, w: [u64; 8]) -> u64x8 {
        let avx = self.avx512f;
        let w = pulp::cast(w);
        let idx = avx._mm512_setr_epi64(0, 4, 1, 5, 2, 6, 3, 7);
        cast(avx._mm512_permutexvar_epi64(idx, w))
    }

    // Conditional subtraction: assumes x ∈ [0, 2p) and returns x mod p.
    #[inline(always)]
    pub fn small_mod_u64x8(self, modulus: u64x8, x: u64x8) -> u64x8 {
        self.select_u64x8(
            self.cmp_gt_u64x8(modulus, x),
            x,
            self.wrapping_sub_u64x8(x, modulus),
        )
    }
}

fn init_negacyclic_twiddles(p: u64, n: usize, twid: &mut [u64], inv_twid: &mut [u64]) {
    let div = Div64::new(p);

    let w = if p == SOLINAS_PRIME {
        // Custom roots of unity used with the Goldilocks prime. These roots
        // yield twiddle factors with low Hamming weight, allowing some
        // multiplications to be replaced by simple shifts.
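        // Illustrative example: for n = 32 the root below is 8 = 2^3; since
        // 2^96 ≡ -1 (mod p), 8 is a primitive 64th (= 2n-th) root of unity,
        // so multiplying by its powers amounts to a shift followed by a
        // Solinas reduction.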
        match n {
            32 => 8_u64,
            64 => 2198989700608_u64,
            128 => 14041890976876060974_u64,
            256 => 14430643036723656017_u64,
            512 => 4440654710286119610_u64,
            1024 => 8816101479115663336_u64,
            2048 => 10974926054405199669_u64,
            4096 => 1206500561358145487_u64,
            8192 => 10930245224889659871_u64,
            16384 => 3333600369887534767_u64,
            32768 => 15893793146607301539_u64,
            _ => find_root_solinas_64(div, 2 * n as u64).unwrap(),
        }
    } else {
        find_primitive_root64(div, 2 * n as u64).unwrap()
    };

    let mut k = 0;
    let mut wk = 1u64;

    let nbits = n.trailing_zeros();
    while k < n {
        let fwd_idx = bit_rev(nbits, k);

        twid[fwd_idx] = wk;

        let inv_idx = bit_rev(nbits, (n - k) % n);
        if k == 0 {
            inv_twid[inv_idx] = wk;
        } else {
            let x = p.wrapping_sub(wk);
            inv_twid[inv_idx] = x;
        }

        wk = Div64::rem_u128(wk as u128 * w as u128, div);
        k += 1;
    }
}

fn init_negacyclic_twiddles_shoup(
    p: u64,
    n: usize,
    max_bits: u32,
    twid: &mut [u64],
    twid_shoup: &mut [u64],
    inv_twid: &mut [u64],
    inv_twid_shoup: &mut [u64],
) {
    let div = Div64::new(p);
    let w = find_primitive_root64(div, 2 * n as u64).unwrap();
    let mut k = 0;
    let mut wk = 1u64;

    let nbits = n.trailing_zeros();
    while k < n {
        let fwd_idx = bit_rev(nbits, k);

        // Shoup precomputation: ⌊wk · 2^max_bits / p⌋, used to speed up
        // modular multiplication by the fixed operand wk.
        let wk_shoup = Div64::div_u128((wk as u128) << max_bits, div) as u64;
        twid[fwd_idx] = wk;
        twid_shoup[fwd_idx] = wk_shoup;

        let inv_idx = bit_rev(nbits, (n - k) % n);
        if k == 0 {
            inv_twid[inv_idx] = wk;
            inv_twid_shoup[inv_idx] = wk_shoup;
        } else {
            let x = p.wrapping_sub(wk);
            inv_twid[inv_idx] = x;
            inv_twid_shoup[inv_idx] = Div64::div_u128((x as u128) << max_bits, div) as u64;
        }

        wk = Div64::rem_u128(wk as u128 * w as u128, div);
        k += 1;
    }
}

/// Negacyclic NTT plan for 64-bit primes.
#[derive(Clone)]
pub struct Plan {
    twid: ABox<[u64]>,
    twid_shoup: ABox<[u64]>,
    inv_twid: ABox<[u64]>,
    inv_twid_shoup: ABox<[u64]>,
    p: u64,
    p_div: Div64,
    use_ifma: bool,
    can_use_fast_reduction_code: bool,

    // used for the elementwise product
    p_barrett: u64,
    big_q: u64,

    n_inv_mod_p: u64,
    n_inv_mod_p_shoup: u64,
}

impl core::fmt::Debug for Plan {
    fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
        f.debug_struct("Plan")
            .field("ntt_size", &self.ntt_size())
            .field("modulus", &self.modulus())
            .finish()
    }
}

#[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
#[cfg(feature = "nightly")]
fn mul_assign_normalize_ifma(
    simd: crate::V4IFma,
    lhs: &mut [u64],
    rhs: &[u64],
    p: u64,
    p_barrett: u64,
    big_q: u64,
    n_inv_mod_p: u64,
    n_inv_mod_p_shoup: u64,
) {
    simd.vectorize(
        #[inline(always)]
        || {
            let lhs = pulp::as_arrays_mut::<8, _>(lhs).0;
            let rhs = pulp::as_arrays::<8, _>(rhs).0;

            let big_q_m1 = simd.splat_u64x8(big_q - 1);
            let big_q_m1_complement = simd.splat_u64x8(52 - (big_q - 1));
            let n_inv_mod_p = simd.splat_u64x8(n_inv_mod_p);
            let n_inv_mod_p_shoup = simd.splat_u64x8(n_inv_mod_p_shoup);
            let p_barrett = simd.splat_u64x8(p_barrett);
            let neg_p = simd.splat_u64x8(p.wrapping_neg());
            let p = simd.splat_u64x8(p);
            let zero = simd.splat_u64x8(0);

            for (lhs_, rhs) in crate::izip!(lhs, rhs) {
                let lhs = cast(*lhs_);
                let rhs = cast(*rhs);

                // lhs × rhs
                let (lo, hi) = simd.widening_mul_u52x8(lhs, rhs);
                let c1 = simd.or_u64x8(
                    simd.shr_dyn_u64x8(lo, big_q_m1),
                    simd.shl_dyn_u64x8(hi, big_q_m1_complement),
                );
                let c3 = simd.widening_mul_u52x8(c1, p_barrett).1;
                // lo - p * c3
                let prod = simd.wrapping_mul_add_u52x8(neg_p, c3, lo);

                // normalization
                let shoup_q = simd.widening_mul_u52x8(prod, n_inv_mod_p_shoup).1;
                let t = simd.wrapping_mul_add_u52x8(
                    shoup_q,
                    neg_p,
                    simd.wrapping_mul_add_u52x8(prod, n_inv_mod_p, zero),
                );

                *lhs_ = cast(simd.small_mod_u64x8(p, t));
            }
        },
    )
}

#[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
#[cfg(feature = "nightly")]
fn mul_accumulate_ifma(
    simd: crate::V4IFma,
    acc: &mut [u64],
    lhs: &[u64],
    rhs: &[u64],
    p: u64,
    p_barrett: u64,
    big_q: u64,
) {
    simd.vectorize(
        #[inline(always)]
        || {
            let acc = pulp::as_arrays_mut::<8, _>(acc).0;
            let lhs = pulp::as_arrays::<8, _>(lhs).0;
            let rhs = pulp::as_arrays::<8, _>(rhs).0;

            let big_q_m1 = simd.splat_u64x8(big_q - 1);
            let big_q_m1_complement = simd.splat_u64x8(52 - (big_q - 1));
            let p_barrett = simd.splat_u64x8(p_barrett);
            let neg_p = simd.splat_u64x8(p.wrapping_neg());
            let p = simd.splat_u64x8(p);

            for (acc, lhs, rhs) in crate::izip!(acc, lhs, rhs) {
                let lhs = cast(*lhs);
                let rhs = cast(*rhs);

                // lhs × rhs
                let (lo, hi) = simd.widening_mul_u52x8(lhs, rhs);
                let c1 = simd.or_u64x8(
                    simd.shr_dyn_u64x8(lo, big_q_m1),
                    simd.shl_dyn_u64x8(hi, big_q_m1_complement),
                );
                let c3 = simd.widening_mul_u52x8(c1, p_barrett).1;
                // lo - p * c3
                let prod = simd.wrapping_mul_add_u52x8(neg_p, c3, lo);
                let prod = simd.small_mod_u64x8(p, prod);

                *acc = cast(simd.small_mod_u64x8(p, simd.wrapping_add_u64x8(prod, cast(*acc))));
            }
        },
    )
}

#[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
#[cfg(feature = "nightly")]
fn mul_assign_normalize_avx512(
    simd: crate::V4,
    lhs: &mut [u64],
    rhs: &[u64],
    p: u64,
    p_barrett: u64,
    big_q: u64,
    n_inv_mod_p: u64,
    n_inv_mod_p_shoup: u64,
) {
    simd.vectorize(
        #[inline(always)]
        move || {
            let lhs = pulp::as_arrays_mut::<8, _>(lhs).0;
            let rhs = pulp::as_arrays::<8, _>(rhs).0;

            let big_q_m1 = simd.splat_u64x8(big_q - 1);
            let big_q_m1_complement = simd.splat_u64x8(64 - (big_q - 1));
            let n_inv_mod_p = simd.splat_u64x8(n_inv_mod_p);
            let n_inv_mod_p_shoup = simd.splat_u64x8(n_inv_mod_p_shoup);
            let p_barrett = simd.splat_u64x8(p_barrett);
            let p = simd.splat_u64x8(p);

            for (lhs_, rhs) in crate::izip!(lhs, rhs) {
                let lhs = cast(*lhs_);
                let rhs = cast(*rhs);

                // lhs × rhs
                let (lo, hi) = simd.widening_mul_u64x8(lhs, rhs);
                let c1 = simd.or_u64x8(
                    simd.shr_dyn_u64x8(lo, big_q_m1),
                    simd.shl_dyn_u64x8(hi, big_q_m1_complement),
                );
                let c3 = simd.widening_mul_u64x8(c1, p_barrett).1;
                let prod = simd.wrapping_sub_u64x8(lo, simd.wrapping_mul_u64x8(p, c3));

                // normalization
                let shoup_q = simd.widening_mul_u64x8(prod, n_inv_mod_p_shoup).1;
                let t = simd.wrapping_sub_u64x8(
                    simd.wrapping_mul_u64x8(prod, n_inv_mod_p),
                    simd.wrapping_mul_u64x8(shoup_q, p),
                );

                *lhs_ = cast(simd.small_mod_u64x8(p, t));
            }
        },
    );
}

#[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
#[cfg(feature = "nightly")]
fn mul_accumulate_avx512(
    simd: crate::V4,
    acc: &mut [u64],
    lhs: &[u64],
    rhs: &[u64],
    p: u64,
    p_barrett: u64,
    big_q: u64,
) {
    simd.vectorize(
        #[inline(always)]
        || {
            let acc = pulp::as_arrays_mut::<8, _>(acc).0;
            let lhs = pulp::as_arrays::<8, _>(lhs).0;
            let rhs = pulp::as_arrays::<8, _>(rhs).0;

            let big_q_m1 = simd.splat_u64x8(big_q - 1);
            let big_q_m1_complement = simd.splat_u64x8(64 - (big_q - 1));
            let p_barrett = simd.splat_u64x8(p_barrett);
            let p = simd.splat_u64x8(p);

            for (acc, lhs, rhs) in crate::izip!(acc, lhs, rhs) {
                let lhs = cast(*lhs);
                let rhs = cast(*rhs);

                // lhs × rhs
                let (lo, hi) = simd.widening_mul_u64x8(lhs, rhs);
                let c1 = simd.or_u64x8(
                    simd.shr_dyn_u64x8(lo, big_q_m1),
                    simd.shl_dyn_u64x8(hi, big_q_m1_complement),
                );
                let c3 = simd.widening_mul_u64x8(c1, p_barrett).1;
                // lo - p * c3
                let prod = simd.wrapping_sub_u64x8(lo, simd.wrapping_mul_u64x8(p, c3));
                let prod = simd.small_mod_u64x8(p, prod);

                *acc = cast(simd.small_mod_u64x8(p, simd.wrapping_add_u64x8(prod, cast(*acc))));
            }
        },
    )
}

#[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
fn mul_assign_normalize_avx2(
    simd: crate::V3,
    lhs: &mut [u64],
    rhs: &[u64],
    p: u64,
    p_barrett: u64,
    big_q: u64,
    n_inv_mod_p: u64,
    n_inv_mod_p_shoup: u64,
) {
    simd.vectorize(
        #[inline(always)]
        move || {
            let lhs = pulp::as_arrays_mut::<4, _>(lhs).0;
            let rhs = pulp::as_arrays::<4, _>(rhs).0;
            let big_q_m1 = simd.splat_u64x4(big_q - 1);
            let big_q_m1_complement = simd.splat_u64x4(64 - (big_q - 1));
            let n_inv_mod_p = simd.splat_u64x4(n_inv_mod_p);
            let n_inv_mod_p_shoup = simd.splat_u64x4(n_inv_mod_p_shoup);
            let p_barrett = simd.splat_u64x4(p_barrett);
            let p = simd.splat_u64x4(p);

            for (lhs_, rhs) in crate::izip!(lhs, rhs) {
                let lhs = cast(*lhs_);
                let rhs = cast(*rhs);

                // lhs × rhs
                let (lo, hi) = simd.widening_mul_u64x4(lhs, rhs);
                let c1 = simd.or_u64x4(
                    simd.shr_dyn_u64x4(lo, big_q_m1),
                    simd.shl_dyn_u64x4(hi, big_q_m1_complement),
                );
                let c3 = simd.widening_mul_u64x4(c1, p_barrett).1;
                let prod = simd.wrapping_sub_u64x4(lo, simd.widening_mul_u64x4(p, c3).0);

                // normalization
                let shoup_q = simd.widening_mul_u64x4(prod, n_inv_mod_p_shoup).1;
                let t = simd.wrapping_sub_u64x4(
                    simd.widening_mul_u64x4(prod, n_inv_mod_p).0,
                    simd.widening_mul_u64x4(shoup_q, p).0,
                );

                *lhs_ = cast(simd.small_mod_u64x4(p, t));
            }
        },
    );
}

#[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
fn mul_accumulate_avx2(
    simd: crate::V3,
    acc: &mut [u64],
    lhs: &[u64],
    rhs: &[u64],
    p: u64,
    p_barrett: u64,
    big_q: u64,
) {
    simd.vectorize(
        #[inline(always)]
        || {
            let acc = pulp::as_arrays_mut::<4, _>(acc).0;
            let lhs = pulp::as_arrays::<4, _>(lhs).0;
            let rhs = pulp::as_arrays::<4, _>(rhs).0;

            let big_q_m1 = simd.splat_u64x4(big_q - 1);
            let big_q_m1_complement = simd.splat_u64x4(64 - (big_q - 1));
            let p_barrett = simd.splat_u64x4(p_barrett);
            let p = simd.splat_u64x4(p);

            for (acc, lhs, rhs) in crate::izip!(acc, lhs, rhs) {
                let lhs = cast(*lhs);
                let rhs = cast(*rhs);

                // lhs × rhs
                let (lo, hi) = simd.widening_mul_u64x4(lhs, rhs);
                let c1 = simd.or_u64x4(
                    simd.shr_dyn_u64x4(lo, big_q_m1),
                    simd.shl_dyn_u64x4(hi, big_q_m1_complement),
                );
                let c3 = simd.widening_mul_u64x4(c1, p_barrett).1;
                // lo - p * c3
                let prod = simd.wrapping_sub_u64x4(lo, simd.widening_mul_u64x4(p, c3).0);
                let prod = simd.small_mod_u64x4(p, prod);

                *acc = cast(simd.small_mod_u64x4(p, simd.wrapping_add_u64x4(prod, cast(*acc))));
            }
        },
    )
}

fn mul_assign_normalize_scalar(
    lhs: &mut [u64],
    rhs: &[u64],
    p: u64,
    p_barrett: u64,
    big_q: u64,
    n_inv_mod_p: u64,
    n_inv_mod_p_shoup: u64,
) {
    let big_q_m1 = big_q - 1;

    for (lhs_, rhs) in crate::izip!(lhs, rhs) {
        let lhs = *lhs_;
        let rhs = *rhs;

        // Barrett reduction of the wide product: c3 approximates ⌊d / p⌋
        // (within one on the moduli this fast path accepts), leaving `prod`
        // in [0, 2p).
        let d = lhs as u128 * rhs as u128;
        let c1 = (d >> big_q_m1) as u64;
        let c3 = ((c1 as u128 * p_barrett as u128) >> 64) as u64;
        let prod = (d as u64).wrapping_sub(p.wrapping_mul(c3));

        // Shoup multiplication by the constant n⁻¹ mod p.
        let shoup_q = (((prod as u128) * (n_inv_mod_p_shoup as u128)) >> 64) as u64;
        let t = u64::wrapping_sub(prod.wrapping_mul(n_inv_mod_p), shoup_q.wrapping_mul(p));

        // Final conditional subtraction brings the result into [0, p).
        *lhs_ = t.min(t.wrapping_sub(p));
    }
}

fn mul_accumulate_scalar(
    acc: &mut [u64],
    lhs: &[u64],
    rhs: &[u64],
    p: u64,
    p_barrett: u64,
    big_q: u64,
) {
    let big_q_m1 = big_q - 1;

    for (acc, lhs, rhs) in crate::izip!(acc, lhs, rhs) {
        let lhs = *lhs;
        let rhs = *rhs;

        // Barrett reduction of the wide product, followed by a conditional
        // subtraction to bring `prod` into [0, p).
        let d = lhs as u128 * rhs as u128;
        let c1 = (d >> big_q_m1) as u64;
        let c3 = ((c1 as u128 * p_barrett as u128) >> 64) as u64;
        let prod = (d as u64).wrapping_sub(p.wrapping_mul(c3));
        let prod = prod.min(prod.wrapping_sub(p));

        let acc_ = prod + *acc;
        *acc = acc_.min(acc_.wrapping_sub(p));
    }
}

#[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
#[cfg(feature = "nightly")]
fn normalize_ifma(
    simd: crate::V4IFma,
    values: &mut [u64],
    p: u64,
    n_inv_mod_p: u64,
    n_inv_mod_p_shoup: u64,
) {
    simd.vectorize(
        #[inline(always)]
        || {
            let values = pulp::as_arrays_mut::<8, _>(values).0;

            let n_inv_mod_p = simd.splat_u64x8(n_inv_mod_p);
            let n_inv_mod_p_shoup = simd.splat_u64x8(n_inv_mod_p_shoup);
            let neg_p = simd.splat_u64x8(p.wrapping_neg());
            let p = simd.splat_u64x8(p);
            let zero = simd.splat_u64x8(0);

            for val_ in values {
                let val = cast(*val_);

                // normalization
                let shoup_q = simd.widening_mul_u52x8(val, n_inv_mod_p_shoup).1;
                let t = simd.wrapping_mul_add_u52x8(
                    shoup_q,
                    neg_p,
                    simd.wrapping_mul_add_u52x8(val, n_inv_mod_p, zero),
                );

                *val_ = cast(simd.small_mod_u64x8(p, t));
            }
        },
    )
}

#[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
#[cfg(feature = "nightly")]
fn normalize_avx512(
    simd: crate::V4,
    values: &mut [u64],
    p: u64,
    n_inv_mod_p: u64,
    n_inv_mod_p_shoup: u64,
) {
    simd.vectorize(
        #[inline(always)]
        move || {
            let values = pulp::as_arrays_mut::<8, _>(values).0;

            let n_inv_mod_p = simd.splat_u64x8(n_inv_mod_p);
            let n_inv_mod_p_shoup = simd.splat_u64x8(n_inv_mod_p_shoup);
            let p = simd.splat_u64x8(p);

            for val_ in values {
                let val = cast(*val_);

                // normalization
                let shoup_q = simd.widening_mul_u64x8(val, n_inv_mod_p_shoup).1;
                let t = simd.wrapping_sub_u64x8(
                    simd.wrapping_mul_u64x8(val, n_inv_mod_p),
                    simd.wrapping_mul_u64x8(shoup_q, p),
                );

                *val_ = cast(simd.small_mod_u64x8(p, t));
            }
        },
    );
}

#[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
fn normalize_avx2(
    simd: crate::V3,
    values: &mut [u64],
    p: u64,
    n_inv_mod_p: u64,
    n_inv_mod_p_shoup: u64,
) {
    simd.vectorize(
        #[inline(always)]
        move || {
            let values = pulp::as_arrays_mut::<4, _>(values).0;

            let n_inv_mod_p = simd.splat_u64x4(n_inv_mod_p);
            let n_inv_mod_p_shoup = simd.splat_u64x4(n_inv_mod_p_shoup);
            let p = simd.splat_u64x4(p);

            for val_ in values {
                let val = cast(*val_);

                // normalization
                let shoup_q = simd.widening_mul_u64x4(val, n_inv_mod_p_shoup).1;
                let t = simd.wrapping_sub_u64x4(
                    simd.widening_mul_u64x4(val, n_inv_mod_p).0,
                    simd.widening_mul_u64x4(shoup_q, p).0,
                );

                *val_ = cast(simd.small_mod_u64x4(p, t));
            }
        },
    );
}

fn normalize_scalar(values: &mut [u64], p: u64, n_inv_mod_p: u64, n_inv_mod_p_shoup: u64) {
    for val_ in values {
        let val = *val_;

        // Shoup multiplication by n⁻¹ mod p, then a conditional subtraction.
        let shoup_q = (((val as u128) * (n_inv_mod_p_shoup as u128)) >> 64) as u64;
        let t = u64::wrapping_sub(val.wrapping_mul(n_inv_mod_p), shoup_q.wrapping_mul(p));

        *val_ = t.min(t.wrapping_sub(p));
    }
}

struct BarrettInit64 {
    bits: u32,
    big_q: u64,
    p_barrett: u64,
    requires_single_reduction_step: bool,
}

impl BarrettInit64 {
    pub fn new(modulus: u64, bits: u32) -> Self {
        let big_q = modulus.ilog2() + 1;
        let big_l = big_q + bits - 1;
        let m_as_u128: u128 = modulus.into();
        let two_to_the_l = 1u128 << big_l; // Equivalent to 2^{2k} from the zksecurity blog
        let (p_barrett, beta) = (
            (two_to_the_l / m_as_u128) as u64,
            (two_to_the_l % m_as_u128),
        );

        // Check that the chosen prime will only trigger a single Barrett reduction step with
        // our implementation. If two reductions are needed, there can be cases where it is not
        // possible to decide whether a reduction is required, yielding wrong results.
        // Formula derived with https://blog.zksecurity.xyz/posts/barrett-tighter-bound/
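        // Illustrative reading of the bound (not part of the computation):
        // the quotient estimate c3 = ⌊c1 · p_barrett / 2^bits⌋ undershoots
        // ⌊d / p⌋ by at most one whenever beta = 2^big_l mod p is at most
        // p - 2^(big_q - 1), so a single conditional subtraction finishes
        // the reduction.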
        let single_reduction_threshold = m_as_u128 - (1 << (big_q - 1));

        let requires_single_reduction_step = beta <= single_reduction_threshold;

        Self {
            bits,
            big_q: big_q.into(),
            p_barrett,
            requires_single_reduction_step,
        }
    }
}

impl Plan {
    /// Returns a negacyclic NTT plan for the given polynomial size and modulus, or `None` if no
    /// suitable roots of unity can be found for the wanted parameters.
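    ///
    /// # Example
    ///
    /// A minimal sketch (assumes this module is exported as `tfhe_ntt::prime64`;
    /// the modulus below is the Goldilocks prime defined in this file):
    ///
    /// ```
    /// use tfhe_ntt::prime64::Plan;
    ///
    /// let p = 0xffff_ffff_0000_0001u64;
    /// let plan = Plan::try_new(1024, p).unwrap();
    /// assert_eq!(plan.ntt_size(), 1024);
    /// assert_eq!(plan.modulus(), p);
    ///
    /// // Sizes below 16 (or non-powers-of-two) are rejected.
    /// assert!(Plan::try_new(8, p).is_none());
    /// ```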
    pub fn try_new(polynomial_size: usize, modulus: u64) -> Option<Self> {
        let p_div = Div64::new(modulus);
        // 16 = 8x2 = max_register_size * ntt_radix,
        // as SIMD registers can contain at most 8 u64 values
        // and the implementation assumes that SIMD registers are full
        if polynomial_size < 16
            || !polynomial_size.is_power_of_two()
            || !is_prime64(modulus)
            || find_primitive_root64(p_div, 2 * polynomial_size as u64).is_none()
        {
            None
        } else {
            let ifma_instructions_available = {
                #[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
                #[cfg(feature = "nightly")]
                let has_ifma = crate::V4IFma::try_new().is_some();
                #[cfg(not(all(
                    any(target_arch = "x86", target_arch = "x86_64"),
                    feature = "nightly",
                )))]
                let has_ifma = false;

                has_ifma
            };

            // See prime32 for the logic behind the checks performed here.
            // They avoid overflows and allow the use of fast code paths.
            let init_ifma = if ifma_instructions_available && modulus < (1u64 << 52) {
                let init_less_than_52_bits = BarrettInit64::new(modulus, 52);
                if (modulus < 1501199875790166)
                    || (init_less_than_52_bits.requires_single_reduction_step
                        && modulus < (1 << 51))
                {
                    // If we comply with the 52-bit code requirements, return the init params
                    Some(init_less_than_52_bits)
                } else {
                    // Otherwise we will need a 64-bit fallback
                    None
                }
            } else {
                None
            };

            let BarrettInit64 {
                bits,
                big_q,
                p_barrett,
                requires_single_reduction_step,
            } = init_ifma.unwrap_or_else(|| BarrettInit64::new(modulus, 64));

            let use_ifma = bits == 52;
            let can_use_fast_reduction_code = use_ifma
                || ((modulus < 6148914691236517206) // ⌈2^64 / 3⌉
                    || (requires_single_reduction_step && (modulus < (1 << 63))));

            let mut twid = avec![0u64; polynomial_size].into_boxed_slice();
            let mut inv_twid = avec![0u64; polynomial_size].into_boxed_slice();
            let (mut twid_shoup, mut inv_twid_shoup) = if modulus < (1u64 << 63) {
                (
                    avec![0u64; polynomial_size].into_boxed_slice(),
                    avec![0u64; polynomial_size].into_boxed_slice(),
                )
            } else {
                (avec![].into_boxed_slice(), avec![].into_boxed_slice())
            };

            if modulus < (1u64 << 63) {
                init_negacyclic_twiddles_shoup(
                    modulus,
                    polynomial_size,
                    bits,
                    &mut twid,
                    &mut twid_shoup,
                    &mut inv_twid,
                    &mut inv_twid_shoup,
                );
            } else {
                init_negacyclic_twiddles(modulus, polynomial_size, &mut twid, &mut inv_twid);
            }

            // n⁻¹ mod p via Fermat's little theorem (p is prime)
            let n_inv_mod_p = crate::prime::exp_mod64(p_div, polynomial_size as u64, modulus - 2);
            let n_inv_mod_p_shoup = (((n_inv_mod_p as u128) << bits) / modulus as u128) as u64;

            Some(Self {
                twid,
                twid_shoup,
                inv_twid_shoup,
                inv_twid,
                p: modulus,
                p_div,
                use_ifma,
                can_use_fast_reduction_code,
                p_barrett,
                big_q,
                n_inv_mod_p,
                n_inv_mod_p_shoup,
            })
        }
    }

    pub(crate) fn p_div(&self) -> Div64 {
        self.p_div
    }

    /// Returns the polynomial size of the negacyclic NTT plan.
    #[inline]
    pub fn ntt_size(&self) -> usize {
        self.twid.len()
    }

    /// Returns the modulus of the negacyclic NTT plan.
    #[inline]
    pub fn modulus(&self) -> u64 {
        self.p
    }

    /// Returns whether the negacyclic NTT plan uses IFMA instructions on x86.
    #[inline]
    pub fn use_ifma(&self) -> bool {
        self.use_ifma
    }

    /// Returns whether the negacyclic NTT plan can use fast reduction code.
    #[inline]
    pub fn can_use_fast_reduction_code(&self) -> bool {
        self.can_use_fast_reduction_code
    }

    /// Applies a forward negacyclic NTT transform in place to the given buffer.
    ///
    /// # Note
    /// On entry, the buffer holds the polynomial coefficients in standard order. On exit, the
    /// buffer holds the negacyclic NTT transform coefficients in bit-reversed order.
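    ///
    /// # Example
    ///
    /// A minimal roundtrip sketch (assumes this module is exported as
    /// `tfhe_ntt::prime64`): `fwd` followed by `inv` yields the input scaled
    /// by the polynomial size `n`, so a final `normalize` recovers it.
    ///
    /// ```
    /// use tfhe_ntt::prime64::Plan;
    ///
    /// let p = 0xffff_ffff_0000_0001u64;
    /// let plan = Plan::try_new(32, p).unwrap();
    ///
    /// let original: Vec<u64> = (0..32).collect();
    /// let mut buf = original.clone();
    /// plan.fwd(&mut buf);
    /// plan.inv(&mut buf);
    /// plan.normalize(&mut buf);
    /// assert_eq!(buf, original);
    /// ```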
    pub fn fwd(&self, buf: &mut [u64]) {
        assert_eq!(buf.len(), self.ntt_size());
        let p = self.p;

        #[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
        #[cfg(feature = "nightly")]
        if p < (1u64 << 50) {
            if let Some(simd) = crate::V4IFma::try_new() {
                less_than_50bit::fwd_avx512(simd, p, buf, &self.twid, &self.twid_shoup);
                return;
            }
        } else if p < (1u64 << 51) {
            if let Some(simd) = crate::V4IFma::try_new() {
                less_than_51bit::fwd_avx512(simd, p, buf, &self.twid, &self.twid_shoup);
                return;
            }
        }

        if p < (1u64 << 62) {
            #[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
            {
                #[cfg(feature = "nightly")]
                if let Some(simd) = crate::V4::try_new() {
                    less_than_62bit::fwd_avx512(simd, p, buf, &self.twid, &self.twid_shoup);
                    return;
                }
                if let Some(simd) = crate::V3::try_new() {
                    less_than_62bit::fwd_avx2(simd, p, buf, &self.twid, &self.twid_shoup);
                    return;
                }
            }
            less_than_62bit::fwd_scalar(p, buf, &self.twid, &self.twid_shoup);
        } else if p < (1u64 << 63) {
            #[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
            {
                #[cfg(feature = "nightly")]
                if let Some(simd) = crate::V4::try_new() {
                    less_than_63bit::fwd_avx512(simd, p, buf, &self.twid, &self.twid_shoup);
                    return;
                }
                if let Some(simd) = crate::V3::try_new() {
                    less_than_63bit::fwd_avx2(simd, p, buf, &self.twid, &self.twid_shoup);
                    return;
                }
            }
            less_than_63bit::fwd_scalar(p, buf, &self.twid, &self.twid_shoup);
        } else if p == Solinas::P {
            #[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
            {
                #[cfg(feature = "nightly")]
                if let Some(simd) = crate::V4::try_new() {
                    generic_solinas::fwd_avx512(simd, buf, Solinas, (), &self.twid);
                    return;
                }
                if let Some(simd) = crate::V3::try_new() {
                    generic_solinas::fwd_avx2(simd, buf, Solinas, (), &self.twid);
                    return;
                }
            }
            generic_solinas::fwd_scalar(buf, Solinas, (), &self.twid);
        } else {
            #[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
            #[cfg(feature = "nightly")]
            if let Some(simd) = crate::V4::try_new() {
                let crate::u256 { x0, x1, x2, x3 } = self.p_div.double_reciprocal;
                let p_div = (p, x0, x1, x2, x3);
                generic_solinas::fwd_avx512(simd, buf, p, p_div, &self.twid);
                return;
            }
            generic_solinas::fwd_scalar(buf, p, self.p_div, &self.twid);
        }
    }

    /// Applies an inverse negacyclic NTT transform in place to the given buffer.
    ///
    /// # Note
    /// On entry, the buffer holds the negacyclic NTT transform coefficients in bit-reversed order.
    /// On exit, the buffer holds the polynomial coefficients in standard order, scaled by the
    /// polynomial size; [`Self::normalize`] (or [`Self::mul_assign_normalize`]) cancels the factor.
    pub fn inv(&self, buf: &mut [u64]) {
        assert_eq!(buf.len(), self.ntt_size());
        let p = self.p;

        #[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
        #[cfg(feature = "nightly")]
        if p < (1u64 << 50) {
            if let Some(simd) = crate::V4IFma::try_new() {
                less_than_50bit::inv_avx512(simd, p, buf, &self.inv_twid, &self.inv_twid_shoup);
                return;
            }
        } else if p < (1u64 << 51) {
            if let Some(simd) = crate::V4IFma::try_new() {
                less_than_51bit::inv_avx512(simd, p, buf, &self.inv_twid, &self.inv_twid_shoup);
                return;
            }
        }

        if p < (1u64 << 62) {
            #[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
            {
                #[cfg(feature = "nightly")]
                if let Some(simd) = crate::V4::try_new() {
                    less_than_62bit::inv_avx512(simd, p, buf, &self.inv_twid, &self.inv_twid_shoup);
                    return;
                }
                if let Some(simd) = crate::V3::try_new() {
                    less_than_62bit::inv_avx2(simd, p, buf, &self.inv_twid, &self.inv_twid_shoup);
                    return;
                }
            }
            less_than_62bit::inv_scalar(p, buf, &self.inv_twid, &self.inv_twid_shoup);
        } else if p < (1u64 << 63) {
            #[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
            {
                #[cfg(feature = "nightly")]
                if let Some(simd) = crate::V4::try_new() {
                    less_than_63bit::inv_avx512(simd, p, buf, &self.inv_twid, &self.inv_twid_shoup);
                    return;
                }
                if let Some(simd) = crate::V3::try_new() {
                    less_than_63bit::inv_avx2(simd, p, buf, &self.inv_twid, &self.inv_twid_shoup);
                    return;
                }
            }
            less_than_63bit::inv_scalar(p, buf, &self.inv_twid, &self.inv_twid_shoup);
        } else if p == Solinas::P {
            #[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
            {
                #[cfg(feature = "nightly")]
                if let Some(simd) = crate::V4::try_new() {
                    generic_solinas::inv_avx512(simd, buf, Solinas, (), &self.inv_twid);
                    return;
                }
                if let Some(simd) = crate::V3::try_new() {
                    generic_solinas::inv_avx2(simd, buf, Solinas, (), &self.inv_twid);
                    return;
                }
            }
            generic_solinas::inv_scalar(buf, Solinas, (), &self.inv_twid);
        } else {
            #[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
            #[cfg(feature = "nightly")]
            if let Some(simd) = crate::V4::try_new() {
                let crate::u256 { x0, x1, x2, x3 } = self.p_div.double_reciprocal;
                let p_div = (p, x0, x1, x2, x3);
                generic_solinas::inv_avx512(simd, buf, p, p_div, &self.inv_twid);
                return;
            }
            generic_solinas::inv_scalar(buf, p, self.p_div, &self.inv_twid);
        }
    }

    /// Computes the elementwise product of `lhs` and `rhs`, multiplied by the inverse of the
    /// polynomial size modulo the NTT modulus, and stores the result in `lhs`.
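    ///
    /// # Example
    ///
    /// A minimal sketch (assumes this module is exported as `tfhe_ntt::prime64`):
    /// multiplying by the NTT of the constant polynomial `1` and transforming
    /// back leaves the input unchanged, since the `n⁻¹` factor is already
    /// folded into this method.
    ///
    /// ```
    /// use tfhe_ntt::prime64::Plan;
    ///
    /// let p = 0xffff_ffff_0000_0001u64;
    /// let plan = Plan::try_new(32, p).unwrap();
    ///
    /// let mut lhs = vec![1u64; 32];
    /// let mut rhs = vec![0u64; 32];
    /// rhs[0] = 1; // rhs is the constant polynomial "1"
    ///
    /// plan.fwd(&mut lhs);
    /// plan.fwd(&mut rhs);
    /// plan.mul_assign_normalize(&mut lhs, &rhs);
    /// plan.inv(&mut lhs);
    /// assert_eq!(lhs, vec![1u64; 32]);
    /// ```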
    pub fn mul_assign_normalize(&self, lhs: &mut [u64], rhs: &[u64]) {
        let p = self.p;
        let can_use_fast_reduction_code = self.can_use_fast_reduction_code;

        if can_use_fast_reduction_code {
            #[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
            #[cfg(feature = "nightly")]
            if self.use_ifma {
                // p < 2^51
                let simd = crate::V4IFma::try_new().unwrap();
                mul_assign_normalize_ifma(
                    simd,
                    lhs,
                    rhs,
                    p,
                    self.p_barrett,
                    self.big_q,
                    self.n_inv_mod_p,
                    self.n_inv_mod_p_shoup,
                );
                return;
            }

            #[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
            #[cfg(feature = "nightly")]
            if let Some(simd) = crate::V4::try_new() {
                mul_assign_normalize_avx512(
                    simd,
                    lhs,
                    rhs,
                    p,
                    self.p_barrett,
                    self.big_q,
                    self.n_inv_mod_p,
                    self.n_inv_mod_p_shoup,
                );
                return;
            }

            #[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
            if let Some(simd) = crate::V3::try_new() {
                mul_assign_normalize_avx2(
                    simd,
                    lhs,
                    rhs,
                    p,
                    self.p_barrett,
                    self.big_q,
                    self.n_inv_mod_p,
                    self.n_inv_mod_p_shoup,
                );
                return;
            }

            mul_assign_normalize_scalar(
                lhs,
                rhs,
                p,
                self.p_barrett,
                self.big_q,
                self.n_inv_mod_p,
                self.n_inv_mod_p_shoup,
            );
        } else if p == Solinas::P {
            let n_inv_mod_p = self.n_inv_mod_p;
            for (lhs_, rhs) in crate::izip!(lhs, rhs) {
                let lhs = *lhs_;
                let rhs = *rhs;
                let prod = <Solinas as PrimeModulus>::mul((), lhs, rhs);
                let prod = <Solinas as PrimeModulus>::mul((), prod, n_inv_mod_p);
                *lhs_ = prod;
            }
        } else {
            let p_div = self.p_div;
            let n_inv_mod_p = self.n_inv_mod_p;
            for (lhs_, rhs) in crate::izip!(lhs, rhs) {
                let lhs = *lhs_;
                let rhs = *rhs;
                let prod = <u64 as PrimeModulus>::mul(p_div, lhs, rhs);
                let prod = <u64 as PrimeModulus>::mul(p_div, prod, n_inv_mod_p);
                *lhs_ = prod;
            }
        }
    }

    /// Multiplies the values by the inverse of the polynomial size modulo the NTT modulus, and
    /// stores the result in `values`.
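    ///
    /// # Example
    ///
    /// A minimal sketch (assumes this module is exported as `tfhe_ntt::prime64`):
    /// each value is multiplied by `n⁻¹ mod p`, so an input of `n` maps to `1`.
    ///
    /// ```
    /// use tfhe_ntt::prime64::Plan;
    ///
    /// let p = 0xffff_ffff_0000_0001u64;
    /// let plan = Plan::try_new(64, p).unwrap();
    ///
    /// let mut values = vec![64u64; 64]; // each entry equals n = 64
    /// plan.normalize(&mut values);
    /// assert_eq!(values, vec![1u64; 64]); // n · n⁻¹ ≡ 1 (mod p)
    /// ```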
    pub fn normalize(&self, values: &mut [u64]) {
        let p = self.p;
        let can_use_fast_reduction_code = self.can_use_fast_reduction_code;

        if can_use_fast_reduction_code {
            #[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
            #[cfg(feature = "nightly")]
            if self.use_ifma {
                // p < 2^51
                let simd = crate::V4IFma::try_new().unwrap();
                normalize_ifma(simd, values, p, self.n_inv_mod_p, self.n_inv_mod_p_shoup);
                return;
            }

            #[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
            #[cfg(feature = "nightly")]
            if let Some(simd) = crate::V4::try_new() {
                normalize_avx512(simd, values, p, self.n_inv_mod_p, self.n_inv_mod_p_shoup);
                return;
            }

            #[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
            if let Some(simd) = crate::V3::try_new() {
                normalize_avx2(simd, values, p, self.n_inv_mod_p, self.n_inv_mod_p_shoup);
                return;
            }

            normalize_scalar(values, p, self.n_inv_mod_p, self.n_inv_mod_p_shoup);
        } else if p == Solinas::P {
            let n_inv_mod_p = self.n_inv_mod_p;
            for val in values {
                let prod = <Solinas as PrimeModulus>::mul((), *val, n_inv_mod_p);
                *val = prod;
            }
        } else {
            let n_inv_mod_p = self.n_inv_mod_p;
            let p_div = self.p_div;
            for val in values {
                let prod = <u64 as PrimeModulus>::mul(p_div, *val, n_inv_mod_p);
                *val = prod;
            }
        }
    }

    /// Computes the elementwise product of `lhs` and `rhs` and accumulates the result into `acc`.
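    ///
    /// # Example
    ///
    /// A minimal sketch (assumes this module is exported as `tfhe_ntt::prime64`):
    /// the products are reduced modulo the plan's modulus before accumulation.
    ///
    /// ```
    /// use tfhe_ntt::prime64::Plan;
    ///
    /// let p = 0xffff_ffff_0000_0001u64;
    /// let plan = Plan::try_new(32, p).unwrap();
    ///
    /// let mut acc = vec![1u64; 32];
    /// let lhs = vec![2u64; 32];
    /// let rhs = vec![3u64; 32];
    /// plan.mul_accumulate(&mut acc, &lhs, &rhs);
    /// assert_eq!(acc, vec![7u64; 32]); // acc + lhs · rhs (mod p)
    /// ```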
    pub fn mul_accumulate(&self, acc: &mut [u64], lhs: &[u64], rhs: &[u64]) {
        let p = self.p;
        let can_use_fast_reduction_code = self.can_use_fast_reduction_code;

        if can_use_fast_reduction_code {
            #[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
            #[cfg(feature = "nightly")]
            if self.use_ifma {
                // p < 2^51
                let simd = crate::V4IFma::try_new().unwrap();
                mul_accumulate_ifma(simd, acc, lhs, rhs, p, self.p_barrett, self.big_q);
                return;
            }

            #[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
            #[cfg(feature = "nightly")]
            if let Some(simd) = crate::V4::try_new() {
                mul_accumulate_avx512(simd, acc, lhs, rhs, p, self.p_barrett, self.big_q);
                return;
            }

            #[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
            if let Some(simd) = crate::V3::try_new() {
                mul_accumulate_avx2(simd, acc, lhs, rhs, p, self.p_barrett, self.big_q);
                return;
            }

            mul_accumulate_scalar(acc, lhs, rhs, p, self.p_barrett, self.big_q);
        } else if p == Solinas::P {
            for (acc, lhs, rhs) in crate::izip!(acc, lhs, rhs) {
                let prod = <Solinas as PrimeModulus>::mul((), *lhs, *rhs);
                *acc = <Solinas as PrimeModulus>::add(Solinas, *acc, prod);
            }
        } else {
            let p_div = self.p_div;
            for (acc, lhs, rhs) in crate::izip!(acc, lhs, rhs) {
                let prod = <u64 as PrimeModulus>::mul(p_div, *lhs, *rhs);
                *acc = <u64 as PrimeModulus>::add(p, *acc, prod);
            }
        }
    }
}

#[cfg(test)]
pub mod tests {
    use super::*;
    use crate::{
        fastdiv::Div64, prime::largest_prime_in_arithmetic_progression64,
        prime64::generic_solinas::PrimeModulus,
    };
    use alloc::{vec, vec::Vec};
    use rand::random;

    extern crate alloc;

    pub fn add(p: u64, a: u64, b: u64) -> u64 {
        let neg_b = p.wrapping_sub(b);
        if a >= neg_b {
            a - neg_b
        } else {
            a + b
        }
    }

    pub fn sub(p: u64, a: u64, b: u64) -> u64 {
        let neg_b = p.wrapping_sub(b);
        if a >= b {
            a - b
        } else {
            a + neg_b
        }
    }

    pub fn mul(p: u64, a: u64, b: u64) -> u64 {
        let wide = a as u128 * b as u128;
        if p == 0 {
            // p == 0 stands for arithmetic modulo 2^64
            wide as u64
        } else {
            (wide % p as u128) as u64
        }
    }

    pub fn negacyclic_convolution(n: usize, p: u64, lhs: &[u64], rhs: &[u64]) -> vec::Vec<u64> {
        let mut full_convolution = vec![0u64; 2 * n];
        let mut negacyclic_convolution = vec![0u64; n];
        for i in 0..n {
            for j in 0..n {
                full_convolution[i + j] = add(p, full_convolution[i + j], mul(p, lhs[i], rhs[j]));
            }
        }
        for i in 0..n {
            negacyclic_convolution[i] = sub(p, full_convolution[i], full_convolution[i + n]);
        }
        negacyclic_convolution
    }

    pub fn random_lhs_rhs_with_negacyclic_convolution(
        n: usize,
        p: u64,
    ) -> (vec::Vec<u64>, vec::Vec<u64>, vec::Vec<u64>) {
        let mut lhs = vec![0u64; n];
        let mut rhs = vec![0u64; n];

        for x in &mut lhs {
            *x = random();
            if p != 0 {
                *x %= p;
            }
        }
        for x in &mut rhs {
            *x = random();
            if p != 0 {
                *x %= p;
            }
        }

        let lhs = lhs;
        let rhs = rhs;

        let negacyclic_convolution = negacyclic_convolution(n, p, &lhs, &rhs);
        (lhs, rhs, negacyclic_convolution)
    }

    #[test]
    fn test_product() {
        for n in [16, 32, 64, 128, 256, 512, 1024] {
            for p in [
                largest_prime_in_arithmetic_progression64(1 << 16, 1, 1 << 49, 1 << 50).unwrap(),
                largest_prime_in_arithmetic_progression64(1 << 16, 1, 1 << 50, 1 << 51).unwrap(),
                largest_prime_in_arithmetic_progression64(1 << 16, 1, 1 << 61, 1 << 62).unwrap(),
                largest_prime_in_arithmetic_progression64(1 << 16, 1, 1 << 62, 1 << 63).unwrap(),
                Solinas::P,
                largest_prime_in_arithmetic_progression64(1 << 16, 1, 1 << 63, u64::MAX).unwrap(),
            ] {
                let plan = Plan::try_new(n, p).unwrap();

                let (lhs, rhs, negacyclic_convolution) =
                    random_lhs_rhs_with_negacyclic_convolution(n, p);

                let mut prod = vec![0u64; n];
                let mut lhs_fourier = lhs.clone();
                let mut rhs_fourier = rhs.clone();

                plan.fwd(&mut lhs_fourier);
                plan.fwd(&mut rhs_fourier);

                for x in &lhs_fourier {
                    assert!(*x < p);
                }
                for x in &rhs_fourier {
                    assert!(*x < p);
                }

                for i in 0..n {
                    prod[i] =
                        <u64 as PrimeModulus>::mul(Div64::new(p), lhs_fourier[i], rhs_fourier[i]);
                }
                plan.inv(&mut prod);

                plan.mul_assign_normalize(&mut lhs_fourier, &rhs_fourier);
                plan.inv(&mut lhs_fourier);

                for x in &prod {
                    assert!(*x < p);
                }

                for i in 0..n {
                    assert_eq!(
                        prod[i],
                        <u64 as PrimeModulus>::mul(
                            Div64::new(p),
                            negacyclic_convolution[i],
                            n as u64
                        ),
                    );
                }
                assert_eq!(lhs_fourier, negacyclic_convolution);
            }
        }
    }

    #[test]
    fn test_normalize_scalar() {
        let p = largest_prime_in_arithmetic_progression64(1 << 16, 1, 1 << 50, 1 << 63).unwrap();
        let p_div = Div64::new(p);
        let polynomial_size = 128;

        let n_inv_mod_p = crate::prime::exp_mod64(p_div, polynomial_size as u64, p - 2);
        let n_inv_mod_p_shoup = (((n_inv_mod_p as u128) << 64) / p as u128) as u64;

        let mut val = (0..polynomial_size)
            .map(|_| rand::random::<u64>() % p)
            .collect::<Vec<_>>();
        let mut val_target = val.clone();

        let mul = |a: u64, b: u64| ((a as u128 * b as u128) % p as u128) as u64;

        for val in val_target.iter_mut() {
            *val = mul(*val, n_inv_mod_p);
        }

        normalize_scalar(&mut val, p, n_inv_mod_p, n_inv_mod_p_shoup);
        assert_eq!(val, val_target);
    }

    #[test]
    fn test_mul_assign_normalize_scalar() {
        let p = largest_prime_in_arithmetic_progression64(1 << 16, 1, 1 << 50, 1 << 63).unwrap();
        let p_div = Div64::new(p);
        let polynomial_size = 128;

        let n_inv_mod_p = crate::prime::exp_mod64(p_div, polynomial_size as u64, p - 2);
        let n_inv_mod_p_shoup = (((n_inv_mod_p as u128) << 64) / p as u128) as u64;
        let big_q = (p.ilog2() + 1) as u64;
        let big_l = big_q + 63;
        let p_barrett = ((1u128 << big_l) / p as u128) as u64;

        let mut lhs = (0..polynomial_size)
            .map(|_| rand::random::<u64>() % p)
            .collect::<Vec<_>>();
        let mut lhs_target = lhs.clone();
        let rhs = (0..polynomial_size)
            .map(|_| rand::random::<u64>() % p)
            .collect::<Vec<_>>();

        let mul = |a: u64, b: u64| ((a as u128 * b as u128) % p as u128) as u64;

        for (lhs, rhs) in lhs_target.iter_mut().zip(&rhs) {
            *lhs = mul(mul(*lhs, *rhs), n_inv_mod_p);
        }

        mul_assign_normalize_scalar(
            &mut lhs,
            &rhs,
            p,
            p_barrett,
            big_q,
            n_inv_mod_p,
            n_inv_mod_p_shoup,
        );
        assert_eq!(lhs, lhs_target);
    }

    #[test]
    fn test_mul_accumulate_scalar() {
        let p = largest_prime_in_arithmetic_progression64(1 << 16, 1, 1 << 50, 1 << 63).unwrap();
        let polynomial_size = 128;

        let big_q = (p.ilog2() + 1) as u64;
        let big_l = big_q + 63;
        let p_barrett = ((1u128 << big_l) / p as u128) as u64;

        let mut acc = (0..polynomial_size)
            .map(|_| rand::random::<u64>() % p)
            .collect::<Vec<_>>();
        let mut acc_target = acc.clone();
        let lhs = (0..polynomial_size)
            .map(|_| rand::random::<u64>() % p)
            .collect::<Vec<_>>();
        let rhs = (0..polynomial_size)
            .map(|_| rand::random::<u64>() % p)
            .collect::<Vec<_>>();

        let mul = |a: u64, b: u64| ((a as u128 * b as u128) % p as u128) as u64;
        let add = |a: u64, b: u64| <u64 as PrimeModulus>::add(p, a, b);

        for (acc, lhs, rhs) in crate::izip!(&mut acc_target, &lhs, &rhs) {
            *acc = add(mul(*lhs, *rhs), *acc);
        }

        mul_accumulate_scalar(&mut acc, &lhs, &rhs, p, p_barrett, big_q);
        assert_eq!(acc, acc_target);
    }

    #[test]
    fn test_mul_accumulate() {
        for p in [
            largest_prime_in_arithmetic_progression64(1 << 16, 1, 0, 1 << 51).unwrap(),
            largest_prime_in_arithmetic_progression64(1 << 16, 1, 0, 1 << 61).unwrap(),
            largest_prime_in_arithmetic_progression64(1 << 16, 1, 0, 1 << 62).unwrap(),
            largest_prime_in_arithmetic_progression64(1 << 16, 1, 0, 1 << 63).unwrap(),
            largest_prime_in_arithmetic_progression64(1 << 16, 1, 0, u64::MAX).unwrap(),
        ] {
            let polynomial_size = 128;

            let mut acc = (0..polynomial_size)
                .map(|_| rand::random::<u64>() % p)
                .collect::<Vec<_>>();
            let mut acc_target = acc.clone();
            let lhs = (0..polynomial_size)
                .map(|_| rand::random::<u64>() % p)
                .collect::<Vec<_>>();
            let rhs = (0..polynomial_size)
                .map(|_| rand::random::<u64>() % p)
                .collect::<Vec<_>>();

            let mul = |a: u64, b: u64| ((a as u128 * b as u128) % p as u128) as u64;
            let add = |a: u64, b: u64| ((a as u128 + b as u128) % p as u128) as u64;

            for (acc, lhs, rhs) in crate::izip!(&mut acc_target, &lhs, &rhs) {
                *acc = add(mul(*lhs, *rhs), *acc);
            }

            Plan::try_new(polynomial_size, p)
                .unwrap()
                .mul_accumulate(&mut acc, &lhs, &rhs);
            assert_eq!(acc, acc_target);
        }
    }

    #[test]
    fn test_mul_assign_normalize() {
        for p in [
            largest_prime_in_arithmetic_progression64(1 << 16, 1, 0, 1 << 51).unwrap(),
            largest_prime_in_arithmetic_progression64(1 << 16, 1, 0, 1 << 61).unwrap(),
            largest_prime_in_arithmetic_progression64(1 << 16, 1, 0, 1 << 62).unwrap(),
            largest_prime_in_arithmetic_progression64(1 << 16, 1, 0, 1 << 63).unwrap(),
            largest_prime_in_arithmetic_progression64(1 << 16, 1, 0, u64::MAX).unwrap(),
        ] {
            let polynomial_size = 128;
            let p_div = Div64::new(p);
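            // n^(p-2) == n^(-1) (mod p) by Fermat's little theorem (p is
            // prime), giving the 1/N normalization factor.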
            let n_inv_mod_p = crate::prime::exp_mod64(p_div, polynomial_size as u64, p - 2);

            let mut lhs = (0..polynomial_size)
                .map(|_| rand::random::<u64>() % p)
                .collect::<Vec<_>>();
            let mut lhs_target = lhs.clone();
            let rhs = (0..polynomial_size)
                .map(|_| rand::random::<u64>() % p)
                .collect::<Vec<_>>();

            let mul = |a: u64, b: u64| ((a as u128 * b as u128) % p as u128) as u64;

            for (lhs, rhs) in lhs_target.iter_mut().zip(&rhs) {
                *lhs = mul(mul(*lhs, *rhs), n_inv_mod_p);
            }

            Plan::try_new(polynomial_size, p)
                .unwrap()
                .mul_assign_normalize(&mut lhs, &rhs);
            assert_eq!(lhs, lhs_target);
        }
    }

    #[test]
    fn test_normalize() {
        for p in [
            largest_prime_in_arithmetic_progression64(1 << 16, 1, 0, 1 << 51).unwrap(),
            largest_prime_in_arithmetic_progression64(1 << 16, 1, 0, 1 << 61).unwrap(),
            largest_prime_in_arithmetic_progression64(1 << 16, 1, 0, 1 << 62).unwrap(),
            largest_prime_in_arithmetic_progression64(1 << 16, 1, 0, 1 << 63).unwrap(),
            largest_prime_in_arithmetic_progression64(1 << 16, 1, 0, u64::MAX).unwrap(),
        ] {
            let polynomial_size = 128;
            let p_div = Div64::new(p);
            let n_inv_mod_p = crate::prime::exp_mod64(p_div, polynomial_size as u64, p - 2);

            let mut val = (0..polynomial_size)
                .map(|_| rand::random::<u64>() % p)
                .collect::<Vec<_>>();
            let mut val_target = val.clone();

            let mul = |a: u64, b: u64| ((a as u128 * b as u128) % p as u128) as u64;

            for val in &mut val_target {
                *val = mul(*val, n_inv_mod_p);
            }

            Plan::try_new(polynomial_size, p)
                .unwrap()
                .normalize(&mut val);
            assert_eq!(val, val_target);
        }
    }

    #[test]
    fn test_plan_can_use_fast_reduction_code() {
        use crate::primes52::{P0, P1, P2, P3, P4, P5};
        const POLYNOMIAL_SIZE: usize = 32;

        // The first two primes are smaller than 6148914691236517206; the
        // remaining ones come from primes52 and are used by the performant
        // code paths. All of them should take the fast reduction code.
        for p in [1062862849, 1431669377, P0, P1, P2, P3, P4, P5] {
            let plan = Plan::try_new(POLYNOMIAL_SIZE, p).unwrap();

            assert!(plan.can_use_fast_reduction_code);
        }
    }
}

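// SIMD-specific tests. Each body runs only if the required instruction set is
// detected at runtime (V3 roughly corresponds to AVX2, V4 to AVX-512, and
// V4IFma to AVX-512 with IFMA).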
#[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
#[cfg(test)]
mod x86_tests {
    extern crate alloc;

    use super::*;
    use crate::prime::largest_prime_in_arithmetic_progression64;
    use alloc::vec::Vec;
    use rand::random as rnd;

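    // The interleave operations are involutions: applying the same interleave
    // twice restores the original pair, which is asserted below alongside the
    // expected lane layouts of the permutes.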
    #[test]
    fn test_interleaves_and_permutes_u64x4() {
        if let Some(simd) = crate::V3::try_new() {
            let a = u64x4(rnd(), rnd(), rnd(), rnd());
            let b = u64x4(rnd(), rnd(), rnd(), rnd());

            assert_eq!(
                simd.interleave2_u64x4([a, b]),
                [u64x4(a.0, a.1, b.0, b.1), u64x4(a.2, a.3, b.2, b.3)],
            );
            assert_eq!(
                simd.interleave2_u64x4(simd.interleave2_u64x4([a, b])),
                [a, b],
            );
            let w = [rnd(), rnd()];
            assert_eq!(simd.permute2_u64x4(w), u64x4(w[0], w[0], w[1], w[1]));

            assert_eq!(
                simd.interleave1_u64x4([a, b]),
                [u64x4(a.0, b.0, a.2, b.2), u64x4(a.1, b.1, a.3, b.3)],
            );
            assert_eq!(
                simd.interleave1_u64x4(simd.interleave1_u64x4([a, b])),
                [a, b],
            );
            let w = [rnd(), rnd(), rnd(), rnd()];
            assert_eq!(simd.permute1_u64x4(w), u64x4(w[0], w[2], w[1], w[3]));
        }
    }

    #[cfg(feature = "nightly")]
    #[test]
    fn test_interleaves_and_permutes_u64x8() {
        if let Some(simd) = crate::V4::try_new() {
            let a = u64x8(rnd(), rnd(), rnd(), rnd(), rnd(), rnd(), rnd(), rnd());
            let b = u64x8(rnd(), rnd(), rnd(), rnd(), rnd(), rnd(), rnd(), rnd());

            assert_eq!(
                simd.interleave4_u64x8([a, b]),
                [
                    u64x8(a.0, a.1, a.2, a.3, b.0, b.1, b.2, b.3),
                    u64x8(a.4, a.5, a.6, a.7, b.4, b.5, b.6, b.7),
                ],
            );
            assert_eq!(
                simd.interleave4_u64x8(simd.interleave4_u64x8([a, b])),
                [a, b],
            );
            let w = [rnd(), rnd()];
            assert_eq!(
                simd.permute4_u64x8(w),
                u64x8(w[0], w[0], w[0], w[0], w[1], w[1], w[1], w[1]),
            );

            assert_eq!(
                simd.interleave2_u64x8([a, b]),
                [
                    u64x8(a.0, a.1, b.0, b.1, a.4, a.5, b.4, b.5),
                    u64x8(a.2, a.3, b.2, b.3, a.6, a.7, b.6, b.7),
                ],
            );
            assert_eq!(
                simd.interleave2_u64x8(simd.interleave2_u64x8([a, b])),
                [a, b],
            );
            let w = [rnd(), rnd(), rnd(), rnd()];
            assert_eq!(
                simd.permute2_u64x8(w),
                u64x8(w[0], w[0], w[2], w[2], w[1], w[1], w[3], w[3]),
            );

            assert_eq!(
                simd.interleave1_u64x8([a, b]),
                [
                    u64x8(a.0, b.0, a.2, b.2, a.4, b.4, a.6, b.6),
                    u64x8(a.1, b.1, a.3, b.3, a.5, b.5, a.7, b.7),
                ],
            );
            assert_eq!(
                simd.interleave1_u64x8(simd.interleave1_u64x8([a, b])),
                [a, b],
            );
            let w = [rnd(), rnd(), rnd(), rnd(), rnd(), rnd(), rnd(), rnd()];
            assert_eq!(
                simd.permute1_u64x8(w),
                u64x8(w[0], w[4], w[1], w[5], w[2], w[6], w[3], w[7]),
            );
        }
    }

    #[cfg(feature = "nightly")]
    #[test]
    fn test_mul_assign_normalize_ifma() {
        if let Some(simd) = crate::V4IFma::try_new() {
            let p =
                largest_prime_in_arithmetic_progression64(1 << 16, 1, 1 << 50, 1 << 51).unwrap();
            let p_div = Div64::new(p);
            let polynomial_size = 128;

            let n_inv_mod_p = crate::prime::exp_mod64(p_div, polynomial_size as u64, p - 2);
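            // Shoup constant scaled by 2^52: the IFMA kernels multiply 52-bit
            // limbs (vpmadd52luq/vpmadd52huq), hence the 52-bit radix here and
            // the big_l = big_q + 51 Barrett shift below.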
            let n_inv_mod_p_shoup = (((n_inv_mod_p as u128) << 52) / p as u128) as u64;
            let big_q = (p.ilog2() + 1) as u64;
            let big_l = big_q + 51;
            let p_barrett = ((1u128 << big_l) / p as u128) as u64;

            let mut lhs = (0..polynomial_size)
                .map(|_| rand::random::<u64>() % p)
                .collect::<Vec<_>>();
            let mut lhs_target = lhs.clone();
            let rhs = (0..polynomial_size)
                .map(|_| rand::random::<u64>() % p)
                .collect::<Vec<_>>();

            let mul = |a: u64, b: u64| ((a as u128 * b as u128) % p as u128) as u64;

            for (lhs, rhs) in lhs_target.iter_mut().zip(&rhs) {
                *lhs = mul(mul(*lhs, *rhs), n_inv_mod_p);
            }

            mul_assign_normalize_ifma(
                simd,
                &mut lhs,
                &rhs,
                p,
                p_barrett,
                big_q,
                n_inv_mod_p,
                n_inv_mod_p_shoup,
            );
            assert_eq!(lhs, lhs_target);
        }
    }

    #[cfg(feature = "nightly")]
    #[test]
    fn test_mul_assign_normalize_avx512() {
        if let Some(simd) = crate::V4::try_new() {
            let p =
                largest_prime_in_arithmetic_progression64(1 << 16, 1, 1 << 50, 1 << 63).unwrap();
            let p_div = Div64::new(p);
            let polynomial_size = 128;

            let n_inv_mod_p = crate::prime::exp_mod64(p_div, polynomial_size as u64, p - 2);
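            // Full-width Shoup constant, floor(n_inv_mod_p * 2^64 / p), for
            // the 64-bit-lane AVX-512 kernel.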
            let n_inv_mod_p_shoup = (((n_inv_mod_p as u128) << 64) / p as u128) as u64;
            let big_q = (p.ilog2() + 1) as u64;
            let big_l = big_q + 63;
            let p_barrett = ((1u128 << big_l) / p as u128) as u64;

            let mut lhs = (0..polynomial_size)
                .map(|_| rand::random::<u64>() % p)
                .collect::<Vec<_>>();
            let mut lhs_target = lhs.clone();
            let rhs = (0..polynomial_size)
                .map(|_| rand::random::<u64>() % p)
                .collect::<Vec<_>>();

            let mul = |a: u64, b: u64| ((a as u128 * b as u128) % p as u128) as u64;

            for (lhs, rhs) in lhs_target.iter_mut().zip(&rhs) {
                *lhs = mul(mul(*lhs, *rhs), n_inv_mod_p);
            }

            mul_assign_normalize_avx512(
                simd,
                &mut lhs,
                &rhs,
                p,
                p_barrett,
                big_q,
                n_inv_mod_p,
                n_inv_mod_p_shoup,
            );
            assert_eq!(lhs, lhs_target);
        }
    }

    #[test]
    fn test_mul_assign_normalize_avx2() {
        if let Some(simd) = crate::V3::try_new() {
            let p =
                largest_prime_in_arithmetic_progression64(1 << 16, 1, 1 << 50, 1 << 63).unwrap();
            let p_div = Div64::new(p);
            let polynomial_size = 128;

            let n_inv_mod_p = crate::prime::exp_mod64(p_div, polynomial_size as u64, p - 2);
            let n_inv_mod_p_shoup = (((n_inv_mod_p as u128) << 64) / p as u128) as u64;
            let big_q = (p.ilog2() + 1) as u64;
            let big_l = big_q + 63;
            let p_barrett = ((1u128 << big_l) / p as u128) as u64;

            let mut lhs = (0..polynomial_size)
                .map(|_| rand::random::<u64>() % p)
                .collect::<Vec<_>>();
            let mut lhs_target = lhs.clone();
            let rhs = (0..polynomial_size)
                .map(|_| rand::random::<u64>() % p)
                .collect::<Vec<_>>();

            let mul = |a: u64, b: u64| ((a as u128 * b as u128) % p as u128) as u64;

            for (lhs, rhs) in lhs_target.iter_mut().zip(&rhs) {
                *lhs = mul(mul(*lhs, *rhs), n_inv_mod_p);
            }

            mul_assign_normalize_avx2(
                simd,
                &mut lhs,
                &rhs,
                p,
                p_barrett,
                big_q,
                n_inv_mod_p,
                n_inv_mod_p_shoup,
            );
            assert_eq!(lhs, lhs_target);
        }
    }

    #[cfg(feature = "nightly")]
    #[test]
    fn test_mul_accumulate_ifma() {
        if let Some(simd) = crate::V4IFma::try_new() {
            let p =
                largest_prime_in_arithmetic_progression64(1 << 16, 1, 1 << 50, 1 << 51).unwrap();
            let polynomial_size = 128;

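            // IFMA variant of the Barrett parameters: big_l = big_q + 51
            // instead of big_q + 63, matching the 52-bit multiplier width.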
            let big_q = (p.ilog2() + 1) as u64;
            let big_l = big_q + 51;
            let p_barrett = ((1u128 << big_l) / p as u128) as u64;

            let mut acc = (0..polynomial_size)
                .map(|_| rand::random::<u64>() % p)
                .collect::<Vec<_>>();
            let mut acc_target = acc.clone();
            let lhs = (0..polynomial_size)
                .map(|_| rand::random::<u64>() % p)
                .collect::<Vec<_>>();
            let rhs = (0..polynomial_size)
                .map(|_| rand::random::<u64>() % p)
                .collect::<Vec<_>>();

            let mul = |a: u64, b: u64| ((a as u128 * b as u128) % p as u128) as u64;
            let add = |a: u64, b: u64| <u64 as PrimeModulus>::add(p, a, b);

            for (acc, lhs, rhs) in crate::izip!(&mut acc_target, &lhs, &rhs) {
                *acc = add(mul(*lhs, *rhs), *acc);
            }

            mul_accumulate_ifma(simd, &mut acc, &lhs, &rhs, p, p_barrett, big_q);
            assert_eq!(acc, acc_target);
        }
    }

    #[cfg(feature = "nightly")]
    #[test]
    fn test_mul_accumulate_avx512() {
        if let Some(simd) = crate::V4::try_new() {
            let p =
                largest_prime_in_arithmetic_progression64(1 << 16, 1, 1 << 50, 1 << 63).unwrap();
            let polynomial_size = 128;

            let big_q = (p.ilog2() + 1) as u64;
            let big_l = big_q + 63;
            let p_barrett = ((1u128 << big_l) / p as u128) as u64;

            let mut acc = (0..polynomial_size)
                .map(|_| rand::random::<u64>() % p)
                .collect::<Vec<_>>();
            let mut acc_target = acc.clone();
            let lhs = (0..polynomial_size)
                .map(|_| rand::random::<u64>() % p)
                .collect::<Vec<_>>();
            let rhs = (0..polynomial_size)
                .map(|_| rand::random::<u64>() % p)
                .collect::<Vec<_>>();

            let mul = |a: u64, b: u64| ((a as u128 * b as u128) % p as u128) as u64;
            let add = |a: u64, b: u64| <u64 as PrimeModulus>::add(p, a, b);

            for (acc, lhs, rhs) in crate::izip!(&mut acc_target, &lhs, &rhs) {
                *acc = add(mul(*lhs, *rhs), *acc);
            }

            mul_accumulate_avx512(simd, &mut acc, &lhs, &rhs, p, p_barrett, big_q);
            assert_eq!(acc, acc_target);
        }
    }

    #[test]
    fn test_mul_accumulate_avx2() {
        if let Some(simd) = crate::V3::try_new() {
            let p =
                largest_prime_in_arithmetic_progression64(1 << 16, 1, 1 << 50, 1 << 63).unwrap();
            let polynomial_size = 128;

            let big_q = (p.ilog2() + 1) as u64;
            let big_l = big_q + 63;
            let p_barrett = ((1u128 << big_l) / p as u128) as u64;

            let mut acc = (0..polynomial_size)
                .map(|_| rand::random::<u64>() % p)
                .collect::<Vec<_>>();
            let mut acc_target = acc.clone();
            let lhs = (0..polynomial_size)
                .map(|_| rand::random::<u64>() % p)
                .collect::<Vec<_>>();
            let rhs = (0..polynomial_size)
                .map(|_| rand::random::<u64>() % p)
                .collect::<Vec<_>>();

            let mul = |a: u64, b: u64| ((a as u128 * b as u128) % p as u128) as u64;
            let add = |a: u64, b: u64| <u64 as PrimeModulus>::add(p, a, b);

            for (acc, lhs, rhs) in crate::izip!(&mut acc_target, &lhs, &rhs) {
                *acc = add(mul(*lhs, *rhs), *acc);
            }

            mul_accumulate_avx2(simd, &mut acc, &lhs, &rhs, p, p_barrett, big_q);
            assert_eq!(acc, acc_target);
        }
    }

    #[cfg(feature = "nightly")]
    #[test]
    fn test_normalize_ifma() {
        if let Some(simd) = crate::V4IFma::try_new() {
            let p =
                largest_prime_in_arithmetic_progression64(1 << 16, 1, 1 << 50, 1 << 51).unwrap();
            let p_div = Div64::new(p);
            let polynomial_size = 128;

            let n_inv_mod_p = crate::prime::exp_mod64(p_div, polynomial_size as u64, p - 2);
            let n_inv_mod_p_shoup = (((n_inv_mod_p as u128) << 52) / p as u128) as u64;

            let mut val = (0..polynomial_size)
                .map(|_| rand::random::<u64>() % p)
                .collect::<Vec<_>>();
            let mut val_target = val.clone();

            let mul = |a: u64, b: u64| ((a as u128 * b as u128) % p as u128) as u64;

            for val in val_target.iter_mut() {
                *val = mul(*val, n_inv_mod_p);
            }

            normalize_ifma(simd, &mut val, p, n_inv_mod_p, n_inv_mod_p_shoup);
            assert_eq!(val, val_target);
        }
    }

    #[cfg(feature = "nightly")]
    #[test]
    fn test_normalize_avx512() {
        if let Some(simd) = crate::V4::try_new() {
            let p =
                largest_prime_in_arithmetic_progression64(1 << 16, 1, 1 << 50, 1 << 63).unwrap();
            let p_div = Div64::new(p);
            let polynomial_size = 128;

            let n_inv_mod_p = crate::prime::exp_mod64(p_div, polynomial_size as u64, p - 2);
            let n_inv_mod_p_shoup = (((n_inv_mod_p as u128) << 64) / p as u128) as u64;

            let mut val = (0..polynomial_size)
                .map(|_| rand::random::<u64>() % p)
                .collect::<Vec<_>>();
            let mut val_target = val.clone();

            let mul = |a: u64, b: u64| ((a as u128 * b as u128) % p as u128) as u64;

            for val in val_target.iter_mut() {
                *val = mul(*val, n_inv_mod_p);
            }

            normalize_avx512(simd, &mut val, p, n_inv_mod_p, n_inv_mod_p_shoup);
            assert_eq!(val, val_target);
        }
    }

    #[test]
    fn test_normalize_avx2() {
        if let Some(simd) = crate::V3::try_new() {
            let p =
                largest_prime_in_arithmetic_progression64(1 << 16, 1, 1 << 50, 1 << 63).unwrap();
            let p_div = Div64::new(p);
            let polynomial_size = 128;

            let n_inv_mod_p = crate::prime::exp_mod64(p_div, polynomial_size as u64, p - 2);
            let n_inv_mod_p_shoup = (((n_inv_mod_p as u128) << 64) / p as u128) as u64;

            let mut val = (0..polynomial_size)
                .map(|_| rand::random::<u64>() % p)
                .collect::<Vec<_>>();
            let mut val_target = val.clone();

            let mul = |a: u64, b: u64| ((a as u128 * b as u128) % p as u128) as u64;

            for val in val_target.iter_mut() {
                *val = mul(*val, n_inv_mod_p);
            }

            normalize_avx2(simd, &mut val, p, n_inv_mod_p, n_inv_mod_p_shoup);
            assert_eq!(val, val_target);
        }
    }

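    // Regression test for GitHub issue #11: Plan::try_new must reject an
    // invalid (non-prime) modulus like 1024 by returning None rather than
    // crashing.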
    #[test]
    fn test_plan_crash_github_11() {
        assert!(Plan::try_new(2048, 1024).is_none());
    }
}