tfhe_ntt/
native_binary128.rs

1pub(crate) use crate::native64::{mul_mod32, mul_mod64};
2use aligned_vec::avec;
3
4pub struct Plan32(
5    crate::prime32::Plan,
6    crate::prime32::Plan,
7    crate::prime32::Plan,
8    crate::prime32::Plan,
9    crate::prime32::Plan,
10);
11
12#[inline(always)]
13fn reconstruct_32bit_01234_v2(
14    mod_p0: u32,
15    mod_p1: u32,
16    mod_p2: u32,
17    mod_p3: u32,
18    mod_p4: u32,
19) -> u128 {
20    use crate::primes32::*;
21
22    let mod_p12 = {
23        let v1 = mod_p1;
24        let v2 = mul_mod32(P2, P1_INV_MOD_P2, 2 * P2 + mod_p2 - v1);
25        v1 as u64 + (v2 as u64 * P1 as u64)
26    };
27    let mod_p34 = {
28        let v3 = mod_p3;
29        let v4 = mul_mod32(P4, P3_INV_MOD_P4, 2 * P4 + mod_p4 - v3);
30        v3 as u64 + (v4 as u64 * P3 as u64)
31    };
32
33    let v0 = mod_p0 as u64;
34    let v12 = mul_mod64(
35        P12.wrapping_neg(),
36        2 * P12 + mod_p12 - v0,
37        P0_INV_MOD_P12,
38        P0_INV_MOD_P12_SHOUP,
39    );
40    let v34 = mul_mod64(
41        P34.wrapping_neg(),
42        2 * P34 + mod_p34 - (v0 + mul_mod64(P34.wrapping_neg(), v12, P0 as u64, P0_MOD_P34_SHOUP)),
43        P012_INV_MOD_P34,
44        P012_INV_MOD_P34_SHOUP,
45    );
46
47    let sign = v34 > (P34 / 2);
48
49    const _0: u128 = P0 as u128;
50    const _012: u128 = _0.wrapping_mul(P12 as u128);
51    const _01234: u128 = _012.wrapping_mul(P34 as u128);
52
53    let pos = (v0 as u128)
54        .wrapping_add((v12 as u128).wrapping_mul(_0))
55        .wrapping_add((v34 as u128).wrapping_mul(_012));
56    let neg = pos.wrapping_sub(_01234);
57
58    if sign {
59        neg
60    } else {
61        pos
62    }
63}
64
65impl Plan32 {
66    /// Returns a negacyclic NTT plan for the given polynomial size, or `None` if no
67    /// suitable roots of unity can be found for the wanted parameters.
68    pub fn try_new(n: usize) -> Option<Self> {
69        use crate::{prime32::Plan, primes32::*};
70        Some(Self(
71            Plan::try_new(n, P0)?,
72            Plan::try_new(n, P1)?,
73            Plan::try_new(n, P2)?,
74            Plan::try_new(n, P3)?,
75            Plan::try_new(n, P4)?,
76        ))
77    }
78
79    /// Returns the polynomial size of the negacyclic NTT plan.
80    #[inline]
81    pub fn ntt_size(&self) -> usize {
82        self.0.ntt_size()
83    }
84
85    pub fn fwd(
86        &self,
87        value: &[u128],
88        mod_p0: &mut [u32],
89        mod_p1: &mut [u32],
90        mod_p2: &mut [u32],
91        mod_p3: &mut [u32],
92        mod_p4: &mut [u32],
93    ) {
94        for (value, mod_p0, mod_p1, mod_p2, mod_p3, mod_p4) in crate::izip!(
95            value,
96            &mut *mod_p0,
97            &mut *mod_p1,
98            &mut *mod_p2,
99            &mut *mod_p3,
100            &mut *mod_p4,
101        ) {
102            *mod_p0 = (value % crate::primes32::P0 as u128) as u32;
103            *mod_p1 = (value % crate::primes32::P1 as u128) as u32;
104            *mod_p2 = (value % crate::primes32::P2 as u128) as u32;
105            *mod_p3 = (value % crate::primes32::P3 as u128) as u32;
106            *mod_p4 = (value % crate::primes32::P4 as u128) as u32;
107        }
108        self.0.fwd(mod_p0);
109        self.1.fwd(mod_p1);
110        self.2.fwd(mod_p2);
111        self.3.fwd(mod_p3);
112        self.4.fwd(mod_p4);
113    }
114
115    pub fn fwd_binary(
116        &self,
117        value: &[u128],
118        mod_p0: &mut [u32],
119        mod_p1: &mut [u32],
120        mod_p2: &mut [u32],
121        mod_p3: &mut [u32],
122        mod_p4: &mut [u32],
123    ) {
124        for (value, mod_p0, mod_p1, mod_p2, mod_p3, mod_p4) in crate::izip!(
125            value,
126            &mut *mod_p0,
127            &mut *mod_p1,
128            &mut *mod_p2,
129            &mut *mod_p3,
130            &mut *mod_p4,
131        ) {
132            *mod_p0 = *value as u32;
133            *mod_p1 = *value as u32;
134            *mod_p2 = *value as u32;
135            *mod_p3 = *value as u32;
136            *mod_p4 = *value as u32;
137        }
138        self.0.fwd(mod_p0);
139        self.1.fwd(mod_p1);
140        self.2.fwd(mod_p2);
141        self.3.fwd(mod_p3);
142        self.4.fwd(mod_p4);
143    }
144
145    pub fn inv(
146        &self,
147        value: &mut [u128],
148        mod_p0: &mut [u32],
149        mod_p1: &mut [u32],
150        mod_p2: &mut [u32],
151        mod_p3: &mut [u32],
152        mod_p4: &mut [u32],
153    ) {
154        self.0.inv(mod_p0);
155        self.1.inv(mod_p1);
156        self.2.inv(mod_p2);
157        self.3.inv(mod_p3);
158        self.4.inv(mod_p4);
159
160        for (value, &mod_p0, &mod_p1, &mod_p2, &mod_p3, &mod_p4) in
161            crate::izip!(value, &*mod_p0, &*mod_p1, &*mod_p2, &*mod_p3, &*mod_p4)
162        {
163            *value = reconstruct_32bit_01234_v2(mod_p0, mod_p1, mod_p2, mod_p3, mod_p4);
164        }
165    }
166
167    /// Computes the negacyclic polynomial product of `lhs` and `rhs`, and stores the result in
168    /// `prod`.
169    pub fn negacyclic_polymul(&self, prod: &mut [u128], lhs: &[u128], rhs: &[u128]) {
170        let n = prod.len();
171        assert_eq!(n, lhs.len());
172        assert_eq!(n, rhs.len());
173
174        let mut lhs0 = avec![0; n];
175        let mut lhs1 = avec![0; n];
176        let mut lhs2 = avec![0; n];
177        let mut lhs3 = avec![0; n];
178        let mut lhs4 = avec![0; n];
179
180        let mut rhs0 = avec![0; n];
181        let mut rhs1 = avec![0; n];
182        let mut rhs2 = avec![0; n];
183        let mut rhs3 = avec![0; n];
184        let mut rhs4 = avec![0; n];
185
186        self.fwd(lhs, &mut lhs0, &mut lhs1, &mut lhs2, &mut lhs3, &mut lhs4);
187        self.fwd_binary(rhs, &mut rhs0, &mut rhs1, &mut rhs2, &mut rhs3, &mut rhs4);
188
189        self.0.mul_assign_normalize(&mut lhs0, &rhs0);
190        self.1.mul_assign_normalize(&mut lhs1, &rhs1);
191        self.2.mul_assign_normalize(&mut lhs2, &rhs2);
192        self.3.mul_assign_normalize(&mut lhs3, &rhs3);
193        self.4.mul_assign_normalize(&mut lhs4, &rhs4);
194
195        self.inv(prod, &mut lhs0, &mut lhs1, &mut lhs2, &mut lhs3, &mut lhs4);
196    }
197}
198
#[cfg(test)]
mod tests {
    use super::*;
    use crate::native128::tests::negacyclic_convolution;
    use alloc::{vec, vec::Vec};
    use rand::random;

    extern crate alloc;

    /// Checks, for several polynomial sizes, that `Plan32::negacyclic_polymul` agrees with
    /// the reference `negacyclic_convolution` when the left operand is random and the right
    /// operand only holds coefficients equal to 0 or 1.
    #[test]
    fn reconstruct_32bit() {
        for n in [32, 64, 256, 1024, 2048] {
            let plan = Plan32::try_new(n).unwrap();

            // Random 128-bit left operand, random binary right operand.
            let lhs: Vec<u128> = (0..n).map(|_| random::<u128>()).collect();
            let rhs: Vec<u128> = (0..n).map(|_| random::<u128>() % 2).collect();
            let expected = negacyclic_convolution(n, &lhs, &rhs);

            let mut actual = vec![0; n];
            plan.negacyclic_polymul(&mut actual, &lhs, &rhs);
            assert_eq!(actual, expected);
        }
    }
}