tfhe_ntt/
native_binary32.rs

1use aligned_vec::avec;
2
3#[allow(unused_imports)]
4use pulp::*;
5
6use crate::native32::mul_mod32;
7
8/// Negacyclic NTT plan for multiplying two 32bit polynomials, where the RHS contains binary
9/// coefficients.
10#[derive(Clone, Debug)]
11pub struct Plan32(crate::prime32::Plan, crate::prime32::Plan);
12
/// Negacyclic NTT plan for multiplying two 32bit polynomials, where the RHS contains binary
/// coefficients.
///
/// This can be more efficient than [`Plan32`], but requires the AVX512 instruction set.
/// It works modulo a single prime from `primes52`. The first field is the 64-bit NTT plan for
/// that prime; the second is the AVX512 (IFMA) SIMD token obtained when the plan was created.
#[cfg(all(feature = "nightly", any(target_arch = "x86", target_arch = "x86_64")))]
#[cfg_attr(docsrs, doc(cfg(feature = "nightly")))]
#[derive(Debug, Clone)]
pub struct Plan52(crate::prime64::Plan, crate::V4IFma);
20
21#[inline(always)]
22pub(crate) fn reconstruct_32bit_01(mod_p0: u32, mod_p1: u32) -> u32 {
23    use crate::primes32::*;
24
25    let v0 = mod_p0;
26    let v1 = mul_mod32(P1, P0_INV_MOD_P1, 2 * P1 + mod_p1 - v0);
27
28    let sign = v1 > (P1 / 2);
29
30    const _0: u32 = P0;
31    const _01: u32 = _0.wrapping_mul(P1);
32
33    let pos = v0.wrapping_add(v1.wrapping_mul(_0));
34    let neg = pos.wrapping_sub(_01);
35
36    if sign {
37        neg
38    } else {
39        pos
40    }
41}
42
43#[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
44#[inline(always)]
45pub(crate) fn reconstruct_32bit_01_avx2(simd: crate::V3, mod_p0: u32x8, mod_p1: u32x8) -> u32x8 {
46    use crate::{native32::mul_mod32_avx2, primes32::*};
47
48    let p0 = simd.splat_u32x8(P0);
49    let p1 = simd.splat_u32x8(P1);
50    let two_p1 = simd.splat_u32x8(2 * P1);
51    let half_p1 = simd.splat_u32x8(P1 / 2);
52
53    let p0_inv_mod_p1 = simd.splat_u32x8(P0_INV_MOD_P1);
54    let p0_inv_mod_p1_shoup = simd.splat_u32x8(P0_INV_MOD_P1_SHOUP);
55
56    let p01 = simd.splat_u32x8(P0.wrapping_mul(P1));
57
58    let v0 = mod_p0;
59    let v1 = mul_mod32_avx2(
60        simd,
61        p1,
62        simd.wrapping_sub_u32x8(simd.wrapping_add_u32x8(two_p1, mod_p1), v0),
63        p0_inv_mod_p1,
64        p0_inv_mod_p1_shoup,
65    );
66
67    let sign = simd.cmp_gt_u32x8(v1, half_p1);
68    let pos = simd.wrapping_add_u32x8(v0, simd.wrapping_mul_u32x8(v1, p0));
69
70    let neg = simd.wrapping_sub_u32x8(pos, p01);
71
72    simd.select_u32x8(sign, neg, pos)
73}
74
/// 16-lane AVX512 counterpart of `reconstruct_32bit_01`.
///
/// Each lane of `mod_p0` / `mod_p1` holds the residues of one coefficient. The result holds the
/// CRT-recombined value of that coefficient modulo 2^32.
#[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
#[cfg(feature = "nightly")]
#[inline(always)]
fn reconstruct_32bit_01_avx512(simd: crate::V4IFma, mod_p0: u32x16, mod_p1: u32x16) -> u32x16 {
    use crate::{native32::mul_mod32_avx512, primes32::*};

    let lo = mod_p0;
    // (2 * P1 + mod_p1 - mod_p0), computed lane-wise with wrapping arithmetic.
    let offset_diff = simd.wrapping_sub_u32x16(
        simd.wrapping_add_u32x16(simd.splat_u32x16(2 * P1), mod_p1),
        lo,
    );
    // High CRT digit: offset_diff * P0^-1 mod P1, using the precomputed Shoup constant.
    let hi = mul_mod32_avx512(
        simd,
        simd.splat_u32x16(P1),
        offset_diff,
        simd.splat_u32x16(P0_INV_MOD_P1),
        simd.splat_u32x16(P0_INV_MOD_P1_SHOUP),
    );

    let lifted = simd.wrapping_add_u32x16(lo, simd.wrapping_mul_u32x16(hi, simd.splat_u32x16(P0)));
    // A high digit above P1 / 2 denotes a negative integer: shift it down by P0 * P1.
    let is_negative = simd.cmp_gt_u32x16(hi, simd.splat_u32x16(P1 / 2));
    let wrapped = simd.wrapping_sub_u32x16(lifted, simd.splat_u32x16(P0.wrapping_mul(P1)));

    simd.select_u32x16(is_negative, wrapped, lifted)
}
107
/// Maps 8 residues modulo the `primes52` prime `P0` to their signed representatives, then
/// narrows each lane to `u32`.
///
/// Residues above `P0 / 2` are treated as negative and have `P0` subtracted (wrapping in 64
/// bits) before narrowing.
#[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
#[cfg(feature = "nightly")]
#[inline(always)]
fn reconstruct_52bit_0_avx512(simd: crate::V4IFma, mod_p0: u64x8) -> u32x8 {
    use crate::primes52::*;

    let is_negative = simd.cmp_gt_u64x8(mod_p0, simd.splat_u64x8(P0 / 2));
    let shifted_down = simd.wrapping_sub_u64x8(mod_p0, simd.splat_u64x8(P0));
    let centered = simd.select_u64x8(is_negative, shifted_down, mod_p0);

    simd.convert_u64x8_to_u32x8(centered)
}
126
127#[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
128fn reconstruct_slice_32bit_01_avx2(
129    simd: crate::V3,
130    value: &mut [u32],
131    mod_p0: &[u32],
132    mod_p1: &[u32],
133) {
134    simd.vectorize(
135        #[inline(always)]
136        move || {
137            let value = pulp::as_arrays_mut::<8, _>(value).0;
138            let mod_p0 = pulp::as_arrays::<8, _>(mod_p0).0;
139            let mod_p1 = pulp::as_arrays::<8, _>(mod_p1).0;
140            for (value, &mod_p0, &mod_p1) in crate::izip!(value, mod_p0, mod_p1) {
141                *value = cast(reconstruct_32bit_01_avx2(simd, cast(mod_p0), cast(mod_p1)));
142            }
143        },
144    );
145}
146
/// Reconstructs `value` from its residues modulo `primes32::P0` and `primes32::P1`, 16 lanes
/// at a time using AVX512.
///
/// Only the first `min(value.len(), mod_p0.len(), mod_p1.len())` elements are processed. This
/// is the same range the scalar fallback in [`Plan32::inv`] covers. Elements past the last full
/// group of 16 go through the scalar [`reconstruct_32bit_01`] instead of being left unwritten.
#[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
#[cfg(feature = "nightly")]
fn reconstruct_slice_32bit_01_avx512(
    simd: crate::V4IFma,
    value: &mut [u32],
    mod_p0: &[u32],
    mod_p1: &[u32],
) {
    // Restrict every input to the common length so the vector body and the scalar tail agree
    // on which elements exist.
    let len = value.len().min(mod_p0.len()).min(mod_p1.len());
    let value = &mut value[..len];
    let mod_p0 = &mod_p0[..len];
    let mod_p1 = &mod_p1[..len];

    simd.vectorize(
        #[inline(always)]
        move || {
            let (out, out_tail) = pulp::as_arrays_mut::<16, _>(value);
            let (r0, r0_tail) = pulp::as_arrays::<16, _>(mod_p0);
            let (r1, r1_tail) = pulp::as_arrays::<16, _>(mod_p1);

            for (out, &r0, &r1) in crate::izip!(out, r0, r1) {
                *out = cast(reconstruct_32bit_01_avx512(simd, cast(r0), cast(r1)));
            }
            // Leftover elements; this loop is empty whenever `len` is a multiple of 16.
            for (out, &r0, &r1) in crate::izip!(out_tail, r0_tail, r1_tail) {
                *out = reconstruct_32bit_01(r0, r1);
            }
        },
    );
}
171
/// Reconstructs `value` from its residues modulo the `primes52` prime `P0`, 8 lanes at a time
/// using AVX512.
///
/// Only the first `min(value.len(), mod_p0.len())` elements are processed. Elements past the
/// last full group of 8 are handled by an equivalent scalar computation instead of being left
/// unwritten.
#[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
#[cfg(feature = "nightly")]
fn reconstruct_slice_52bit_0_avx512(simd: crate::V4IFma, value: &mut [u32], mod_p0: &[u64]) {
    use crate::primes52::P0;

    // Restrict both inputs to the common length so the vector body and the scalar tail agree
    // on which elements exist.
    let len = value.len().min(mod_p0.len());
    let value = &mut value[..len];
    let mod_p0 = &mod_p0[..len];

    simd.vectorize(
        #[inline(always)]
        move || {
            let (out, out_tail) = pulp::as_arrays_mut::<8, _>(value);
            let (r0, r0_tail) = pulp::as_arrays::<8, _>(mod_p0);

            for (out, &r0) in crate::izip!(out, r0) {
                *out = cast(reconstruct_52bit_0_avx512(simd, cast(r0)));
            }
            // Leftover elements; this loop is empty whenever `len` is a multiple of 8.
            for (out, &r0) in crate::izip!(out_tail, r0_tail) {
                // Scalar form of `reconstruct_52bit_0_avx512`: residues above P0 / 2 are
                // negative, and the signed value is then truncated to its low 32 bits.
                let signed = if r0 > P0 / 2 { r0.wrapping_sub(P0) } else { r0 };
                *out = signed as u32;
            }
        },
    );
}
186
187impl Plan32 {
188    /// Returns a negacyclic NTT plan for the given polynomial size, or `None` if no
189    /// suitable roots of unity can be found for the wanted parameters.
190    pub fn try_new(n: usize) -> Option<Self> {
191        use crate::{prime32::Plan, primes32::*};
192        Some(Self(Plan::try_new(n, P0)?, Plan::try_new(n, P1)?))
193    }
194
195    /// Returns the polynomial size of the negacyclic NTT plan.
196    #[inline]
197    pub fn ntt_size(&self) -> usize {
198        self.0.ntt_size()
199    }
200
201    pub fn fwd(&self, value: &[u32], mod_p0: &mut [u32], mod_p1: &mut [u32]) {
202        for (value, mod_p0, mod_p1) in crate::izip!(value, &mut *mod_p0, &mut *mod_p1) {
203            *mod_p0 = value % crate::primes32::P0;
204            *mod_p1 = value % crate::primes32::P1;
205        }
206        self.0.fwd(mod_p0);
207        self.1.fwd(mod_p1);
208    }
209
210    pub fn fwd_binary(&self, value: &[u32], mod_p0: &mut [u32], mod_p1: &mut [u32]) {
211        for (value, mod_p0, mod_p1) in crate::izip!(value, &mut *mod_p0, &mut *mod_p1) {
212            *mod_p0 = *value;
213            *mod_p1 = *value;
214        }
215        self.0.fwd(mod_p0);
216        self.1.fwd(mod_p1);
217    }
218
219    pub fn inv(&self, value: &mut [u32], mod_p0: &mut [u32], mod_p1: &mut [u32]) {
220        self.0.inv(mod_p0);
221        self.1.inv(mod_p1);
222
223        #[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
224        {
225            #[cfg(feature = "nightly")]
226            if let Some(simd) = crate::V4IFma::try_new() {
227                reconstruct_slice_32bit_01_avx512(simd, value, mod_p0, mod_p1);
228                return;
229            }
230            if let Some(simd) = crate::V3::try_new() {
231                reconstruct_slice_32bit_01_avx2(simd, value, mod_p0, mod_p1);
232                return;
233            }
234        }
235
236        for (value, &mod_p0, &mod_p1) in crate::izip!(value, &*mod_p0, &*mod_p1) {
237            *value = reconstruct_32bit_01(mod_p0, mod_p1);
238        }
239    }
240
241    /// Computes the negacyclic polynomial product of `lhs` and `rhs`, and stores the result in
242    /// `prod`.
243    pub fn negacyclic_polymul(&self, prod: &mut [u32], lhs: &[u32], rhs_binary: &[u32]) {
244        let n = prod.len();
245        assert_eq!(n, lhs.len());
246        assert_eq!(n, rhs_binary.len());
247
248        let mut lhs0 = avec![0; n];
249        let mut lhs1 = avec![0; n];
250
251        let mut rhs0 = avec![0; n];
252        let mut rhs1 = avec![0; n];
253
254        self.fwd(lhs, &mut lhs0, &mut lhs1);
255        self.fwd_binary(rhs_binary, &mut rhs0, &mut rhs1);
256
257        self.0.mul_assign_normalize(&mut lhs0, &rhs0);
258        self.1.mul_assign_normalize(&mut lhs1, &rhs1);
259
260        self.inv(prod, &mut lhs0, &mut lhs1);
261    }
262}
263
#[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
#[cfg(feature = "nightly")]
impl Plan52 {
    /// Returns a negacyclic NTT plan for the given polynomial size, or `None` if no
    /// suitable roots of unity can be found for the wanted parameters, or if the AVX512
    /// instruction set isn't detected.
    pub fn try_new(n: usize) -> Option<Self> {
        use crate::{prime64::Plan, primes52::*};
        let simd = crate::V4IFma::try_new()?;
        Some(Self(Plan::try_new(n, P0)?, simd))
    }

