1use crate::{
2 bit_rev,
3 fastdiv::{Div32, Div64},
4 prime::is_prime64,
5 roots::find_primitive_root64,
6};
7use aligned_vec::{avec, ABox};
8
9#[allow(unused_imports)]
10use pulp::*;
11
12const RECURSION_THRESHOLD: usize = 2048;
13
14mod generic;
15mod shoup;
16
17mod less_than_30bit;
18mod less_than_31bit;
19
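// SIMD lane-shuffling helpers intended for the vectorized NTT kernels in the submodules.
// `interleaveN_*` regroups two vectors of operands by blocks of N lanes and is its own
// inverse, while the matching `permuteN_*` spreads a small set of twiddle factors so that
// each one occupies N consecutive lanes in the same layout (see the checks in `x86_tests`).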
20#[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
21impl crate::V3 {
22 #[inline(always)]
23 fn interleave4_u32x8(self, z0z0z0z0z1z1z1z1: [u32x8; 2]) -> [u32x8; 2] {
24 let avx = self.avx;
25 [
26 cast(avx._mm256_permute2f128_si256::<0b0010_0000>(
27 cast(z0z0z0z0z1z1z1z1[0]),
28 cast(z0z0z0z0z1z1z1z1[1]),
29 )),
30 cast(avx._mm256_permute2f128_si256::<0b0011_0001>(
31 cast(z0z0z0z0z1z1z1z1[0]),
32 cast(z0z0z0z0z1z1z1z1[1]),
33 )),
34 ]
35 }
36
37 #[inline(always)]
38 fn permute4_u32x8(self, w: [u32; 2]) -> u32x8 {
39 let avx = self.avx;
40 let w0 = self.sse2._mm_set1_epi32(w[0] as i32);
41 let w1 = self.sse2._mm_set1_epi32(w[1] as i32);
42 cast(avx._mm256_insertf128_si256::<1>(avx._mm256_castsi128_si256(w0), w1))
43 }
44
45 #[inline(always)]
46 fn interleave2_u32x8(self, z0z0z1z1: [u32x8; 2]) -> [u32x8; 2] {
47 let avx = self.avx2;
48 [
49 cast(avx._mm256_unpacklo_epi64(cast(z0z0z1z1[0]), cast(z0z0z1z1[1]))),
50 cast(avx._mm256_unpackhi_epi64(cast(z0z0z1z1[0]), cast(z0z0z1z1[1]))),
51 ]
52 }
53
54 #[inline(always)]
55 fn permute2_u32x8(self, w: [u32; 4]) -> u32x8 {
56 let avx = self.avx;
57 let w0123 = pulp::cast(w);
58 let w0022 = self.sse2._mm_castps_si128(self.sse3._mm_moveldup_ps(w0123));
59 let w1133 = self.sse2._mm_castps_si128(self.sse3._mm_movehdup_ps(w0123));
60 cast(avx._mm256_insertf128_si256::<1>(avx._mm256_castsi128_si256(w0022), w1133))
61 }
62
63 #[inline(always)]
64 fn interleave1_u32x8(self, z0z0z1z1: [u32x8; 2]) -> [u32x8; 2] {
65 let avx = self.avx2;
66 let x = [
67 (avx._mm256_unpacklo_epi32(cast(z0z0z1z1[0]), cast(z0z0z1z1[1]))),
69 avx._mm256_unpackhi_epi32(cast(z0z0z1z1[0]), cast(z0z0z1z1[1])),
71 ];
72 [
73 cast(avx._mm256_unpacklo_epi64(x[0], x[1])),
75 cast(avx._mm256_unpackhi_epi64(x[0], x[1])),
77 ]
78 }
79
80 #[inline(always)]
81 fn permute1_u32x8(self, w: [u32; 8]) -> u32x8 {
82 let avx = self.avx;
83 let [w0123, w4567]: [u32x4; 2] = pulp::cast(w);
84 let w0415 = self.sse2._mm_unpacklo_epi32(cast(w0123), cast(w4567));
85 let w2637 = self.sse2._mm_unpackhi_epi32(cast(w0123), cast(w4567));
86 cast(avx._mm256_insertf128_si256::<1>(avx._mm256_castsi128_si256(w0415), w2637))
87 }
88
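    /// Reduces `x` from the range `[0, 2 * modulus)` to `[0, modulus)` with a single
    /// conditional subtraction: keeps `x` where `x < modulus`, otherwise `x - modulus`.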
89 #[inline(always)]
90 pub fn small_mod_u32x8(self, modulus: u32x8, x: u32x8) -> u32x8 {
91 self.select_u32x8(
92 self.cmp_gt_u32x8(modulus, x),
93 x,
94 self.wrapping_sub_u32x8(x, modulus),
95 )
96 }
97}
98
99#[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
100#[cfg(feature = "nightly")]
101impl crate::V4 {
102 #[inline(always)]
103 fn interleave8_u32x16(self, z0z0z0z0z0z0z0z0z1z1z1z1z1z1z1z1: [u32x16; 2]) -> [u32x16; 2] {
104 let avx = self.avx512f;
105 let idx_0 = avx._mm512_setr_epi32(
106 0x0, 0x1, 0x2, 0x3, 0x4, 0x5, 0x6, 0x7, 0x10, 0x11, 0x12, 0x13, 0x14, 0x15, 0x16, 0x17,
107 );
108 let idx_1 = avx._mm512_setr_epi32(
109 0x8, 0x9, 0xa, 0xb, 0xc, 0xd, 0xe, 0xf, 0x18, 0x19, 0x1a, 0x1b, 0x1c, 0x1d, 0x1e, 0x1f,
110 );
111 [
112 cast(avx._mm512_permutex2var_epi32(
113 cast(z0z0z0z0z0z0z0z0z1z1z1z1z1z1z1z1[0]),
114 idx_0,
115 cast(z0z0z0z0z0z0z0z0z1z1z1z1z1z1z1z1[1]),
116 )),
117 cast(avx._mm512_permutex2var_epi32(
118 cast(z0z0z0z0z0z0z0z0z1z1z1z1z1z1z1z1[0]),
119 idx_1,
120 cast(z0z0z0z0z0z0z0z0z1z1z1z1z1z1z1z1[1]),
121 )),
122 ]
123 }
124
125 #[inline(always)]
126 fn permute8_u32x16(self, w: [u32; 2]) -> u32x16 {
127 let avx = self.avx512f;
128 let w = pulp::cast(w);
129 let w01xxxxxxxxxxxxxx = avx._mm512_set1_epi64(w);
130 let idx = avx._mm512_setr_epi32(0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1);
131 cast(avx._mm512_permutexvar_epi32(idx, w01xxxxxxxxxxxxxx))
132 }
133
134 #[inline(always)]
135 fn interleave4_u32x16(self, z0z0z0z0z1z1z1z1: [u32x16; 2]) -> [u32x16; 2] {
136 let avx = self.avx512f;
137 let idx_0 = avx._mm512_setr_epi32(
138 0x0, 0x1, 0x2, 0x3, 0x10, 0x11, 0x12, 0x13, 0x8, 0x9, 0xa, 0xb, 0x18, 0x19, 0x1a, 0x1b,
139 );
140 let idx_1 = avx._mm512_setr_epi32(
141 0x4, 0x5, 0x6, 0x7, 0x14, 0x15, 0x16, 0x17, 0xc, 0xd, 0xe, 0xf, 0x1c, 0x1d, 0x1e, 0x1f,
142 );
143 [
144 cast(avx._mm512_permutex2var_epi32(
145 cast(z0z0z0z0z1z1z1z1[0]),
146 idx_0,
147 cast(z0z0z0z0z1z1z1z1[1]),
148 )),
149 cast(avx._mm512_permutex2var_epi32(
150 cast(z0z0z0z0z1z1z1z1[0]),
151 idx_1,
152 cast(z0z0z0z0z1z1z1z1[1]),
153 )),
154 ]
155 }
156
157 #[inline(always)]
158 fn permute4_u32x16(self, w: [u32; 4]) -> u32x16 {
159 let avx = self.avx512f;
160 let w = pulp::cast(w);
161 let w0123xxxxxxxxxxxx = avx._mm512_castsi128_si512(w);
162 let idx = avx._mm512_setr_epi32(0, 0, 0, 0, 2, 2, 2, 2, 1, 1, 1, 1, 3, 3, 3, 3);
163 cast(avx._mm512_permutexvar_epi32(idx, w0123xxxxxxxxxxxx))
164 }
165
166 #[inline(always)]
167 fn interleave2_u32x16(self, z0z0z1z1: [u32x16; 2]) -> [u32x16; 2] {
168 let avx = self.avx512f;
169 [
170 cast(avx._mm512_unpacklo_epi64(cast(z0z0z1z1[0]), cast(z0z0z1z1[1]))),
172 cast(avx._mm512_unpackhi_epi64(cast(z0z0z1z1[0]), cast(z0z0z1z1[1]))),
174 ]
175 }
176
177 #[inline(always)]
178 fn permute2_u32x16(self, w: [u32; 8]) -> u32x16 {
179 let avx = self.avx512f;
180 let w = pulp::cast(w);
181 let w01234567xxxxxxxx = avx._mm512_castsi256_si512(w);
182 let idx = avx._mm512_setr_epi32(0, 0, 4, 4, 1, 1, 5, 5, 2, 2, 6, 6, 3, 3, 7, 7);
183 cast(avx._mm512_permutexvar_epi32(idx, w01234567xxxxxxxx))
184 }
185
186 #[inline(always)]
187 fn interleave1_u32x16(self, z0z1: [u32x16; 2]) -> [u32x16; 2] {
188 let avx = self.avx512f;
189 let x = [
190 avx._mm512_unpacklo_epi32(cast(z0z1[0]), cast(z0z1[1])),
192 avx._mm512_unpackhi_epi32(cast(z0z1[0]), cast(z0z1[1])),
194 ];
195 [
196 cast(avx._mm512_unpacklo_epi64(x[0], x[1])),
198 cast(avx._mm512_unpackhi_epi64(x[0], x[1])),
200 ]
201 }
202
203 #[inline(always)]
204 fn permute1_u32x16(self, w: [u32; 16]) -> u32x16 {
205 let avx = self.avx512f;
206 let w = pulp::cast(w);
207 let idx = avx._mm512_setr_epi32(
208 0x0, 0x8, 0x1, 0x9, 0x2, 0xa, 0x3, 0xb, 0x4, 0xc, 0x5, 0xd, 0x6, 0xe, 0x7, 0xf,
209 );
210 cast(avx._mm512_permutexvar_epi32(idx, w))
211 }
212
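    /// Same operation as `V3::small_mod_u32x8`, for 16-lane AVX-512 vectors.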
213 #[inline(always)]
214 pub fn small_mod_u32x16(self, modulus: u32x16, x: u32x16) -> u32x16 {
215 self.select_u32x16(
216 self.cmp_gt_u32x16(modulus, x),
217 x,
218 self.wrapping_sub_u32x16(x, modulus),
219 )
220 }
221}
222
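// Fills the forward and inverse twiddle tables for a negacyclic NTT of size `n` modulo `p`.
// `w` is a primitive `2n`-th root of unity mod `p`, so `w^n ≡ -1 (mod p)`. The forward table
// stores `w^k` at the bit-reversed index of `k`; the inverse table stores the inverse powers,
// using `p - w^k ≡ w^{-(n - k)} (mod p)` so that `inv_twid[bit_rev(j)] = w^{-j}`.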
223fn init_negacyclic_twiddles(p: u32, n: usize, twid: &mut [u32], inv_twid: &mut [u32]) {
224 let div = Div32::new(p);
225 let w = find_primitive_root64(Div64::new(p as u64), 2 * n as u64).unwrap() as u32;
226 let mut k = 0;
227 let mut wk = 1u32;
228
229 let nbits = n.trailing_zeros();
230 while k < n {
231 let fwd_idx = bit_rev(nbits, k);
232
233 twid[fwd_idx] = wk;
234
235 let inv_idx = bit_rev(nbits, (n - k) % n);
236 if k == 0 {
237 inv_twid[inv_idx] = wk;
238 } else {
239 let x = p.wrapping_sub(wk);
240 inv_twid[inv_idx] = x;
241 }
242
243 wk = Div32::rem_u64(wk as u64 * w as u64, div);
244 k += 1;
245 }
246}
247
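// Same as `init_negacyclic_twiddles`, but additionally precomputes the Shoup companion
// `floor((t << 32) / p)` of every twiddle `t`. With that constant, `t * x mod p` can later be
// computed from two 32-bit multiplications: `q = hi(x * t_shoup)` approximates `floor(x * t / p)`,
// and `x * t - q * p` (with wrapping arithmetic) lands in `[0, 2p)` when `p < 2^31`, which is
// why these tables are only built for such moduli.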
248fn init_negacyclic_twiddles_shoup(
249 p: u32,
250 n: usize,
251 twid: &mut [u32],
252 twid_shoup: &mut [u32],
253 inv_twid: &mut [u32],
254 inv_twid_shoup: &mut [u32],
255) {
256 let div = Div32::new(p);
257 let w = find_primitive_root64(Div64::new(p as u64), 2 * n as u64).unwrap() as u32;
258 let mut k = 0;
259 let mut wk = 1u32;
260
261 let nbits = n.trailing_zeros();
262 while k < n {
263 let fwd_idx = bit_rev(nbits, k);
264
265 let wk_shoup = Div32::div_u64((wk as u64) << 32, div) as u32;
266 twid[fwd_idx] = wk;
267 twid_shoup[fwd_idx] = wk_shoup;
268
269 let inv_idx = bit_rev(nbits, (n - k) % n);
270 if k == 0 {
271 inv_twid[inv_idx] = wk;
272 inv_twid_shoup[inv_idx] = wk_shoup;
273 } else {
274 let x = p.wrapping_sub(wk);
275 inv_twid[inv_idx] = x;
276 inv_twid_shoup[inv_idx] = Div32::div_u64((x as u64) << 32, div) as u32;
277 }
278
279 wk = Div32::rem_u64(wk as u64 * w as u64, div);
280 k += 1;
281 }
282}
283
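// The `mul_assign_normalize_*` family computes `lhs[i] = lhs[i] * rhs[i] * n^{-1} mod p`.
// The product is reduced with Barrett reduction: with `big_q` the bit length of `p` and
// `p_barrett = floor(2^(big_q + 31) / p)`, the quotient estimate is
// `c3 = (c1 * p_barrett) >> 32` where `c1 = d >> (big_q - 1)`, and `d - c3 * p` stays in
// `[0, 2p)` for the moduli where `can_use_fast_reduction_code` holds (otherwise `Plan` uses
// the division-based fallback). The result is then multiplied by `n^{-1} mod p` via its
// Shoup companion and reduced to `[0, p)`.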
284#[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
285#[cfg(feature = "nightly")]
286fn mul_assign_normalize_avx512(
287 simd: crate::V4,
288 lhs: &mut [u32],
289 rhs: &[u32],
290 p: u32,
291 p_barrett: u32,
292 big_q: u32,
293 n_inv_mod_p: u32,
294 n_inv_mod_p_shoup: u32,
295) {
296 simd.vectorize(
297 #[inline(always)]
298 move || {
299 let lhs = pulp::as_arrays_mut::<16, _>(lhs).0;
300 let rhs = pulp::as_arrays::<16, _>(rhs).0;
301 let big_q_m1 = simd.splat_u32x16(big_q - 1);
302 let big_q_m1_complement = simd.splat_u32x16(32 - (big_q - 1));
303 let n_inv_mod_p = simd.splat_u32x16(n_inv_mod_p);
304 let n_inv_mod_p_shoup = simd.splat_u32x16(n_inv_mod_p_shoup);
305 let p_barrett = simd.splat_u32x16(p_barrett);
306 let p = simd.splat_u32x16(p);
307
308 for (lhs_, rhs) in crate::izip!(lhs, rhs) {
309 let lhs = cast(*lhs_);
310 let rhs = cast(*rhs);
311
312 let (lo, hi) = simd.widening_mul_u32x16(lhs, rhs);
314 let c1 = simd.or_u32x16(
315 simd.shr_dyn_u32x16(lo, big_q_m1),
316 simd.shl_dyn_u32x16(hi, big_q_m1_complement),
317 );
318 let c3 = simd.widening_mul_u32x16(c1, p_barrett).1;
319 let prod = simd.wrapping_sub_u32x16(lo, simd.wrapping_mul_u32x16(p, c3));
320
321 let shoup_q = simd.widening_mul_u32x16(prod, n_inv_mod_p_shoup).1;
323 let t = simd.wrapping_sub_u32x16(
324 simd.wrapping_mul_u32x16(prod, n_inv_mod_p),
325 simd.wrapping_mul_u32x16(shoup_q, p),
326 );
327
328 *lhs_ = cast(simd.small_mod_u32x16(p, t));
329 }
330 },
331 );
332}
333
334#[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
335fn mul_assign_normalize_avx2(
336 simd: crate::V3,
337 lhs: &mut [u32],
338 rhs: &[u32],
339 p: u32,
340 p_barrett: u32,
341 big_q: u32,
342 n_inv_mod_p: u32,
343 n_inv_mod_p_shoup: u32,
344) {
345 simd.vectorize(
346 #[inline(always)]
347 move || {
348 let lhs = pulp::as_arrays_mut::<8, _>(lhs).0;
349 let rhs = pulp::as_arrays::<8, _>(rhs).0;
350 let big_q_m1 = simd.splat_u32x8(big_q - 1);
351 let big_q_m1_complement = simd.splat_u32x8(32 - (big_q - 1));
352 let n_inv_mod_p = simd.splat_u32x8(n_inv_mod_p);
353 let n_inv_mod_p_shoup = simd.splat_u32x8(n_inv_mod_p_shoup);
354 let p_barrett = simd.splat_u32x8(p_barrett);
355 let p = simd.splat_u32x8(p);
356
357 for (lhs_, rhs) in crate::izip!(lhs, rhs) {
358 let lhs = cast(*lhs_);
359 let rhs = cast(*rhs);
360
361 let (lo, hi) = simd.widening_mul_u32x8(lhs, rhs);
363 let c1 = simd.or_u32x8(
364 simd.shr_dyn_u32x8(lo, big_q_m1),
365 simd.shl_dyn_u32x8(hi, big_q_m1_complement),
366 );
367 let c3 = simd.widening_mul_u32x8(c1, p_barrett).1;
368 let prod = simd.wrapping_sub_u32x8(lo, simd.wrapping_mul_u32x8(p, c3));
369
370 let shoup_q = simd.widening_mul_u32x8(prod, n_inv_mod_p_shoup).1;
372 let t = simd.wrapping_sub_u32x8(
373 simd.wrapping_mul_u32x8(prod, n_inv_mod_p),
374 simd.wrapping_mul_u32x8(shoup_q, p),
375 );
376
377 *lhs_ = cast(simd.small_mod_u32x8(p, t));
378 }
379 },
380 );
381}
382
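// Scalar fallback with the same Barrett-then-Shoup arithmetic as the SIMD versions above,
// one element at a time; `t.min(t.wrapping_sub(p))` is the branchless form of the final
// conditional subtraction from `[0, 2p)` to `[0, p)`.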
383fn mul_assign_normalize_scalar(
384 lhs: &mut [u32],
385 rhs: &[u32],
386 p: u32,
387 p_barrett: u32,
388 big_q: u32,
389 n_inv_mod_p: u32,
390 n_inv_mod_p_shoup: u32,
391) {
392 let big_q_m1 = big_q - 1;
393
394 for (lhs_, rhs) in crate::izip!(lhs, rhs) {
395 let lhs = *lhs_;
396 let rhs = *rhs;
397
398 let d = lhs as u64 * rhs as u64;
399 let c1 = (d >> big_q_m1) as u32;
400 let c3 = ((c1 as u64 * p_barrett as u64) >> 32) as u32;
401 let prod = (d as u32).wrapping_sub(p.wrapping_mul(c3));
402
403 let shoup_q = (((prod as u64) * (n_inv_mod_p_shoup as u64)) >> 32) as u32;
404 let t = u32::wrapping_sub(prod.wrapping_mul(n_inv_mod_p), shoup_q.wrapping_mul(p));
405
406 *lhs_ = t.min(t.wrapping_sub(p));
407 }
408}
409
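// The `normalize_*` family multiplies every value by `n^{-1} mod p` using Shoup
// multiplication: `q = hi(v * n_inv_mod_p_shoup)` estimates `floor(v * n_inv / p)`, so
// `v * n_inv - q * p` (with wrapping arithmetic) lies in `[0, 2p)` and a single conditional
// subtraction finishes the reduction.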
410#[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
411#[cfg(feature = "nightly")]
412fn normalize_avx512(
413 simd: crate::V4,
414 values: &mut [u32],
415 p: u32,
416 n_inv_mod_p: u32,
417 n_inv_mod_p_shoup: u32,
418) {
419 simd.vectorize(
420 #[inline(always)]
421 move || {
422 let values = pulp::as_arrays_mut::<16, _>(values).0;
423
424 let n_inv_mod_p = simd.splat_u32x16(n_inv_mod_p);
425 let n_inv_mod_p_shoup = simd.splat_u32x16(n_inv_mod_p_shoup);
426 let p = simd.splat_u32x16(p);
427
428 for val_ in values {
429 let val = cast(*val_);
430
431 let shoup_q = simd.widening_mul_u32x16(val, n_inv_mod_p_shoup).1;
433 let t = simd.wrapping_sub_u32x16(
434 simd.wrapping_mul_u32x16(val, n_inv_mod_p),
435 simd.wrapping_mul_u32x16(shoup_q, p),
436 );
437
438 *val_ = cast(simd.small_mod_u32x16(p, t));
439 }
440 },
441 );
442}
443
444#[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
445fn normalize_avx2(
446 simd: crate::V3,
447 values: &mut [u32],
448 p: u32,
449 n_inv_mod_p: u32,
450 n_inv_mod_p_shoup: u32,
451) {
452 simd.vectorize(
453 #[inline(always)]
454 move || {
455 let values = pulp::as_arrays_mut::<8, _>(values).0;
456
457 let n_inv_mod_p = simd.splat_u32x8(n_inv_mod_p);
458 let n_inv_mod_p_shoup = simd.splat_u32x8(n_inv_mod_p_shoup);
459 let p = simd.splat_u32x8(p);
460
461 for val_ in values {
462 let val = cast(*val_);
463
464 let shoup_q = simd.widening_mul_u32x8(val, n_inv_mod_p_shoup).1;
466 let t = simd.wrapping_sub_u32x8(
467 simd.widening_mul_u32x8(val, n_inv_mod_p).0,
468 simd.widening_mul_u32x8(shoup_q, p).0,
469 );
470
471 *val_ = cast(simd.small_mod_u32x8(p, t));
472 }
473 },
474 );
475}
476
477fn normalize_scalar(values: &mut [u32], p: u32, n_inv_mod_p: u32, n_inv_mod_p_shoup: u32) {
478 for val_ in values {
479 let val = *val_;
480
481 let shoup_q = (((val as u64) * (n_inv_mod_p_shoup as u64)) >> 32) as u32;
482 let t = u32::wrapping_sub(val.wrapping_mul(n_inv_mod_p), shoup_q.wrapping_mul(p));
483
484 *val_ = t.min(t.wrapping_sub(p));
485 }
486}
487
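// The `mul_accumulate_*` family computes `acc[i] = acc[i] + lhs[i] * rhs[i] mod p`, using the
// same Barrett product as `mul_assign_normalize_*` (without the `n^{-1}` normalization),
// followed by a modular addition into the accumulator.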
488#[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
489#[cfg(feature = "nightly")]
490fn mul_accumulate_avx512(
491 simd: crate::V4,
492 acc: &mut [u32],
493 lhs: &[u32],
494 rhs: &[u32],
495 p: u32,
496 p_barrett: u32,
497 big_q: u32,
498) {
499 simd.vectorize(
500 #[inline(always)]
501 || {
502 let acc = pulp::as_arrays_mut::<16, _>(acc).0;
503 let lhs = pulp::as_arrays::<16, _>(lhs).0;
504 let rhs = pulp::as_arrays::<16, _>(rhs).0;
505
506 let big_q_m1 = simd.splat_u32x16(big_q - 1);
507 let big_q_m1_complement = simd.splat_u32x16(32 - (big_q - 1));
508 let p_barrett = simd.splat_u32x16(p_barrett);
509 let p = simd.splat_u32x16(p);
510
511 for (acc, lhs, rhs) in crate::izip!(acc, lhs, rhs) {
512 let lhs = cast(*lhs);
513 let rhs = cast(*rhs);
514
515 let (lo, hi) = simd.widening_mul_u32x16(lhs, rhs);
517 let c1 = simd.or_u32x16(
518 simd.shr_dyn_u32x16(lo, big_q_m1),
519 simd.shl_dyn_u32x16(hi, big_q_m1_complement),
520 );
521 let c3 = simd.widening_mul_u32x16(c1, p_barrett).1;
522 let prod = simd.wrapping_sub_u32x16(lo, simd.wrapping_mul_u32x16(p, c3));
524 let prod = simd.small_mod_u32x16(p, prod);
525
526 *acc = cast(simd.small_mod_u32x16(p, simd.wrapping_add_u32x16(prod, cast(*acc))));
527 }
528 },
529 )
530}
531
532#[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
533fn mul_accumulate_avx2(
534 simd: crate::V3,
535 acc: &mut [u32],
536 lhs: &[u32],
537 rhs: &[u32],
538 p: u32,
539 p_barrett: u32,
540 big_q: u32,
541) {
542 simd.vectorize(
543 #[inline(always)]
544 || {
545 let acc = pulp::as_arrays_mut::<8, _>(acc).0;
546 let lhs = pulp::as_arrays::<8, _>(lhs).0;
547 let rhs = pulp::as_arrays::<8, _>(rhs).0;
548
549 let big_q_m1 = simd.splat_u32x8(big_q - 1);
550 let big_q_m1_complement = simd.splat_u32x8(32 - (big_q - 1));
551 let p_barrett = simd.splat_u32x8(p_barrett);
552 let p = simd.splat_u32x8(p);
553
554 for (acc, lhs, rhs) in crate::izip!(acc, lhs, rhs) {
555 let lhs = cast(*lhs);
556 let rhs = cast(*rhs);
557
558 let (lo, hi) = simd.widening_mul_u32x8(lhs, rhs);
560 let c1 = simd.or_u32x8(
561 simd.shr_dyn_u32x8(lo, big_q_m1),
562 simd.shl_dyn_u32x8(hi, big_q_m1_complement),
563 );
564 let c3 = simd.widening_mul_u32x8(c1, p_barrett).1;
565 let prod = simd.wrapping_sub_u32x8(lo, simd.wrapping_mul_u32x8(p, c3));
567 let prod = simd.small_mod_u32x8(p, prod);
568
569 *acc = cast(simd.small_mod_u32x8(p, simd.wrapping_add_u32x8(prod, cast(*acc))));
570 }
571 },
572 )
573}
574
575fn mul_accumulate_scalar(
576 acc: &mut [u32],
577 lhs: &[u32],
578 rhs: &[u32],
579 p: u32,
580 p_barrett: u32,
581 big_q: u32,
582) {
583 let big_q_m1 = big_q - 1;
584
585 for (acc, lhs, rhs) in crate::izip!(acc, lhs, rhs) {
586 let lhs = *lhs;
587 let rhs = *rhs;
588
589 let d = lhs as u64 * rhs as u64;
590 let c1 = (d >> big_q_m1) as u32;
591 let c3 = ((c1 as u64 * p_barrett as u64) >> 32) as u32;
592 let prod = (d as u32).wrapping_sub(p.wrapping_mul(c3));
593 let prod = prod.min(prod.wrapping_sub(p));
594
595 let acc_ = prod + *acc;
596 *acc = acc_.min(acc_.wrapping_sub(p));
597 }
598}
599
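// Precomputed Barrett parameters for a 32-bit modulus: `big_q` is the bit length of the
// modulus and `p_barrett = floor(2^(big_q + 31) / p)`. `requires_single_reduction_step`
// records whether `beta = 2^(big_q + 31) mod p` is at most `p - 2^(big_q - 1)`, in which case
// a single conditional subtraction after the Barrett step is enough; `Plan::try_new` uses it
// to choose between the fast and the division-based code paths.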
600struct BarrettInit32 {
601 big_q: u32,
602 p_barrett: u32,
603 requires_single_reduction_step: bool,
604}
605
606impl BarrettInit32 {
607 pub fn new(modulus: u32) -> Self {
608 let big_q = modulus.ilog2() + 1;
609 let big_l = big_q + 31;
610 let m_as_u64: u64 = modulus.into();
        let two_to_the_l = 1u64 << big_l;
        let p_barrett = (two_to_the_l / m_as_u64) as u32;
        let beta = two_to_the_l % m_as_u64;
613
614 let single_reduction_threshold = m_as_u64 - (1 << (big_q - 1));
619
620 let requires_single_reduction_step = beta <= single_reduction_threshold;
621
622 Self {
623 big_q,
624 p_barrett,
625 requires_single_reduction_step,
626 }
627 }
628}
629
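/// Negacyclic NTT plan for 32-bit primes, for computing products of polynomials modulo
/// `X^n + 1` over `Z/pZ`. Twiddle factors (and their Shoup companions when `p < 2^31`) are
/// precomputed at construction.
///
/// A minimal usage sketch; the import path is assumed and the snippet is not compiled as a
/// doctest:
///
/// ```ignore
/// // Assumed re-export path; adjust to the crate layout.
/// use crate::prime32::Plan;
///
/// let p = 1062862849u32; // prime with 2 * n dividing p - 1, taken from the tests below
/// let n = 32;
/// let plan = Plan::try_new(n, p).unwrap();
///
/// let mut lhs = vec![1u32; n];
/// let mut rhs = vec![2u32; n];
/// plan.fwd(&mut lhs);
/// plan.fwd(&mut rhs);
/// // Pointwise product scaled by n^{-1}, so that `inv` recovers the actual coefficients.
/// plan.mul_assign_normalize(&mut lhs, &rhs);
/// plan.inv(&mut lhs);
/// // `lhs` now holds the negacyclic convolution of the two inputs modulo p.
/// ```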
630#[derive(Clone)]
632pub struct Plan {
633 twid: ABox<[u32]>,
634 twid_shoup: ABox<[u32]>,
635 inv_twid: ABox<[u32]>,
636 inv_twid_shoup: ABox<[u32]>,
637 p: u32,
638 p_div: Div32,
639 can_use_fast_reduction_code: bool,
641
642 p_barrett: u32,
644 big_q: u32,
645
646 n_inv_mod_p: u32,
647 n_inv_mod_p_shoup: u32,
648}
649
650impl core::fmt::Debug for Plan {
651 fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
652 f.debug_struct("Plan")
653 .field("ntt_size", &self.ntt_size())
654 .field("modulus", &self.modulus())
655 .finish()
656 }
657}
658
659impl Plan {
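    /// Returns a plan for the given polynomial size and modulus, or `None` if the size is not
    /// a power of two of at least 32, the modulus is not prime, or `Z/pZ` has no primitive
    /// `2 * polynomial_size`-th root of unity.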
660 pub fn try_new(polynomial_size: usize, modulus: u32) -> Option<Self> {
663 let p_div = Div32::new(modulus);
664 if polynomial_size < 32
668 || !polynomial_size.is_power_of_two()
669 || !is_prime64(modulus as u64)
670 || find_primitive_root64(Div64::new(modulus as u64), 2 * polynomial_size as u64)
671 .is_none()
672 {
673 None
674 } else {
675 let mut twid = avec![0u32; polynomial_size].into_boxed_slice();
676 let mut inv_twid = avec![0u32; polynomial_size].into_boxed_slice();
677 let (mut twid_shoup, mut inv_twid_shoup) = if modulus < (1u32 << 31) {
678 (
679 avec![0u32; polynomial_size].into_boxed_slice(),
680 avec![0u32; polynomial_size].into_boxed_slice(),
681 )
682 } else {
683 (avec![].into_boxed_slice(), avec![].into_boxed_slice())
684 };
685
686 if modulus < (1u32 << 31) {
687 init_negacyclic_twiddles_shoup(
688 modulus,
689 polynomial_size,
690 &mut twid,
691 &mut twid_shoup,
692 &mut inv_twid,
693 &mut inv_twid_shoup,
694 );
695 } else {
696 init_negacyclic_twiddles(modulus, polynomial_size, &mut twid, &mut inv_twid);
697 }
698
699 let n_inv_mod_p = crate::prime::exp_mod32(p_div, polynomial_size as u32, modulus - 2);
700 let n_inv_mod_p_shoup = (((n_inv_mod_p as u64) << 32) / modulus as u64) as u32;
701
702 let BarrettInit32 {
703 big_q,
704 p_barrett,
705 requires_single_reduction_step,
706 } = BarrettInit32::new(modulus);
707
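            // The Barrett/Shoup fast paths are used when p is at most floor(2^32 / 3), or when
            // p <= 2^31 and the Barrett parameters guarantee a single conditional subtraction.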
708 let can_use_fast_reduction_code =
749 (modulus < 1431655766) || (requires_single_reduction_step && modulus <= (1 << 31));
750
751 Some(Self {
752 twid,
753 twid_shoup,
754 inv_twid_shoup,
755 inv_twid,
756 p: modulus,
757 p_div,
758 can_use_fast_reduction_code,
759 n_inv_mod_p,
760 n_inv_mod_p_shoup,
761 p_barrett,
762 big_q,
763 })
764 }
765 }
766
767 pub(crate) fn p_div(&self) -> Div32 {
768 self.p_div
769 }
770
771 #[inline]
773 pub fn ntt_size(&self) -> usize {
774 self.twid.len()
775 }
776
777 #[inline]
779 pub fn modulus(&self) -> u32 {
780 self.p
781 }
782
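    /// Returns `true` if the modulus admits the Barrett/Shoup fast paths; when `false`,
    /// `mul_assign_normalize`, `normalize` and `mul_accumulate` fall back to division-based
    /// arithmetic.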
783 #[inline]
788 pub fn can_use_fast_reduction_code(&self) -> bool {
789 self.can_use_fast_reduction_code
790 }
791
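    /// Computes the forward negacyclic NTT of `buf` in place, with all outputs reduced to
    /// `[0, p)`. `inv` expects its input in the layout produced by this transform.
    ///
    /// # Panics
    ///
    /// Panics if `buf.len()` differs from `self.ntt_size()`.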
792 pub fn fwd(&self, buf: &mut [u32]) {
798 assert_eq!(buf.len(), self.ntt_size());
799 let p = self.p;
800
801 if p < (1u32 << 30) {
802 #[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
803 {
804 #[cfg(feature = "nightly")]
805 if let Some(simd) = crate::V4::try_new() {
806 less_than_30bit::fwd_avx512(simd, p, buf, &self.twid, &self.twid_shoup);
807 return;
808 }
809 if let Some(simd) = crate::V3::try_new() {
810 less_than_30bit::fwd_avx2(simd, p, buf, &self.twid, &self.twid_shoup);
811 return;
812 }
813 }
814 less_than_30bit::fwd_scalar(p, buf, &self.twid, &self.twid_shoup);
815 } else if p < (1u32 << 31) {
816 #[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
817 {
818 #[cfg(feature = "nightly")]
819 if let Some(simd) = crate::V4::try_new() {
820 less_than_31bit::fwd_avx512(simd, p, buf, &self.twid, &self.twid_shoup);
821 return;
822 }
823 if let Some(simd) = crate::V3::try_new() {
824 less_than_31bit::fwd_avx2(simd, p, buf, &self.twid, &self.twid_shoup);
825 return;
826 }
827 }
828 less_than_31bit::fwd_scalar(p, buf, &self.twid, &self.twid_shoup);
829 } else {
830 #[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
831 #[cfg(feature = "nightly")]
832 if let Some(simd) = crate::V4::try_new() {
833 generic::fwd_avx512(simd, buf, p, self.p_div, &self.twid);
834 return;
835 }
836 #[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
837 if let Some(simd) = crate::V3::try_new() {
838 generic::fwd_avx2(simd, buf, p, self.p_div, &self.twid);
839 return;
840 }
841 generic::fwd_scalar(buf, p, self.p_div, &self.twid);
842 }
843 }
844
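    /// Computes the inverse transform of `fwd` in place, up to a multiplicative factor of
    /// `ntt_size()`; use `normalize` or `mul_assign_normalize` to cancel that factor.
    ///
    /// # Panics
    ///
    /// Panics if `buf.len()` differs from `self.ntt_size()`.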
845 pub fn inv(&self, buf: &mut [u32]) {
851 assert_eq!(buf.len(), self.ntt_size());
852 let p = self.p;
853
854 if p < (1u32 << 30) {
855 #[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
856 {
857 #[cfg(feature = "nightly")]
858 if let Some(simd) = crate::V4::try_new() {
859 less_than_30bit::inv_avx512(simd, p, buf, &self.inv_twid, &self.inv_twid_shoup);
860 return;
861 }
862 if let Some(simd) = crate::V3::try_new() {
863 less_than_30bit::inv_avx2(simd, p, buf, &self.inv_twid, &self.inv_twid_shoup);
864 return;
865 }
866 }
867 less_than_30bit::inv_scalar(p, buf, &self.inv_twid, &self.inv_twid_shoup);
868 } else if p < (1u32 << 31) {
869 #[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
870 {
871 #[cfg(feature = "nightly")]
872 if let Some(simd) = crate::V4::try_new() {
873 less_than_31bit::inv_avx512(simd, p, buf, &self.inv_twid, &self.inv_twid_shoup);
874 return;
875 }
876 if let Some(simd) = crate::V3::try_new() {
877 less_than_31bit::inv_avx2(simd, p, buf, &self.inv_twid, &self.inv_twid_shoup);
878 return;
879 }
880 }
881 less_than_31bit::inv_scalar(p, buf, &self.inv_twid, &self.inv_twid_shoup);
882 } else {
883 #[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
884 #[cfg(feature = "nightly")]
885 if let Some(simd) = crate::V4::try_new() {
886 generic::inv_avx512(simd, buf, p, self.p_div, &self.inv_twid);
887 return;
888 }
889 #[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
890 if let Some(simd) = crate::V3::try_new() {
891 generic::inv_avx2(simd, buf, p, self.p_div, &self.inv_twid);
892 return;
893 }
894 generic::inv_scalar(buf, p, self.p_div, &self.inv_twid);
895 }
896 }
897
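    /// Computes `lhs[i] = lhs[i] * rhs[i] * n^{-1} mod p` element-wise, where `n` is the NTT
    /// size, so that applying `inv` afterwards yields the exact negacyclic product.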
898 pub fn mul_assign_normalize(&self, lhs: &mut [u32], rhs: &[u32]) {
901 if self.can_use_fast_reduction_code {
902 #[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
903 #[cfg(feature = "nightly")]
904 if let Some(simd) = crate::V4::try_new() {
905 mul_assign_normalize_avx512(
906 simd,
907 lhs,
908 rhs,
909 self.p,
910 self.p_barrett,
911 self.big_q,
912 self.n_inv_mod_p,
913 self.n_inv_mod_p_shoup,
914 );
915 return;
916 }
917 #[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
918 if let Some(simd) = crate::V3::try_new() {
919 mul_assign_normalize_avx2(
920 simd,
921 lhs,
922 rhs,
923 self.p,
924 self.p_barrett,
925 self.big_q,
926 self.n_inv_mod_p,
927 self.n_inv_mod_p_shoup,
928 );
929 return;
930 }
931
932 mul_assign_normalize_scalar(
933 lhs,
934 rhs,
935 self.p,
936 self.p_barrett,
937 self.big_q,
938 self.n_inv_mod_p,
939 self.n_inv_mod_p_shoup,
940 );
941 } else {
942 let p_div = self.p_div;
943 let n_inv_mod_p = self.n_inv_mod_p;
944 for (lhs_, rhs) in crate::izip!(lhs, rhs) {
945 let lhs = *lhs_;
946 let rhs = *rhs;
947 let prod = Div32::rem_u64(lhs as u64 * rhs as u64, p_div);
948 let prod = Div32::rem_u64(prod as u64 * n_inv_mod_p as u64, p_div);
949 *lhs_ = prod;
950 }
951 }
952 }
953
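    /// Multiplies every value by `n^{-1} mod p`, where `n` is the NTT size.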
954 pub fn normalize(&self, values: &mut [u32]) {
957 if self.can_use_fast_reduction_code {
958 #[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
959 #[cfg(feature = "nightly")]
960 if let Some(simd) = crate::V4::try_new() {
961 normalize_avx512(
962 simd,
963 values,
964 self.p,
965 self.n_inv_mod_p,
966 self.n_inv_mod_p_shoup,
967 );
968 return;
969 }
970 #[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
971 if let Some(simd) = crate::V3::try_new() {
972 normalize_avx2(
973 simd,
974 values,
975 self.p,
976 self.n_inv_mod_p,
977 self.n_inv_mod_p_shoup,
978 );
979 return;
980 }
981 normalize_scalar(values, self.p, self.n_inv_mod_p, self.n_inv_mod_p_shoup);
982 } else {
983 let p_div = self.p_div;
984 let n_inv_mod_p = self.n_inv_mod_p;
985 for values in values {
986 let prod = Div32::rem_u64(*values as u64 * n_inv_mod_p as u64, p_div);
987 *values = prod;
988 }
989 }
990 }
991
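    /// Computes `acc[i] = acc[i] + lhs[i] * rhs[i] mod p` element-wise.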
992 pub fn mul_accumulate(&self, acc: &mut [u32], lhs: &[u32], rhs: &[u32]) {
994 if self.can_use_fast_reduction_code {
995 #[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
996 #[cfg(feature = "nightly")]
997 if let Some(simd) = crate::V4::try_new() {
998 mul_accumulate_avx512(simd, acc, lhs, rhs, self.p, self.p_barrett, self.big_q);
999 return;
1000 }
1001 #[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
1002 if let Some(simd) = crate::V3::try_new() {
1003 mul_accumulate_avx2(simd, acc, lhs, rhs, self.p, self.p_barrett, self.big_q);
1004 return;
1005 }
1006 mul_accumulate_scalar(acc, lhs, rhs, self.p, self.p_barrett, self.big_q);
1007 } else {
1008 let p = self.p;
1009 let p_div = self.p_div;
1010 for (acc, lhs, rhs) in crate::izip!(acc, lhs, rhs) {
1011 let prod = generic::mul(p_div, *lhs, *rhs);
1012 *acc = generic::add(p, *acc, prod);
1013 }
1014 }
1015 }
1016}
1017
1018#[cfg(test)]
1019pub mod tests {
1020 use super::*;
1021 use crate::prime::largest_prime_in_arithmetic_progression64;
1022 use alloc::{vec, vec::Vec};
1023 use rand::random;
1024
1025 extern crate alloc;
1026
1027 pub fn add(p: u32, a: u32, b: u32) -> u32 {
1028 let neg_b = p.wrapping_sub(b);
1029 if a >= neg_b {
1030 a - neg_b
1031 } else {
1032 a + b
1033 }
1034 }
1035
1036 pub fn sub(p: u32, a: u32, b: u32) -> u32 {
1037 let neg_b = p.wrapping_sub(b);
1038 if a >= b {
1039 a - b
1040 } else {
1041 a + neg_b
1042 }
1043 }
1044
1045 pub fn mul(p: u32, a: u32, b: u32) -> u32 {
1046 let wide = a as u64 * b as u64;
1047 if p == 0 {
1048 wide as u32
1049 } else {
1050 (wide % p as u64) as u32
1051 }
1052 }
1053
1054 pub fn negacyclic_convolution(n: usize, p: u32, lhs: &[u32], rhs: &[u32]) -> vec::Vec<u32> {
1055 let mut full_convolution = vec![0u32; 2 * n];
1056 let mut negacyclic_convolution = vec![0u32; n];
1057 for i in 0..n {
1058 for j in 0..n {
1059 full_convolution[i + j] = add(p, full_convolution[i + j], mul(p, lhs[i], rhs[j]));
1060 }
1061 }
1062 for i in 0..n {
1063 negacyclic_convolution[i] = sub(p, full_convolution[i], full_convolution[i + n]);
1064 }
1065 negacyclic_convolution
1066 }
1067
1068 pub fn random_lhs_rhs_with_negacyclic_convolution(
1069 n: usize,
1070 p: u32,
1071 ) -> (vec::Vec<u32>, vec::Vec<u32>, vec::Vec<u32>) {
1072 let mut lhs = vec![0u32; n];
1073 let mut rhs = vec![0u32; n];
1074
1075 for x in &mut lhs {
1076 *x = random();
1077 if p != 0 {
1078 *x %= p;
1079 }
1080 }
1081 for x in &mut rhs {
1082 *x = random();
1083 if p != 0 {
1084 *x %= p;
1085 }
1086 }
1087
1088 let lhs = lhs;
1089 let rhs = rhs;
1090
1091 let negacyclic_convolution = negacyclic_convolution(n, p, &lhs, &rhs);
1092 (lhs, rhs, negacyclic_convolution)
1093 }
1094
1095 #[test]
1096 fn test_product() {
1097 for n in [32, 64, 128, 256, 512, 1024] {
1098 for p in [
1099 largest_prime_in_arithmetic_progression64(1 << 16, 1, 1 << 29, 1 << 30).unwrap(),
1100 largest_prime_in_arithmetic_progression64(1 << 16, 1, 1 << 30, 1 << 31).unwrap(),
1101 largest_prime_in_arithmetic_progression64(1 << 16, 1, 1 << 31, 1 << 32).unwrap(),
1102 ] {
1103 let p = p as u32;
1104 let plan = Plan::try_new(n, p).unwrap();
1105
1106 let (lhs, rhs, negacyclic_convolution) =
1107 random_lhs_rhs_with_negacyclic_convolution(n, p);
1108
1109 let mut prod = vec![0u32; n];
1110 let mut lhs_fourier = lhs.clone();
1111 let mut rhs_fourier = rhs.clone();
1112
1113 plan.fwd(&mut lhs_fourier);
1114 plan.fwd(&mut rhs_fourier);
1115
1116 for x in &lhs_fourier {
1117 assert!(*x < p);
1118 }
1119 for x in &rhs_fourier {
1120 assert!(*x < p);
1121 }
1122
1123 for i in 0..n {
1124 prod[i] = mul(p, lhs_fourier[i], rhs_fourier[i]);
1125 }
1126 plan.inv(&mut prod);
1127
1128 plan.mul_assign_normalize(&mut lhs_fourier, &rhs_fourier);
1129 plan.inv(&mut lhs_fourier);
1130
1131 for x in &prod {
1132 assert!(*x < p);
1133 }
1134
1135 for i in 0..n {
1136 assert_eq!(prod[i], mul(p, negacyclic_convolution[i], n as u32),);
1137 }
1138 assert_eq!(lhs_fourier, negacyclic_convolution);
1139 }
1140 }
1141 }
1142
1143 #[test]
1144 fn test_mul_assign_normalize_scalar() {
1145 let p =
1146 largest_prime_in_arithmetic_progression64(1 << 16, 1, 1 << 30, 1 << 31).unwrap() as u32;
1147 let p_div = Div64::new(p as u64);
1148 let polynomial_size = 128;
1149
1150 let n_inv_mod_p =
1151 crate::prime::exp_mod64(p_div, polynomial_size as u64, p as u64 - 2) as u32;
1152 let n_inv_mod_p_shoup = (((n_inv_mod_p as u128) << 32) / p as u128) as u32;
1153 let big_q = p.ilog2() + 1;
1154 let big_l = big_q + 31;
1155 let p_barrett = ((1u128 << big_l) / p as u128) as u32;
1156
1157 let mut lhs = (0..polynomial_size)
1158 .map(|_| rand::random::<u32>() % p)
1159 .collect::<Vec<_>>();
1160 let mut lhs_target = lhs.clone();
1161 let rhs = (0..polynomial_size)
1162 .map(|_| rand::random::<u32>() % p)
1163 .collect::<Vec<_>>();
1164
1165 let mul = |a: u32, b: u32| ((a as u128 * b as u128) % p as u128) as u32;
1166
1167 for (lhs, rhs) in lhs_target.iter_mut().zip(&rhs) {
1168 *lhs = mul(mul(*lhs, *rhs), n_inv_mod_p);
1169 }
1170
1171 mul_assign_normalize_scalar(
1172 &mut lhs,
1173 &rhs,
1174 p,
1175 p_barrett,
1176 big_q,
1177 n_inv_mod_p,
1178 n_inv_mod_p_shoup,
1179 );
1180 assert_eq!(lhs, lhs_target);
1181 }
1182
1183 #[test]
1184 fn test_mul_accumulate_scalar() {
1185 let p =
1186 largest_prime_in_arithmetic_progression64(1 << 16, 1, 1 << 30, 1 << 31).unwrap() as u32;
1187 let polynomial_size = 128;
1188
1189 let big_q = p.ilog2() + 1;
1190 let big_l = big_q + 31;
1191 let p_barrett = ((1u128 << big_l) / p as u128) as u32;
1192
1193 let mut acc = (0..polynomial_size)
1194 .map(|_| rand::random::<u32>() % p)
1195 .collect::<Vec<_>>();
1196 let mut acc_target = acc.clone();
1197 let lhs = (0..polynomial_size)
1198 .map(|_| rand::random::<u32>() % p)
1199 .collect::<Vec<_>>();
1200 let rhs = (0..polynomial_size)
1201 .map(|_| rand::random::<u32>() % p)
1202 .collect::<Vec<_>>();
1203
1204 let mul = |a: u32, b: u32| ((a as u128 * b as u128) % p as u128) as u32;
1205 let add = |a: u32, b: u32| (a + b) % p;
1206
1207 for (acc, lhs, rhs) in crate::izip!(&mut acc_target, &lhs, &rhs) {
1208 *acc = add(mul(*lhs, *rhs), *acc);
1209 }
1210
1211 mul_accumulate_scalar(&mut acc, &lhs, &rhs, p, p_barrett, big_q);
1212 assert_eq!(acc, acc_target);
1213 }
1214
1215 #[test]
1216 fn test_normalize_scalar() {
1217 let p =
1218 largest_prime_in_arithmetic_progression64(1 << 16, 1, 1 << 30, 1 << 31).unwrap() as u32;
1219 let p_div = Div64::new(p as u64);
1220 let polynomial_size = 128;
1221
1222 let n_inv_mod_p =
1223 crate::prime::exp_mod64(p_div, polynomial_size as u64, p as u64 - 2) as u32;
1224 let n_inv_mod_p_shoup = (((n_inv_mod_p as u128) << 32) / p as u128) as u32;
1225
1226 let mut val = (0..polynomial_size)
1227 .map(|_| rand::random::<u32>() % p)
1228 .collect::<Vec<_>>();
1229 let mut val_target = val.clone();
1230
1231 let mul = |a: u32, b: u32| ((a as u128 * b as u128) % p as u128) as u32;
1232
1233 for val in val_target.iter_mut() {
1234 *val = mul(*val, n_inv_mod_p);
1235 }
1236
1237 normalize_scalar(&mut val, p, n_inv_mod_p, n_inv_mod_p_shoup);
1238 assert_eq!(val, val_target);
1239 }
1240
1241 #[test]
1242 fn test_mul_accumulate() {
1243 for p in [
1244 largest_prime_in_arithmetic_progression64(1 << 16, 1, 0, 1 << 31).unwrap(),
1245 largest_prime_in_arithmetic_progression64(1 << 16, 1, 0, u32::MAX as u64).unwrap(),
1246 ] {
1247 let p = p as u32;
1248 let polynomial_size = 128;
1249
1250 let mut acc = (0..polynomial_size)
1251 .map(|_| rand::random::<u32>() % p)
1252 .collect::<Vec<_>>();
1253 let mut acc_target = acc.clone();
1254 let lhs = (0..polynomial_size)
1255 .map(|_| rand::random::<u32>() % p)
1256 .collect::<Vec<_>>();
1257 let rhs = (0..polynomial_size)
1258 .map(|_| rand::random::<u32>() % p)
1259 .collect::<Vec<_>>();
1260
1261 let mul = |a: u32, b: u32| ((a as u64 * b as u64) % p as u64) as u32;
1262 let add = |a: u32, b: u32| ((a as u64 + b as u64) % p as u64) as u32;
1263
1264 for (acc, lhs, rhs) in crate::izip!(&mut acc_target, &lhs, &rhs) {
1265 *acc = add(mul(*lhs, *rhs), *acc);
1266 }
1267
1268 Plan::try_new(polynomial_size, p)
1269 .unwrap()
1270 .mul_accumulate(&mut acc, &lhs, &rhs);
1271 assert_eq!(acc, acc_target);
1272 }
1273 }
1274
1275 #[test]
1276 fn test_mul_assign_normalize() {
1277 for p in [
1278 largest_prime_in_arithmetic_progression64(1 << 16, 1, 0, 1 << 31).unwrap(),
1279 largest_prime_in_arithmetic_progression64(1 << 16, 1, 0, u32::MAX as u64).unwrap(),
1280 ] {
1281 let p = p as u32;
1282 let polynomial_size = 128;
1283 let p_div = Div64::new(p as u64);
1284 let n_inv_mod_p =
1285 crate::prime::exp_mod64(p_div, polynomial_size as u64, p as u64 - 2) as u32;
1286
1287 let mut lhs = (0..polynomial_size)
1288 .map(|_| rand::random::<u32>() % p)
1289 .collect::<Vec<_>>();
1290 let mut lhs_target = lhs.clone();
1291 let rhs = (0..polynomial_size)
1292 .map(|_| rand::random::<u32>() % p)
1293 .collect::<Vec<_>>();
1294
1295 let mul = |a: u32, b: u32| ((a as u64 * b as u64) % p as u64) as u32;
1296
1297 for (lhs, rhs) in lhs_target.iter_mut().zip(&rhs) {
1298 *lhs = mul(mul(*lhs, *rhs), n_inv_mod_p);
1299 }
1300
1301 Plan::try_new(polynomial_size, p)
1302 .unwrap()
1303 .mul_assign_normalize(&mut lhs, &rhs);
1304 assert_eq!(lhs, lhs_target);
1305 }
1306 }
1307
1308 #[test]
1309 fn test_normalize() {
1310 for p in [
1311 largest_prime_in_arithmetic_progression64(1 << 16, 1, 0, 1 << 31).unwrap(),
1312 largest_prime_in_arithmetic_progression64(1 << 16, 1, 0, u32::MAX as u64).unwrap(),
1313 ] {
1314 let p = p as u32;
1315 let polynomial_size = 128;
1316 let p_div = Div64::new(p as u64);
1317 let n_inv_mod_p =
1318 crate::prime::exp_mod64(p_div, polynomial_size as u64, p as u64 - 2) as u32;
1319
1320 let mut val = (0..polynomial_size)
1321 .map(|_| rand::random::<u32>() % p)
1322 .collect::<Vec<_>>();
1323 let mut val_target = val.clone();
1324
1325 let mul = |a: u32, b: u32| ((a as u64 * b as u64) % p as u64) as u32;
1326
1327 for val in &mut val_target {
1328 *val = mul(*val, n_inv_mod_p);
1329 }
1330
1331 Plan::try_new(polynomial_size, p)
1332 .unwrap()
1333 .normalize(&mut val);
1334 assert_eq!(val, val_target);
1335 }
1336 }
1337
1338 #[test]
1339 fn test_plan_can_use_fast_reduction_code() {
1340 use crate::primes32::{P0, P1, P2, P3, P4, P5, P6, P7, P8, P9};
1341 const POLYNOMIAL_SIZE: usize = 32;
1342
1343 for p in [
1347 1062862849, 1431669377, P0, P1, P2, P3, P4, P5, P6, P7, P8, P9,
1348 ] {
1349 let plan = Plan::try_new(POLYNOMIAL_SIZE, p).unwrap();
1350
1351 assert!(plan.can_use_fast_reduction_code);
1352 }
1353
1354 let plan = Plan::try_new(POLYNOMIAL_SIZE, 0x7fe0_1001).unwrap();
1357 assert!(!plan.can_use_fast_reduction_code);
1358 }
1359
1360 #[test]
    fn test_barrett_invalid_reduction_non_regression() {
1362 const POLYNOMIAL_SIZE: usize = 32;
1363
1364 let p: u32 = 0x7fe0_1001;
1365 let plan = Plan::try_new(POLYNOMIAL_SIZE, p).unwrap();
1366
1367 let value = 0x6e63593a;
1368 let mut acc = [0u32; POLYNOMIAL_SIZE];
1369 let input: [u32; POLYNOMIAL_SIZE] =
1371 core::array::from_fn(|i| if i == 0 { value } else { 0 });
1372
1373 plan.mul_accumulate(&mut acc, &input, &input);
1374
1375 let expected = (u64::from(value) * u64::from(value) % u64::from(p)) as u32;
1376 assert_eq!(acc[0], expected);
1377 }
1378
1379 #[test]
1380 fn test_barrett_invalid_reduction_normalize_non_regression() {
1381 const POLYNOMIAL_SIZE: usize = 32;
1382
1383 let p: u32 = 0x7fe0_1001;
1384 let plan = Plan::try_new(POLYNOMIAL_SIZE, p).unwrap();
1385
1386 let value = 0x6e63593a;
1387 let input: [u32; POLYNOMIAL_SIZE] =
1389 core::array::from_fn(|i| if i == 0 { value } else { 0 });
1390 let mut acc = input;
1391
1392 plan.mul_assign_normalize(&mut acc, &input);
1393
1394 let expected = (u128::from(value) * u128::from(value) * u128::from(plan.n_inv_mod_p)
1395 % u128::from(p)) as u32;
1396 assert_eq!(acc[0], expected);
1397 }
1398}
1399
1400#[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
1401#[cfg(test)]
1402mod x86_tests {
1403 use super::*;
1404 use crate::prime::largest_prime_in_arithmetic_progression64;
1405 use alloc::vec::Vec;
1406 use rand::random as rnd;
1407
1408 extern crate alloc;
1409
1410 #[test]
1411 fn test_interleaves_and_permutes_u32x8() {
1412 if let Some(simd) = crate::V3::try_new() {
1413 let a = u32x8(rnd(), rnd(), rnd(), rnd(), rnd(), rnd(), rnd(), rnd());
1414 let b = u32x8(rnd(), rnd(), rnd(), rnd(), rnd(), rnd(), rnd(), rnd());
1415
1416 assert_eq!(
1417 simd.interleave4_u32x8([a, b]),
1418 [
1419 u32x8(a.0, a.1, a.2, a.3, b.0, b.1, b.2, b.3),
1420 u32x8(a.4, a.5, a.6, a.7, b.4, b.5, b.6, b.7),
1421 ],
1422 );
1423 assert_eq!(
1424 simd.interleave4_u32x8(simd.interleave4_u32x8([a, b])),
1425 [a, b],
1426 );
1427 let w = [rnd(), rnd()];
1428 assert_eq!(
1429 simd.permute4_u32x8(w),
1430 u32x8(w[0], w[0], w[0], w[0], w[1], w[1], w[1], w[1]),
1431 );
1432
1433 assert_eq!(
1434 simd.interleave2_u32x8([a, b]),
1435 [
1436 u32x8(a.0, a.1, b.0, b.1, a.4, a.5, b.4, b.5),
1437 u32x8(a.2, a.3, b.2, b.3, a.6, a.7, b.6, b.7),
1438 ],
1439 );
1440 assert_eq!(
1441 simd.interleave2_u32x8(simd.interleave2_u32x8([a, b])),
1442 [a, b],
1443 );
1444 let w = [rnd(), rnd(), rnd(), rnd()];
1445 assert_eq!(
1446 simd.permute2_u32x8(w),
1447 u32x8(w[0], w[0], w[2], w[2], w[1], w[1], w[3], w[3]),
1448 );
1449
1450 assert_eq!(
1451 simd.interleave1_u32x8([a, b]),
1452 [
1453 u32x8(a.0, b.0, a.2, b.2, a.4, b.4, a.6, b.6),
1454 u32x8(a.1, b.1, a.3, b.3, a.5, b.5, a.7, b.7),
1455 ],
1456 );
1457 assert_eq!(
1458 simd.interleave1_u32x8(simd.interleave1_u32x8([a, b])),
1459 [a, b],
1460 );
1461 let w = [rnd(), rnd(), rnd(), rnd(), rnd(), rnd(), rnd(), rnd()];
1462 assert_eq!(
1463 simd.permute1_u32x8(w),
1464 u32x8(w[0], w[4], w[1], w[5], w[2], w[6], w[3], w[7]),
1465 );
1466 }
1467 }
1468
1469 #[cfg(feature = "nightly")]
1470 #[test]
1471 fn test_interleaves_and_permutes_u32x16() {
1472 if let Some(simd) = crate::V4::try_new() {
1473 #[rustfmt::skip]
1474 let a = u32x16(rnd(), rnd(), rnd(), rnd(), rnd(), rnd(), rnd(), rnd(), rnd(), rnd(), rnd(), rnd(), rnd(), rnd(), rnd(), rnd());
1475 #[rustfmt::skip]
1476 let b = u32x16(rnd(), rnd(), rnd(), rnd(), rnd(), rnd(), rnd(), rnd(), rnd(), rnd(), rnd(), rnd(), rnd(), rnd(), rnd(), rnd());
1477
1478 assert_eq!(
1479 simd.interleave8_u32x16([a, b]),
1480 [
1481 u32x16(
1482 a.0, a.1, a.2, a.3, a.4, a.5, a.6, a.7, b.0, b.1, b.2, b.3, b.4, b.5, b.6,
1483 b.7,
1484 ),
1485 u32x16(
1486 a.8, a.9, a.10, a.11, a.12, a.13, a.14, a.15, b.8, b.9, b.10, b.11, b.12,
1487 b.13, b.14, b.15,
1488 ),
1489 ],
1490 );
1491 assert_eq!(
1492 simd.interleave8_u32x16(simd.interleave8_u32x16([a, b])),
1493 [a, b],
1494 );
1495 let w = [rnd(), rnd()];
1496 assert_eq!(
1497 simd.permute8_u32x16(w),
1498 u32x16(
1499 w[0], w[0], w[0], w[0], w[0], w[0], w[0], w[0], w[1], w[1], w[1], w[1], w[1],
1500 w[1], w[1], w[1],
1501 ),
1502 );
1503
1504 assert_eq!(
1505 simd.interleave4_u32x16([a, b]),
1506 [
1507 u32x16(
1508 a.0, a.1, a.2, a.3, b.0, b.1, b.2, b.3, a.8, a.9, a.10, a.11, b.8, b.9,
1509 b.10, b.11,
1510 ),
1511 u32x16(
1512 a.4, a.5, a.6, a.7, b.4, b.5, b.6, b.7, a.12, a.13, a.14, a.15, b.12, b.13,
1513 b.14, b.15,
1514 ),
1515 ],
1516 );
1517 assert_eq!(
1518 simd.interleave4_u32x16(simd.interleave4_u32x16([a, b])),
1519 [a, b],
1520 );
1521 let w = [rnd(), rnd(), rnd(), rnd()];
1522 assert_eq!(
1523 simd.permute4_u32x16(w),
1524 u32x16(
1525 w[0], w[0], w[0], w[0], w[2], w[2], w[2], w[2], w[1], w[1], w[1], w[1], w[3],
1526 w[3], w[3], w[3],
1527 ),
1528 );
1529
1530 assert_eq!(
1531 simd.interleave2_u32x16([a, b]),
1532 [
1533 u32x16(
1534 a.0, a.1, b.0, b.1, a.4, a.5, b.4, b.5, a.8, a.9, b.8, b.9, a.12, a.13,
1535 b.12, b.13,
1536 ),
1537 u32x16(
1538 a.2, a.3, b.2, b.3, a.6, a.7, b.6, b.7, a.10, a.11, b.10, b.11, a.14, a.15,
1539 b.14, b.15,
1540 ),
1541 ],
1542 );
1543 assert_eq!(
1544 simd.interleave2_u32x16(simd.interleave2_u32x16([a, b])),
1545 [a, b],
1546 );
1547 let w = [rnd(), rnd(), rnd(), rnd(), rnd(), rnd(), rnd(), rnd()];
1548 assert_eq!(
1549 simd.permute2_u32x16(w),
1550 u32x16(
1551 w[0], w[0], w[4], w[4], w[1], w[1], w[5], w[5], w[2], w[2], w[6], w[6], w[3],
1552 w[3], w[7], w[7],
1553 ),
1554 );
1555
1556 assert_eq!(
1557 simd.interleave1_u32x16([a, b]),
1558 [
1559 u32x16(
1560 a.0, b.0, a.2, b.2, a.4, b.4, a.6, b.6, a.8, b.8, a.10, b.10, a.12, b.12,
1561 a.14, b.14,
1562 ),
1563 u32x16(
1564 a.1, b.1, a.3, b.3, a.5, b.5, a.7, b.7, a.9, b.9, a.11, b.11, a.13, b.13,
1565 a.15, b.15,
1566 ),
1567 ],
1568 );
1569 assert_eq!(
1570 simd.interleave1_u32x16(simd.interleave1_u32x16([a, b])),
1571 [a, b],
1572 );
1573 #[rustfmt::skip]
1574 let w = [rnd(), rnd(), rnd(), rnd(), rnd(), rnd(), rnd(), rnd(), rnd(), rnd(), rnd(), rnd(), rnd(), rnd(), rnd(), rnd()];
1575 assert_eq!(
1576 simd.permute1_u32x16(w),
1577 u32x16(
1578 w[0], w[8], w[1], w[9], w[2], w[10], w[3], w[11], w[4], w[12], w[5], w[13],
1579 w[6], w[14], w[7], w[15],
1580 ),
1581 );
1582 }
1583 }
1584
1585 #[cfg(feature = "nightly")]
1586 #[test]
1587 fn test_mul_assign_normalize_avx512() {
1588 if let Some(simd) = crate::V4::try_new() {
1589 let p = largest_prime_in_arithmetic_progression64(1 << 16, 1, 1 << 30, 1 << 31).unwrap()
1590 as u32;
1591 let p_div = Div64::new(p as u64);
1592 let polynomial_size = 128;
1593
1594 let n_inv_mod_p =
1595 crate::prime::exp_mod64(p_div, polynomial_size as u64, p as u64 - 2) as u32;
1596 let n_inv_mod_p_shoup = (((n_inv_mod_p as u128) << 32) / p as u128) as u32;
1597 let big_q = p.ilog2() + 1;
1598 let big_l = big_q + 31;
1599 let p_barrett = ((1u128 << big_l) / p as u128) as u32;
1600
1601 let mut lhs = (0..polynomial_size)
1602 .map(|_| rand::random::<u32>() % p)
1603 .collect::<Vec<_>>();
1604 let mut lhs_target = lhs.clone();
1605 let rhs = (0..polynomial_size)
1606 .map(|_| rand::random::<u32>() % p)
1607 .collect::<Vec<_>>();
1608
1609 let mul = |a: u32, b: u32| ((a as u128 * b as u128) % p as u128) as u32;
1610
1611 for (lhs, rhs) in lhs_target.iter_mut().zip(&rhs) {
1612 *lhs = mul(mul(*lhs, *rhs), n_inv_mod_p);
1613 }
1614
1615 mul_assign_normalize_avx512(
1616 simd,
1617 &mut lhs,
1618 &rhs,
1619 p,
1620 p_barrett,
1621 big_q,
1622 n_inv_mod_p,
1623 n_inv_mod_p_shoup,
1624 );
1625 assert_eq!(lhs, lhs_target);
1626 }
1627 }
1628
1629 #[test]
1630 fn test_mul_assign_normalize_avx2() {
1631 if let Some(simd) = crate::V3::try_new() {
1632 let p = largest_prime_in_arithmetic_progression64(1 << 16, 1, 1 << 30, 1 << 31).unwrap()
1633 as u32;
1634 let p_div = Div64::new(p as u64);
1635 let polynomial_size = 128;
1636
1637 let n_inv_mod_p =
1638 crate::prime::exp_mod64(p_div, polynomial_size as u64, p as u64 - 2) as u32;
1639 let n_inv_mod_p_shoup = (((n_inv_mod_p as u128) << 32) / p as u128) as u32;
1640 let big_q = p.ilog2() + 1;
1641 let big_l = big_q + 31;
1642 let p_barrett = ((1u128 << big_l) / p as u128) as u32;
1643
1644 let mut lhs = (0..polynomial_size)
1645 .map(|_| rand::random::<u32>() % p)
1646 .collect::<Vec<_>>();
1647 let mut lhs_target = lhs.clone();
1648 let rhs = (0..polynomial_size)
1649 .map(|_| rand::random::<u32>() % p)
1650 .collect::<Vec<_>>();
1651
1652 let mul = |a: u32, b: u32| ((a as u128 * b as u128) % p as u128) as u32;
1653
1654 for (lhs, rhs) in lhs_target.iter_mut().zip(&rhs) {
1655 *lhs = mul(mul(*lhs, *rhs), n_inv_mod_p);
1656 }
1657
1658 mul_assign_normalize_avx2(
1659 simd,
1660 &mut lhs,
1661 &rhs,
1662 p,
1663 p_barrett,
1664 big_q,
1665 n_inv_mod_p,
1666 n_inv_mod_p_shoup,
1667 );
1668 assert_eq!(lhs, lhs_target);
1669 }
1670 }
1671
1672 #[cfg(feature = "nightly")]
1673 #[test]
1674 fn test_mul_accumulate_avx512() {
1675 if let Some(simd) = crate::V4::try_new() {
1676 let p = largest_prime_in_arithmetic_progression64(1 << 16, 1, 1 << 30, 1 << 31).unwrap()
1677 as u32;
1678 let polynomial_size = 128;
1679
1680 let big_q = p.ilog2() + 1;
1681 let big_l = big_q + 31;
1682 let p_barrett = ((1u128 << big_l) / p as u128) as u32;
1683
1684 let mut acc = (0..polynomial_size)
1685 .map(|_| rand::random::<u32>() % p)
1686 .collect::<Vec<_>>();
1687 let mut acc_target = acc.clone();
1688 let lhs = (0..polynomial_size)
1689 .map(|_| rand::random::<u32>() % p)
1690 .collect::<Vec<_>>();
1691 let rhs = (0..polynomial_size)
1692 .map(|_| rand::random::<u32>() % p)
1693 .collect::<Vec<_>>();
1694
1695 let mul = |a: u32, b: u32| ((a as u128 * b as u128) % p as u128) as u32;
1696 let add = |a: u32, b: u32| (a + b) % p;
1697
1698 for (acc, lhs, rhs) in crate::izip!(&mut acc_target, &lhs, &rhs) {
1699 *acc = add(mul(*lhs, *rhs), *acc);
1700 }
1701
1702 mul_accumulate_avx512(simd, &mut acc, &lhs, &rhs, p, p_barrett, big_q);
1703 assert_eq!(acc, acc_target);
1704 }
1705 }
1706
1707 #[test]
1708 fn test_mul_accumulate_avx2() {
1709 if let Some(simd) = crate::V3::try_new() {
1710 let p = largest_prime_in_arithmetic_progression64(1 << 16, 1, 1 << 30, 1 << 31).unwrap()
1711 as u32;
1712 let polynomial_size = 128;
1713
1714 let big_q = p.ilog2() + 1;
1715 let big_l = big_q + 31;
1716 let p_barrett = ((1u128 << big_l) / p as u128) as u32;
1717
1718 let mut acc = (0..polynomial_size)
1719 .map(|_| rand::random::<u32>() % p)
1720 .collect::<Vec<_>>();
1721 let mut acc_target = acc.clone();
1722 let lhs = (0..polynomial_size)
1723 .map(|_| rand::random::<u32>() % p)
1724 .collect::<Vec<_>>();
1725 let rhs = (0..polynomial_size)
1726 .map(|_| rand::random::<u32>() % p)
1727 .collect::<Vec<_>>();
1728
1729 let mul = |a: u32, b: u32| ((a as u128 * b as u128) % p as u128) as u32;
1730 let add = |a: u32, b: u32| (a + b) % p;
1731
1732 for (acc, lhs, rhs) in crate::izip!(&mut acc_target, &lhs, &rhs) {
1733 *acc = add(mul(*lhs, *rhs), *acc);
1734 }
1735
1736 mul_accumulate_avx2(simd, &mut acc, &lhs, &rhs, p, p_barrett, big_q);
1737 assert_eq!(acc, acc_target);
1738 }
1739 }
1740
1741 #[cfg(feature = "nightly")]
1742 #[test]
1743 fn test_normalize_avx512() {
1744 if let Some(simd) = crate::V4::try_new() {
1745 let p = largest_prime_in_arithmetic_progression64(1 << 16, 1, 1 << 30, 1 << 31).unwrap()
1746 as u32;
1747 let p_div = Div64::new(p as u64);
1748 let polynomial_size = 128;
1749
1750 let n_inv_mod_p =
1751 crate::prime::exp_mod64(p_div, polynomial_size as u64, p as u64 - 2) as u32;
1752 let n_inv_mod_p_shoup = (((n_inv_mod_p as u128) << 32) / p as u128) as u32;
1753
1754 let mut val = (0..polynomial_size)
1755 .map(|_| rand::random::<u32>() % p)
1756 .collect::<Vec<_>>();
1757 let mut val_target = val.clone();
1758
1759 let mul = |a: u32, b: u32| ((a as u128 * b as u128) % p as u128) as u32;
1760
1761 for val in val_target.iter_mut() {
1762 *val = mul(*val, n_inv_mod_p);
1763 }
1764
1765 normalize_avx512(simd, &mut val, p, n_inv_mod_p, n_inv_mod_p_shoup);
1766 assert_eq!(val, val_target);
1767 }
1768 }
1769
1770 #[test]
1771 fn test_normalize_avx2() {
1772 if let Some(simd) = crate::V3::try_new() {
1773 let p = largest_prime_in_arithmetic_progression64(1 << 16, 1, 1 << 30, 1 << 31).unwrap()
1774 as u32;
1775 let p_div = Div64::new(p as u64);
1776 let polynomial_size = 128;
1777
1778 let n_inv_mod_p =
1779 crate::prime::exp_mod64(p_div, polynomial_size as u64, p as u64 - 2) as u32;
1780 let n_inv_mod_p_shoup = (((n_inv_mod_p as u128) << 32) / p as u128) as u32;
1781
1782 let mut val = (0..polynomial_size)
1783 .map(|_| rand::random::<u32>() % p)
1784 .collect::<Vec<_>>();
1785 let mut val_target = val.clone();
1786
1787 let mul = |a: u32, b: u32| ((a as u128 * b as u128) % p as u128) as u32;
1788
1789 for val in val_target.iter_mut() {
1790 *val = mul(*val, n_inv_mod_p);
1791 }
1792
1793 normalize_avx2(simd, &mut val, p, n_inv_mod_p, n_inv_mod_p_shoup);
1794 assert_eq!(val, val_target);
1795 }
1796 }
1797}