use crate::{
    fastdiv::{Div32, Div64},
    izip, prime32, prime64,
};

extern crate alloc;
type Box<T> = alloc::boxed::Box<T>;

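/// Input-range hint for [`Plan::fwd`].
///
/// `Generic` performs a full division-based reduction of each coefficient modulo every prime.
/// `Bounded(b)` promises that each input, interpreted centered modulo the full modulus, has
/// absolute value below `b`; when `b` is smaller than every prime factor, the reduction only
/// needs a sign fixup. It is currently exploited only by the two-32-bit-prime fast path.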
#[derive(Copy, Clone, Debug, PartialEq, Eq)]
pub enum FwdMode {
    Generic,
    Bounded(u64),
}

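/// Output mode for [`Plan::inv`]: either overwrite `standard` with the CRT reconstruction, or
/// add the reconstruction to it modulo the full modulus.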
#[derive(Copy, Clone, Debug, PartialEq, Eq)]
pub enum InvMode {
    Replace,
    Accumulate,
}

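/// Multiplicative inverse of `n` modulo the divisor of `modulus`, via the extended Euclidean
/// algorithm with the Bézout coefficient tracked modulo the divisor. The result is only
/// meaningful when `n` and the modulus are coprime, which holds here since the plan moduli are
/// distinct primes.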
fn modular_inv_u32(modulus: Div32, n: u32) -> u32 {
    let modulus_div = modulus;
    let modulus = modulus.divisor();

    let mut old_r = Div32::rem(n, modulus_div);
    let mut r = modulus;

    let mut old_s = 1u32;
    let mut s = 0u32;

    while r != 0 {
        let q = old_r / r;
        (old_r, r) = (r, old_r - q * r);
        (old_s, s) = (
            s,
            sub_mod_u32(modulus, old_s, mul_mod_u32(modulus_div, q, s)),
        );
    }

    old_s
}

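/// 64-bit variant of [`modular_inv_u32`].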
fn modular_inv_u64(modulus: Div64, n: u64) -> u64 {
    let modulus_div = modulus;
    let modulus = modulus.divisor();

    let mut old_r = Div64::rem(n, modulus_div);
    let mut r = modulus;

    let mut old_s = 1u64;
    let mut s = 0u64;

    while r != 0 {
        let q = old_r / r;
        (old_r, r) = (r, old_r - q * r);
        (old_s, s) = (
            s,
            sub_mod_u64(modulus, old_s, mul_mod_u64(modulus_div, q, s)),
        );
    }

    old_s
}

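// Scalar modular arithmetic helpers. Unless stated otherwise, callers guarantee that the
// operands are already reduced modulo `modulus`.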
#[inline]
fn sub_mod_u64(modulus: u64, a: u64, b: u64) -> u64 {
    if a >= b {
        a - b
    } else {
        a.wrapping_sub(b).wrapping_add(modulus)
    }
}

#[inline]
fn sub_mod_u32(modulus: u32, a: u32, b: u32) -> u32 {
    if a >= b {
        a - b
    } else {
        a.wrapping_sub(b).wrapping_add(modulus)
    }
}

#[inline]
fn add_mod_u64(modulus: u64, a: u64, b: u64) -> u64 {
    let (sum, overflow) = a.overflowing_add(b);
    if sum >= modulus || overflow {
        sum.wrapping_sub(modulus)
    } else {
        sum
    }
}

#[inline]
fn add_mod_u64_less_than_2_63(modulus: u64, a: u64, b: u64) -> u64 {
    debug_assert!(modulus < 1 << 63);

    let sum = a + b;
    if sum >= modulus {
        sum - modulus
    } else {
        sum
    }
}

#[inline]
fn add_mod_u32(modulus: u32, a: u32, b: u32) -> u32 {
    let (sum, overflow) = a.overflowing_add(b);
    if sum >= modulus || overflow {
        sum.wrapping_sub(modulus)
    } else {
        sum
    }
}

#[inline]
fn mul_mod_u64(modulus: Div64, a: u64, b: u64) -> u64 {
    Div64::rem_u128(a as u128 * b as u128, modulus)
}

#[inline]
fn mul_mod_u32(modulus: Div32, a: u32, b: u32) -> u32 {
    Div32::rem_u64(a as u64 * b as u64, modulus)
}

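/// Shoup modular multiplication: given the precomputed companion value
/// `b_shoup = floor((b << 32) / modulus)`, computes `a * b mod modulus` with one high
/// multiplication and a single conditional correction, which is why `modulus < 2^31` is
/// required.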
#[inline]
fn shoup_mul_mod_u32(modulus: u32, a: u32, b: u32, b_shoup: u32) -> u32 {
    debug_assert!(modulus < 1 << 31);
    let q = ((a as u64 * b_shoup as u64) >> 32) as u32;
    let mut r = u32::wrapping_sub(b.wrapping_mul(a), q.wrapping_mul(modulus));
    if r >= modulus {
        r -= modulus;
    }
    r
}

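/// NTT plan for a modulus made of distinct prime factors, combined through the Chinese
/// remainder theorem: primes below `2^32` are handled by [`prime32::Plan`] (two `u32` residues
/// packed per `u64` word of the NTT-domain buffer) and larger primes by [`prime64::Plan`].
/// `modular_inverses` caches the Garner constants `inv(p_i) mod p_j` for `i < j`.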
#[derive(Clone, Debug)]
pub struct Plan {
    polynomial_size: usize,
    modulus: u64,
    modular_inverses: Box<[u64]>,
    plan_32: Box<[prime32::Plan]>,
    plan_64: Box<[prime64::Plan]>,
    div_32: Box<[Div32]>,
    div_64: Box<[Div64]>,
}

impl Plan {
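    /// Builds a plan for polynomials of even length `polynomial_size` modulo `modulus`, where
    /// `factors` lists the prime factorization of `modulus` (factors equal to `1` are ignored).
    ///
    /// Returns `None` if the length is odd, a factor is zero or repeated, the product of the
    /// factors does not equal `modulus`, or one of the per-prime plans cannot be created.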
    pub fn try_new(
        polynomial_size: usize,
        modulus: u64,
        factors: impl AsRef<[u64]>,
    ) -> Option<Self> {
        fn try_new_impl(polynomial_size: usize, modulus: u64, primes: &mut [u64]) -> Option<Plan> {
            if polynomial_size % 2 != 0 {
                return None;
            }

            primes.sort_unstable();

            // Since `prev` starts at 0, this also rejects a factor of zero.
            let mut prev = 0;
            for &factor in &*primes {
                if factor == prev {
                    return None;
                }
                prev = factor;
            }

            let start = primes.partition_point(|&factor| factor == 1);
            let primes = &primes[start..];

            if primes
                .iter()
                .try_fold(1u64, |prod, &p| prod.checked_mul(p))
                != Some(modulus)
            {
                return None;
            }

            let mid = primes.partition_point(|&p| p < (1u64 << 32));
            let (primes_32, primes_64) = primes.split_at(mid);

            let plan_32 = primes_32
                .iter()
                .map(|&p| prime32::Plan::try_new(polynomial_size, p as u32))
                .collect::<Option<Box<[_]>>>()?;

            let plan_64 = primes_64
                .iter()
                .map(|&p| prime64::Plan::try_new(polynomial_size, p))
                .collect::<Option<Box<[_]>>>()?;

            let div_32 = plan_32
                .iter()
                .map(prime32::Plan::p_div)
                .collect::<Box<[_]>>();
            let div_64 = plan_64
                .iter()
                .map(prime64::Plan::p_div)
                .collect::<Box<[_]>>();

            let len = primes.len();

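            // Packed lower-triangular table of Garner constants: row `j` stores
            // `inv(p_i) mod p_j` for all `i < j`, with rows laid out contiguously.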
            // `saturating_sub` avoids a usize underflow when `modulus == 1` leaves no primes.
            let mut modular_inverses =
                alloc::vec![0u64; (len * len.saturating_sub(1)) / 2].into_boxed_slice();
            let mut offset = 0;
            for (j, pj) in plan_32.iter().map(prime32::Plan::p_div).enumerate() {
                for (inv, &pi) in modular_inverses[offset..][..j]
                    .iter_mut()
                    .zip(&primes_32[..j])
                {
                    *inv = modular_inv_u32(pj, pi as u32) as u64;
                }
                offset += j;
            }

            let count_32 = plan_32.len();
            for (j, pj) in plan_64.iter().map(prime64::Plan::p_div).enumerate() {
                let j = j + count_32;

                for (inv, &pi) in modular_inverses[offset..][..j].iter_mut().zip(&primes[..j]) {
                    *inv = modular_inv_u64(pj, pi);
                }
                offset += j;
            }

            Some(Plan {
                polynomial_size,
                modulus,
                modular_inverses,
                plan_32,
                plan_64,
                div_32,
                div_64,
            })
        }

        try_new_impl(
            polynomial_size,
            modulus,
            &mut factors.as_ref().iter().copied().collect::<Box<[_]>>(),
        )
    }

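    /// Returns the polynomial size this plan was created for.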
    #[inline]
    pub fn ntt_size(&self) -> usize {
        self.polynomial_size
    }

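    /// Returns the full composite modulus.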
    #[inline]
    pub fn modulus(&self) -> u64 {
        self.modulus
    }

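    /// Length in `u64` words of the 32-bit portion of the NTT domain: each word packs two
    /// `u32` residues, so each 32-bit prime contributes `polynomial_size / 2` words.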
    fn ntt_domain_len_u32(&self) -> usize {
        (self.polynomial_size / 2) * self.plan_32.len()
    }
    fn ntt_domain_len_u64(&self) -> usize {
        self.polynomial_size * self.plan_64.len()
    }

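    /// Total length in `u64` words of the NTT-domain buffer expected by [`Plan::fwd`],
    /// [`Plan::inv`], and the pointwise operations.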
    pub fn ntt_domain_len(&self) -> usize {
        self.ntt_domain_len_u32() + self.ntt_domain_len_u64()
    }

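    /// Forward transform: reduces `standard` modulo each prime, then applies that prime's
    /// forward NTT in place. The `ntt` buffer holds all 32-bit residue vectors first, followed
    /// by the 64-bit ones; `mode` is a reduction hint (see [`FwdMode`]).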
    #[track_caller]
    pub fn fwd(&self, ntt: &mut [u64], standard: &[u64], mode: FwdMode) {
        assert_eq!(standard.len(), self.ntt_size());
        assert_eq!(ntt.len(), self.ntt_domain_len());

        let (ntt_32, ntt_64) = ntt.split_at_mut(self.ntt_domain_len_u32());
        let ntt_32: &mut [u32] = bytemuck::cast_slice_mut(ntt_32);

        if self.plan_32.is_empty() && self.plan_64.len() == 1 {
            ntt_64.copy_from_slice(standard);
            self.plan_64[0].fwd(ntt_64);
            return;
        }
        if self.plan_32.len() == 1 && self.plan_64.is_empty() {
            for (ntt, &standard) in ntt_32.iter_mut().zip(standard) {
                *ntt = standard as u32;
            }
            self.plan_32[0].fwd(ntt_32);
            return;
        }

        if self.plan_32.len() == 2 && self.plan_64.is_empty() {
            let (ntt0, ntt1) = ntt_32.split_at_mut(self.ntt_size());
            let p0_div = self.plan_32[0].p_div();
            let p1_div = self.plan_32[1].p_div();
            let p0 = self.plan_32[0].modulus();
            let p1 = self.plan_32[1].modulus();
            let p = self.modulus();
            let p_u32 = p as u32;

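            // With |coefficient| < bound < min(p0, p1) in centered representation, reduction
            // modulo each prime needs no division: nonnegative values pass through unchanged,
            // and negative values map to `p_i - (p - value)`.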
            match mode {
                FwdMode::Bounded(bound) if bound < p0 as u64 && bound < p1 as u64 => {
                    for ((ntt0, ntt1), &standard) in
                        ntt0.iter_mut().zip(ntt1.iter_mut()).zip(standard)
                    {
                        let positive = standard < p / 2;
                        let standard = standard as u32;
                        let complement = p_u32.wrapping_sub(standard);
                        *ntt0 = if positive {
                            standard
                        } else {
                            p0.wrapping_sub(complement)
                        };
                        *ntt1 = if positive {
                            standard
                        } else {
                            p1.wrapping_sub(complement)
                        };
                    }
                }
                _ => {
                    for ((ntt0, ntt1), &standard) in
                        ntt0.iter_mut().zip(ntt1.iter_mut()).zip(standard)
                    {
                        *ntt0 = Div32::rem_u64(standard, p0_div);
                        *ntt1 = Div32::rem_u64(standard, p1_div);
                    }
                }
            }

            self.plan_32[0].fwd(ntt0);
            self.plan_32[1].fwd(ntt1);

            return;
        }

        for (ntt, plan) in ntt_32.chunks_exact_mut(self.ntt_size()).zip(&*self.plan_32) {
            let modulus = plan.p_div();

            for (ntt, &standard) in ntt.iter_mut().zip(standard) {
                *ntt = Div32::rem_u64(standard, modulus);
            }

            plan.fwd(ntt);
        }

        for (ntt, plan) in ntt_64.chunks_exact_mut(self.ntt_size()).zip(&*self.plan_64) {
            let modulus = plan.p_div();
            for (ntt, &standard) in ntt.iter_mut().zip(standard) {
                *ntt = Div64::rem(standard, modulus);
            }

            plan.fwd(ntt);
        }
    }

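    /// Inverse transform: applies each prime's inverse NTT, then recombines the residues into
    /// the standard representation using Garner's mixed-radix CRT algorithm, either replacing
    /// `standard` or accumulating into it depending on `mode`. The result is not normalized by
    /// the inverse of the polynomial size (see the tests, which multiply by `n^{-1} mod p`).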
    #[track_caller]
    pub fn inv(&self, standard: &mut [u64], ntt: &mut [u64], mode: InvMode) {
        assert_eq!(standard.len(), self.ntt_size());
        assert_eq!(ntt.len(), self.ntt_domain_len());

        let (ntt_32, ntt_64) = ntt.split_at_mut(self.ntt_domain_len_u32());
        let ntt_32: &mut [u32] = bytemuck::cast_slice_mut(ntt_32);

        for (ntt, plan) in ntt_32.chunks_exact_mut(self.ntt_size()).zip(&*self.plan_32) {
            plan.inv(ntt);
        }
        for (ntt, plan) in ntt_64.chunks_exact_mut(self.ntt_size()).zip(&*self.plan_64) {
            plan.inv(ntt);
        }

        let ntt_32 = &*ntt_32;
        let ntt_64 = &*ntt_64;

        if self.plan_32.is_empty() && self.plan_64.is_empty() {
            match mode {
                InvMode::Replace => standard.fill(0),
                InvMode::Accumulate => {}
            }
            return;
        }

        if self.plan_32.is_empty() && self.plan_64.len() == 1 {
            match mode {
                InvMode::Replace => standard.copy_from_slice(ntt_64),
                InvMode::Accumulate => {
                    let p = self.plan_64[0].modulus();

                    for (standard, &ntt) in standard.iter_mut().zip(ntt_64) {
                        *standard = add_mod_u64(p, *standard, ntt);
                    }
                }
            }
            return;
        }
        if self.plan_32.len() == 1 && self.plan_64.is_empty() {
            match mode {
                InvMode::Replace => {
                    for (standard, &ntt) in standard.iter_mut().zip(ntt_32) {
                        *standard = ntt as u64;
                    }
                }
                InvMode::Accumulate => {
                    let p = self.plan_32[0].modulus();

                    for (standard, &ntt) in standard.iter_mut().zip(ntt_32) {
                        *standard = add_mod_u32(p, *standard as u32, ntt) as u64;
                    }
                }
            }
            return;
        }

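        // Fast path for exactly two 32-bit primes: with `v0 = u0` and
        // `v1 = (u1 - v0) * inv(p0) mod p1`, the CRT reconstruction is `v0 + v1 * p0 < p0 * p1`.
        // `modular_inverses[0]` is precisely `inv(p0) mod p1`.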
        if self.plan_32.len() == 2 && self.plan_64.is_empty() {
            let (ntt0, ntt1) = ntt_32.split_at(self.ntt_size());
            let p0 = self.plan_32[0].modulus();
            let p1 = self.plan_32[1].modulus();
            let p = self.modulus();
            let p1_div = self.plan_32[1].p_div();

            let inv = self.modular_inverses[0] as u32;

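            // When `p1 < 2^31`, the multiplication by `inv(p0) mod p1` can use Shoup's method;
            // its companion constant is `floor((inv << 32) / p1)`.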
            if p1 < 1 << 31 {
                let inv_shoup = Div32::div_u64((inv as u64) << 32, p1_div) as u32;
                match mode {
                    InvMode::Replace => {
                        for (standard, &ntt0, &ntt1) in izip!(standard.iter_mut(), ntt0, ntt1) {
                            let u0 = ntt0;
                            let u1 = ntt1;

                            let v0 = u0;

                            let diff = sub_mod_u32(p1, u1, v0);
                            let v1 = shoup_mul_mod_u32(p1, diff, inv, inv_shoup);

                            *standard = v0 as u64 + (v1 as u64 * p0 as u64);
                        }
                    }
                    InvMode::Accumulate => {
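                        // Vectorized kernels for the accumulating case: AVX-512 when the
                        // `nightly` feature and CPU support are available, then AVX2, with the
                        // scalar loop below as the fallback.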
                        #[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
                        {
                            #[cfg(feature = "nightly")]
                            if let Some(simd) = pulp::x86::V4::try_new() {
                                struct Impl<'a> {
                                    simd: pulp::x86::V4,
                                    standard: &'a mut [u64],
                                    ntt0: &'a [u32],
                                    ntt1: &'a [u32],
                                    p: u64,
                                    p0: u32,
                                    p1: u32,
                                    inv: u32,
                                    inv_shoup: u32,
                                }

                                impl pulp::NullaryFnOnce for Impl<'_> {
                                    type Output = ();

                                    #[inline(always)]
                                    fn call(self) -> Self::Output {
                                        let Self {
                                            simd,
                                            standard,
                                            ntt0,
                                            ntt1,
                                            p,
                                            p0,
                                            p1,
                                            inv,
                                            inv_shoup,
                                        } = self;

                                        {
                                            let standard = pulp::as_arrays_mut::<8, _>(standard).0;
                                            let ntt0 = pulp::as_arrays::<8, _>(ntt0).0;
                                            let ntt1 = pulp::as_arrays::<8, _>(ntt1).0;

                                            let standard: &mut [pulp::u64x8] =
                                                bytemuck::cast_slice_mut(standard);
                                            let ntt0: &[pulp::u32x8] = bytemuck::cast_slice(ntt0);
                                            let ntt1: &[pulp::u32x8] = bytemuck::cast_slice(ntt1);

                                            let p1_u32 = simd.splat_u32x8(p1);
                                            let p1_u64 = simd.convert_u32x8_to_u64x8(p1_u32);
                                            let p0 =
                                                simd.convert_u32x8_to_u64x8(simd.splat_u32x8(p0));
                                            let p = simd.splat_u64x8(p);
                                            let inv =
                                                simd.convert_u32x8_to_u64x8(simd.splat_u32x8(inv));
                                            let inv_shoup = simd.convert_u32x8_to_u64x8(
                                                simd.splat_u32x8(inv_shoup),
                                            );

                                            for (standard, &ntt0, &ntt1) in
                                                izip!(standard.iter_mut(), ntt0, ntt1)
                                            {
                                                let u0 = ntt0;
                                                let u1 = ntt1;

                                                let v0 = u0;

                                                let diff = simd.wrapping_sub_u32x8(u1, v0);
                                                let diff = simd.min_u32x8(
                                                    diff,
                                                    simd.wrapping_add_u32x8(diff, p1_u32),
                                                );
                                                let diff = simd.convert_u32x8_to_u64x8(diff);

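                                                // Shoup multiplication on 64-bit lanes:
                                                // `_mm512_mul_epu32` multiplies the even 32-bit
                                                // lanes, and the odd lanes of the remainder are
                                                // masked off before the branchless `min`-based
                                                // conditional subtraction.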
                                                let v1: pulp::u64x8 = {
                                                    let a = diff;
                                                    let b = inv;
                                                    let b_shoup = inv_shoup;
                                                    let modulus = p1_u64;

                                                    let q = pulp::cast(
                                                        simd.avx512f._mm512_mul_epu32(
                                                            pulp::cast(a),
                                                            pulp::cast(b_shoup),
                                                        ),
                                                    );
                                                    let q = simd.shr_const_u64x8::<32>(q);

                                                    let ab = pulp::cast(
                                                        simd.avx512f._mm512_mul_epu32(
                                                            pulp::cast(a),
                                                            pulp::cast(b),
                                                        ),
                                                    );

                                                    let qmod = pulp::cast(
                                                        simd.avx512f._mm512_mul_epu32(
                                                            pulp::cast(q),
                                                            pulp::cast(modulus),
                                                        ),
                                                    );

                                                    let r = simd.wrapping_sub_u32x16(ab, qmod);
                                                    let r = simd.and_u32x16(
                                                        r,
                                                        pulp::u32x16(
                                                            !0, 0, !0, 0, !0, 0, !0, 0, !0, 0, !0,
                                                            0, !0, 0, !0, 0,
                                                        ),
                                                    );

                                                    let r = simd.min_u32x16(
                                                        r,
                                                        simd.wrapping_sub_u32x16(
                                                            r,
                                                            pulp::cast(modulus),
                                                        ),
                                                    );
                                                    pulp::cast(r)
                                                };

                                                let v0 = simd.convert_u32x8_to_u64x8(v0);
                                                let v = simd.wrapping_add_u64x8(
                                                    v0,
                                                    pulp::cast(simd.avx512f._mm512_mul_epu32(
                                                        pulp::cast(v1),
                                                        pulp::cast(p0),
                                                    )),
                                                );
                                                let sum = simd.wrapping_add_u64x8(*standard, v);
                                                let smaller_than_p = simd.cmp_lt_u64x8(sum, p);
                                                *standard = simd.select_u64x8(
                                                    smaller_than_p,
                                                    sum,
                                                    simd.wrapping_sub_u64x8(sum, p),
                                                );
                                            }
                                        }
                                    }
                                }

                                simd.vectorize(Impl {
                                    simd,
                                    standard,
                                    ntt0,
                                    ntt1,
                                    p,
                                    p0,
                                    p1,
                                    inv,
                                    inv_shoup,
                                });

                                return;
                            }

                            if let Some(simd) = pulp::x86::V3::try_new() {
                                struct Impl<'a> {
                                    simd: pulp::x86::V3,
                                    standard: &'a mut [u64],
                                    ntt0: &'a [u32],
                                    ntt1: &'a [u32],
                                    p: u64,
                                    p0: u32,
                                    p1: u32,
                                    inv: u32,
                                    inv_shoup: u32,
                                }

                                impl pulp::NullaryFnOnce for Impl<'_> {
                                    type Output = ();

                                    #[inline(always)]
                                    fn call(self) -> Self::Output {
                                        let Self {
                                            simd,
                                            standard,
                                            ntt0,
                                            ntt1,
                                            p,
                                            p0,
                                            p1,
                                            inv,
                                            inv_shoup,
                                        } = self;

                                        {
                                            let standard = pulp::as_arrays_mut::<4, _>(standard).0;
                                            let ntt0 = pulp::as_arrays::<4, _>(ntt0).0;
                                            let ntt1 = pulp::as_arrays::<4, _>(ntt1).0;

                                            let standard: &mut [pulp::u64x4] =
                                                bytemuck::cast_slice_mut(standard);
                                            let ntt0: &[pulp::u32x4] = bytemuck::cast_slice(ntt0);
                                            let ntt1: &[pulp::u32x4] = bytemuck::cast_slice(ntt1);

                                            let p1_u32 = simd.splat_u32x4(p1);
                                            let p1_u64 = simd.convert_u32x4_to_u64x4(p1_u32);
                                            let p0 =
                                                simd.convert_u32x4_to_u64x4(simd.splat_u32x4(p0));
                                            let p = simd.splat_u64x4(p);
                                            let inv =
                                                simd.convert_u32x4_to_u64x4(simd.splat_u32x4(inv));
                                            let inv_shoup = simd.convert_u32x4_to_u64x4(
                                                simd.splat_u32x4(inv_shoup),
                                            );

                                            for (standard, &ntt0, &ntt1) in
                                                izip!(standard.iter_mut(), ntt0, ntt1)
                                            {
                                                let u0 = ntt0;
                                                let u1 = ntt1;

                                                let v0 = u0;

                                                let diff = simd.wrapping_sub_u32x4(u1, v0);
                                                let diff = simd.min_u32x4(
                                                    diff,
                                                    simd.wrapping_add_u32x4(diff, p1_u32),
                                                );
                                                let diff = simd.convert_u32x4_to_u64x4(diff);

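                                                // Same Shoup kernel as the AVX-512 path, on
                                                // 256-bit vectors.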
                                                let v1: pulp::u64x4 = {
                                                    let a = diff;
                                                    let b = inv;
                                                    let b_shoup = inv_shoup;
                                                    let modulus = p1_u64;

                                                    let q = pulp::cast(simd.avx2._mm256_mul_epu32(
                                                        pulp::cast(a),
                                                        pulp::cast(b_shoup),
                                                    ));
                                                    let q = simd.shr_const_u64x4::<32>(q);

                                                    let ab = pulp::cast(simd.avx2._mm256_mul_epu32(
                                                        pulp::cast(a),
                                                        pulp::cast(b),
                                                    ));

                                                    let qmod =
                                                        pulp::cast(simd.avx2._mm256_mul_epu32(
                                                            pulp::cast(q),
                                                            pulp::cast(modulus),
                                                        ));

                                                    let r = simd.wrapping_sub_u32x8(ab, qmod);
                                                    let r = simd.and_u32x8(
                                                        r,
                                                        pulp::u32x8(!0, 0, !0, 0, !0, 0, !0, 0),
                                                    );

                                                    let r = simd.min_u32x8(
                                                        r,
                                                        simd.wrapping_sub_u32x8(
                                                            r,
                                                            pulp::cast(modulus),
                                                        ),
                                                    );
                                                    pulp::cast(r)
                                                };

                                                let v0 = simd.convert_u32x4_to_u64x4(v0);
                                                let v = simd.wrapping_add_u64x4(
                                                    v0,
                                                    pulp::cast(simd.avx2._mm256_mul_epu32(
                                                        pulp::cast(v1),
                                                        pulp::cast(p0),
                                                    )),
                                                );
                                                let sum = simd.wrapping_add_u64x4(*standard, v);
                                                let smaller_than_p = simd.cmp_lt_u64x4(sum, p);
                                                *standard = simd.select_u64x4(
                                                    smaller_than_p,
                                                    sum,
                                                    simd.wrapping_sub_u64x4(sum, p),
                                                );
                                            }
                                        }
                                    }
                                }

                                simd.vectorize(Impl {
                                    simd,
                                    standard,
                                    ntt0,
                                    ntt1,
                                    p,
                                    p0,
                                    p1,
                                    inv,
                                    inv_shoup,
                                });

                                return;
                            }
                        }

                        for (standard, &ntt0, &ntt1) in izip!(standard.iter_mut(), ntt0, ntt1) {
                            let u0 = ntt0;
                            let u1 = ntt1;

                            let v0 = u0;

                            let diff = sub_mod_u32(p1, u1, v0);
                            let v1 = shoup_mul_mod_u32(p1, diff, inv, inv_shoup);

                            *standard = add_mod_u64_less_than_2_63(
                                p,
                                *standard,
                                v0 as u64 + (v1 as u64 * p0 as u64),
                            );
                        }
                    }
                }
            } else {
                match mode {
                    InvMode::Replace => {
                        for (standard, &ntt0, &ntt1) in izip!(standard.iter_mut(), ntt0, ntt1) {
                            let u0 = ntt0;
                            let u1 = ntt1;

                            let v0 = u0;

                            let diff = sub_mod_u32(p1, u1, v0);
                            let v1 = mul_mod_u32(p1_div, diff, inv);

                            *standard = v0 as u64 + (v1 as u64 * p0 as u64);
                        }
                    }
                    InvMode::Accumulate => {
                        for (standard, &ntt0, &ntt1) in izip!(standard.iter_mut(), ntt0, ntt1) {
                            let u0 = ntt0;
                            let u1 = ntt1;

                            let v0 = u0;

                            let diff = sub_mod_u32(p1, u1, v0);
                            let v1 = mul_mod_u32(p1_div, diff, inv);

                            *standard =
                                add_mod_u64(p, *standard, v0 as u64 + (v1 as u64 * p0 as u64));
                        }
                    }
                }
            }

            return;
        }

        let u_32 = &mut *alloc::vec![0u32; self.plan_32.len()];
        let v_32 = &mut *alloc::vec![0u32; self.plan_32.len()];
        let u_64 = &mut *alloc::vec![0u64; self.plan_64.len()];
        let v_64 = &mut *alloc::vec![0u64; self.plan_64.len()];

        let div_32 = &*self.div_32;
        let div_64 = &*self.div_64;

        let p = self.modulus();

        let count_32 = self.plan_32.len();

        let modular_inverses = &*self.modular_inverses;

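        // General case. For each coefficient, gather its residues `u_j` (one per prime, primes
        // in increasing order), convert them to mixed-radix digits with Garner's recurrence
        //     v_j = (((u_j - v_0) * inv(p_0) - v_1) * inv(p_1) - ...) mod p_j,
        // then evaluate `v_0 + p_0 * (v_1 + p_1 * (...))` by Horner's rule.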
        for (idx, standard) in standard.iter_mut().enumerate() {
            let ntt_32 = ntt_32.get(idx..).unwrap_or(&[]);
            let ntt_64 = ntt_64.get(idx..).unwrap_or(&[]);

            let ntt_32 = ntt_32.iter().step_by(self.ntt_size()).copied();
            let ntt_64 = ntt_64.iter().step_by(self.ntt_size()).copied();

            u_32.iter_mut()
                .zip(ntt_32)
                .for_each(|(dst, src)| *dst = src);
            u_64.iter_mut()
                .zip(ntt_64)
                .for_each(|(dst, src)| *dst = src);

            let u_32 = &*u_32;
            let u_64 = &*u_64;

            let mut offset = 0;

            for (j, (&uj, &div_j)) in u_32.iter().zip(div_32).enumerate() {
                let pj = div_j.divisor();
                let mut x = uj;
                {
                    let v = &v_32[..j];

                    for (&vj, &inv) in v.iter().zip(&modular_inverses[offset..][..j]) {
                        let diff = sub_mod_u32(pj, x, vj);
                        x = mul_mod_u32(div_j, diff, inv as u32);
                    }
                    offset += j;
                }
                v_32[j] = x;
            }

            for (j, (&uj, &div_j)) in u_64.iter().zip(div_64).enumerate() {
                let pj = div_j.divisor();
                let mut x = uj;
                {
                    let v = &*v_32;

                    for (&vj, &inv) in v.iter().zip(&modular_inverses[offset..][..count_32]) {
                        let diff = sub_mod_u64(pj, x, vj as u64);
                        x = mul_mod_u64(div_j, diff, inv);
                    }
                    offset += count_32;
                }
                {
                    let v = &v_64[..j];

                    for (&vj, &inv) in v.iter().zip(&modular_inverses[offset..][..j]) {
                        let diff = sub_mod_u64(pj, x, vj);
                        x = mul_mod_u64(div_j, diff, inv);
                    }
                    offset += j;
                }
                v_64[j] = x;
            }

            let mut acc = 0u64;
            for (&v, &p) in v_64.iter().zip(div_64).rev() {
                let p = p.divisor();
                acc *= p;
                acc += v;
            }
            for (&v, &p) in v_32.iter().zip(div_32).rev() {
                let p = p.divisor();
                acc *= p as u64;
                acc += v as u64;
            }

            match mode {
                InvMode::Replace => *standard = acc,
                InvMode::Accumulate => *standard = add_mod_u64(p, *standard, acc),
            }
        }
    }

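    /// Pointwise product in the NTT domain, `lhs *= rhs`, delegated to each per-prime plan's
    /// `mul_assign_normalize`, which also folds in that plan's normalization factor.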
    #[track_caller]
    pub fn mul_assign_normalize(&self, lhs: &mut [u64], rhs: &[u64]) {
        assert_eq!(lhs.len(), self.ntt_domain_len());
        assert_eq!(rhs.len(), self.ntt_domain_len());

        let (lhs_32, lhs_64) = lhs.split_at_mut(self.ntt_domain_len_u32());
        let (rhs_32, rhs_64) = rhs.split_at(self.ntt_domain_len_u32());

        let lhs_32: &mut [u32] = bytemuck::cast_slice_mut(lhs_32);
        let rhs_32: &[u32] = bytemuck::cast_slice(rhs_32);

        let size = self.ntt_size();

        for ((lhs, rhs), plan) in lhs_32
            .chunks_exact_mut(size)
            .zip(rhs_32.chunks_exact(size))
            .zip(&*self.plan_32)
        {
            plan.mul_assign_normalize(lhs, rhs);
        }

        for ((lhs, rhs), plan) in lhs_64
            .chunks_exact_mut(size)
            .zip(rhs_64.chunks_exact(size))
            .zip(&*self.plan_64)
        {
            plan.mul_assign_normalize(lhs, rhs);
        }
    }

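    /// Normalizes `values` in the NTT domain, delegating to each per-prime plan.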
    #[track_caller]
    pub fn normalize(&self, values: &mut [u64]) {
        assert_eq!(values.len(), self.ntt_domain_len());

        let (values_32, values_64) = values.split_at_mut(self.ntt_domain_len_u32());
        let values_32: &mut [u32] = bytemuck::cast_slice_mut(values_32);

        let size = self.ntt_size();

        for (values, plan) in values_32.chunks_exact_mut(size).zip(&*self.plan_32) {
            plan.normalize(values);
        }
        for (values, plan) in values_64.chunks_exact_mut(size).zip(&*self.plan_64) {
            plan.normalize(values);
        }
    }

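    /// Pointwise multiply-accumulate in the NTT domain, `acc += lhs * rhs`, per prime.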
    #[track_caller]
    pub fn mul_accumulate(&self, acc: &mut [u64], lhs: &[u64], rhs: &[u64]) {
        // Also check `acc`, so that out-of-size chunks cannot be silently skipped.
        assert_eq!(acc.len(), self.ntt_domain_len());
        assert_eq!(lhs.len(), self.ntt_domain_len());
        assert_eq!(rhs.len(), self.ntt_domain_len());

        let (acc_32, acc_64) = acc.split_at_mut(self.ntt_domain_len_u32());
        let (lhs_32, lhs_64) = lhs.split_at(self.ntt_domain_len_u32());
        let (rhs_32, rhs_64) = rhs.split_at(self.ntt_domain_len_u32());

        let acc_32: &mut [u32] = bytemuck::cast_slice_mut(acc_32);
        let lhs_32: &[u32] = bytemuck::cast_slice(lhs_32);
        let rhs_32: &[u32] = bytemuck::cast_slice(rhs_32);

        let size = self.ntt_size();

        for (((acc, lhs), rhs), plan) in acc_32
            .chunks_exact_mut(size)
            .zip(lhs_32.chunks_exact(size))
            .zip(rhs_32.chunks_exact(size))
            .zip(&*self.plan_32)
        {
            plan.mul_accumulate(acc, lhs, rhs);
        }

        for (((acc, lhs), rhs), plan) in acc_64
            .chunks_exact_mut(size)
            .zip(lhs_64.chunks_exact(size))
            .zip(rhs_64.chunks_exact(size))
            .zip(&*self.plan_64)
        {
            plan.mul_accumulate(acc, lhs, rhs);
        }
    }
}

#[cfg(test)]
mod tests {
    use super::*;
    use crate::prime::largest_prime_in_arithmetic_progression64;

    extern crate alloc;

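    // Round-trip convention used throughout: `fwd` followed by `inv` computes `n` times the
    // identity, so each test multiplies the result by `n^{-1} mod p` before comparing.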
    #[test]
    fn test_product_u64x1() {
        let n = 256;

        let p = largest_prime_in_arithmetic_progression64(2 * n as u64, 1, 0, u64::MAX).unwrap();
        let plan = Plan::try_new(n, p, [p]).unwrap();

        let standard = &*(0..n)
            .map(|_| rand::random::<u64>() % p)
            .collect::<Box<[_]>>();
        let ntt = &mut *alloc::vec![0u64; plan.ntt_domain_len()];
        let roundtrip = &mut *alloc::vec![0u64; n];

        let p_div = Div64::new(p);
        let mul = |a, b| mul_mod_u64(p_div, a, b);

        let n_inv_mod_p = modular_inv_u64(p_div, n as u64);
        plan.fwd(ntt, standard, FwdMode::Generic);
        plan.inv(roundtrip, ntt, InvMode::Replace);
        for x in roundtrip.iter_mut() {
            *x = mul(*x, n_inv_mod_p);
        }

        assert_eq!(roundtrip, standard);
    }

    #[test]
    fn test_product_u32x1() {
        let n = 256;

        let p =
            largest_prime_in_arithmetic_progression64(2 * n as u64, 1, 0, u32::MAX as u64).unwrap();
        let plan = Plan::try_new(n, p, [p]).unwrap();

        let standard = &*(0..n)
            .map(|_| rand::random::<u64>() % p)
            .collect::<Box<[_]>>();
        let ntt = &mut *alloc::vec![0u64; plan.ntt_domain_len()];
        let roundtrip = &mut *alloc::vec![0u64; n];

        let p_div = Div64::new(p);
        let mul = |a, b| mul_mod_u64(p_div, a, b);

        let n_inv_mod_p = modular_inv_u64(p_div, n as u64);
        plan.fwd(ntt, standard, FwdMode::Generic);
        plan.inv(roundtrip, ntt, InvMode::Replace);
        for x in roundtrip.iter_mut() {
            *x = mul(*x, n_inv_mod_p);
        }

        assert_eq!(roundtrip, standard);
    }

    #[test]
    fn test_product_u32x2() {
        let n = 256;

        let p0 =
            largest_prime_in_arithmetic_progression64(2 * n as u64, 1, 0, u32::MAX as u64).unwrap();
        let p1 = largest_prime_in_arithmetic_progression64(2 * n as u64, 1, 0, p0 - 1).unwrap();

        let p = p0 * p1;
        let plan = Plan::try_new(n, p, [p0, p1]).unwrap();

        let standard = &*(0..n)
            .map(|_| rand::random::<u64>() % p)
            .collect::<Box<[_]>>();
        for inv_mode in [InvMode::Replace, InvMode::Accumulate] {
            let ntt = &mut *alloc::vec![0u64; plan.ntt_domain_len()];
            let roundtrip = &mut *alloc::vec![0u64; n];

            let p_div = Div64::new(p);
            let mul = |a, b| mul_mod_u64(p_div, a, b);

            let n_inv_mod_p = modular_inv_u64(p_div, n as u64);
            plan.fwd(ntt, standard, FwdMode::Generic);
            plan.inv(roundtrip, ntt, inv_mode);
            for x in roundtrip.iter_mut() {
                *x = mul(*x, n_inv_mod_p);
            }

            assert_eq!(roundtrip, standard);
        }
    }

    #[test]
    fn test_product_u30x2() {
        let n = 256;

        let p0 = largest_prime_in_arithmetic_progression64(2 * n as u64, 1, 0, 1 << 30).unwrap();
        let p1 = largest_prime_in_arithmetic_progression64(2 * n as u64, 1, 0, p0 - 1).unwrap();

        let p = p0 * p1;
        let plan = Plan::try_new(n, p, [p0, p1]).unwrap();

        let standard = &*(0..n)
            .map(|_| rand::random::<u64>() % p)
            .collect::<Box<[_]>>();
        for inv_mode in [InvMode::Replace, InvMode::Accumulate] {
            let ntt = &mut *alloc::vec![0u64; plan.ntt_domain_len()];
            let roundtrip = &mut *alloc::vec![0u64; n];

            let p_div = Div64::new(p);
            let mul = |a, b| mul_mod_u64(p_div, a, b);

            let n_inv_mod_p = modular_inv_u64(p_div, n as u64);
            plan.fwd(ntt, standard, FwdMode::Generic);
            plan.inv(roundtrip, ntt, inv_mode);
            for x in roundtrip.iter_mut() {
                *x = mul(*x, n_inv_mod_p);
            }

            assert_eq!(roundtrip, standard);
        }
    }

    #[test]
    fn test_product_u32x4() {
        let n = 256;

        let p0 =
            largest_prime_in_arithmetic_progression64(2 * n as u64, 1, 0, u16::MAX as u64).unwrap();
        let p1 = largest_prime_in_arithmetic_progression64(2 * n as u64, 1, 0, p0 - 1).unwrap();
        let p2 = largest_prime_in_arithmetic_progression64(2 * n as u64, 1, 0, p1 - 1).unwrap();
        let p3 = largest_prime_in_arithmetic_progression64(2 * n as u64, 1, 0, p2 - 1).unwrap();

        let p = p0 * p1 * p2 * p3;
        let plan = Plan::try_new(n, p, [p0, p1, p2, p3]).unwrap();

        let standard = &*(0..n)
            .map(|_| rand::random::<u64>() % p)
            .collect::<Box<[_]>>();
        let ntt = &mut *alloc::vec![0u64; plan.ntt_domain_len()];
        let roundtrip = &mut *alloc::vec![0u64; n];

        let p_div = Div64::new(p);
        let mul = |a, b| mul_mod_u64(p_div, a, b);

        let n_inv_mod_p = modular_inv_u64(p_div, n as u64);
        plan.fwd(ntt, standard, FwdMode::Generic);
        plan.inv(roundtrip, ntt, InvMode::Replace);
        for x in roundtrip.iter_mut() {
            *x = mul(*x, n_inv_mod_p);
        }

        assert_eq!(roundtrip, standard);
    }

    #[test]
    fn test_product_u32x2_u64x1() {
        let n = 256;

        let p0 = largest_prime_in_arithmetic_progression64(2 * n as u64, 1, 0, 1u64 << 33).unwrap();
        let p1 = largest_prime_in_arithmetic_progression64(2 * n as u64, 1, 0, 1u64 << 15).unwrap();
        let p2 = largest_prime_in_arithmetic_progression64(2 * n as u64, 1, 0, p1 - 1).unwrap();

        let p = p0 * p1 * p2;
        let plan = Plan::try_new(n, p, [p0, p1, p2]).unwrap();

        let standard = &*(0..n)
            .map(|_| rand::random::<u64>() % p)
            .collect::<Box<[_]>>();
        let ntt = &mut *alloc::vec![0u64; plan.ntt_domain_len()];
        let roundtrip = &mut *alloc::vec![0u64; n];

        let p_div = Div64::new(p);
        let mul = |a, b| mul_mod_u64(p_div, a, b);

        let n_inv_mod_p = modular_inv_u64(p_div, n as u64);
        plan.fwd(ntt, standard, FwdMode::Generic);
        plan.inv(roundtrip, ntt, InvMode::Replace);
        for x in roundtrip.iter_mut() {
            *x = mul(*x, n_inv_mod_p);
        }

        assert_eq!(roundtrip, standard);
    }

    #[test]
    fn test_plan_failure_zero() {
        let n = 256;
        let p0 = largest_prime_in_arithmetic_progression64(2 * n as u64, 1, 0, 1u64 << 33).unwrap();
        assert!(Plan::try_new(n, 0, [p0, 0]).is_none());
    }

    #[test]
    fn test_plan_failure_dup() {
        let n = 256;
        let p0 = largest_prime_in_arithmetic_progression64(2 * n as u64, 1, 0, 1u64 << 33).unwrap();
        let p1 = largest_prime_in_arithmetic_progression64(2 * n as u64, 1, 0, 1u64 << 15).unwrap();
        assert!(Plan::try_new(n, p0 * p1 * p1, [p1, p0, p1]).is_none());
    }
}