tfhe_ntt/
native128.rs

1pub(crate) use crate::native64::{mul_mod32, mul_mod64};
2use aligned_vec::avec;
3
/// Negacyclic NTT plan for multiplying two 128-bit polynomials.
///
/// Holds ten 32-bit prime plans. [`Plan32::try_new`] builds them for the moduli `P0` through
/// `P9` of `crate::primes32`, stored in that order, so field `k` is the plan for `Pk`.
#[derive(Clone, Debug)]
pub struct Plan32(
    // Plan for `P0`.
    crate::prime32::Plan,
    // Plan for `P1`.
    crate::prime32::Plan,
    // Plan for `P2`.
    crate::prime32::Plan,
    // Plan for `P3`.
    crate::prime32::Plan,
    // Plan for `P4`.
    crate::prime32::Plan,
    // Plan for `P5`.
    crate::prime32::Plan,
    // Plan for `P6`.
    crate::prime32::Plan,
    // Plan for `P7`.
    crate::prime32::Plan,
    // Plan for `P8`.
    crate::prime32::Plan,
    // Plan for `P9`.
    crate::prime32::Plan,
);
18
19#[inline(always)]
20fn reconstruct_32bit_0123456789_v2(
21    mod_p0: u32,
22    mod_p1: u32,
23    mod_p2: u32,
24    mod_p3: u32,
25    mod_p4: u32,
26    mod_p5: u32,
27    mod_p6: u32,
28    mod_p7: u32,
29    mod_p8: u32,
30    mod_p9: u32,
31) -> u128 {
32    use crate::primes32::*;
33
34    let mod_p01 = {
35        let v0 = mod_p0;
36        let v1 = mul_mod32(P1, P0_INV_MOD_P1, 2 * P1 + mod_p1 - v0);
37        v0 as u64 + (v1 as u64 * P0 as u64)
38    };
39    let mod_p23 = {
40        let v2 = mod_p2;
41        let v3 = mul_mod32(P3, P2_INV_MOD_P3, 2 * P3 + mod_p3 - v2);
42        v2 as u64 + (v3 as u64 * P2 as u64)
43    };
44    let mod_p45 = {
45        let v4 = mod_p4;
46        let v5 = mul_mod32(P5, P4_INV_MOD_P5, 2 * P5 + mod_p5 - v4);
47        v4 as u64 + (v5 as u64 * P4 as u64)
48    };
49    let mod_p67 = {
50        let v6 = mod_p6;
51        let v7 = mul_mod32(P7, P6_INV_MOD_P7, 2 * P7 + mod_p7 - v6);
52        v6 as u64 + (v7 as u64 * P6 as u64)
53    };
54    let mod_p89 = {
55        let v8 = mod_p8;
56        let v9 = mul_mod32(P9, P8_INV_MOD_P9, 2 * P9 + mod_p9 - v8);
57        v8 as u64 + (v9 as u64 * P8 as u64)
58    };
59
60    let v01 = mod_p01;
61    let v23 = mul_mod64(
62        P23.wrapping_neg(),
63        2 * P23 + mod_p23 - v01,
64        P01_INV_MOD_P23,
65        P01_INV_MOD_P23_SHOUP,
66    );
67    let v45 = mul_mod64(
68        P45.wrapping_neg(),
69        2 * P45 + mod_p45 - (v01 + mul_mod64(P45.wrapping_neg(), v23, P01, P01_MOD_P45_SHOUP)),
70        P0123_INV_MOD_P45,
71        P0123_INV_MOD_P45_SHOUP,
72    );
73    let v67 = mul_mod64(
74        P67.wrapping_neg(),
75        2 * P67 + mod_p67
76            - (v01
77                + mul_mod64(
78                    P67.wrapping_neg(),
79                    v23 + mul_mod64(P67.wrapping_neg(), v45, P23, P23_MOD_P67_SHOUP),
80                    P01,
81                    P01_MOD_P67_SHOUP,
82                )),
83        P012345_INV_MOD_P67,
84        P012345_INV_MOD_P67_SHOUP,
85    );
86    let v89 = mul_mod64(
87        P89.wrapping_neg(),
88        2 * P89 + mod_p89
89            - (v01
90                + mul_mod64(
91                    P89.wrapping_neg(),
92                    v23 + mul_mod64(
93                        P89.wrapping_neg(),
94                        v45 + mul_mod64(P89.wrapping_neg(), v67, P45, P45_MOD_P89_SHOUP),
95                        P23,
96                        P23_MOD_P89_SHOUP,
97                    ),
98                    P01,
99                    P01_MOD_P89_SHOUP,
100                )),
101        P01234567_INV_MOD_P89,
102        P01234567_INV_MOD_P89_SHOUP,
103    );
104
105    let sign = v89 > (P89 / 2);
106    let pos = (v01 as u128)
107        .wrapping_add(u128::wrapping_mul(v23 as u128, P01 as u128))
108        .wrapping_add(u128::wrapping_mul(v45 as u128, P0123))
109        .wrapping_add(u128::wrapping_mul(v67 as u128, P012345))
110        .wrapping_add(u128::wrapping_mul(v89 as u128, P01234567));
111    let neg = pos.wrapping_sub(P0123456789);
112
113    if sign {
114        neg
115    } else {
116        pos
117    }
118}
119
120impl Plan32 {
121    /// Returns a negacyclic NTT plan for the given polynomial size, or `None` if no
122    /// suitable roots of unity can be found for the wanted parameters.
123    pub fn try_new(n: usize) -> Option<Self> {
124        use crate::{prime32::Plan, primes32::*};
125        Some(Self(
126            Plan::try_new(n, P0)?,
127            Plan::try_new(n, P1)?,
128            Plan::try_new(n, P2)?,
129            Plan::try_new(n, P3)?,
130            Plan::try_new(n, P4)?,
131            Plan::try_new(n, P5)?,
132            Plan::try_new(n, P6)?,
133            Plan::try_new(n, P7)?,
134            Plan::try_new(n, P8)?,
135            Plan::try_new(n, P9)?,
136        ))
137    }
138
139    /// Returns the polynomial size of the negacyclic NTT plan.
140    #[inline]
141    pub fn ntt_size(&self) -> usize {
142        self.0.ntt_size()
143    }
144
145    #[inline]
146    pub fn ntt_0(&self) -> &crate::prime32::Plan {
147        &self.0
148    }
149    #[inline]
150    pub fn ntt_1(&self) -> &crate::prime32::Plan {
151        &self.1
152    }
153    #[inline]
154    pub fn ntt_2(&self) -> &crate::prime32::Plan {
155        &self.2
156    }
157    #[inline]
158    pub fn ntt_3(&self) -> &crate::prime32::Plan {
159        &self.3
160    }
161    #[inline]
162    pub fn ntt_4(&self) -> &crate::prime32::Plan {
163        &self.4
164    }
165    #[inline]
166    pub fn ntt_5(&self) -> &crate::prime32::Plan {
167        &self.5
168    }
169    #[inline]
170    pub fn ntt_6(&self) -> &crate::prime32::Plan {
171        &self.6
172    }
173    #[inline]
174    pub fn ntt_7(&self) -> &crate::prime32::Plan {
175        &self.7
176    }
177    #[inline]
178    pub fn ntt_8(&self) -> &crate::prime32::Plan {
179        &self.8
180    }
181    #[inline]
182    pub fn ntt_9(&self) -> &crate::prime32::Plan {
183        &self.9
184    }
185
186    pub fn fwd(
187        &self,
188        value: &[u128],
189        mod_p0: &mut [u32],
190        mod_p1: &mut [u32],
191        mod_p2: &mut [u32],
192        mod_p3: &mut [u32],
193        mod_p4: &mut [u32],
194        mod_p5: &mut [u32],
195        mod_p6: &mut [u32],
196        mod_p7: &mut [u32],
197        mod_p8: &mut [u32],
198        mod_p9: &mut [u32],
199    ) {
200        for (
201            value,
202            mod_p0,
203            mod_p1,
204            mod_p2,
205            mod_p3,
206            mod_p4,
207            mod_p5,
208            mod_p6,
209            mod_p7,
210            mod_p8,
211            mod_p9,
212        ) in crate::izip!(
213            value,
214            &mut *mod_p0,
215            &mut *mod_p1,
216            &mut *mod_p2,
217            &mut *mod_p3,
218            &mut *mod_p4,
219            &mut *mod_p5,
220            &mut *mod_p6,
221            &mut *mod_p7,
222            &mut *mod_p8,
223            &mut *mod_p9,
224        ) {
225            *mod_p0 = (value % crate::primes32::P0 as u128) as u32;
226            *mod_p1 = (value % crate::primes32::P1 as u128) as u32;
227            *mod_p2 = (value % crate::primes32::P2 as u128) as u32;
228            *mod_p3 = (value % crate::primes32::P3 as u128) as u32;
229            *mod_p4 = (value % crate::primes32::P4 as u128) as u32;
230            *mod_p5 = (value % crate::primes32::P5 as u128) as u32;
231            *mod_p6 = (value % crate::primes32::P6 as u128) as u32;
232            *mod_p7 = (value % crate::primes32::P7 as u128) as u32;
233            *mod_p8 = (value % crate::primes32::P8 as u128) as u32;
234            *mod_p9 = (value % crate::primes32::P9 as u128) as u32;
235        }
236        self.0.fwd(mod_p0);
237        self.1.fwd(mod_p1);
238        self.2.fwd(mod_p2);
239        self.3.fwd(mod_p3);
240        self.4.fwd(mod_p4);
241        self.5.fwd(mod_p5);
242        self.6.fwd(mod_p6);
243        self.7.fwd(mod_p7);
244        self.8.fwd(mod_p8);
245        self.9.fwd(mod_p9);
246    }
247
248    pub fn inv(
249        &self,
250        value: &mut [u128],
251        mod_p0: &mut [u32],
252        mod_p1: &mut [u32],
253        mod_p2: &mut [u32],
254        mod_p3: &mut [u32],
255        mod_p4: &mut [u32],
256        mod_p5: &mut [u32],
257        mod_p6: &mut [u32],
258        mod_p7: &mut [u32],
259        mod_p8: &mut [u32],
260        mod_p9: &mut [u32],
261    ) {
262        self.0.inv(mod_p0);
263        self.1.inv(mod_p1);
264        self.2.inv(mod_p2);
265        self.3.inv(mod_p3);
266        self.4.inv(mod_p4);
267        self.5.inv(mod_p5);
268        self.6.inv(mod_p6);
269        self.7.inv(mod_p7);
270        self.8.inv(mod_p8);
271        self.9.inv(mod_p9);
272
273        for (
274            value,
275            &mod_p0,
276            &mod_p1,
277            &mod_p2,
278            &mod_p3,
279            &mod_p4,
280            &mod_p5,
281            &mod_p6,
282            &mod_p7,
283            &mod_p8,
284            &mod_p9,
285        ) in crate::izip!(
286            value, &*mod_p0, &*mod_p1, &*mod_p2, &*mod_p3, &*mod_p4, &*mod_p5, &*mod_p6, &*mod_p7,
287            &*mod_p8, &*mod_p9,
288        ) {
289            *value = reconstruct_32bit_0123456789_v2(
290                mod_p0, mod_p1, mod_p2, mod_p3, mod_p4, mod_p5, mod_p6, mod_p7, mod_p8, mod_p9,
291            );
292        }
293    }
294
295    /// Computes the negacyclic polynomial product of `lhs` and `rhs`, and stores the result in
296    /// `prod`.
297    pub fn negacyclic_polymul(&self, prod: &mut [u128], lhs: &[u128], rhs: &[u128]) {
298        let n = prod.len();
299        assert_eq!(n, lhs.len());
300        assert_eq!(n, rhs.len());
301
302        let mut lhs0 = avec![0; n];
303        let mut lhs1 = avec![0; n];
304        let mut lhs2 = avec![0; n];
305        let mut lhs3 = avec![0; n];
306        let mut lhs4 = avec![0; n];
307        let mut lhs5 = avec![0; n];
308        let mut lhs6 = avec![0; n];
309        let mut lhs7 = avec![0; n];
310        let mut lhs8 = avec![0; n];
311        let mut lhs9 = avec![0; n];
312
313        let mut rhs0 = avec![0; n];
314        let mut rhs1 = avec![0; n];
315        let mut rhs2 = avec![0; n];
316        let mut rhs3 = avec![0; n];
317        let mut rhs4 = avec![0; n];
318        let mut rhs5 = avec![0; n];
319        let mut rhs6 = avec![0; n];
320        let mut rhs7 = avec![0; n];
321        let mut rhs8 = avec![0; n];
322        let mut rhs9 = avec![0; n];
323
324        self.fwd(
325            lhs, &mut lhs0, &mut lhs1, &mut lhs2, &mut lhs3, &mut lhs4, &mut lhs5, &mut lhs6,
326            &mut lhs7, &mut lhs8, &mut lhs9,
327        );
328        self.fwd(
329            rhs, &mut rhs0, &mut rhs1, &mut rhs2, &mut rhs3, &mut rhs4, &mut rhs5, &mut rhs6,
330            &mut rhs7, &mut rhs8, &mut rhs9,
331        );
332
333        self.0.mul_assign_normalize(&mut lhs0, &rhs0);
334        self.1.mul_assign_normalize(&mut lhs1, &rhs1);
335        self.2.mul_assign_normalize(&mut lhs2, &rhs2);
336        self.3.mul_assign_normalize(&mut lhs3, &rhs3);
337        self.4.mul_assign_normalize(&mut lhs4, &rhs4);
338        self.5.mul_assign_normalize(&mut lhs5, &rhs5);
339        self.6.mul_assign_normalize(&mut lhs6, &rhs6);
340        self.7.mul_assign_normalize(&mut lhs7, &rhs7);
341        self.8.mul_assign_normalize(&mut lhs8, &rhs8);
342        self.9.mul_assign_normalize(&mut lhs9, &rhs9);
343
344        self.inv(
345            prod, &mut lhs0, &mut lhs1, &mut lhs2, &mut lhs3, &mut lhs4, &mut lhs5, &mut lhs6,
346            &mut lhs7, &mut lhs8, &mut lhs9,
347        );
348    }
349}
350
#[cfg(test)]
pub mod tests {
    use super::*;
    use alloc::{vec, vec::Vec};
    use rand::random;

