tfhe_ntt/prime32.rs

1use crate::{
2    bit_rev,
3    fastdiv::{Div32, Div64},
4    prime::is_prime64,
5    roots::find_primitive_root64,
6};
7use aligned_vec::{avec, ABox};
8
9#[allow(unused_imports)]
10use pulp::*;
11
12const RECURSION_THRESHOLD: usize = 2048;
13
14mod generic;
15mod shoup;
16
17mod less_than_30bit;
18mod less_than_31bit;
19
20#[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
21impl crate::V3 {
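    // Regroups two vectors [a0..a7], [b0..b7] into [a0..a3, b0..b3] and [a4..a7, b4..b7];
    // applying it twice gives back the inputs (see `test_interleaves_and_permutes_u32x8`).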
22    #[inline(always)]
23    fn interleave4_u32x8(self, z0z0z0z0z1z1z1z1: [u32x8; 2]) -> [u32x8; 2] {
24        let avx = self.avx;
25        [
26            cast(avx._mm256_permute2f128_si256::<0b0010_0000>(
27                cast(z0z0z0z0z1z1z1z1[0]),
28                cast(z0z0z0z0z1z1z1z1[1]),
29            )),
30            cast(avx._mm256_permute2f128_si256::<0b0011_0001>(
31                cast(z0z0z0z0z1z1z1z1[0]),
32                cast(z0z0z0z0z1z1z1z1[1]),
33            )),
34        ]
35    }
36
37    #[inline(always)]
38    fn permute4_u32x8(self, w: [u32; 2]) -> u32x8 {
39        let avx = self.avx;
40        let w0 = self.sse2._mm_set1_epi32(w[0] as i32);
41        let w1 = self.sse2._mm_set1_epi32(w[1] as i32);
42        cast(avx._mm256_insertf128_si256::<1>(avx._mm256_castsi128_si256(w0), w1))
43    }
44
45    #[inline(always)]
46    fn interleave2_u32x8(self, z0z0z1z1: [u32x8; 2]) -> [u32x8; 2] {
47        let avx = self.avx2;
48        [
49            cast(avx._mm256_unpacklo_epi64(cast(z0z0z1z1[0]), cast(z0z0z1z1[1]))),
50            cast(avx._mm256_unpackhi_epi64(cast(z0z0z1z1[0]), cast(z0z0z1z1[1]))),
51        ]
52    }
53
54    #[inline(always)]
55    fn permute2_u32x8(self, w: [u32; 4]) -> u32x8 {
56        let avx = self.avx;
57        let w0123 = pulp::cast(w);
58        let w0022 = self.sse2._mm_castps_si128(self.sse3._mm_moveldup_ps(w0123));
59        let w1133 = self.sse2._mm_castps_si128(self.sse3._mm_movehdup_ps(w0123));
60        cast(avx._mm256_insertf128_si256::<1>(avx._mm256_castsi128_si256(w0022), w1133))
61    }
62
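    // Interleaves the two vectors element by element within each pair of lanes:
    // [a0..a7], [b0..b7] -> [a0, b0, a2, b2, a4, b4, a6, b6] and [a1, b1, a3, b3, a5, b5, a7, b7].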
63    #[inline(always)]
64    fn interleave1_u32x8(self, z0z0z1z1: [u32x8; 2]) -> [u32x8; 2] {
65        let avx = self.avx2;
66        let x = [
67            // 00 10 01 11 04 14 05 15
68            (avx._mm256_unpacklo_epi32(cast(z0z0z1z1[0]), cast(z0z0z1z1[1]))),
69            // 02 12 03 13 06 16 07 17
70            avx._mm256_unpackhi_epi32(cast(z0z0z1z1[0]), cast(z0z0z1z1[1])),
71        ];
72        [
73            // 00 10 02 12 04 14 06 16
74            cast(avx._mm256_unpacklo_epi64(x[0], x[1])),
75            // 01 11 03 13 05 15 07 17
76            cast(avx._mm256_unpackhi_epi64(x[0], x[1])),
77        ]
78    }
79
80    #[inline(always)]
81    fn permute1_u32x8(self, w: [u32; 8]) -> u32x8 {
82        let avx = self.avx;
83        let [w0123, w4567]: [u32x4; 2] = pulp::cast(w);
84        let w0415 = self.sse2._mm_unpacklo_epi32(cast(w0123), cast(w4567));
85        let w2637 = self.sse2._mm_unpackhi_epi32(cast(w0123), cast(w4567));
86        cast(avx._mm256_insertf128_si256::<1>(avx._mm256_castsi128_si256(w0415), w2637))
87    }
88
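    /// Conditional subtraction: maps `x` in `[0, 2 * modulus)` to `x mod modulus` by keeping `x`
    /// where `x < modulus` and returning `x - modulus` otherwise.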
89    #[inline(always)]
90    pub fn small_mod_u32x8(self, modulus: u32x8, x: u32x8) -> u32x8 {
91        self.select_u32x8(
92            self.cmp_gt_u32x8(modulus, x),
93            x,
94            self.wrapping_sub_u32x8(x, modulus),
95        )
96    }
97}
98
99#[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
100#[cfg(feature = "nightly")]
101impl crate::V4 {
102    #[inline(always)]
103    fn interleave8_u32x16(self, z0z0z0z0z0z0z0z0z1z1z1z1z1z1z1z1: [u32x16; 2]) -> [u32x16; 2] {
104        let avx = self.avx512f;
105        let idx_0 = avx._mm512_setr_epi32(
106            0x0, 0x1, 0x2, 0x3, 0x4, 0x5, 0x6, 0x7, 0x10, 0x11, 0x12, 0x13, 0x14, 0x15, 0x16, 0x17,
107        );
108        let idx_1 = avx._mm512_setr_epi32(
109            0x8, 0x9, 0xa, 0xb, 0xc, 0xd, 0xe, 0xf, 0x18, 0x19, 0x1a, 0x1b, 0x1c, 0x1d, 0x1e, 0x1f,
110        );
111        [
112            cast(avx._mm512_permutex2var_epi32(
113                cast(z0z0z0z0z0z0z0z0z1z1z1z1z1z1z1z1[0]),
114                idx_0,
115                cast(z0z0z0z0z0z0z0z0z1z1z1z1z1z1z1z1[1]),
116            )),
117            cast(avx._mm512_permutex2var_epi32(
118                cast(z0z0z0z0z0z0z0z0z1z1z1z1z1z1z1z1[0]),
119                idx_1,
120                cast(z0z0z0z0z0z0z0z0z1z1z1z1z1z1z1z1[1]),
121            )),
122        ]
123    }
124
125    #[inline(always)]
126    fn permute8_u32x16(self, w: [u32; 2]) -> u32x16 {
127        let avx = self.avx512f;
128        let w = pulp::cast(w);
129        let w01xxxxxxxxxxxxxx = avx._mm512_set1_epi64(w);
130        let idx = avx._mm512_setr_epi32(0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1);
131        cast(avx._mm512_permutexvar_epi32(idx, w01xxxxxxxxxxxxxx))
132    }
133
134    #[inline(always)]
135    fn interleave4_u32x16(self, z0z0z0z0z1z1z1z1: [u32x16; 2]) -> [u32x16; 2] {
136        let avx = self.avx512f;
137        let idx_0 = avx._mm512_setr_epi32(
138            0x0, 0x1, 0x2, 0x3, 0x10, 0x11, 0x12, 0x13, 0x8, 0x9, 0xa, 0xb, 0x18, 0x19, 0x1a, 0x1b,
139        );
140        let idx_1 = avx._mm512_setr_epi32(
141            0x4, 0x5, 0x6, 0x7, 0x14, 0x15, 0x16, 0x17, 0xc, 0xd, 0xe, 0xf, 0x1c, 0x1d, 0x1e, 0x1f,
142        );
143        [
144            cast(avx._mm512_permutex2var_epi32(
145                cast(z0z0z0z0z1z1z1z1[0]),
146                idx_0,
147                cast(z0z0z0z0z1z1z1z1[1]),
148            )),
149            cast(avx._mm512_permutex2var_epi32(
150                cast(z0z0z0z0z1z1z1z1[0]),
151                idx_1,
152                cast(z0z0z0z0z1z1z1z1[1]),
153            )),
154        ]
155    }
156
157    #[inline(always)]
158    fn permute4_u32x16(self, w: [u32; 4]) -> u32x16 {
159        let avx = self.avx512f;
160        let w = pulp::cast(w);
161        let w0123xxxxxxxxxxxx = avx._mm512_castsi128_si512(w);
162        let idx = avx._mm512_setr_epi32(0, 0, 0, 0, 2, 2, 2, 2, 1, 1, 1, 1, 3, 3, 3, 3);
163        cast(avx._mm512_permutexvar_epi32(idx, w0123xxxxxxxxxxxx))
164    }
165
166    #[inline(always)]
167    fn interleave2_u32x16(self, z0z0z1z1: [u32x16; 2]) -> [u32x16; 2] {
168        let avx = self.avx512f;
169        [
170            // 00 01 10 11 04 05 14 15 08 09 18 19 0c 0d 1c 1d
171            cast(avx._mm512_unpacklo_epi64(cast(z0z0z1z1[0]), cast(z0z0z1z1[1]))),
172            // 02 03 12 13 06 07 16 17 0a 0b 1a 1b 0e 0f 1e 1f
173            cast(avx._mm512_unpackhi_epi64(cast(z0z0z1z1[0]), cast(z0z0z1z1[1]))),
174        ]
175    }
176
177    #[inline(always)]
178    fn permute2_u32x16(self, w: [u32; 8]) -> u32x16 {
179        let avx = self.avx512f;
180        let w = pulp::cast(w);
181        let w01234567xxxxxxxx = avx._mm512_castsi256_si512(w);
182        let idx = avx._mm512_setr_epi32(0, 0, 4, 4, 1, 1, 5, 5, 2, 2, 6, 6, 3, 3, 7, 7);
183        cast(avx._mm512_permutexvar_epi32(idx, w01234567xxxxxxxx))
184    }
185
186    #[inline(always)]
187    fn interleave1_u32x16(self, z0z1: [u32x16; 2]) -> [u32x16; 2] {
188        let avx = self.avx512f;
189        let x = [
190            // 00 10 01 11 04 14 05 15 08 18 09 19 0c 1c 0d 1d
191            avx._mm512_unpacklo_epi32(cast(z0z1[0]), cast(z0z1[1])),
192            // 02 12 03 13 06 16 07 17 0a 1a 0b 1b 0e 1e 0f 1f
193            avx._mm512_unpackhi_epi32(cast(z0z1[0]), cast(z0z1[1])),
194        ];
195        [
196            // 00 10 02 12 04 14 06 16 08 18 0a 1a 0c 1c 0e 1e
197            cast(avx._mm512_unpacklo_epi64(x[0], x[1])),
198            // 01 11 03 13 05 15 07 17 09 19 0b 1b 0d 1d 0f 1f
199            cast(avx._mm512_unpackhi_epi64(x[0], x[1])),
200        ]
201    }
202
203    #[inline(always)]
204    fn permute1_u32x16(self, w: [u32; 16]) -> u32x16 {
205        let avx = self.avx512f;
206        let w = pulp::cast(w);
207        let idx = avx._mm512_setr_epi32(
208            0x0, 0x8, 0x1, 0x9, 0x2, 0xa, 0x3, 0xb, 0x4, 0xc, 0x5, 0xd, 0x6, 0xe, 0x7, 0xf,
209        );
210        cast(avx._mm512_permutexvar_epi32(idx, w))
211    }
212
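    /// Conditional subtraction: maps `x` in `[0, 2 * modulus)` to `x mod modulus` by keeping `x`
    /// where `x < modulus` and returning `x - modulus` otherwise.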
213    #[inline(always)]
214    pub fn small_mod_u32x16(self, modulus: u32x16, x: u32x16) -> u32x16 {
215        self.select_u32x16(
216            self.cmp_gt_u32x16(modulus, x),
217            x,
218            self.wrapping_sub_u32x16(x, modulus),
219        )
220    }
221}
222
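// Fills `twid` and `inv_twid` with the powers of a primitive 2n-th root of unity `w` mod `p`,
// stored in bit-reversed order: `twid[bit_rev(k)] = w^k`, and for `k > 0`,
// `inv_twid[bit_rev(n - k)] = p - w^k`, which equals `w^-(n - k)` since `w^n = -1 mod p`.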
223fn init_negacyclic_twiddles(p: u32, n: usize, twid: &mut [u32], inv_twid: &mut [u32]) {
224    let div = Div32::new(p);
225    let w = find_primitive_root64(Div64::new(p as u64), 2 * n as u64).unwrap() as u32;
226    let mut k = 0;
227    let mut wk = 1u32;
228
229    let nbits = n.trailing_zeros();
230    while k < n {
231        let fwd_idx = bit_rev(nbits, k);
232
233        twid[fwd_idx] = wk;
234
235        let inv_idx = bit_rev(nbits, (n - k) % n);
236        if k == 0 {
237            inv_twid[inv_idx] = wk;
238        } else {
239            let x = p.wrapping_sub(wk);
240            inv_twid[inv_idx] = x;
241        }
242
243        wk = Div32::rem_u64(wk as u64 * w as u64, div);
244        k += 1;
245    }
246}
247
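// Same as `init_negacyclic_twiddles`, but also stores each twiddle's Shoup companion value
// `floor(t * 2^32 / p)`, used by the specialized kernels for moduli below 2^31.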
248fn init_negacyclic_twiddles_shoup(
249    p: u32,
250    n: usize,
251    twid: &mut [u32],
252    twid_shoup: &mut [u32],
253    inv_twid: &mut [u32],
254    inv_twid_shoup: &mut [u32],
255) {
256    let div = Div32::new(p);
257    let w = find_primitive_root64(Div64::new(p as u64), 2 * n as u64).unwrap() as u32;
258    let mut k = 0;
259    let mut wk = 1u32;
260
261    let nbits = n.trailing_zeros();
262    while k < n {
263        let fwd_idx = bit_rev(nbits, k);
264
265        let wk_shoup = Div32::div_u64((wk as u64) << 32, div) as u32;
266        twid[fwd_idx] = wk;
267        twid_shoup[fwd_idx] = wk_shoup;
268
269        let inv_idx = bit_rev(nbits, (n - k) % n);
270        if k == 0 {
271            inv_twid[inv_idx] = wk;
272            inv_twid_shoup[inv_idx] = wk_shoup;
273        } else {
274            let x = p.wrapping_sub(wk);
275            inv_twid[inv_idx] = x;
276            inv_twid_shoup[inv_idx] = Div32::div_u64((x as u64) << 32, div) as u32;
277        }
278
279        wk = Div32::rem_u64(wk as u64 * w as u64, div);
280        k += 1;
281    }
282}
283
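// Illustrative sketch (not part of the original kernels): how a precomputed Shoup companion
// value `w_shoup = floor(w * 2^32 / p)` is used to multiply by a fixed constant `w < p` modulo
// `p` without any division, assuming `p < 2^31` as in the specialized kernels of this module.
// The intermediate result lands in `[0, 2 * p)`; a final conditional subtraction brings it into
// `[0, p)`.
#[allow(dead_code)]
fn shoup_mul_sketch(x: u32, w: u32, w_shoup: u32, p: u32) -> u32 {
    // Approximation of floor(x * w / p), computed with a single high multiply.
    let q = ((x as u64 * w_shoup as u64) >> 32) as u32;
    // x * w - q * p, taken mod 2^32, equals (x * w mod p) or (x * w mod p) + p.
    let t = u32::wrapping_sub(x.wrapping_mul(w), q.wrapping_mul(p));
    // Conditional subtraction, as in `normalize_scalar` below.
    t.min(t.wrapping_sub(p))
}
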
284#[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
285#[cfg(feature = "nightly")]
286fn mul_assign_normalize_avx512(
287    simd: crate::V4,
288    lhs: &mut [u32],
289    rhs: &[u32],
290    p: u32,
291    p_barrett: u32,
292    big_q: u32,
293    n_inv_mod_p: u32,
294    n_inv_mod_p_shoup: u32,
295) {
296    simd.vectorize(
297        #[inline(always)]
298        move || {
299            let lhs = pulp::as_arrays_mut::<16, _>(lhs).0;
300            let rhs = pulp::as_arrays::<16, _>(rhs).0;
301            let big_q_m1 = simd.splat_u32x16(big_q - 1);
302            let big_q_m1_complement = simd.splat_u32x16(32 - (big_q - 1));
303            let n_inv_mod_p = simd.splat_u32x16(n_inv_mod_p);
304            let n_inv_mod_p_shoup = simd.splat_u32x16(n_inv_mod_p_shoup);
305            let p_barrett = simd.splat_u32x16(p_barrett);
306            let p = simd.splat_u32x16(p);
307
308            for (lhs_, rhs) in crate::izip!(lhs, rhs) {
309                let lhs = cast(*lhs_);
310                let rhs = cast(*rhs);
311
312                // lhs × rhs
313                let (lo, hi) = simd.widening_mul_u32x16(lhs, rhs);
314                let c1 = simd.or_u32x16(
315                    simd.shr_dyn_u32x16(lo, big_q_m1),
316                    simd.shl_dyn_u32x16(hi, big_q_m1_complement),
317                );
318                let c3 = simd.widening_mul_u32x16(c1, p_barrett).1;
319                let prod = simd.wrapping_sub_u32x16(lo, simd.wrapping_mul_u32x16(p, c3));
320
321                // normalization
322                let shoup_q = simd.widening_mul_u32x16(prod, n_inv_mod_p_shoup).1;
323                let t = simd.wrapping_sub_u32x16(
324                    simd.wrapping_mul_u32x16(prod, n_inv_mod_p),
325                    simd.wrapping_mul_u32x16(shoup_q, p),
326                );
327
328                *lhs_ = cast(simd.small_mod_u32x16(p, t));
329            }
330        },
331    );
332}
333
334#[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
335fn mul_assign_normalize_avx2(
336    simd: crate::V3,
337    lhs: &mut [u32],
338    rhs: &[u32],
339    p: u32,
340    p_barrett: u32,
341    big_q: u32,
342    n_inv_mod_p: u32,
343    n_inv_mod_p_shoup: u32,
344) {
345    simd.vectorize(
346        #[inline(always)]
347        move || {
348            let lhs = pulp::as_arrays_mut::<8, _>(lhs).0;
349            let rhs = pulp::as_arrays::<8, _>(rhs).0;
350            let big_q_m1 = simd.splat_u32x8(big_q - 1);
351            let big_q_m1_complement = simd.splat_u32x8(32 - (big_q - 1));
352            let n_inv_mod_p = simd.splat_u32x8(n_inv_mod_p);
353            let n_inv_mod_p_shoup = simd.splat_u32x8(n_inv_mod_p_shoup);
354            let p_barrett = simd.splat_u32x8(p_barrett);
355            let p = simd.splat_u32x8(p);
356
357            for (lhs_, rhs) in crate::izip!(lhs, rhs) {
358                let lhs = cast(*lhs_);
359                let rhs = cast(*rhs);
360
361                // lhs × rhs
362                let (lo, hi) = simd.widening_mul_u32x8(lhs, rhs);
363                let c1 = simd.or_u32x8(
364                    simd.shr_dyn_u32x8(lo, big_q_m1),
365                    simd.shl_dyn_u32x8(hi, big_q_m1_complement),
366                );
367                let c3 = simd.widening_mul_u32x8(c1, p_barrett).1;
368                let prod = simd.wrapping_sub_u32x8(lo, simd.wrapping_mul_u32x8(p, c3));
369
370                // normalization
371                let shoup_q = simd.widening_mul_u32x8(prod, n_inv_mod_p_shoup).1;
372                let t = simd.wrapping_sub_u32x8(
373                    simd.wrapping_mul_u32x8(prod, n_inv_mod_p),
374                    simd.wrapping_mul_u32x8(shoup_q, p),
375                );
376
377                *lhs_ = cast(simd.small_mod_u32x8(p, t));
378            }
379        },
380    );
381}
382
383fn mul_assign_normalize_scalar(
384    lhs: &mut [u32],
385    rhs: &[u32],
386    p: u32,
387    p_barrett: u32,
388    big_q: u32,
389    n_inv_mod_p: u32,
390    n_inv_mod_p_shoup: u32,
391) {
392    let big_q_m1 = big_q - 1;
393
394    for (lhs_, rhs) in crate::izip!(lhs, rhs) {
395        let lhs = *lhs_;
396        let rhs = *rhs;
397
398        let d = lhs as u64 * rhs as u64;
399        let c1 = (d >> big_q_m1) as u32;
400        let c3 = ((c1 as u64 * p_barrett as u64) >> 32) as u32;
401        let prod = (d as u32).wrapping_sub(p.wrapping_mul(c3));
402
403        let shoup_q = (((prod as u64) * (n_inv_mod_p_shoup as u64)) >> 32) as u32;
404        let t = u32::wrapping_sub(prod.wrapping_mul(n_inv_mod_p), shoup_q.wrapping_mul(p));
405
406        *lhs_ = t.min(t.wrapping_sub(p));
407    }
408}
409
410#[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
411#[cfg(feature = "nightly")]
412fn normalize_avx512(
413    simd: crate::V4,
414    values: &mut [u32],
415    p: u32,
416    n_inv_mod_p: u32,
417    n_inv_mod_p_shoup: u32,
418) {
419    simd.vectorize(
420        #[inline(always)]
421        move || {
422            let values = pulp::as_arrays_mut::<16, _>(values).0;
423
424            let n_inv_mod_p = simd.splat_u32x16(n_inv_mod_p);
425            let n_inv_mod_p_shoup = simd.splat_u32x16(n_inv_mod_p_shoup);
426            let p = simd.splat_u32x16(p);
427
428            for val_ in values {
429                let val = cast(*val_);
430
431                // normalization
432                let shoup_q = simd.widening_mul_u32x16(val, n_inv_mod_p_shoup).1;
433                let t = simd.wrapping_sub_u32x16(
434                    simd.wrapping_mul_u32x16(val, n_inv_mod_p),
435                    simd.wrapping_mul_u32x16(shoup_q, p),
436                );
437
438                *val_ = cast(simd.small_mod_u32x16(p, t));
439            }
440        },
441    );
442}
443
444#[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
445fn normalize_avx2(
446    simd: crate::V3,
447    values: &mut [u32],
448    p: u32,
449    n_inv_mod_p: u32,
450    n_inv_mod_p_shoup: u32,
451) {
452    simd.vectorize(
453        #[inline(always)]
454        move || {
455            let values = pulp::as_arrays_mut::<8, _>(values).0;
456
457            let n_inv_mod_p = simd.splat_u32x8(n_inv_mod_p);
458            let n_inv_mod_p_shoup = simd.splat_u32x8(n_inv_mod_p_shoup);
459            let p = simd.splat_u32x8(p);
460
461            for val_ in values {
462                let val = cast(*val_);
463
464                // normalization
465                let shoup_q = simd.widening_mul_u32x8(val, n_inv_mod_p_shoup).1;
466                let t = simd.wrapping_sub_u32x8(
467                    simd.widening_mul_u32x8(val, n_inv_mod_p).0,
468                    simd.widening_mul_u32x8(shoup_q, p).0,
469                );
470
471                *val_ = cast(simd.small_mod_u32x8(p, t));
472            }
473        },
474    );
475}
476
477fn normalize_scalar(values: &mut [u32], p: u32, n_inv_mod_p: u32, n_inv_mod_p_shoup: u32) {
478    for val_ in values {
479        let val = *val_;
480
481        let shoup_q = (((val as u64) * (n_inv_mod_p_shoup as u64)) >> 32) as u32;
482        let t = u32::wrapping_sub(val.wrapping_mul(n_inv_mod_p), shoup_q.wrapping_mul(p));
483
484        *val_ = t.min(t.wrapping_sub(p));
485    }
486}
487
488#[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
489#[cfg(feature = "nightly")]
490fn mul_accumulate_avx512(
491    simd: crate::V4,
492    acc: &mut [u32],
493    lhs: &[u32],
494    rhs: &[u32],
495    p: u32,
496    p_barrett: u32,
497    big_q: u32,
498) {
499    simd.vectorize(
500        #[inline(always)]
501        || {
502            let acc = pulp::as_arrays_mut::<16, _>(acc).0;
503            let lhs = pulp::as_arrays::<16, _>(lhs).0;
504            let rhs = pulp::as_arrays::<16, _>(rhs).0;
505
506            let big_q_m1 = simd.splat_u32x16(big_q - 1);
507            let big_q_m1_complement = simd.splat_u32x16(32 - (big_q - 1));
508            let p_barrett = simd.splat_u32x16(p_barrett);
509            let p = simd.splat_u32x16(p);
510
511            for (acc, lhs, rhs) in crate::izip!(acc, lhs, rhs) {
512                let lhs = cast(*lhs);
513                let rhs = cast(*rhs);
514
515                // lhs × rhs
516                let (lo, hi) = simd.widening_mul_u32x16(lhs, rhs);
517                let c1 = simd.or_u32x16(
518                    simd.shr_dyn_u32x16(lo, big_q_m1),
519                    simd.shl_dyn_u32x16(hi, big_q_m1_complement),
520                );
521                let c3 = simd.widening_mul_u32x16(c1, p_barrett).1;
522                // lo - p * c3
523                let prod = simd.wrapping_sub_u32x16(lo, simd.wrapping_mul_u32x16(p, c3));
524                let prod = simd.small_mod_u32x16(p, prod);
525
526                *acc = cast(simd.small_mod_u32x16(p, simd.wrapping_add_u32x16(prod, cast(*acc))));
527            }
528        },
529    )
530}
531
532#[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
533fn mul_accumulate_avx2(
534    simd: crate::V3,
535    acc: &mut [u32],
536    lhs: &[u32],
537    rhs: &[u32],
538    p: u32,
539    p_barrett: u32,
540    big_q: u32,
541) {
542    simd.vectorize(
543        #[inline(always)]
544        || {
545            let acc = pulp::as_arrays_mut::<8, _>(acc).0;
546            let lhs = pulp::as_arrays::<8, _>(lhs).0;
547            let rhs = pulp::as_arrays::<8, _>(rhs).0;
548
549            let big_q_m1 = simd.splat_u32x8(big_q - 1);
550            let big_q_m1_complement = simd.splat_u32x8(32 - (big_q - 1));
551            let p_barrett = simd.splat_u32x8(p_barrett);
552            let p = simd.splat_u32x8(p);
553
554            for (acc, lhs, rhs) in crate::izip!(acc, lhs, rhs) {
555                let lhs = cast(*lhs);
556                let rhs = cast(*rhs);
557
558                // lhs × rhs
559                let (lo, hi) = simd.widening_mul_u32x8(lhs, rhs);
560                let c1 = simd.or_u32x8(
561                    simd.shr_dyn_u32x8(lo, big_q_m1),
562                    simd.shl_dyn_u32x8(hi, big_q_m1_complement),
563                );
564                let c3 = simd.widening_mul_u32x8(c1, p_barrett).1;
565                // lo - p * c3
566                let prod = simd.wrapping_sub_u32x8(lo, simd.wrapping_mul_u32x8(p, c3));
567                let prod = simd.small_mod_u32x8(p, prod);
568
569                *acc = cast(simd.small_mod_u32x8(p, simd.wrapping_add_u32x8(prod, cast(*acc))));
570            }
571        },
572    )
573}
574
575fn mul_accumulate_scalar(
576    acc: &mut [u32],
577    lhs: &[u32],
578    rhs: &[u32],
579    p: u32,
580    p_barrett: u32,
581    big_q: u32,
582) {
583    let big_q_m1 = big_q - 1;
584
585    for (acc, lhs, rhs) in crate::izip!(acc, lhs, rhs) {
586        let lhs = *lhs;
587        let rhs = *rhs;
588
589        let d = lhs as u64 * rhs as u64;
590        let c1 = (d >> big_q_m1) as u32;
591        let c3 = ((c1 as u64 * p_barrett as u64) >> 32) as u32;
592        let prod = (d as u32).wrapping_sub(p.wrapping_mul(c3));
593        let prod = prod.min(prod.wrapping_sub(p));
594
595        let acc_ = prod + *acc;
596        *acc = acc_.min(acc_.wrapping_sub(p));
597    }
598}
599
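/// Barrett reduction constants for a 32-bit modulus: `big_q = ilog2(modulus) + 1` is the bit
/// width of the modulus and `p_barrett = floor(2^(big_q + 31) / modulus)` is the precomputed
/// reciprocal. `requires_single_reduction_step` records whether the modulus is known to need at
/// most one correction step after the approximate quotient is subtracted.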
600struct BarrettInit32 {
601    big_q: u32,
602    p_barrett: u32,
603    requires_single_reduction_step: bool,
604}
605
606impl BarrettInit32 {
607    pub fn new(modulus: u32) -> Self {
608        let big_q = modulus.ilog2() + 1;
609        let big_l = big_q + 31;
610        let m_as_u64: u64 = modulus.into();
611        let two_to_the_l = 1u64 << big_l; // Equivalent to 2^{2k} from the zk security blog
612        let (p_barrett, beta) = ((two_to_the_l / m_as_u64) as u32, (two_to_the_l % m_as_u64));
613
614        // Check that the chosen prime will only trigger a single Barrett reduction step with
615        // our implementation. If two reduction steps may be needed, there are cases where it is
616        // not possible to decide whether a further reduction is required, yielding wrong results.
617        // Formula derived with https://blog.zksecurity.xyz/posts/barrett-tighter-bound/
618        let single_reduction_threshold = m_as_u64 - (1 << (big_q - 1));
619
620        let requires_single_reduction_step = beta <= single_reduction_threshold;
621
622        Self {
623            big_q,
624            p_barrett,
625            requires_single_reduction_step,
626        }
627    }
628}
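
// Illustrative sketch (not used by the kernels above): the Barrett multiplication step shared by
// `mul_assign_normalize_scalar` and `mul_accumulate_scalar`. With `big_q` and `p_barrett` taken
// from `BarrettInit32`, the returned value equals `lhs * rhs mod p` plus `c * p` for a small
// `c` in `{0, 1, 2}`, provided that value fits in 32 bits (see the overflow discussion in
// `Plan::try_new`), so callers follow it with conditional subtractions.
#[allow(dead_code)]
fn barrett_mul_sketch(lhs: u32, rhs: u32, p: u32, p_barrett: u32, big_q: u32) -> u32 {
    let d = lhs as u64 * rhs as u64;
    // Approximate quotient of d by p: drop big_q - 1 bits, then take the high half of the
    // product with the precomputed reciprocal.
    let c1 = (d >> (big_q - 1)) as u32;
    let c3 = ((c1 as u64 * p_barrett as u64) >> 32) as u32;
    // Low 32 bits of d minus an approximate multiple of p.
    (d as u32).wrapping_sub(p.wrapping_mul(c3))
}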
629
630/// Negacyclic NTT plan for 32-bit primes.
631#[derive(Clone)]
632pub struct Plan {
633    twid: ABox<[u32]>,
634    twid_shoup: ABox<[u32]>,
635    inv_twid: ABox<[u32]>,
636    inv_twid_shoup: ABox<[u32]>,
637    p: u32,
638    p_div: Div32,
639    /// If true, the plan can use non-generic code and optimized reduction algorithms.
640    can_use_fast_reduction_code: bool,
641
642    // used for elementwise product
643    p_barrett: u32,
644    big_q: u32,
645
646    n_inv_mod_p: u32,
647    n_inv_mod_p_shoup: u32,
648}
649
650impl core::fmt::Debug for Plan {
651    fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
652        f.debug_struct("Plan")
653            .field("ntt_size", &self.ntt_size())
654            .field("modulus", &self.modulus())
655            .finish()
656    }
657}
658
659impl Plan {
660    /// Returns a negacyclic NTT plan for the given polynomial size and modulus, or `None` if no
661    /// suitable roots of unity can be found for the wanted parameters.
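    ///
    /// A minimal usage sketch (assuming this module is exposed as `tfhe_ntt::prime32`, per the
    /// file path above; the prime below is one used in this module's tests):
    ///
    /// ```ignore
    /// use tfhe_ntt::prime32::Plan;
    ///
    /// let plan = Plan::try_new(32, 1062862849).expect("suitable NTT parameters");
    /// assert_eq!(plan.ntt_size(), 32);
    /// assert_eq!(plan.modulus(), 1062862849);
    /// ```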
662    pub fn try_new(polynomial_size: usize, modulus: u32) -> Option<Self> {
663        let p_div = Div32::new(modulus);
664        // 32 = 16 * 2 = max_register_size * ntt_radix,
665        // since a SIMD register can hold at most 16 u32 values
666        // and the implementation assumes that the registers are full
667        if polynomial_size < 32
668            || !polynomial_size.is_power_of_two()
669            || !is_prime64(modulus as u64)
670            || find_primitive_root64(Div64::new(modulus as u64), 2 * polynomial_size as u64)
671                .is_none()
672        {
673            None
674        } else {
675            let mut twid = avec![0u32; polynomial_size].into_boxed_slice();
676            let mut inv_twid = avec![0u32; polynomial_size].into_boxed_slice();
677            let (mut twid_shoup, mut inv_twid_shoup) = if modulus < (1u32 << 31) {
678                (
679                    avec![0u32; polynomial_size].into_boxed_slice(),
680                    avec![0u32; polynomial_size].into_boxed_slice(),
681                )
682            } else {
683                (avec![].into_boxed_slice(), avec![].into_boxed_slice())
684            };
685
686            if modulus < (1u32 << 31) {
687                init_negacyclic_twiddles_shoup(
688                    modulus,
689                    polynomial_size,
690                    &mut twid,
691                    &mut twid_shoup,
692                    &mut inv_twid,
693                    &mut inv_twid_shoup,
694                );
695            } else {
696                init_negacyclic_twiddles(modulus, polynomial_size, &mut twid, &mut inv_twid);
697            }
698
699            let n_inv_mod_p = crate::prime::exp_mod32(p_div, polynomial_size as u32, modulus - 2);
700            let n_inv_mod_p_shoup = (((n_inv_mod_p as u64) << 32) / modulus as u64) as u32;
701
702            let BarrettInit32 {
703                big_q,
704                p_barrett,
705                requires_single_reduction_step,
706            } = BarrettInit32::new(modulus);
707
708            // The mul_accumulate_scalar code (and derived SIMD code) can have an overflow issue due
709            // to how Barrett reduction works. It can also overflow during the following
710            // accumulations. We'll derive the requirements for the mul_accumulate_scalar code to be
711            // correct and use that to determine whether a given prime is a "fast" prime for our
712            // code.
713            //
714            // You can check https://blog.zksecurity.xyz/posts/barrett-tighter-bound/ for a readable
715            // Barrett reduction analysis/breakdown.
716            //
717            // Barrett reduction gives an approximation q_approx of the true quotient q of the
718            // value we are reducing by p,
719            // which can differ from q by at most 2: q_approx ∈ {q - 2, q - 1, q}. Given that we
720            // subtract q_approx * p from our result to start the reduction we have:
721            // prod = to_reduce - q_approx * p
722            // prod = true_prod + c * p with c ∈ {0, 1, 2}
723            // We need to make sure that prod does not overflow 32 bits
724            // We have true_prod which is reduced mod p, so true_prod <= p - 1
725            // prod <= 2^32 - 1 <=>
726            // true_prod + 2 * p <= 2^32 - 1 <=>
727            // 3 * p - 1 <= 2^32 - 1 <=>
728            // p <= 2^32 / 3 <= 1431655765.3333333 < 1431655766
729            //
730            // After the computation of prod a first reduction step is performed, meaning we now
731            // have:
732            // prod = true_prod + c * p with c ∈ {0, 1}
733            // We are accumulating into acc, which is already reduced, so acc <= p - 1
734            // For the accumulation not to overflow we need:
735            // acc + prod <= 2^32 - 1 <=>
736            // (p - 1) + (p - 1) + p <= 2^32 - 1 <=>
737            // 3p - 2 <= 2^32 - 1 <=>
738            // 3p <= 2^32 + 1 <=>
739            // p <= (2^32 + 1) / 3 <= 1431655765.6666667 < 1431655766
740            //
741            // It is the same criterion.
742            // Now for cases where moduli are known to yield a Barrett reduction with a single step
743            // of reduction required (see blog post again) the conditions become:
744            // true_prod + p <= 2^32 - 1 <=>
745            // 2p - 1 <= 2^32 - 1 <=>
746            // p <= 2^31
747
748            let can_use_fast_reduction_code =
749                (modulus < 1431655766) || (requires_single_reduction_step && modulus <= (1 << 31));
750
751            Some(Self {
752                twid,
753                twid_shoup,
754                inv_twid_shoup,
755                inv_twid,
756                p: modulus,
757                p_div,
758                can_use_fast_reduction_code,
759                n_inv_mod_p,
760                n_inv_mod_p_shoup,
761                p_barrett,
762                big_q,
763            })
764        }
765    }
766
767    pub(crate) fn p_div(&self) -> Div32 {
768        self.p_div
769    }
770
771    /// Returns the polynomial size of the negacyclic NTT plan.
772    #[inline]
773    pub fn ntt_size(&self) -> usize {
774        self.twid.len()
775    }
776
777    /// Returns the modulus of the negacyclic NTT plan.
778    #[inline]
779    pub fn modulus(&self) -> u32 {
780        self.p
781    }
782
783    /// Returns whether the negacyclic NTT plan can use fast reduction code.
784    ///
785    /// To avoid correctness issues linked to overflows, the code has to make performance
786    /// sacrifices for certain primes and will not yield the best possible performance for them.
787    #[inline]
788    pub fn can_use_fast_reduction_code(&self) -> bool {
789        self.can_use_fast_reduction_code
790    }
791
792    /// Applies a forward negacyclic NTT transform in place to the given buffer.
793    ///
794    /// # Note
795    /// On entry, the buffer holds the polynomial coefficients in standard order. On exit, the
796    /// buffer holds the negacyclic NTT transform coefficients in bit reversed order.
797    pub fn fwd(&self, buf: &mut [u32]) {
798        assert_eq!(buf.len(), self.ntt_size());
799        let p = self.p;
800
801        if p < (1u32 << 30) {
802            #[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
803            {
804                #[cfg(feature = "nightly")]
805                if let Some(simd) = crate::V4::try_new() {
806                    less_than_30bit::fwd_avx512(simd, p, buf, &self.twid, &self.twid_shoup);
807                    return;
808                }
809                if let Some(simd) = crate::V3::try_new() {
810                    less_than_30bit::fwd_avx2(simd, p, buf, &self.twid, &self.twid_shoup);
811                    return;
812                }
813            }
814            less_than_30bit::fwd_scalar(p, buf, &self.twid, &self.twid_shoup);
815        } else if p < (1u32 << 31) {
816            #[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
817            {
818                #[cfg(feature = "nightly")]
819                if let Some(simd) = crate::V4::try_new() {
820                    less_than_31bit::fwd_avx512(simd, p, buf, &self.twid, &self.twid_shoup);
821                    return;
822                }
823                if let Some(simd) = crate::V3::try_new() {
824                    less_than_31bit::fwd_avx2(simd, p, buf, &self.twid, &self.twid_shoup);
825                    return;
826                }
827            }
828            less_than_31bit::fwd_scalar(p, buf, &self.twid, &self.twid_shoup);
829        } else {
830            #[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
831            #[cfg(feature = "nightly")]
832            if let Some(simd) = crate::V4::try_new() {
833                generic::fwd_avx512(simd, buf, p, self.p_div, &self.twid);
834                return;
835            }
836            #[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
837            if let Some(simd) = crate::V3::try_new() {
838                generic::fwd_avx2(simd, buf, p, self.p_div, &self.twid);
839                return;
840            }
841            generic::fwd_scalar(buf, p, self.p_div, &self.twid);
842        }
843    }
844
845    /// Applies an inverse negacyclic NTT transform in place to the given buffer.
846    ///
847    /// # Note
848    /// On entry, the buffer holds the negacyclic NTT transform coefficients in bit reversed order.
849    /// On exit, the buffer holds the polynomial coefficients in standard order.
850    pub fn inv(&self, buf: &mut [u32]) {
851        assert_eq!(buf.len(), self.ntt_size());
852        let p = self.p;
853
854        if p < (1u32 << 30) {
855            #[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
856            {
857                #[cfg(feature = "nightly")]
858                if let Some(simd) = crate::V4::try_new() {
859                    less_than_30bit::inv_avx512(simd, p, buf, &self.inv_twid, &self.inv_twid_shoup);
860                    return;
861                }
862                if let Some(simd) = crate::V3::try_new() {
863                    less_than_30bit::inv_avx2(simd, p, buf, &self.inv_twid, &self.inv_twid_shoup);
864                    return;
865                }
866            }
867            less_than_30bit::inv_scalar(p, buf, &self.inv_twid, &self.inv_twid_shoup);
868        } else if p < (1u32 << 31) {
869            #[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
870            {
871                #[cfg(feature = "nightly")]
872                if let Some(simd) = crate::V4::try_new() {
873                    less_than_31bit::inv_avx512(simd, p, buf, &self.inv_twid, &self.inv_twid_shoup);
874                    return;
875                }
876                if let Some(simd) = crate::V3::try_new() {
877                    less_than_31bit::inv_avx2(simd, p, buf, &self.inv_twid, &self.inv_twid_shoup);
878                    return;
879                }
880            }
881            less_than_31bit::inv_scalar(p, buf, &self.inv_twid, &self.inv_twid_shoup);
882        } else {
883            #[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
884            #[cfg(feature = "nightly")]
885            if let Some(simd) = crate::V4::try_new() {
886                generic::inv_avx512(simd, buf, p, self.p_div, &self.inv_twid);
887                return;
888            }
889            #[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
890            if let Some(simd) = crate::V3::try_new() {
891                generic::inv_avx2(simd, buf, p, self.p_div, &self.inv_twid);
892                return;
893            }
894            generic::inv_scalar(buf, p, self.p_div, &self.inv_twid);
895        }
896    }
897
898    /// Computes the elementwise product of `lhs` and `rhs`, multiplied by the inverse of the
899    /// polynomial size modulo the NTT modulus, and stores the result in `lhs`.
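    ///
    /// A sketch of the intended flow together with [`Plan::fwd`] and [`Plan::inv`] (illustrative;
    /// it mirrors this module's `test_product` test):
    ///
    /// ```ignore
    /// // `lhs` and `rhs` hold coefficients reduced modulo `plan.modulus()`.
    /// plan.fwd(&mut lhs);
    /// plan.fwd(&mut rhs);
    /// // Elementwise product scaled by n^-1 mod p, so that after the inverse transform
    /// // `lhs` holds the negacyclic convolution of the original inputs.
    /// plan.mul_assign_normalize(&mut lhs, &rhs);
    /// plan.inv(&mut lhs);
    /// ```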
900    pub fn mul_assign_normalize(&self, lhs: &mut [u32], rhs: &[u32]) {
901        if self.can_use_fast_reduction_code {
902            #[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
903            #[cfg(feature = "nightly")]
904            if let Some(simd) = crate::V4::try_new() {
905                mul_assign_normalize_avx512(
906                    simd,
907                    lhs,
908                    rhs,
909                    self.p,
910                    self.p_barrett,
911                    self.big_q,
912                    self.n_inv_mod_p,
913                    self.n_inv_mod_p_shoup,
914                );
915                return;
916            }
917            #[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
918            if let Some(simd) = crate::V3::try_new() {
919                mul_assign_normalize_avx2(
920                    simd,
921                    lhs,
922                    rhs,
923                    self.p,
924                    self.p_barrett,
925                    self.big_q,
926                    self.n_inv_mod_p,
927                    self.n_inv_mod_p_shoup,
928                );
929                return;
930            }
931
932            mul_assign_normalize_scalar(
933                lhs,
934                rhs,
935                self.p,
936                self.p_barrett,
937                self.big_q,
938                self.n_inv_mod_p,
939                self.n_inv_mod_p_shoup,
940            );
941        } else {
942            let p_div = self.p_div;
943            let n_inv_mod_p = self.n_inv_mod_p;
944            for (lhs_, rhs) in crate::izip!(lhs, rhs) {
945                let lhs = *lhs_;
946                let rhs = *rhs;
947                let prod = Div32::rem_u64(lhs as u64 * rhs as u64, p_div);
948                let prod = Div32::rem_u64(prod as u64 * n_inv_mod_p as u64, p_div);
949                *lhs_ = prod;
950            }
951        }
952    }
953
954    /// Multiplies the values by the inverse of the polynomial size modulo the NTT modulus, and
955    /// stores the result in `values`.
956    pub fn normalize(&self, values: &mut [u32]) {
957        if self.can_use_fast_reduction_code {
958            #[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
959            #[cfg(feature = "nightly")]
960            if let Some(simd) = crate::V4::try_new() {
961                normalize_avx512(
962                    simd,
963                    values,
964                    self.p,
965                    self.n_inv_mod_p,
966                    self.n_inv_mod_p_shoup,
967                );
968                return;
969            }
970            #[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
971            if let Some(simd) = crate::V3::try_new() {
972                normalize_avx2(
973                    simd,
974                    values,
975                    self.p,
976                    self.n_inv_mod_p,
977                    self.n_inv_mod_p_shoup,
978                );
979                return;
980            }
981            normalize_scalar(values, self.p, self.n_inv_mod_p, self.n_inv_mod_p_shoup);
982        } else {
983            let p_div = self.p_div;
984            let n_inv_mod_p = self.n_inv_mod_p;
985            for values in values {
986                let prod = Div32::rem_u64(*values as u64 * n_inv_mod_p as u64, p_div);
987                *values = prod;
988            }
989        }
990    }
991
992    /// Computes the elementwise product of `lhs` and `rhs` and accumulates the result to `acc`.
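    ///
    /// That is, `acc[i]` becomes `(acc[i] + lhs[i] * rhs[i])` reduced modulo the NTT modulus, for
    /// every index `i`.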
993    pub fn mul_accumulate(&self, acc: &mut [u32], lhs: &[u32], rhs: &[u32]) {
994        if self.can_use_fast_reduction_code {
995            #[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
996            #[cfg(feature = "nightly")]
997            if let Some(simd) = crate::V4::try_new() {
998                mul_accumulate_avx512(simd, acc, lhs, rhs, self.p, self.p_barrett, self.big_q);
999                return;
1000            }
1001            #[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
1002            if let Some(simd) = crate::V3::try_new() {
1003                mul_accumulate_avx2(simd, acc, lhs, rhs, self.p, self.p_barrett, self.big_q);
1004                return;
1005            }
1006            mul_accumulate_scalar(acc, lhs, rhs, self.p, self.p_barrett, self.big_q);
1007        } else {
1008            let p = self.p;
1009            let p_div = self.p_div;
1010            for (acc, lhs, rhs) in crate::izip!(acc, lhs, rhs) {
1011                let prod = generic::mul(p_div, *lhs, *rhs);
1012                *acc = generic::add(p, *acc, prod);
1013            }
1014        }
1015    }
1016}
1017
1018#[cfg(test)]
1019pub mod tests {
1020    use super::*;
1021    use crate::prime::largest_prime_in_arithmetic_progression64;
1022    use alloc::{vec, vec::Vec};
1023    use rand::random;
1024
1025    extern crate alloc;
1026
1027    pub fn add(p: u32, a: u32, b: u32) -> u32 {
1028        let neg_b = p.wrapping_sub(b);
1029        if a >= neg_b {
1030            a - neg_b
1031        } else {
1032            a + b
1033        }
1034    }
1035
1036    pub fn sub(p: u32, a: u32, b: u32) -> u32 {
1037        let neg_b = p.wrapping_sub(b);
1038        if a >= b {
1039            a - b
1040        } else {
1041            a + neg_b
1042        }
1043    }
1044
1045    pub fn mul(p: u32, a: u32, b: u32) -> u32 {
1046        let wide = a as u64 * b as u64;
1047        if p == 0 {
1048            wide as u32
1049        } else {
1050            (wide % p as u64) as u32
1051        }
1052    }
1053
1054    pub fn negacyclic_convolution(n: usize, p: u32, lhs: &[u32], rhs: &[u32]) -> vec::Vec<u32> {
1055        let mut full_convolution = vec![0u32; 2 * n];
1056        let mut negacyclic_convolution = vec![0u32; n];
1057        for i in 0..n {
1058            for j in 0..n {
1059                full_convolution[i + j] = add(p, full_convolution[i + j], mul(p, lhs[i], rhs[j]));
1060            }
1061        }
1062        for i in 0..n {
1063            negacyclic_convolution[i] = sub(p, full_convolution[i], full_convolution[i + n]);
1064        }
1065        negacyclic_convolution
1066    }
1067
1068    pub fn random_lhs_rhs_with_negacyclic_convolution(
1069        n: usize,
1070        p: u32,
1071    ) -> (vec::Vec<u32>, vec::Vec<u32>, vec::Vec<u32>) {
1072        let mut lhs = vec![0u32; n];
1073        let mut rhs = vec![0u32; n];
1074
1075        for x in &mut lhs {
1076            *x = random();
1077            if p != 0 {
1078                *x %= p;
1079            }
1080        }
1081        for x in &mut rhs {
1082            *x = random();
1083            if p != 0 {
1084                *x %= p;
1085            }
1086        }
1087
1088        let lhs = lhs;
1089        let rhs = rhs;
1090
1091        let negacyclic_convolution = negacyclic_convolution(n, p, &lhs, &rhs);
1092        (lhs, rhs, negacyclic_convolution)
1093    }
1094
1095    #[test]
1096    fn test_product() {
1097        for n in [32, 64, 128, 256, 512, 1024] {
1098            for p in [
1099                largest_prime_in_arithmetic_progression64(1 << 16, 1, 1 << 29, 1 << 30).unwrap(),
1100                largest_prime_in_arithmetic_progression64(1 << 16, 1, 1 << 30, 1 << 31).unwrap(),
1101                largest_prime_in_arithmetic_progression64(1 << 16, 1, 1 << 31, 1 << 32).unwrap(),
1102            ] {
1103                let p = p as u32;
1104                let plan = Plan::try_new(n, p).unwrap();
1105
1106                let (lhs, rhs, negacyclic_convolution) =
1107                    random_lhs_rhs_with_negacyclic_convolution(n, p);
1108
1109                let mut prod = vec![0u32; n];
1110                let mut lhs_fourier = lhs.clone();
1111                let mut rhs_fourier = rhs.clone();
1112
1113                plan.fwd(&mut lhs_fourier);
1114                plan.fwd(&mut rhs_fourier);
1115
1116                for x in &lhs_fourier {
1117                    assert!(*x < p);
1118                }
1119                for x in &rhs_fourier {
1120                    assert!(*x < p);
1121                }
1122
1123                for i in 0..n {
1124                    prod[i] = mul(p, lhs_fourier[i], rhs_fourier[i]);
1125                }
1126                plan.inv(&mut prod);
1127
1128                plan.mul_assign_normalize(&mut lhs_fourier, &rhs_fourier);
1129                plan.inv(&mut lhs_fourier);
1130
1131                for x in &prod {
1132                    assert!(*x < p);
1133                }
1134
1135                for i in 0..n {
1136                    assert_eq!(prod[i], mul(p, negacyclic_convolution[i], n as u32));
1137                }
1138                assert_eq!(lhs_fourier, negacyclic_convolution);
1139            }
1140        }
1141    }
1142
1143    #[test]
1144    fn test_mul_assign_normalize_scalar() {
1145        let p =
1146            largest_prime_in_arithmetic_progression64(1 << 16, 1, 1 << 30, 1 << 31).unwrap() as u32;
1147        let p_div = Div64::new(p as u64);
1148        let polynomial_size = 128;
1149
1150        let n_inv_mod_p =
1151            crate::prime::exp_mod64(p_div, polynomial_size as u64, p as u64 - 2) as u32;
1152        let n_inv_mod_p_shoup = (((n_inv_mod_p as u128) << 32) / p as u128) as u32;
1153        let big_q = p.ilog2() + 1;
1154        let big_l = big_q + 31;
1155        let p_barrett = ((1u128 << big_l) / p as u128) as u32;
1156
1157        let mut lhs = (0..polynomial_size)
1158            .map(|_| rand::random::<u32>() % p)
1159            .collect::<Vec<_>>();
1160        let mut lhs_target = lhs.clone();
1161        let rhs = (0..polynomial_size)
1162            .map(|_| rand::random::<u32>() % p)
1163            .collect::<Vec<_>>();
1164
1165        let mul = |a: u32, b: u32| ((a as u128 * b as u128) % p as u128) as u32;
1166
1167        for (lhs, rhs) in lhs_target.iter_mut().zip(&rhs) {
1168            *lhs = mul(mul(*lhs, *rhs), n_inv_mod_p);
1169        }
1170
1171        mul_assign_normalize_scalar(
1172            &mut lhs,
1173            &rhs,
1174            p,
1175            p_barrett,
1176            big_q,
1177            n_inv_mod_p,
1178            n_inv_mod_p_shoup,
1179        );
1180        assert_eq!(lhs, lhs_target);
1181    }
1182
1183    #[test]
1184    fn test_mul_accumulate_scalar() {
1185        let p =
1186            largest_prime_in_arithmetic_progression64(1 << 16, 1, 1 << 30, 1 << 31).unwrap() as u32;
1187        let polynomial_size = 128;
1188
1189        let big_q = p.ilog2() + 1;
1190        let big_l = big_q + 31;
1191        let p_barrett = ((1u128 << big_l) / p as u128) as u32;
1192
1193        let mut acc = (0..polynomial_size)
1194            .map(|_| rand::random::<u32>() % p)
1195            .collect::<Vec<_>>();
1196        let mut acc_target = acc.clone();
1197        let lhs = (0..polynomial_size)
1198            .map(|_| rand::random::<u32>() % p)
1199            .collect::<Vec<_>>();
1200        let rhs = (0..polynomial_size)
1201            .map(|_| rand::random::<u32>() % p)
1202            .collect::<Vec<_>>();
1203
1204        let mul = |a: u32, b: u32| ((a as u128 * b as u128) % p as u128) as u32;
1205        let add = |a: u32, b: u32| (a + b) % p;
1206
1207        for (acc, lhs, rhs) in crate::izip!(&mut acc_target, &lhs, &rhs) {
1208            *acc = add(mul(*lhs, *rhs), *acc);
1209        }
1210
1211        mul_accumulate_scalar(&mut acc, &lhs, &rhs, p, p_barrett, big_q);
1212        assert_eq!(acc, acc_target);
1213    }
1214
1215    #[test]
1216    fn test_normalize_scalar() {
1217        let p =
1218            largest_prime_in_arithmetic_progression64(1 << 16, 1, 1 << 30, 1 << 31).unwrap() as u32;
1219        let p_div = Div64::new(p as u64);
1220        let polynomial_size = 128;
1221
1222        let n_inv_mod_p =
1223            crate::prime::exp_mod64(p_div, polynomial_size as u64, p as u64 - 2) as u32;
1224        let n_inv_mod_p_shoup = (((n_inv_mod_p as u128) << 32) / p as u128) as u32;
1225
1226        let mut val = (0..polynomial_size)
1227            .map(|_| rand::random::<u32>() % p)
1228            .collect::<Vec<_>>();
1229        let mut val_target = val.clone();
1230
1231        let mul = |a: u32, b: u32| ((a as u128 * b as u128) % p as u128) as u32;
1232
1233        for val in val_target.iter_mut() {
1234            *val = mul(*val, n_inv_mod_p);
1235        }
1236
1237        normalize_scalar(&mut val, p, n_inv_mod_p, n_inv_mod_p_shoup);
1238        assert_eq!(val, val_target);
1239    }
1240
1241    #[test]
1242    fn test_mul_accumulate() {
1243        for p in [
1244            largest_prime_in_arithmetic_progression64(1 << 16, 1, 0, 1 << 31).unwrap(),
1245            largest_prime_in_arithmetic_progression64(1 << 16, 1, 0, u32::MAX as u64).unwrap(),
1246        ] {
1247            let p = p as u32;
1248            let polynomial_size = 128;
1249
1250            let mut acc = (0..polynomial_size)
1251                .map(|_| rand::random::<u32>() % p)
1252                .collect::<Vec<_>>();
1253            let mut acc_target = acc.clone();
1254            let lhs = (0..polynomial_size)
1255                .map(|_| rand::random::<u32>() % p)
1256                .collect::<Vec<_>>();
1257            let rhs = (0..polynomial_size)
1258                .map(|_| rand::random::<u32>() % p)
1259                .collect::<Vec<_>>();
1260
1261            let mul = |a: u32, b: u32| ((a as u64 * b as u64) % p as u64) as u32;
1262            let add = |a: u32, b: u32| ((a as u64 + b as u64) % p as u64) as u32;
1263
1264            for (acc, lhs, rhs) in crate::izip!(&mut acc_target, &lhs, &rhs) {
1265                *acc = add(mul(*lhs, *rhs), *acc);
1266            }
1267
1268            Plan::try_new(polynomial_size, p)
1269                .unwrap()
1270                .mul_accumulate(&mut acc, &lhs, &rhs);
1271            assert_eq!(acc, acc_target);
1272        }
1273    }
1274
1275    #[test]
1276    fn test_mul_assign_normalize() {
1277        for p in [
1278            largest_prime_in_arithmetic_progression64(1 << 16, 1, 0, 1 << 31).unwrap(),
1279            largest_prime_in_arithmetic_progression64(1 << 16, 1, 0, u32::MAX as u64).unwrap(),
1280        ] {
1281            let p = p as u32;
1282            let polynomial_size = 128;
1283            let p_div = Div64::new(p as u64);
1284            let n_inv_mod_p =
1285                crate::prime::exp_mod64(p_div, polynomial_size as u64, p as u64 - 2) as u32;
1286
1287            let mut lhs = (0..polynomial_size)
1288                .map(|_| rand::random::<u32>() % p)
1289                .collect::<Vec<_>>();
1290            let mut lhs_target = lhs.clone();
1291            let rhs = (0..polynomial_size)
1292                .map(|_| rand::random::<u32>() % p)
1293                .collect::<Vec<_>>();
1294
1295            let mul = |a: u32, b: u32| ((a as u64 * b as u64) % p as u64) as u32;
1296
1297            for (lhs, rhs) in lhs_target.iter_mut().zip(&rhs) {
1298                *lhs = mul(mul(*lhs, *rhs), n_inv_mod_p);
1299            }
1300
1301            Plan::try_new(polynomial_size, p)
1302                .unwrap()
1303                .mul_assign_normalize(&mut lhs, &rhs);
1304            assert_eq!(lhs, lhs_target);
1305        }
1306    }
1307
1308    #[test]
1309    fn test_normalize() {
1310        for p in [
1311            largest_prime_in_arithmetic_progression64(1 << 16, 1, 0, 1 << 31).unwrap(),
1312            largest_prime_in_arithmetic_progression64(1 << 16, 1, 0, u32::MAX as u64).unwrap(),
1313        ] {
1314            let p = p as u32;
1315            let polynomial_size = 128;
1316            let p_div = Div64::new(p as u64);
1317            let n_inv_mod_p =
1318                crate::prime::exp_mod64(p_div, polynomial_size as u64, p as u64 - 2) as u32;
1319
1320            let mut val = (0..polynomial_size)
1321                .map(|_| rand::random::<u32>() % p)
1322                .collect::<Vec<_>>();
1323            let mut val_target = val.clone();
1324
1325            let mul = |a: u32, b: u32| ((a as u64 * b as u64) % p as u64) as u32;
1326
1327            for val in &mut val_target {
1328                *val = mul(*val, n_inv_mod_p);
1329            }
1330
1331            Plan::try_new(polynomial_size, p)
1332                .unwrap()
1333                .normalize(&mut val);
1334            assert_eq!(val, val_target);
1335        }
1336    }
1337
1338    #[test]
1339    fn test_plan_can_use_fast_reduction_code() {
1340        use crate::primes32::{P0, P1, P2, P3, P4, P5, P6, P7, P8, P9};
1341        const POLYNOMIAL_SIZE: usize = 32;
1342
1343        // The first prime is smaller than 1431655766.
1344        // The second one is larger, but satisfies the single-reduction condition for Barrett.
1345        // The remaining ones are the crate's primes32 constants; we want all of them to be fast.
1346        for p in [
1347            1062862849, 1431669377, P0, P1, P2, P3, P4, P5, P6, P7, P8, P9,
1348        ] {
1349            let plan = Plan::try_new(POLYNOMIAL_SIZE, p).unwrap();
1350
1351            assert!(plan.can_use_fast_reduction_code);
1352        }
1353
1354        // Prime is bigger than threshold and does not satisfy the single reduction condition for
1355        // Barrett
1356        let plan = Plan::try_new(POLYNOMIAL_SIZE, 0x7fe0_1001).unwrap();
1357        assert!(!plan.can_use_fast_reduction_code);
1358    }
1359
1360    #[test]
1361    fn test_barrett_invalid_reduction_non_regression() {
1362        const POLYNOMIAL_SIZE: usize = 32;
1363
1364        let p: u32 = 0x7fe0_1001;
1365        let plan = Plan::try_new(POLYNOMIAL_SIZE, p).unwrap();
1366
1367        let value = 0x6e63593a;
1368        let mut acc = [0u32; POLYNOMIAL_SIZE];
1369        // Essentially = [value, 0, 0, ...]
1370        let input: [u32; POLYNOMIAL_SIZE] =
1371            core::array::from_fn(|i| if i == 0 { value } else { 0 });
1372
1373        plan.mul_accumulate(&mut acc, &input, &input);
1374
1375        let expected = (u64::from(value) * u64::from(value) % u64::from(p)) as u32;
1376        assert_eq!(acc[0], expected);
1377    }
1378
1379    #[test]
1380    fn test_barrett_invalid_reduction_normalize_non_regression() {
1381        const POLYNOMIAL_SIZE: usize = 32;
1382
1383        let p: u32 = 0x7fe0_1001;
1384        let plan = Plan::try_new(POLYNOMIAL_SIZE, p).unwrap();
1385
1386        let value = 0x6e63593a;
1387        // Essentially = [value, 0, 0, ...]
1388        let input: [u32; POLYNOMIAL_SIZE] =
1389            core::array::from_fn(|i| if i == 0 { value } else { 0 });
1390        let mut acc = input;
1391
1392        plan.mul_assign_normalize(&mut acc, &input);
1393
1394        let expected = (u128::from(value) * u128::from(value) * u128::from(plan.n_inv_mod_p)
1395            % u128::from(p)) as u32;
1396        assert_eq!(acc[0], expected);
1397    }
1398}
1399
1400#[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
1401#[cfg(test)]
1402mod x86_tests {
1403    use super::*;
1404    use crate::prime::largest_prime_in_arithmetic_progression64;
1405    use alloc::vec::Vec;
1406    use rand::random as rnd;
1407
1408    extern crate alloc;
1409
1410    #[test]
1411    fn test_interleaves_and_permutes_u32x8() {
1412        if let Some(simd) = crate::V3::try_new() {
1413            let a = u32x8(rnd(), rnd(), rnd(), rnd(), rnd(), rnd(), rnd(), rnd());
1414            let b = u32x8(rnd(), rnd(), rnd(), rnd(), rnd(), rnd(), rnd(), rnd());
1415
1416            assert_eq!(
1417                simd.interleave4_u32x8([a, b]),
1418                [
1419                    u32x8(a.0, a.1, a.2, a.3, b.0, b.1, b.2, b.3),
1420                    u32x8(a.4, a.5, a.6, a.7, b.4, b.5, b.6, b.7),
1421                ],
1422            );
1423            assert_eq!(
1424                simd.interleave4_u32x8(simd.interleave4_u32x8([a, b])),
1425                [a, b],
1426            );
1427            let w = [rnd(), rnd()];
1428            assert_eq!(
1429                simd.permute4_u32x8(w),
1430                u32x8(w[0], w[0], w[0], w[0], w[1], w[1], w[1], w[1]),
1431            );
1432
1433            assert_eq!(
1434                simd.interleave2_u32x8([a, b]),
1435                [
1436                    u32x8(a.0, a.1, b.0, b.1, a.4, a.5, b.4, b.5),
1437                    u32x8(a.2, a.3, b.2, b.3, a.6, a.7, b.6, b.7),
1438                ],
1439            );
1440            assert_eq!(
1441                simd.interleave2_u32x8(simd.interleave2_u32x8([a, b])),
1442                [a, b],
1443            );
1444            let w = [rnd(), rnd(), rnd(), rnd()];
1445            assert_eq!(
1446                simd.permute2_u32x8(w),
1447                u32x8(w[0], w[0], w[2], w[2], w[1], w[1], w[3], w[3]),
1448            );
1449
1450            assert_eq!(
1451                simd.interleave1_u32x8([a, b]),
1452                [
1453                    u32x8(a.0, b.0, a.2, b.2, a.4, b.4, a.6, b.6),
1454                    u32x8(a.1, b.1, a.3, b.3, a.5, b.5, a.7, b.7),
1455                ],
1456            );
1457            assert_eq!(
1458                simd.interleave1_u32x8(simd.interleave1_u32x8([a, b])),
1459                [a, b],
1460            );
1461            let w = [rnd(), rnd(), rnd(), rnd(), rnd(), rnd(), rnd(), rnd()];
1462            assert_eq!(
1463                simd.permute1_u32x8(w),
1464                u32x8(w[0], w[4], w[1], w[5], w[2], w[6], w[3], w[7]),
1465            );
1466        }
1467    }

    #[cfg(feature = "nightly")]
    #[test]
    fn test_interleaves_and_permutes_u32x16() {
        if let Some(simd) = crate::V4::try_new() {
            #[rustfmt::skip]
            let a = u32x16(rnd(), rnd(), rnd(), rnd(), rnd(), rnd(), rnd(), rnd(), rnd(), rnd(), rnd(), rnd(), rnd(), rnd(), rnd(), rnd());
            #[rustfmt::skip]
            let b = u32x16(rnd(), rnd(), rnd(), rnd(), rnd(), rnd(), rnd(), rnd(), rnd(), rnd(), rnd(), rnd(), rnd(), rnd(), rnd(), rnd());

            assert_eq!(
                simd.interleave8_u32x16([a, b]),
                [
                    u32x16(
                        a.0, a.1, a.2, a.3, a.4, a.5, a.6, a.7, b.0, b.1, b.2, b.3, b.4, b.5, b.6,
                        b.7,
                    ),
                    u32x16(
                        a.8, a.9, a.10, a.11, a.12, a.13, a.14, a.15, b.8, b.9, b.10, b.11, b.12,
                        b.13, b.14, b.15,
                    ),
                ],
            );
            assert_eq!(
                simd.interleave8_u32x16(simd.interleave8_u32x16([a, b])),
                [a, b],
            );
            let w = [rnd(), rnd()];
            assert_eq!(
                simd.permute8_u32x16(w),
                u32x16(
                    w[0], w[0], w[0], w[0], w[0], w[0], w[0], w[0], w[1], w[1], w[1], w[1], w[1],
                    w[1], w[1], w[1],
                ),
            );

            assert_eq!(
                simd.interleave4_u32x16([a, b]),
                [
                    u32x16(
                        a.0, a.1, a.2, a.3, b.0, b.1, b.2, b.3, a.8, a.9, a.10, a.11, b.8, b.9,
                        b.10, b.11,
                    ),
                    u32x16(
                        a.4, a.5, a.6, a.7, b.4, b.5, b.6, b.7, a.12, a.13, a.14, a.15, b.12, b.13,
                        b.14, b.15,
                    ),
                ],
            );
            assert_eq!(
                simd.interleave4_u32x16(simd.interleave4_u32x16([a, b])),
                [a, b],
            );
            let w = [rnd(), rnd(), rnd(), rnd()];
            assert_eq!(
                simd.permute4_u32x16(w),
                u32x16(
                    w[0], w[0], w[0], w[0], w[2], w[2], w[2], w[2], w[1], w[1], w[1], w[1], w[3],
                    w[3], w[3], w[3],
                ),
            );

            assert_eq!(
                simd.interleave2_u32x16([a, b]),
                [
                    u32x16(
                        a.0, a.1, b.0, b.1, a.4, a.5, b.4, b.5, a.8, a.9, b.8, b.9, a.12, a.13,
                        b.12, b.13,
                    ),
                    u32x16(
                        a.2, a.3, b.2, b.3, a.6, a.7, b.6, b.7, a.10, a.11, b.10, b.11, a.14, a.15,
                        b.14, b.15,
                    ),
                ],
            );
            assert_eq!(
                simd.interleave2_u32x16(simd.interleave2_u32x16([a, b])),
                [a, b],
            );
            let w = [rnd(), rnd(), rnd(), rnd(), rnd(), rnd(), rnd(), rnd()];
            assert_eq!(
                simd.permute2_u32x16(w),
                u32x16(
                    w[0], w[0], w[4], w[4], w[1], w[1], w[5], w[5], w[2], w[2], w[6], w[6], w[3],
                    w[3], w[7], w[7],
                ),
            );

            assert_eq!(
                simd.interleave1_u32x16([a, b]),
                [
                    u32x16(
                        a.0, b.0, a.2, b.2, a.4, b.4, a.6, b.6, a.8, b.8, a.10, b.10, a.12, b.12,
                        a.14, b.14,
                    ),
                    u32x16(
                        a.1, b.1, a.3, b.3, a.5, b.5, a.7, b.7, a.9, b.9, a.11, b.11, a.13, b.13,
                        a.15, b.15,
                    ),
                ],
            );
            assert_eq!(
                simd.interleave1_u32x16(simd.interleave1_u32x16([a, b])),
                [a, b],
            );
            #[rustfmt::skip]
            let w = [rnd(), rnd(), rnd(), rnd(), rnd(), rnd(), rnd(), rnd(), rnd(), rnd(), rnd(), rnd(), rnd(), rnd(), rnd(), rnd()];
            assert_eq!(
                simd.permute1_u32x16(w),
                u32x16(
                    w[0], w[8], w[1], w[9], w[2], w[10], w[3], w[11], w[4], w[12], w[5], w[13],
                    w[6], w[14], w[7], w[15],
                ),
            );
        }
    }
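
    // The mul/normalize tests below drive the SIMD kernels with three precomputed
    // constants: `n_inv_mod_p` (the inverse of the polynomial size modulo p, obtained
    // as N^(p - 2) mod p by Fermat's little theorem), its Shoup companion
    // `n_inv_mod_p_shoup = floor(n_inv_mod_p * 2^32 / p)`, and a Barrett constant
    // `p_barrett = floor(2^(big_q + 31) / p)` with `big_q = ceil(log2 p)`. The sketch
    // below is for exposition only (it is not called by the assertions) and shows the
    // standard Shoup multiplication that such a constant enables; the vectorized
    // kernels under test may sequence the operations differently.
    #[allow(dead_code)]
    fn mul_shoup_scalar_sketch(x: u32, w: u32, w_shoup: u32, p: u32) -> u32 {
        // Approximate quotient of (w * x) / p using the precomputed w_shoup.
        let q = (w_shoup as u64 * x as u64) >> 32;
        // The remainder lands in [0, 2p); one conditional subtraction finishes the job.
        let r = (w as u64 * x as u64 - q * p as u64) as u32;
        if r >= p {
            r - p
        } else {
            r
        }
    }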

    #[cfg(feature = "nightly")]
    #[test]
    fn test_mul_assign_normalize_avx512() {
        if let Some(simd) = crate::V4::try_new() {
            let p = largest_prime_in_arithmetic_progression64(1 << 16, 1, 1 << 30, 1 << 31).unwrap()
                as u32;
            let p_div = Div64::new(p as u64);
            let polynomial_size = 128;

            let n_inv_mod_p =
                crate::prime::exp_mod64(p_div, polynomial_size as u64, p as u64 - 2) as u32;
            let n_inv_mod_p_shoup = (((n_inv_mod_p as u128) << 32) / p as u128) as u32;
            let big_q = p.ilog2() + 1;
            let big_l = big_q + 31;
            let p_barrett = ((1u128 << big_l) / p as u128) as u32;

            let mut lhs = (0..polynomial_size)
                .map(|_| rand::random::<u32>() % p)
                .collect::<Vec<_>>();
            let mut lhs_target = lhs.clone();
            let rhs = (0..polynomial_size)
                .map(|_| rand::random::<u32>() % p)
                .collect::<Vec<_>>();

            let mul = |a: u32, b: u32| ((a as u128 * b as u128) % p as u128) as u32;

            for (lhs, rhs) in lhs_target.iter_mut().zip(&rhs) {
                *lhs = mul(mul(*lhs, *rhs), n_inv_mod_p);
            }

            mul_assign_normalize_avx512(
                simd,
                &mut lhs,
                &rhs,
                p,
                p_barrett,
                big_q,
                n_inv_mod_p,
                n_inv_mod_p_shoup,
            );
            assert_eq!(lhs, lhs_target);
        }
    }

    #[test]
    fn test_mul_assign_normalize_avx2() {
        if let Some(simd) = crate::V3::try_new() {
            let p = largest_prime_in_arithmetic_progression64(1 << 16, 1, 1 << 30, 1 << 31).unwrap()
                as u32;
            let p_div = Div64::new(p as u64);
            let polynomial_size = 128;

            let n_inv_mod_p =
                crate::prime::exp_mod64(p_div, polynomial_size as u64, p as u64 - 2) as u32;
            let n_inv_mod_p_shoup = (((n_inv_mod_p as u128) << 32) / p as u128) as u32;
            let big_q = p.ilog2() + 1;
            let big_l = big_q + 31;
            let p_barrett = ((1u128 << big_l) / p as u128) as u32;

            let mut lhs = (0..polynomial_size)
                .map(|_| rand::random::<u32>() % p)
                .collect::<Vec<_>>();
            let mut lhs_target = lhs.clone();
            let rhs = (0..polynomial_size)
                .map(|_| rand::random::<u32>() % p)
                .collect::<Vec<_>>();

            let mul = |a: u32, b: u32| ((a as u128 * b as u128) % p as u128) as u32;

            for (lhs, rhs) in lhs_target.iter_mut().zip(&rhs) {
                *lhs = mul(mul(*lhs, *rhs), n_inv_mod_p);
            }

            mul_assign_normalize_avx2(
                simd,
                &mut lhs,
                &rhs,
                p,
                p_barrett,
                big_q,
                n_inv_mod_p,
                n_inv_mod_p_shoup,
            );
            assert_eq!(lhs, lhs_target);
        }
    }
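
    // The accumulate tests only rely on the Barrett constant. This scalar sketch
    // (again exposition only, not used by the assertions) shows how `p_barrett` and
    // `big_q` reduce a product of two residues modulo p; the approximate quotient is
    // short by at most a couple of units, hence the small correction loop. The actual
    // AVX2/AVX-512 kernels may implement the reduction differently.
    #[allow(dead_code)]
    fn barrett_reduce_scalar_sketch(d: u64, p: u32, p_barrett: u32, big_q: u32) -> u32 {
        // d is assumed to be a product of two values in [0, p), so d < p^2 < 2^62.
        let big_l = big_q + 31;
        let q_approx = ((d as u128 * p_barrett as u128) >> big_l) as u64;
        let mut r = d - q_approx * p as u64;
        while r >= p as u64 {
            r -= p as u64;
        }
        r as u32
    }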

    #[cfg(feature = "nightly")]
    #[test]
    fn test_mul_accumulate_avx512() {
        if let Some(simd) = crate::V4::try_new() {
            let p = largest_prime_in_arithmetic_progression64(1 << 16, 1, 1 << 30, 1 << 31).unwrap()
                as u32;
            let polynomial_size = 128;

            let big_q = p.ilog2() + 1;
            let big_l = big_q + 31;
            let p_barrett = ((1u128 << big_l) / p as u128) as u32;

            let mut acc = (0..polynomial_size)
                .map(|_| rand::random::<u32>() % p)
                .collect::<Vec<_>>();
            let mut acc_target = acc.clone();
            let lhs = (0..polynomial_size)
                .map(|_| rand::random::<u32>() % p)
                .collect::<Vec<_>>();
            let rhs = (0..polynomial_size)
                .map(|_| rand::random::<u32>() % p)
                .collect::<Vec<_>>();

            let mul = |a: u32, b: u32| ((a as u128 * b as u128) % p as u128) as u32;
            let add = |a: u32, b: u32| (a + b) % p;

            for (acc, lhs, rhs) in crate::izip!(&mut acc_target, &lhs, &rhs) {
                *acc = add(mul(*lhs, *rhs), *acc);
            }

            mul_accumulate_avx512(simd, &mut acc, &lhs, &rhs, p, p_barrett, big_q);
            assert_eq!(acc, acc_target);
        }
    }

    #[test]
    fn test_mul_accumulate_avx2() {
        if let Some(simd) = crate::V3::try_new() {
            let p = largest_prime_in_arithmetic_progression64(1 << 16, 1, 1 << 30, 1 << 31).unwrap()
                as u32;
            let polynomial_size = 128;

            let big_q = p.ilog2() + 1;
            let big_l = big_q + 31;
            let p_barrett = ((1u128 << big_l) / p as u128) as u32;

            let mut acc = (0..polynomial_size)
                .map(|_| rand::random::<u32>() % p)
                .collect::<Vec<_>>();
            let mut acc_target = acc.clone();
            let lhs = (0..polynomial_size)
                .map(|_| rand::random::<u32>() % p)
                .collect::<Vec<_>>();
            let rhs = (0..polynomial_size)
                .map(|_| rand::random::<u32>() % p)
                .collect::<Vec<_>>();

            let mul = |a: u32, b: u32| ((a as u128 * b as u128) % p as u128) as u32;
            let add = |a: u32, b: u32| (a + b) % p;

            for (acc, lhs, rhs) in crate::izip!(&mut acc_target, &lhs, &rhs) {
                *acc = add(mul(*lhs, *rhs), *acc);
            }

            mul_accumulate_avx2(simd, &mut acc, &lhs, &rhs, p, p_barrett, big_q);
            assert_eq!(acc, acc_target);
        }
    }
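
    // The normalize tests check the pure 1/N scaling step: every coefficient is
    // multiplied by `n_inv_mod_p`, with the reference computed by plain modular
    // multiplication and the kernels taking the Shoup constant `n_inv_mod_p_shoup`
    // as a parameter.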

    #[cfg(feature = "nightly")]
    #[test]
    fn test_normalize_avx512() {
        if let Some(simd) = crate::V4::try_new() {
            let p = largest_prime_in_arithmetic_progression64(1 << 16, 1, 1 << 30, 1 << 31).unwrap()
                as u32;
            let p_div = Div64::new(p as u64);
            let polynomial_size = 128;

            let n_inv_mod_p =
                crate::prime::exp_mod64(p_div, polynomial_size as u64, p as u64 - 2) as u32;
            let n_inv_mod_p_shoup = (((n_inv_mod_p as u128) << 32) / p as u128) as u32;

            let mut val = (0..polynomial_size)
                .map(|_| rand::random::<u32>() % p)
                .collect::<Vec<_>>();
            let mut val_target = val.clone();

            let mul = |a: u32, b: u32| ((a as u128 * b as u128) % p as u128) as u32;

            for val in val_target.iter_mut() {
                *val = mul(*val, n_inv_mod_p);
            }

            normalize_avx512(simd, &mut val, p, n_inv_mod_p, n_inv_mod_p_shoup);
            assert_eq!(val, val_target);
        }
    }

    #[test]
    fn test_normalize_avx2() {
        if let Some(simd) = crate::V3::try_new() {
            let p = largest_prime_in_arithmetic_progression64(1 << 16, 1, 1 << 30, 1 << 31).unwrap()
                as u32;
            let p_div = Div64::new(p as u64);
            let polynomial_size = 128;

            let n_inv_mod_p =
                crate::prime::exp_mod64(p_div, polynomial_size as u64, p as u64 - 2) as u32;
            let n_inv_mod_p_shoup = (((n_inv_mod_p as u128) << 32) / p as u128) as u32;

            let mut val = (0..polynomial_size)
                .map(|_| rand::random::<u32>() % p)
                .collect::<Vec<_>>();
            let mut val_target = val.clone();

            let mul = |a: u32, b: u32| ((a as u128 * b as u128) % p as u128) as u32;

            for val in val_target.iter_mut() {
                *val = mul(*val, n_inv_mod_p);
            }

            normalize_avx2(simd, &mut val, p, n_inv_mod_p, n_inv_mod_p_shoup);
            assert_eq!(val, val_target);
        }
    }
}