    /// Returns the polynomial size of the negacyclic NTT plan.
    #[inline]
    pub fn ntt_size(&self) -> usize {
        self.0.ntt_size()
    }

    /// Widens each coefficient of `value` to `u64` into `mod_p0`, then applies the forward NTT
    /// to `mod_p0` in place.
    ///
    /// No modular reduction is performed. This presumably relies on the 52-bit prime `P0`
    /// being larger than any `u32`.
    pub fn fwd(&self, value: &[u32], mod_p0: &mut [u64]) {
        self.1.vectorize(
            #[inline(always)]
            || {
                for (value, mod_p0) in crate::izip!(value, &mut *mod_p0) {
                    *mod_p0 = *value as u64;
                }
            },
        );
        self.0.fwd(mod_p0);
    }

    /// Same as [`Self::fwd`]; binary (0 or 1) coefficients need no special treatment here.
    pub fn fwd_binary(&self, value: &[u32], mod_p0: &mut [u64]) {
        self.fwd(value, mod_p0);
    }

    /// Applies the inverse NTT to `mod_p0` in place, then writes the signed representative of
    /// each residue, narrowed to 32 bits, into `value`.
    ///
    /// Residues above `P0 / 2` are treated as negative integers.
    pub fn inv(&self, value: &mut [u32], mod_p0: &mut [u64]) {
        self.0.inv(mod_p0);

        let simd = self.1;
        reconstruct_slice_52bit_0_avx512(simd, value, mod_p0);
    }