    extern crate alloc;

    /// Reference negacyclic convolution of the first `n` coefficients of `lhs` and `rhs`,
    /// computed with the quadratic schoolbook method in wrapping `u128` arithmetic.
    pub fn negacyclic_convolution(n: usize, lhs: &[u128], rhs: &[u128]) -> Vec<u128> {
        // Plain (acyclic) product first: degree `i + j` accumulates `lhs[i] * rhs[j]`.
        let mut acyclic = vec![0u128; 2 * n];
        for i in 0..n {
            for j in 0..n {
                let term = lhs[i].wrapping_mul(rhs[j]);
                let slot = &mut acyclic[i + j];
                *slot = slot.wrapping_add(term);
            }
        }

        // Reduce modulo `X^n + 1`: the upper half wraps around with a sign flip.
        let (low, high) = acyclic.split_at(n);
        low.iter()
            .zip(high)
            .map(|(lo, hi)| lo.wrapping_sub(*hi))
            .collect()
    }

    /// Draws two random polynomials of size `n` and returns them together with their
    /// reference negacyclic convolution, as `(lhs, rhs, product)`.
    pub fn random_lhs_rhs_with_negacyclic_convolution(
        n: usize,
    ) -> (Vec<u128>, Vec<u128>, Vec<u128>) {
        let lhs: Vec<u128> = (0..n).map(|_| random()).collect();
        let rhs: Vec<u128> = (0..n).map(|_| random()).collect();

        let expected = negacyclic_convolution(n, &lhs, &rhs);
        (lhs, rhs, expected)
    }

