tfhe_ntt/product.rs

use crate::{
    fastdiv::{Div32, Div64},
    izip, prime32, prime64,
};

// for no_std environments
extern crate alloc;
type Box<T> = alloc::boxed::Box<T>;

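/// Controls how the forward NTT interprets its input.
///
/// `Generic` performs a full modular reduction of each input value. `Bounded(b)`
/// promises that each input, read as a signed value centered around zero, has
/// magnitude less than `b`, which enables a cheaper conversion to the residue
/// domain on some fast paths.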
#[derive(Copy, Clone, Debug, PartialEq, Eq)]
pub enum FwdMode {
    Generic,
    Bounded(u64),
}

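/// Controls how the inverse NTT writes its output: `Replace` overwrites the
/// destination, while `Accumulate` adds to it modulo the NTT modulus.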
#[derive(Copy, Clone, Debug, PartialEq, Eq)]
pub enum InvMode {
    Replace,
    Accumulate,
}

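// Extended Euclidean algorithm: returns the inverse of n modulo the divisor of
// `modulus`. The caller must guarantee gcd(n, modulus) == 1, which holds here
// since the moduli are distinct primes.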
fn modular_inv_u32(modulus: Div32, n: u32) -> u32 {
    let modulus_div = modulus;
    let modulus = modulus.divisor();

    let mut old_r = Div32::rem(n, modulus_div);
    let mut r = modulus;

    let mut old_s = 1u32;
    let mut s = 0u32;

    while r != 0 {
        let q = old_r / r;
        (old_r, r) = (r, old_r - q * r);
        (old_s, s) = (
            s,
            sub_mod_u32(modulus, old_s, mul_mod_u32(modulus_div, q, s)),
        );
    }

    old_s
}

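// Same as `modular_inv_u32`, for 64-bit moduli.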
fn modular_inv_u64(modulus: Div64, n: u64) -> u64 {
    let modulus_div = modulus;
    let modulus = modulus.divisor();

    let mut old_r = Div64::rem(n, modulus_div);
    let mut r = modulus;

    let mut old_s = 1u64;
    let mut s = 0u64;

    while r != 0 {
        let q = old_r / r;
        (old_r, r) = (r, old_r - q * r);
        (old_s, s) = (
            s,
            sub_mod_u64(modulus, old_s, mul_mod_u64(modulus_div, q, s)),
        );
    }

    old_s
}

#[inline]
fn sub_mod_u64(modulus: u64, a: u64, b: u64) -> u64 {
    if a >= b {
        a - b
    } else {
        a.wrapping_sub(b).wrapping_add(modulus)
    }
}

#[inline]
fn sub_mod_u32(modulus: u32, a: u32, b: u32) -> u32 {
    if a >= b {
        a - b
    } else {
        a.wrapping_sub(b).wrapping_add(modulus)
    }
}

#[inline]
fn add_mod_u64(modulus: u64, a: u64, b: u64) -> u64 {
    let (sum, overflow) = a.overflowing_add(b);
    if sum >= modulus || overflow {
        sum.wrapping_sub(modulus)
    } else {
        sum
    }
}

#[inline]
fn add_mod_u64_less_than_2_63(modulus: u64, a: u64, b: u64) -> u64 {
    debug_assert!(modulus < 1 << 63);

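    // with modulus < 2^63 and reduced inputs a, b < modulus, the sum a + b fits
    // in a u64, so the unchecked addition cannot overflow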
    let sum = a + b;
    if sum >= modulus {
        sum - modulus
    } else {
        sum
    }
}

#[inline]
fn add_mod_u32(modulus: u32, a: u32, b: u32) -> u32 {
    let (sum, overflow) = a.overflowing_add(b);
    if sum >= modulus || overflow {
        sum.wrapping_sub(modulus)
    } else {
        sum
    }
}

#[inline]
fn mul_mod_u64(modulus: Div64, a: u64, b: u64) -> u64 {
    Div64::rem_u128(a as u128 * b as u128, modulus)
}

#[inline]
fn mul_mod_u32(modulus: Div32, a: u32, b: u32) -> u32 {
    Div32::rem_u64(a as u64 * b as u64, modulus)
}

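// Shoup modular multiplication: given the precomputed value
// b_shoup = floor(b * 2^32 / modulus), computes a * b mod modulus using a
// high-half multiply and a single conditional subtraction instead of a full
// division. Requires modulus < 2^31.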
#[inline]
fn shoup_mul_mod_u32(modulus: u32, a: u32, b: u32, b_shoup: u32) -> u32 {
    debug_assert!(modulus < 1 << 31);
    let q = ((a as u64 * b_shoup as u64) >> 32) as u32;
    let mut r = u32::wrapping_sub(b.wrapping_mul(a), q.wrapping_mul(modulus));
    if r >= modulus {
        r -= modulus
    }
    r
}

/// Negacyclic NTT plan for a 64-bit modulus that is a product of distinct primes.
#[derive(Clone, Debug)]
pub struct Plan {
    polynomial_size: usize,
    modulus: u64,
    modular_inverses: Box<[u64]>,
    plan_32: Box<[prime32::Plan]>,
    plan_64: Box<[prime64::Plan]>,
    div_32: Box<[Div32]>,
    div_64: Box<[Div64]>,
}

impl Plan {
    /// Returns a negacyclic NTT plan for the given polynomial size and modulus (product of the
    /// given distinct primes), or `None` if no suitable roots of unity can be found for the
    /// requested parameters.
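    ///
    /// A minimal usage sketch, mirroring the tests at the bottom of this module (marked
    /// `ignore` since it assumes the prime-search helper is exported as a public item):
    ///
    /// ```ignore
    /// use tfhe_ntt::prime::largest_prime_in_arithmetic_progression64;
    ///
    /// let n = 256;
    /// // NTT-friendly primes: congruent to 1 modulo 2 * n
    /// let p0 = largest_prime_in_arithmetic_progression64(2 * n as u64, 1, 0, 1 << 30).unwrap();
    /// let p1 = largest_prime_in_arithmetic_progression64(2 * n as u64, 1, 0, p0 - 1).unwrap();
    /// let plan = Plan::try_new(n, p0 * p1, [p0, p1]).unwrap();
    /// assert_eq!(plan.modulus(), p0 * p1);
    /// ```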
    pub fn try_new(
        polynomial_size: usize,
        modulus: u64,
        factors: impl AsRef<[u64]>,
    ) -> Option<Self> {
        fn try_new_impl(polynomial_size: usize, modulus: u64, primes: &mut [u64]) -> Option<Plan> {
            if polynomial_size % 2 != 0 {
                return None;
            }

            // check for zeros and duplicates; since `prev` starts at 0, a zero
            // factor is rejected by the same comparison
            primes.sort_unstable();

            let mut prev = 0;
            for &factor in &*primes {
                if factor == prev {
                    return None;
                }
                prev = factor;
            }

            // skip trivial factors equal to 1
            let start = primes.partition_point(|&modulus| modulus == 1);
            let primes = &primes[start..];

            if primes
                .iter()
                .try_fold(1u64, |prod, &modulus| prod.checked_mul(modulus))
                != Some(modulus)
            {
                return None;
            };

            let mid = primes.partition_point(|&modulus| modulus < (1u64 << 32));
            let (primes_32, primes_64) = primes.split_at(mid);

            let plan_32 = primes_32
                .iter()
                .map(|&modulus| prime32::Plan::try_new(polynomial_size, modulus as u32))
                .collect::<Option<Box<[_]>>>()?;

            let plan_64 = primes_64
                .iter()
                .map(|&modulus| prime64::Plan::try_new(polynomial_size, modulus))
                .collect::<Option<Box<[_]>>>()?;

            let div_32 = plan_32
                .iter()
                .map(prime32::Plan::p_div)
                .collect::<Box<[_]>>();
            let div_64 = plan_64
                .iter()
                .map(prime64::Plan::p_div)
                .collect::<Box<[_]>>();

            let len = primes.len();

            // triangular table of the inverses p_i^{-1} mod p_j for all i < j,
            // stored row by row
            let mut modular_inverses = alloc::vec![0u64; (len * (len - 1)) / 2].into_boxed_slice();
            let mut offset = 0;
            for (j, pj) in plan_32.iter().map(prime32::Plan::p_div).enumerate() {
                for (inv, &pi) in modular_inverses[offset..][..j]
                    .iter_mut()
                    .zip(&primes_32[..j])
                {
                    *inv = modular_inv_u32(pj, pi as u32) as u64;
                }
                offset += j;
            }

            let count_32 = plan_32.len();
            for (j, pj) in plan_64.iter().map(prime64::Plan::p_div).enumerate() {
                let j = j + count_32;

                for (inv, &pi) in modular_inverses[offset..][..j].iter_mut().zip(&primes[..j]) {
                    *inv = modular_inv_u64(pj, pi);
                }
                offset += j;
            }

            Some(Plan {
                polynomial_size,
                modulus,
                modular_inverses,
                plan_32,
                plan_64,
                div_32,
                div_64,
            })
        }

        try_new_impl(
            polynomial_size,
            modulus,
            &mut factors.as_ref().iter().copied().collect::<Box<[_]>>(),
        )
    }

    /// Returns the polynomial size of the negacyclic NTT plan.
    #[inline]
    pub fn ntt_size(&self) -> usize {
        self.polynomial_size
    }

    /// Returns the modulus of the negacyclic NTT plan.
    #[inline]
    pub fn modulus(&self) -> u64 {
        self.modulus
    }

    fn ntt_domain_len_u32(&self) -> usize {
        (self.polynomial_size / 2) * self.plan_32.len()
    }
    fn ntt_domain_len_u64(&self) -> usize {
        self.polynomial_size * self.plan_64.len()
    }

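    /// Returns the length of the `u64` buffer holding the NTT-domain representation:
    /// half a word per coefficient for each 32-bit prime (two `u32` residues are
    /// packed per word) plus one word per coefficient for each 64-bit prime.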
    pub fn ntt_domain_len(&self) -> usize {
        self.ntt_domain_len_u32() + self.ntt_domain_len_u64()
    }

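    /// Computes the forward negacyclic NTT of `standard`, storing one residue NTT per
    /// prime factor in `ntt` (residues for the 32-bit primes come first, packed two per
    /// `u64` word).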
    #[track_caller]
    pub fn fwd(&self, ntt: &mut [u64], standard: &[u64], mode: FwdMode) {
        assert_eq!(standard.len(), self.ntt_size());
        assert_eq!(ntt.len(), self.ntt_domain_len());

        let (ntt_32, ntt_64) = ntt.split_at_mut(self.ntt_domain_len_u32());
        let ntt_32: &mut [u32] = bytemuck::cast_slice_mut(ntt_32);

        // fast paths for the common cases: a single 64-bit prime (u64x1), a single
        // 32-bit prime (u32x1)
        if self.plan_32.is_empty() && self.plan_64.len() == 1 {
            ntt_64.copy_from_slice(standard);
            self.plan_64[0].fwd(ntt_64);
            return;
        }
        if self.plan_32.len() == 1 && self.plan_64.is_empty() {
            for (ntt, &standard) in ntt_32.iter_mut().zip(standard) {
                *ntt = standard as u32;
            }
            self.plan_32[0].fwd(ntt_32);
            return;
        }

        if self.plan_32.len() == 2 && self.plan_64.is_empty() {
            let (ntt0, ntt1) = ntt_32.split_at_mut(self.ntt_size());
            let p0_div = self.plan_32[0].p_div();
            let p1_div = self.plan_32[1].p_div();
            let p0 = self.plan_32[0].modulus();
            let p1 = self.plan_32[1].modulus();
            let p = self.modulus();
            let p_u32 = p as u32;

            match mode {
                // the inputs are small in absolute value when read as signed integers
                // centered around zero (|x| < bound < p0, p1), so the residues can be
                // formed directly without a full modular reduction
                FwdMode::Bounded(bound) if bound < p0 as u64 && bound < p1 as u64 => {
                    for ((ntt0, ntt1), &standard) in
                        ntt0.iter_mut().zip(ntt1.iter_mut()).zip(standard)
                    {
                        let positive = standard < p / 2;
                        let standard = standard as u32;
                        let complement = p_u32.wrapping_sub(standard);
                        *ntt0 = if positive {
                            standard
                        } else {
                            p0.wrapping_sub(complement)
                        };
                        *ntt1 = if positive {
                            standard
                        } else {
                            p1.wrapping_sub(complement)
                        };
                    }
                }
                _ => {
                    for ((ntt0, ntt1), &standard) in
                        ntt0.iter_mut().zip(ntt1.iter_mut()).zip(standard)
                    {
                        *ntt0 = Div32::rem_u64(standard, p0_div);
                        *ntt1 = Div32::rem_u64(standard, p1_div);
                    }
                }
            }
            self.plan_32[0].fwd(ntt0);
            self.plan_32[1].fwd(ntt1);

            return;
        }

        for (ntt, plan) in ntt_32.chunks_exact_mut(self.ntt_size()).zip(&*self.plan_32) {
            let modulus = plan.p_div();

            for (ntt, &standard) in ntt.iter_mut().zip(standard) {
                *ntt = Div32::rem_u64(standard, modulus);
            }

            plan.fwd(ntt);
        }

        for (ntt, plan) in ntt_64.chunks_exact_mut(self.ntt_size()).zip(&*self.plan_64) {
            let modulus = plan.p_div();
            for (ntt, &standard) in ntt.iter_mut().zip(standard) {
                *ntt = Div64::rem(standard, modulus);
            }

            plan.fwd(ntt);
        }
    }

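    /// Computes the inverse negacyclic NTT of `ntt`, reconstructing each coefficient from
    /// its residues via the Chinese remainder theorem, and writes (or accumulates,
    /// depending on `mode`) the result into `standard`. The residue NTTs in `ntt` are
    /// inverted in place.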
    #[track_caller]
    pub fn inv(&self, standard: &mut [u64], ntt: &mut [u64], mode: InvMode) {
        assert_eq!(standard.len(), self.ntt_size());
        assert_eq!(ntt.len(), self.ntt_domain_len());

        let (ntt_32, ntt_64) = ntt.split_at_mut(self.ntt_domain_len_u32());
        let ntt_32: &mut [u32] = bytemuck::cast_slice_mut(ntt_32);

        for (ntt, plan) in ntt_32.chunks_exact_mut(self.ntt_size()).zip(&*self.plan_32) {
            plan.inv(ntt);
        }
        for (ntt, plan) in ntt_64.chunks_exact_mut(self.ntt_size()).zip(&*self.plan_64) {
            plan.inv(ntt);
        }

        let ntt_32 = &*ntt_32;
        let ntt_64 = &*ntt_64;

        // fast paths for common cases: empty product, u64x1, u32x1, u32x2
        if self.plan_32.is_empty() && self.plan_64.is_empty() {
            match mode {
                InvMode::Replace => standard.fill(0),
                InvMode::Accumulate => {}
            }
            return;
        }

        if self.plan_32.is_empty() && self.plan_64.len() == 1 {
            match mode {
                InvMode::Replace => standard.copy_from_slice(ntt_64),
                InvMode::Accumulate => {
                    let p = self.plan_64[0].modulus();

                    for (standard, &ntt) in standard.iter_mut().zip(ntt_64) {
                        *standard = add_mod_u64(p, *standard, ntt);
                    }
                }
            }
            return;
        }
        if self.plan_32.len() == 1 && self.plan_64.is_empty() {
            match mode {
                InvMode::Replace => {
                    for (standard, &ntt) in standard.iter_mut().zip(ntt_32) {
                        *standard = ntt as u64;
                    }
                }
                InvMode::Accumulate => {
                    let p = self.plan_32[0].modulus();

                    for (standard, &ntt) in standard.iter_mut().zip(ntt_32) {
                        *standard = add_mod_u32(p, *standard as u32, ntt) as u64;
                    }
                }
            }
            return;
        }

        // implements the algorithm from "The Art of Computer Programming" (Donald E.
        // Knuth), section 4.3.2, for finding the solution of the Chinese remainder
        // theorem in mixed-radix form
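        // for the two-prime case below, the mixed-radix reconstruction is:
        //   v0 = u0                         (residue mod p0)
        //   v1 = (u1 - v0) * p0^{-1} mod p1
        //   x  = v0 + v1 * p0               with 0 <= x < p0 * p1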
        if self.plan_32.len() == 2 && self.plan_64.is_empty() {
            let (ntt0, ntt1) = ntt_32.split_at(self.ntt_size());
            let p0 = self.plan_32[0].modulus();
            let p1 = self.plan_32[1].modulus();
            let p = self.modulus();
            let p1_div = self.plan_32[1].p_div();

            let inv = self.modular_inverses[0] as u32;

            if p1 < 1 << 31 {
                let inv_shoup = Div32::div_u64((inv as u64) << 32, p1_div) as u32;
                match mode {
                    InvMode::Replace => {
                        for (standard, &ntt0, &ntt1) in izip!(standard.iter_mut(), ntt0, ntt1) {
                            let u0 = ntt0;
                            let u1 = ntt1;

                            let v0 = u0;

                            let diff = sub_mod_u32(p1, u1, v0);
                            let v1 = shoup_mul_mod_u32(p1, diff, inv, inv_shoup);

                            *standard = v0 as u64 + (v1 as u64 * p0 as u64);
                        }
                    }
                    // we optimize this path in particular because it corresponds to a
                    // possibly hot loop in tfhe-rs (NTT PBS with modulus = a product of
                    // two u32 primes < 2^31)
                    InvMode::Accumulate => {
                        #[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
                        {
                            #[cfg(feature = "nightly")]
                            if let Some(simd) = pulp::x86::V4::try_new() {
                                struct Impl<'a> {
                                    simd: pulp::x86::V4,
                                    standard: &'a mut [u64],
                                    ntt0: &'a [u32],
                                    ntt1: &'a [u32],
                                    p: u64,
                                    p0: u32,
                                    p1: u32,
                                    inv: u32,
                                    inv_shoup: u32,
                                }

                                impl pulp::NullaryFnOnce for Impl<'_> {
                                    type Output = ();

                                    #[inline(always)]
                                    fn call(self) -> Self::Output {
                                        let Self {
                                            simd,
                                            standard,
                                            ntt0,
                                            ntt1,
                                            p,
                                            p0,
                                            p1,
                                            inv,
                                            inv_shoup,
                                        } = self;

                                        {
                                            let standard = pulp::as_arrays_mut::<8, _>(standard).0;
                                            let ntt0 = pulp::as_arrays::<8, _>(ntt0).0;
                                            let ntt1 = pulp::as_arrays::<8, _>(ntt1).0;

                                            let standard: &mut [pulp::u64x8] =
                                                bytemuck::cast_slice_mut(standard);
                                            let ntt0: &[pulp::u32x8] = bytemuck::cast_slice(ntt0);
                                            let ntt1: &[pulp::u32x8] = bytemuck::cast_slice(ntt1);

                                            let p1_u32 = simd.splat_u32x8(p1);
                                            let p1_u64 = simd.convert_u32x8_to_u64x8(p1_u32);
                                            let p0 =
                                                simd.convert_u32x8_to_u64x8(simd.splat_u32x8(p0));
                                            let p = simd.splat_u64x8(p);
                                            let inv =
                                                simd.convert_u32x8_to_u64x8(simd.splat_u32x8(inv));
                                            let inv_shoup = simd.convert_u32x8_to_u64x8(
                                                simd.splat_u32x8(inv_shoup),
                                            );

                                            for (standard, &ntt0, &ntt1) in
                                                izip!(standard.iter_mut(), ntt0, ntt1)
                                            {
                                                let u0 = ntt0;
                                                let u1 = ntt1;

                                                let v0 = u0;

                                                // (u1 - v0) mod p1: the min picks the
                                                // wrapped or unwrapped difference,
                                                // whichever landed in [0, p1)
                                                let diff = simd.wrapping_sub_u32x8(u1, v0);
                                                let diff = simd.min_u32x8(
                                                    diff,
                                                    simd.wrapping_add_u32x8(diff, p1_u32),
                                                );
                                                let diff = simd.convert_u32x8_to_u64x8(diff);

                                                let v1: pulp::u64x8 = {
                                                    // shoup mul mod
                                                    let a = diff;
                                                    let b = inv;
                                                    let b_shoup = inv_shoup;
                                                    let modulus = p1_u64;

                                                    let q =
                                                        pulp::cast(simd.avx512f._mm512_mul_epu32(
                                                            pulp::cast(a),
                                                            pulp::cast(b_shoup),
                                                        ));
                                                    let q = simd.shr_const_u64x8::<32>(q);

                                                    let ab =
                                                        pulp::cast(simd.avx512f._mm512_mul_epu32(
                                                            pulp::cast(a),
                                                            pulp::cast(b),
                                                        ));

                                                    let qmod =
                                                        pulp::cast(simd.avx512f._mm512_mul_epu32(
                                                            pulp::cast(q),
                                                            pulp::cast(modulus),
                                                        ));

                                                    // keep only the low 32 bits of each
                                                    // 64-bit lane of ab - q * modulus
                                                    let r = simd.wrapping_sub_u32x16(ab, qmod);
                                                    let r = simd.and_u32x16(
                                                        r,
                                                        pulp::u32x16(
                                                            !0, 0, !0, 0, !0, 0, !0, 0, !0, 0, !0,
                                                            0, !0, 0, !0, 0,
                                                        ),
                                                    );

                                                    let r = simd.min_u32x16(
                                                        r,
                                                        simd.wrapping_sub_u32x16(
                                                            r,
                                                            pulp::cast(modulus),
                                                        ),
                                                    );
                                                    pulp::cast(r)
                                                };

                                                let v0 = simd.convert_u32x8_to_u64x8(v0);
                                                let v = simd.wrapping_add_u64x8(
                                                    v0,
                                                    pulp::cast(simd.avx512f._mm512_mul_epu32(
                                                        pulp::cast(v1),
                                                        pulp::cast(p0),
                                                    )),
                                                );
                                                let sum = simd.wrapping_add_u64x8(*standard, v);
                                                let smaller_than_p = simd.cmp_lt_u64x8(sum, p);
                                                *standard = simd.select_u64x8(
                                                    smaller_than_p,
                                                    sum,
                                                    simd.wrapping_sub_u64x8(sum, p),
                                                );
                                            }
                                        }
                                    }
                                }

                                simd.vectorize(Impl {
                                    simd,
                                    standard,
                                    ntt0,
                                    ntt1,
                                    p,
                                    p0,
                                    p1,
                                    inv,
                                    inv_shoup,
                                });

                                return;
                            }

                            if let Some(simd) = pulp::x86::V3::try_new() {
                                struct Impl<'a> {
                                    simd: pulp::x86::V3,
                                    standard: &'a mut [u64],
                                    ntt0: &'a [u32],
                                    ntt1: &'a [u32],
                                    p: u64,
                                    p0: u32,
                                    p1: u32,
                                    inv: u32,
                                    inv_shoup: u32,
                                }

                                impl pulp::NullaryFnOnce for Impl<'_> {
                                    type Output = ();

                                    #[inline(always)]
                                    fn call(self) -> Self::Output {
                                        let Self {
                                            simd,
                                            standard,
                                            ntt0,
                                            ntt1,
                                            p,
                                            p0,
                                            p1,
                                            inv,
                                            inv_shoup,
                                        } = self;

                                        {
                                            let standard = pulp::as_arrays_mut::<4, _>(standard).0;
                                            let ntt0 = pulp::as_arrays::<4, _>(ntt0).0;
                                            let ntt1 = pulp::as_arrays::<4, _>(ntt1).0;

                                            let standard: &mut [pulp::u64x4] =
                                                bytemuck::cast_slice_mut(standard);
                                            let ntt0: &[pulp::u32x4] = bytemuck::cast_slice(ntt0);
                                            let ntt1: &[pulp::u32x4] = bytemuck::cast_slice(ntt1);

                                            let p1_u32 = simd.splat_u32x4(p1);
                                            let p1_u64 = simd.convert_u32x4_to_u64x4(p1_u32);
                                            let p0 =
                                                simd.convert_u32x4_to_u64x4(simd.splat_u32x4(p0));
                                            let p = simd.splat_u64x4(p);
                                            let inv =
                                                simd.convert_u32x4_to_u64x4(simd.splat_u32x4(inv));
                                            let inv_shoup = simd.convert_u32x4_to_u64x4(
                                                simd.splat_u32x4(inv_shoup),
                                            );

                                            for (standard, &ntt0, &ntt1) in
                                                izip!(standard.iter_mut(), ntt0, ntt1)
                                            {
                                                let u0 = ntt0;
                                                let u1 = ntt1;

                                                let v0 = u0;

                                                // (u1 - v0) mod p1, as in the V4 path
                                                let diff = simd.wrapping_sub_u32x4(u1, v0);
                                                let diff = simd.min_u32x4(
                                                    diff,
                                                    simd.wrapping_add_u32x4(diff, p1_u32),
                                                );
                                                let diff = simd.convert_u32x4_to_u64x4(diff);

                                                let v1: pulp::u64x4 = {
                                                    // shoup mul mod
                                                    let a = diff;
                                                    let b = inv;
                                                    let b_shoup = inv_shoup;
                                                    let modulus = p1_u64;

                                                    let q = pulp::cast(simd.avx2._mm256_mul_epu32(
                                                        pulp::cast(a),
                                                        pulp::cast(b_shoup),
                                                    ));
                                                    let q = simd.shr_const_u64x4::<32>(q);

                                                    let ab =
                                                        pulp::cast(simd.avx2._mm256_mul_epu32(
                                                            pulp::cast(a),
                                                            pulp::cast(b),
                                                        ));

                                                    let qmod =
                                                        pulp::cast(simd.avx2._mm256_mul_epu32(
                                                            pulp::cast(q),
                                                            pulp::cast(modulus),
                                                        ));

                                                    // keep only the low 32 bits of each
                                                    // 64-bit lane of ab - q * modulus
                                                    let r = simd.wrapping_sub_u32x8(ab, qmod);
                                                    let r = simd.and_u32x8(
                                                        r,
                                                        pulp::u32x8(!0, 0, !0, 0, !0, 0, !0, 0),
                                                    );

                                                    let r = simd.min_u32x8(
                                                        r,
                                                        simd.wrapping_sub_u32x8(
                                                            r,
                                                            pulp::cast(modulus),
                                                        ),
                                                    );
                                                    pulp::cast(r)
                                                };

                                                let v0 = simd.convert_u32x4_to_u64x4(v0);
                                                let v = simd.wrapping_add_u64x4(
                                                    v0,
                                                    pulp::cast(simd.avx2._mm256_mul_epu32(
                                                        pulp::cast(v1),
                                                        pulp::cast(p0),
                                                    )),
                                                );
                                                let sum = simd.wrapping_add_u64x4(*standard, v);
                                                let smaller_than_p = simd.cmp_lt_u64x4(sum, p);
                                                *standard = simd.select_u64x4(
                                                    smaller_than_p,
                                                    sum,
                                                    simd.wrapping_sub_u64x4(sum, p),
                                                );
                                            }
                                        }
                                    }
                                }

                                simd.vectorize(Impl {
                                    simd,
                                    standard,
                                    ntt0,
                                    ntt1,
                                    p,
                                    p0,
                                    p1,
                                    inv,
                                    inv_shoup,
                                });

                                return;
                            }
                        }

                        for (standard, &ntt0, &ntt1) in izip!(standard.iter_mut(), ntt0, ntt1) {
                            let u0 = ntt0;
                            let u1 = ntt1;

                            let v0 = u0;

                            let diff = sub_mod_u32(p1, u1, v0);
                            let v1 = shoup_mul_mod_u32(p1, diff, inv, inv_shoup);

                            *standard = add_mod_u64_less_than_2_63(
                                p,
                                *standard,
                                v0 as u64 + (v1 as u64 * p0 as u64),
                            );
                        }
                    }
                }
            } else {
                match mode {
                    InvMode::Replace => {
                        for (standard, &ntt0, &ntt1) in izip!(standard.iter_mut(), ntt0, ntt1) {
                            let u0 = ntt0;
                            let u1 = ntt1;

                            let v0 = u0;

                            let diff = sub_mod_u32(p1, u1, v0);
                            let v1 = mul_mod_u32(p1_div, diff, inv);

                            *standard = v0 as u64 + (v1 as u64 * p0 as u64);
                        }
                    }
                    InvMode::Accumulate => {
                        for (standard, &ntt0, &ntt1) in izip!(standard.iter_mut(), ntt0, ntt1) {
                            let u0 = ntt0;
                            let u1 = ntt1;

                            let v0 = u0;

                            let diff = sub_mod_u32(p1, u1, v0);
                            let v1 = mul_mod_u32(p1_div, diff, inv);

                            *standard =
                                add_mod_u64(p, *standard, v0 as u64 + (v1 as u64 * p0 as u64));
                        }
                    }
                }
            }

            return;
        }

        let u_32 = &mut *alloc::vec![0u32; self.plan_32.len()];
        let v_32 = &mut *alloc::vec![0u32; self.plan_32.len()];
        let u_64 = &mut *alloc::vec![0u64; self.plan_64.len()];
        let v_64 = &mut *alloc::vec![0u64; self.plan_64.len()];

        let div_32 = &*self.div_32;
        let div_64 = &*self.div_64;

        let p = self.modulus();

        let count_32 = self.plan_32.len();

        let modular_inverses = &*self.modular_inverses;

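        // general case: per-coefficient mixed-radix CRT reconstruction (Garner's
        // algorithm), computing digits v_j such that
        //   x = v_0 + v_1 * p_0 + v_2 * p_0 * p_1 + ...
        // using the precomputed inverses p_i^{-1} mod p_j for i < j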
        for (idx, standard) in standard.iter_mut().enumerate() {
            let ntt_32 = ntt_32.get(idx..).unwrap_or(&[]);
            let ntt_64 = ntt_64.get(idx..).unwrap_or(&[]);

            let ntt_32 = ntt_32.iter().step_by(self.ntt_size()).copied();
            let ntt_64 = ntt_64.iter().step_by(self.ntt_size()).copied();

            u_32.iter_mut()
                .zip(ntt_32)
                .for_each(|(dst, src)| *dst = src);
            u_64.iter_mut()
                .zip(ntt_64)
                .for_each(|(dst, src)| *dst = src);

            let u_32 = &*u_32;
            let u_64 = &*u_64;

            let mut offset = 0;

            for (j, (&uj, &div_j)) in u_32.iter().zip(div_32).enumerate() {
                let pj = div_j.divisor();
                let mut x = uj;
                {
                    let v = &v_32[..j];

                    for (&vj, &inv) in v.iter().zip(&modular_inverses[offset..][..j]) {
                        let diff = sub_mod_u32(pj, x, vj);
                        x = mul_mod_u32(div_j, diff, inv as u32);
                    }
                    offset += j;
                }
                v_32[j] = x;
            }

            for (j, (&uj, &div_j)) in u_64.iter().zip(div_64).enumerate() {
                let pj = div_j.divisor();
                let mut x = uj;
                {
                    let v = &*v_32;

                    for (&vj, &inv) in v.iter().zip(&modular_inverses[offset..][..count_32]) {
                        let diff = sub_mod_u64(pj, x, vj as u64);
                        x = mul_mod_u64(div_j, diff, inv);
                    }
                    offset += count_32;
                }
                {
                    let v = &v_64[..j];

                    for (&vj, &inv) in v.iter().zip(&modular_inverses[offset..][..j]) {
                        let diff = sub_mod_u64(pj, x, vj);
                        x = mul_mod_u64(div_j, diff, inv);
                    }
                    offset += j;
                }
                v_64[j] = x;
            }

            // combine the mixed-radix digits: x = v_0 + p_0 * (v_1 + p_1 * (v_2 + ...))
            let mut acc = 0u64;
            for (&v, &p) in v_64.iter().zip(div_64).rev() {
                let p = p.divisor();
                acc *= p;
                acc += v;
            }
            for (&v, &p) in v_32.iter().zip(div_32).rev() {
                let p = p.divisor();
                acc *= p as u64;
                acc += v as u64;
            }

            match mode {
                InvMode::Replace => *standard = acc,
                InvMode::Accumulate => *standard = add_mod_u64(p, *standard, acc),
            }
        }
    }

    /// Computes the elementwise product of `lhs` and `rhs`, multiplied by the inverse of the
    /// polynomial size modulo the NTT modulus, and stores the result in `lhs`.
    #[track_caller]
    pub fn mul_assign_normalize(&self, lhs: &mut [u64], rhs: &[u64]) {
        assert_eq!(lhs.len(), self.ntt_domain_len());
        assert_eq!(rhs.len(), self.ntt_domain_len());

        let (lhs_32, lhs_64) = lhs.split_at_mut(self.ntt_domain_len_u32());
        let (rhs_32, rhs_64) = rhs.split_at(self.ntt_domain_len_u32());

        let lhs_32: &mut [u32] = bytemuck::cast_slice_mut(lhs_32);
        let rhs_32: &[u32] = bytemuck::cast_slice(rhs_32);

        let size = self.ntt_size();

        for ((lhs, rhs), plan) in lhs_32
            .chunks_exact_mut(size)
            .zip(rhs_32.chunks_exact(size))
            .zip(&*self.plan_32)
        {
            plan.mul_assign_normalize(lhs, rhs);
        }

        for ((lhs, rhs), plan) in lhs_64
            .chunks_exact_mut(size)
            .zip(rhs_64.chunks_exact(size))
            .zip(&*self.plan_64)
        {
            plan.mul_assign_normalize(lhs, rhs);
        }
    }

    /// Multiplies the values by the inverse of the polynomial size modulo the NTT modulus, and
    /// stores the result in `values`.
    #[track_caller]
    pub fn normalize(&self, values: &mut [u64]) {
        assert_eq!(values.len(), self.ntt_domain_len());

        let (values_32, values_64) = values.split_at_mut(self.ntt_domain_len_u32());
        let values_32: &mut [u32] = bytemuck::cast_slice_mut(values_32);

        let size = self.ntt_size();

        for (values, plan) in values_32.chunks_exact_mut(size).zip(&*self.plan_32) {
            plan.normalize(values);
        }
        for (values, plan) in values_64.chunks_exact_mut(size).zip(&*self.plan_64) {
            plan.normalize(values);
        }
    }

    /// Computes the elementwise product of `lhs` and `rhs` and accumulates the result into
    /// `acc`.
    #[track_caller]
    pub fn mul_accumulate(&self, acc: &mut [u64], lhs: &[u64], rhs: &[u64]) {
        assert_eq!(acc.len(), self.ntt_domain_len());
        assert_eq!(lhs.len(), self.ntt_domain_len());
        assert_eq!(rhs.len(), self.ntt_domain_len());

        let (acc_32, acc_64) = acc.split_at_mut(self.ntt_domain_len_u32());
        let (lhs_32, lhs_64) = lhs.split_at(self.ntt_domain_len_u32());
        let (rhs_32, rhs_64) = rhs.split_at(self.ntt_domain_len_u32());

        let acc_32: &mut [u32] = bytemuck::cast_slice_mut(acc_32);
        let lhs_32: &[u32] = bytemuck::cast_slice(lhs_32);
        let rhs_32: &[u32] = bytemuck::cast_slice(rhs_32);

        let size = self.ntt_size();

        for (((acc, lhs), rhs), plan) in acc_32
            .chunks_exact_mut(size)
            .zip(lhs_32.chunks_exact(size))
            .zip(rhs_32.chunks_exact(size))
            .zip(&*self.plan_32)
        {
            plan.mul_accumulate(acc, lhs, rhs);
        }

        for (((acc, lhs), rhs), plan) in acc_64
            .chunks_exact_mut(size)
            .zip(lhs_64.chunks_exact(size))
            .zip(rhs_64.chunks_exact(size))
            .zip(&*self.plan_64)
        {
            plan.mul_accumulate(acc, lhs, rhs);
        }
    }
}

#[cfg(test)]
mod tests {
    use super::*;
    use crate::prime::largest_prime_in_arithmetic_progression64;

    extern crate alloc;

    #[test]
    fn test_product_u64x1() {
        let n = 256;

        let p = largest_prime_in_arithmetic_progression64(2 * n as u64, 1, 0, u64::MAX).unwrap();
        let plan = Plan::try_new(n, p, [p]).unwrap();

        let standard = &*(0..n)
            .map(|_| rand::random::<u64>() % p)
            .collect::<Box<[_]>>();
        let ntt = &mut *alloc::vec![0u64; plan.ntt_domain_len()];
        let roundtrip = &mut *alloc::vec![0u64; n];

        let p_div = Div64::new(p);
        let mul = |a, b| mul_mod_u64(p_div, a, b);

        let n_inv_mod_p = modular_inv_u64(p_div, n as u64);
        plan.fwd(ntt, standard, FwdMode::Generic);
        plan.inv(roundtrip, ntt, InvMode::Replace);
        for x in roundtrip.iter_mut() {
            *x = mul(*x, n_inv_mod_p);
        }

        assert_eq!(roundtrip, standard);
    }

    #[test]
    fn test_product_u32x1() {
        let n = 256;

        let p =
            largest_prime_in_arithmetic_progression64(2 * n as u64, 1, 0, u32::MAX as u64).unwrap();
        let plan = Plan::try_new(n, p, [p]).unwrap();

        let standard = &*(0..n)
            .map(|_| rand::random::<u64>() % p)
            .collect::<Box<[_]>>();
        let ntt = &mut *alloc::vec![0u64; plan.ntt_domain_len()];
        let roundtrip = &mut *alloc::vec![0u64; n];

        let p_div = Div64::new(p);
        let mul = |a, b| mul_mod_u64(p_div, a, b);

        let n_inv_mod_p = modular_inv_u64(p_div, n as u64);
        plan.fwd(ntt, standard, FwdMode::Generic);
        plan.inv(roundtrip, ntt, InvMode::Replace);
        for x in roundtrip.iter_mut() {
            *x = mul(*x, n_inv_mod_p);
        }

        assert_eq!(roundtrip, standard);
    }

    #[test]
    fn test_product_u32x2() {
        let n = 256;

        let p0 =
            largest_prime_in_arithmetic_progression64(2 * n as u64, 1, 0, u32::MAX as u64).unwrap();
        let p1 = largest_prime_in_arithmetic_progression64(2 * n as u64, 1, 0, p0 - 1).unwrap();

        let p = p0 * p1;
        let plan = Plan::try_new(n, p, [p0, p1]).unwrap();

        let standard = &*(0..n)
            .map(|_| rand::random::<u64>() % p)
            .collect::<Box<[_]>>();
        for inv_mode in [InvMode::Replace, InvMode::Accumulate] {
            let ntt = &mut *alloc::vec![0u64; plan.ntt_domain_len()];
            let roundtrip = &mut *alloc::vec![0u64; n];

            let p_div = Div64::new(p);
            let mul = |a, b| mul_mod_u64(p_div, a, b);

            let n_inv_mod_p = modular_inv_u64(p_div, n as u64);
            plan.fwd(ntt, standard, FwdMode::Generic);
            plan.inv(roundtrip, ntt, inv_mode);
            for x in roundtrip.iter_mut() {
                *x = mul(*x, n_inv_mod_p);
            }

            assert_eq!(roundtrip, standard);
        }
    }

    #[test]
    fn test_product_u30x2() {
        let n = 256;

        let p0 = largest_prime_in_arithmetic_progression64(2 * n as u64, 1, 0, 1 << 30).unwrap();
        let p1 = largest_prime_in_arithmetic_progression64(2 * n as u64, 1, 0, p0 - 1).unwrap();

        let p = p0 * p1;
        let plan = Plan::try_new(n, p, [p0, p1]).unwrap();

        let standard = &*(0..n)
            .map(|_| rand::random::<u64>() % p)
            .collect::<Box<[_]>>();
        for inv_mode in [InvMode::Replace, InvMode::Accumulate] {
            let ntt = &mut *alloc::vec![0u64; plan.ntt_domain_len()];
            let roundtrip = &mut *alloc::vec![0u64; n];

            let p_div = Div64::new(p);
            let mul = |a, b| mul_mod_u64(p_div, a, b);

            let n_inv_mod_p = modular_inv_u64(p_div, n as u64);
            plan.fwd(ntt, standard, FwdMode::Generic);
            plan.inv(roundtrip, ntt, inv_mode);
            for x in roundtrip.iter_mut() {
                *x = mul(*x, n_inv_mod_p);
            }

            assert_eq!(roundtrip, standard);
        }
    }

    #[test]
    fn test_product_u32x4() {
        let n = 256;

        let p0 =
            largest_prime_in_arithmetic_progression64(2 * n as u64, 1, 0, u16::MAX as u64).unwrap();
        let p1 = largest_prime_in_arithmetic_progression64(2 * n as u64, 1, 0, p0 - 1).unwrap();
        let p2 = largest_prime_in_arithmetic_progression64(2 * n as u64, 1, 0, p1 - 1).unwrap();
        let p3 = largest_prime_in_arithmetic_progression64(2 * n as u64, 1, 0, p2 - 1).unwrap();

        let p = p0 * p1 * p2 * p3;
        let plan = Plan::try_new(n, p, [p0, p1, p2, p3]).unwrap();

        let standard = &*(0..n)
            .map(|_| rand::random::<u64>() % p)
            .collect::<Box<[_]>>();
        let ntt = &mut *alloc::vec![0u64; plan.ntt_domain_len()];
        let roundtrip = &mut *alloc::vec![0u64; n];

        let p_div = Div64::new(p);
        let mul = |a, b| mul_mod_u64(p_div, a, b);

        let n_inv_mod_p = modular_inv_u64(p_div, n as u64);
        plan.fwd(ntt, standard, FwdMode::Generic);
        plan.inv(roundtrip, ntt, InvMode::Replace);
        for x in roundtrip.iter_mut() {
            *x = mul(*x, n_inv_mod_p);
        }

        assert_eq!(roundtrip, standard);
    }

    #[test]
    fn test_product_u32x2_u64x1() {
        let n = 256;

        let p0 = largest_prime_in_arithmetic_progression64(2 * n as u64, 1, 0, 1u64 << 33).unwrap();
        let p1 = largest_prime_in_arithmetic_progression64(2 * n as u64, 1, 0, 1u64 << 15).unwrap();
        let p2 = largest_prime_in_arithmetic_progression64(2 * n as u64, 1, 0, p1 - 1).unwrap();

        let p = p0 * p1 * p2;
        let plan = Plan::try_new(n, p, [p0, p1, p2]).unwrap();

        let standard = &*(0..n)
            .map(|_| rand::random::<u64>() % p)
            .collect::<Box<[_]>>();
        let ntt = &mut *alloc::vec![0u64; plan.ntt_domain_len()];
        let roundtrip = &mut *alloc::vec![0u64; n];

        let p_div = Div64::new(p);
        let mul = |a, b| mul_mod_u64(p_div, a, b);

        let n_inv_mod_p = modular_inv_u64(p_div, n as u64);
        plan.fwd(ntt, standard, FwdMode::Generic);
        plan.inv(roundtrip, ntt, InvMode::Replace);
        for x in roundtrip.iter_mut() {
            *x = mul(*x, n_inv_mod_p);
        }

        assert_eq!(roundtrip, standard);
    }

    #[test]
    fn test_plan_failure_zero() {
        let n = 256;
        let p0 = largest_prime_in_arithmetic_progression64(2 * n as u64, 1, 0, 1u64 << 33).unwrap();
        assert!(Plan::try_new(n, 0, [p0, 0]).is_none());
    }

    #[test]
    fn test_plan_failure_dup() {
        let n = 256;
        let p0 = largest_prime_in_arithmetic_progression64(2 * n as u64, 1, 0, 1u64 << 33).unwrap();
        let p1 = largest_prime_in_arithmetic_progression64(2 * n as u64, 1, 0, 1u64 << 15).unwrap();
        assert!(Plan::try_new(n, p0 * p1 * p1, [p1, p0, p1]).is_none());
    }
}
1168}