    /// Computes the negacyclic polynomial product of `lhs` and `rhs_binary`, and stores the
    /// result, modulo 2^32, in `prod`.
    ///
    /// `rhs_binary` is expected to have binary (0 or 1) coefficients.
    ///
    /// # Panics
    ///
    /// Panics if `lhs` or `rhs_binary` does not have the same length as `prod`. The underlying
    /// plan may also panic if that length differs from [`Self::ntt_size`].
    pub fn negacyclic_polymul(&self, prod: &mut [u32], lhs: &[u32], rhs_binary: &[u32]) {
        let n = prod.len();
        assert_eq!(n, lhs.len());
        assert_eq!(n, rhs_binary.len());

        let mut lhs0 = avec![0; n];
        let mut rhs0 = avec![0; n];

        self.fwd(lhs, &mut lhs0);
        self.fwd_binary(rhs_binary, &mut rhs0);

        self.0.mul_assign_normalize(&mut lhs0, &rhs0);

        self.inv(prod, &mut lhs0);
    }
}
323
#[cfg(test)]
mod tests {
    use super::*;
    use crate::prime32::tests::negacyclic_convolution;
    use alloc::{vec, vec::Vec};
    use rand::random;

    extern crate alloc;

    /// Polynomial sizes covered by the tests below.
    const SIZES: [usize; 5] = [32, 64, 256, 1024, 2048];