    /// Checks the forward/inverse round trip (which scales by `n`) and the full polynomial
    /// product against the schoolbook reference, for several sizes.
    #[test]
    fn reconstruct_32bit() {
        for n in [32, 64, 256, 1024, 2048] {
            let value: Vec<u128> = core::iter::repeat_with(random::<u128>).take(n).collect();
            let mut value_roundtrip = vec![0u128; n];

            let mut mod_p0 = vec![0u32; n];
            let mut mod_p1 = vec![0u32; n];
            let mut mod_p2 = vec![0u32; n];
            let mut mod_p3 = vec![0u32; n];
            let mut mod_p4 = vec![0u32; n];
            let mut mod_p5 = vec![0u32; n];
            let mut mod_p6 = vec![0u32; n];
            let mut mod_p7 = vec![0u32; n];
            let mut mod_p8 = vec![0u32; n];
            let mut mod_p9 = vec![0u32; n];

            let plan = Plan32::try_new(n).unwrap();
            plan.fwd(
                &value,
                &mut mod_p0,
                &mut mod_p1,
                &mut mod_p2,
                &mut mod_p3,
                &mut mod_p4,
                &mut mod_p5,
                &mut mod_p6,
                &mut mod_p7,
                &mut mod_p8,
                &mut mod_p9,
            );
            plan.inv(
                &mut value_roundtrip,
                &mut mod_p0,
                &mut mod_p1,
                &mut mod_p2,
                &mut mod_p3,
                &mut mod_p4,
                &mut mod_p5,
                &mut mod_p6,
                &mut mod_p7,
                &mut mod_p8,
                &mut mod_p9,
            );

            // The inverse transform is unnormalized, so the round trip multiplies by `n`.
            let scale = n as u128;
            for (original, roundtrip) in value.iter().zip(value_roundtrip.iter()) {
                assert_eq!(*roundtrip, original.wrapping_mul(scale));
            }

            let (lhs, rhs, expected) = random_lhs_rhs_with_negacyclic_convolution(n);

            let mut prod = vec![0u128; n];
            plan.negacyclic_polymul(&mut prod, &lhs, &rhs);
            assert_eq!(prod, expected);
        }
    }
}