    /// Multiplies a random `u32` polynomial by a random binary polynomial of size `n` using
    /// `polymul`, and checks the result against the reference negacyclic convolution.
    fn check_against_reference(n: usize, polymul: impl Fn(&mut [u32], &[u32], &[u32])) {
        let lhs: Vec<u32> = (0..n).map(|_| random::<u32>()).collect();
        let rhs_binary: Vec<u32> = (0..n).map(|_| random::<u32>() % 2).collect();
        let expected = negacyclic_convolution(n, 0, &lhs, &rhs_binary);

        let mut actual = vec![0u32; n];
        polymul(&mut actual, &lhs, &rhs_binary);
        assert_eq!(actual, expected);
    }

    #[test]
    fn reconstruct_32bit() {
        for &n in SIZES.iter() {
            let plan = Plan32::try_new(n).unwrap();
            check_against_reference(n, |prod, lhs, rhs| {
                plan.negacyclic_polymul(prod, lhs, rhs)
            });
        }
    }

    #[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
    #[cfg(feature = "nightly")]
    #[test]
    fn reconstruct_52bit() {
        // Sizes for which no plan can be built (e.g. AVX512 not detected) are skipped.
        let plans = SIZES
            .iter()
            .filter_map(|&n| Plan52::try_new(n).map(|plan| (n, plan)));
        for (n, plan) in plans {
            check_against_reference(n, |prod, lhs, rhs| {
                plan.negacyclic_polymul(prod, lhs, rhs)
            });
        }
    }
}