// tfhe_ntt/native64.rs

use aligned_vec::avec;

#[allow(unused_imports)]
use pulp::*;

pub(crate) use crate::native32::mul_mod32;

#[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
pub(crate) use crate::native32::mul_mod32_avx2;
#[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
#[cfg(feature = "nightly")]
pub(crate) use crate::native32::{mul_mod32_avx512, mul_mod52_avx512};

/// Negacyclic NTT plan for multiplying two 64-bit polynomials.
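///
/// Inputs are reduced modulo five ~30-bit NTT-friendly primes
/// (`crate::primes32::P0` through `P4`); products are computed in each prime
/// field, and the 64-bit result is recovered by CRT reconstruction, wrapping
/// modulo `2^64`.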
#[derive(Clone, Debug)]
pub struct Plan32(
    crate::prime32::Plan,
    crate::prime32::Plan,
    crate::prime32::Plan,
    crate::prime32::Plan,
    crate::prime32::Plan,
);

/// Negacyclic NTT plan for multiplying two 64-bit polynomials.
/// This can be more efficient than [`Plan32`], but requires the AVX512-IFMA instruction set.
#[cfg(all(feature = "nightly", any(target_arch = "x86", target_arch = "x86_64")))]
#[cfg_attr(docsrs, doc(cfg(feature = "nightly")))]
#[derive(Clone, Debug)]
pub struct Plan52(
    crate::prime64::Plan,
    crate::prime64::Plan,
    crate::prime64::Plan,
    crate::V4IFma,
);

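// Shoup modular multiplication: `b_shoup` holds `floor(b * 2^64 / p)` and
// `p_neg` holds `p.wrapping_neg()`. The estimated quotient `q` leaves
// `a * b - q * p` within one extra multiple of `p`; the final `min` performs
// the conditional subtraction, since the out-of-range branch wraps around to
// a larger value.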
#[inline(always)]
pub(crate) fn mul_mod64(p_neg: u64, a: u64, b: u64, b_shoup: u64) -> u64 {
    let q = ((a as u128 * b_shoup as u128) >> 64) as u64;
    let r = a.wrapping_mul(b).wrapping_add(p_neg.wrapping_mul(q));
    r.min(r.wrapping_add(p_neg))
}

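// CRT reconstruction in mixed radix (Garner's algorithm): `v0..v4` are the
// digits of the value modulo `P0*P1*P2*P3*P4` in the bases `1, P0, P0*P1,
// P0*P1*P2, P0*P1*P2*P3`. The `2 * Pi + ...` offsets keep the subtractions
// non-negative, and the sign test at the end selects the centered
// representative before wrapping to `u64`.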
#[inline(always)]
#[allow(dead_code)]
fn reconstruct_32bit_01234(mod_p0: u32, mod_p1: u32, mod_p2: u32, mod_p3: u32, mod_p4: u32) -> u64 {
    use crate::primes32::*;

    let v0 = mod_p0;
    let v1 = mul_mod32(P1, P0_INV_MOD_P1, 2 * P1 + mod_p1 - v0);
    let v2 = mul_mod32(
        P2,
        P01_INV_MOD_P2,
        2 * P2 + mod_p2 - (v0 + mul_mod32(P2, P0, v1)),
    );
    let v3 = mul_mod32(
        P3,
        P012_INV_MOD_P3,
        2 * P3 + mod_p3 - (v0 + mul_mod32(P3, P0, v1 + mul_mod32(P3, P1, v2))),
    );
    let v4 = mul_mod32(
        P4,
        P0123_INV_MOD_P4,
        2 * P4 + mod_p4
            - (v0 + mul_mod32(P4, P0, v1 + mul_mod32(P4, P1, v2 + mul_mod32(P4, P2, v3)))),
    );

    let sign = v4 > (P4 / 2);

    const _0: u64 = P0 as u64;
    const _01: u64 = _0.wrapping_mul(P1 as u64);
    const _012: u64 = _01.wrapping_mul(P2 as u64);
    const _0123: u64 = _012.wrapping_mul(P3 as u64);
    const _01234: u64 = _0123.wrapping_mul(P4 as u64);

    let pos = (v0 as u64)
        .wrapping_add((v1 as u64).wrapping_mul(_0))
        .wrapping_add((v2 as u64).wrapping_mul(_01))
        .wrapping_add((v3 as u64).wrapping_mul(_012))
        .wrapping_add((v4 as u64).wrapping_mul(_0123));

    let neg = pos.wrapping_sub(_01234);

    if sign {
        neg
    } else {
        pos
    }
}

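// Two-level variant of the reconstruction above: the residues mod `P1, P2`
// and mod `P3, P4` are first combined into residues mod the 64-bit moduli
// `P12 = P1 * P2` and `P34 = P3 * P4` using 32-bit arithmetic; a single
// 64-bit Garner pass then recovers the value modulo `P0 * P12 * P34`.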
#[inline(always)]
fn reconstruct_32bit_01234_v2(
    mod_p0: u32,
    mod_p1: u32,
    mod_p2: u32,
    mod_p3: u32,
    mod_p4: u32,
) -> u64 {
    use crate::primes32::*;

    let mod_p12 = {
        let v1 = mod_p1;
        let v2 = mul_mod32(P2, P1_INV_MOD_P2, 2 * P2 + mod_p2 - v1);
        v1 as u64 + (v2 as u64 * P1 as u64)
    };
    let mod_p34 = {
        let v3 = mod_p3;
        let v4 = mul_mod32(P4, P3_INV_MOD_P4, 2 * P4 + mod_p4 - v3);
        v3 as u64 + (v4 as u64 * P3 as u64)
    };

    let v0 = mod_p0 as u64;
    let v12 = mul_mod64(
        P12.wrapping_neg(),
        2 * P12 + mod_p12 - v0,
        P0_INV_MOD_P12,
        P0_INV_MOD_P12_SHOUP,
    );
    let v34 = mul_mod64(
        P34.wrapping_neg(),
        2 * P34 + mod_p34 - (v0 + mul_mod64(P34.wrapping_neg(), v12, P0 as u64, P0_MOD_P34_SHOUP)),
        P012_INV_MOD_P34,
        P012_INV_MOD_P34_SHOUP,
    );

    let sign = v34 > (P34 / 2);

    const _0: u64 = P0 as u64;
    const _012: u64 = _0.wrapping_mul(P12);
    const _01234: u64 = _012.wrapping_mul(P34);

    let pos = v0
        .wrapping_add(v12.wrapping_mul(_0))
        .wrapping_add(v34.wrapping_mul(_012));
    let neg = pos.wrapping_sub(_01234);

    if sign {
        neg
    } else {
        pos
    }
}

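// AVX2 Shoup multiplication for 32-bit values stored one per 64-bit lane:
// `b_shoup = floor(b * 2^32 / p)`, so the quotient estimate is the high half
// of the 32x32 product `a * b_shoup`. The difference is masked to 32 bits,
// leaving a value in `[0, 2p)` that `small_mod` brings back below `p`.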
#[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
#[inline(always)]
pub(crate) fn mul_mod32_v2_avx2(
    simd: crate::V3,
    p: u64x4,
    a: u64x4,
    b: u64x4,
    b_shoup: u64x4,
) -> u64x4 {
    let shoup_q = simd.shr_const_u64x4::<32>(simd.mul_low_32_bits_u64x4(a, b_shoup));
    let t = simd.and_u64x4(
        simd.splat_u64x4((1u64 << 32) - 1),
        simd.wrapping_sub_u64x4(
            simd.mul_low_32_bits_u64x4(a, b),
            simd.mul_low_32_bits_u64x4(shoup_q, p),
        ),
    );
    simd.small_mod_u64x4(p, t)
}

#[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
#[cfg(feature = "nightly")]
#[inline(always)]
pub(crate) fn mul_mod32_v2_avx512(
    simd: crate::V4IFma,
    p: u64x8,
    a: u64x8,
    b: u64x8,
    b_shoup: u64x8,
) -> u64x8 {
    let shoup_q = simd.shr_const_u64x8::<32>(simd.mul_low_32_bits_u64x8(a, b_shoup));
    let t = simd.and_u64x8(
        simd.splat_u64x8((1u64 << 32) - 1),
        simd.wrapping_sub_u64x8(
            simd.mul_low_32_bits_u64x8(a, b),
            simd.mul_low_32_bits_u64x8(shoup_q, p),
        ),
    );
    simd.small_mod_u64x8(p, t)
}

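// Vector counterpart of `mul_mod64`: the quotient estimate is the high half
// of a full 64x64 widening multiply, and only the low halves of `a * b` and
// `p * q` are needed for the remainder in `[0, 2p)`.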
#[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
#[inline(always)]
pub(crate) fn mul_mod64_avx2(
    simd: crate::V3,
    p: u64x4,
    a: u64x4,
    b: u64x4,
    b_shoup: u64x4,
) -> u64x4 {
    let q = simd.widening_mul_u64x4(a, b_shoup).1;
    let r = simd.wrapping_sub_u64x4(
        simd.widening_mul_u64x4(a, b).0,
        simd.widening_mul_u64x4(p, q).0,
    );
    simd.small_mod_u64x4(p, r)
}

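// AVX512 variant: with the wider instruction set, `wrapping_mul_u64x8` lowers
// to a native low 64-bit lane multiply, so the low halves need no widening
// multiplies.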
#[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
#[cfg(feature = "nightly")]
#[inline(always)]
pub(crate) fn mul_mod64_avx512(
    simd: crate::V4IFma,
    p: u64x8,
    a: u64x8,
    b: u64x8,
    b_shoup: u64x8,
) -> u64x8 {
    let q = simd.widening_mul_u64x8(a, b_shoup).1;
    let r = simd.wrapping_sub_u64x8(simd.wrapping_mul_u64x8(a, b), simd.wrapping_mul_u64x8(p, q));
    simd.small_mod_u64x8(p, r)
}

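// Lane-parallel version of `reconstruct_32bit_01234_v2`: all constants are
// splatted once, the five 32-bit residues are widened to 64-bit lanes, and
// four values are reconstructed per call.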
#[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
#[inline(always)]
fn reconstruct_32bit_01234_v2_avx2(
    simd: crate::V3,
    mod_p0: u32x4,
    mod_p1: u32x4,
    mod_p2: u32x4,
    mod_p3: u32x4,
    mod_p4: u32x4,
) -> u64x4 {
    use crate::primes32::*;

    let p0 = simd.splat_u64x4(P0 as u64);
    let p1 = simd.splat_u64x4(P1 as u64);
    let p2 = simd.splat_u64x4(P2 as u64);
    let p3 = simd.splat_u64x4(P3 as u64);
    let p4 = simd.splat_u64x4(P4 as u64);
    let p12 = simd.splat_u64x4(P12);
    let p34 = simd.splat_u64x4(P34);
    let p012 = simd.splat_u64x4((P0 as u64).wrapping_mul(P12));
    let p01234 = simd.splat_u64x4((P0 as u64).wrapping_mul(P12).wrapping_mul(P34));

    let two_p2 = simd.splat_u64x4(2 * P2 as u64);
    let two_p4 = simd.splat_u64x4(2 * P4 as u64);
    let two_p12 = simd.splat_u64x4(2 * P12);
    let two_p34 = simd.splat_u64x4(2 * P34);
    let half_p34 = simd.splat_u64x4(P34 / 2);

    let p0_inv_mod_p12 = simd.splat_u64x4(P0_INV_MOD_P12);
    let p0_inv_mod_p12_shoup = simd.splat_u64x4(P0_INV_MOD_P12_SHOUP);
    let p1_inv_mod_p2 = simd.splat_u64x4(P1_INV_MOD_P2 as u64);
    let p1_inv_mod_p2_shoup = simd.splat_u64x4(P1_INV_MOD_P2_SHOUP as u64);
    let p3_inv_mod_p4 = simd.splat_u64x4(P3_INV_MOD_P4 as u64);
    let p3_inv_mod_p4_shoup = simd.splat_u64x4(P3_INV_MOD_P4_SHOUP as u64);

    let p012_inv_mod_p34 = simd.splat_u64x4(P012_INV_MOD_P34);
    let p012_inv_mod_p34_shoup = simd.splat_u64x4(P012_INV_MOD_P34_SHOUP);
    let p0_mod_p34_shoup = simd.splat_u64x4(P0_MOD_P34_SHOUP);

    let mod_p0 = simd.convert_u32x4_to_u64x4(mod_p0);
    let mod_p1 = simd.convert_u32x4_to_u64x4(mod_p1);
    let mod_p2 = simd.convert_u32x4_to_u64x4(mod_p2);
    let mod_p3 = simd.convert_u32x4_to_u64x4(mod_p3);
    let mod_p4 = simd.convert_u32x4_to_u64x4(mod_p4);

    let mod_p12 = {
        let v1 = mod_p1;
        let v2 = mul_mod32_v2_avx2(
            simd,
            p2,
            simd.wrapping_sub_u64x4(simd.wrapping_add_u64x4(two_p2, mod_p2), v1),
            p1_inv_mod_p2,
            p1_inv_mod_p2_shoup,
        );
        simd.wrapping_add_u64x4(v1, simd.mul_low_32_bits_u64x4(v2, p1))
    };
    let mod_p34 = {
        let v3 = mod_p3;
        let v4 = mul_mod32_v2_avx2(
            simd,
            p4,
            simd.wrapping_sub_u64x4(simd.wrapping_add_u64x4(two_p4, mod_p4), v3),
            p3_inv_mod_p4,
            p3_inv_mod_p4_shoup,
        );
        simd.wrapping_add_u64x4(v3, simd.mul_low_32_bits_u64x4(v4, p3))
    };

    let v0 = mod_p0;
    let v12 = mul_mod64_avx2(
        simd,
        p12,
        simd.wrapping_sub_u64x4(simd.wrapping_add_u64x4(two_p12, mod_p12), v0),
        p0_inv_mod_p12,
        p0_inv_mod_p12_shoup,
    );
    let v34 = mul_mod64_avx2(
        simd,
        p34,
        simd.wrapping_sub_u64x4(
            simd.wrapping_add_u64x4(two_p34, mod_p34),
            simd.wrapping_add_u64x4(v0, mul_mod64_avx2(simd, p34, v12, p0, p0_mod_p34_shoup)),
        ),
        p012_inv_mod_p34,
        p012_inv_mod_p34_shoup,
    );

    let sign = simd.cmp_gt_u64x4(v34, half_p34);
    let pos = v0;
    let pos = simd.wrapping_add_u64x4(
        pos,
        simd.wrapping_mul_lhs_with_low_32_bits_of_rhs_u64x4(v12, p0),
    );
    let pos = simd.wrapping_add_u64x4(pos, simd.widening_mul_u64x4(v34, p012).0);
    let neg = simd.wrapping_sub_u64x4(pos, p01234);
    simd.select_u64x4(sign, neg, pos)
}

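// Lane-parallel version of the five-step `reconstruct_32bit_01234`: it works
// on eight residues at a time in `u32x8` lanes, then splits into two `u64x4`
// halves for the final accumulation, sign-extending the 32-bit comparison
// mask into 64-bit select masks.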
#[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
#[allow(dead_code)]
#[inline(always)]
fn reconstruct_32bit_01234_avx2(
    simd: crate::V3,
    mod_p0: u32x8,
    mod_p1: u32x8,
    mod_p2: u32x8,
    mod_p3: u32x8,
    mod_p4: u32x8,
) -> [u64x4; 2] {
    use crate::primes32::*;

    let p0 = simd.splat_u32x8(P0);
    let p1 = simd.splat_u32x8(P1);
    let p2 = simd.splat_u32x8(P2);
    let p3 = simd.splat_u32x8(P3);
    let p4 = simd.splat_u32x8(P4);
    let two_p1 = simd.splat_u32x8(2 * P1);
    let two_p2 = simd.splat_u32x8(2 * P2);
    let two_p3 = simd.splat_u32x8(2 * P3);
    let two_p4 = simd.splat_u32x8(2 * P4);
    let half_p4 = simd.splat_u32x8(P4 / 2);

    let p0_inv_mod_p1 = simd.splat_u32x8(P0_INV_MOD_P1);
    let p0_inv_mod_p1_shoup = simd.splat_u32x8(P0_INV_MOD_P1_SHOUP);
    let p0_mod_p2_shoup = simd.splat_u32x8(P0_MOD_P2_SHOUP);
    let p0_mod_p3_shoup = simd.splat_u32x8(P0_MOD_P3_SHOUP);
    let p1_mod_p3_shoup = simd.splat_u32x8(P1_MOD_P3_SHOUP);
    let p0_mod_p4_shoup = simd.splat_u32x8(P0_MOD_P4_SHOUP);
    let p1_mod_p4_shoup = simd.splat_u32x8(P1_MOD_P4_SHOUP);
    let p2_mod_p4_shoup = simd.splat_u32x8(P2_MOD_P4_SHOUP);

    let p01_inv_mod_p2 = simd.splat_u32x8(P01_INV_MOD_P2);
    let p01_inv_mod_p2_shoup = simd.splat_u32x8(P01_INV_MOD_P2_SHOUP);
    let p012_inv_mod_p3 = simd.splat_u32x8(P012_INV_MOD_P3);
    let p012_inv_mod_p3_shoup = simd.splat_u32x8(P012_INV_MOD_P3_SHOUP);
    let p0123_inv_mod_p4 = simd.splat_u32x8(P0123_INV_MOD_P4);
    let p0123_inv_mod_p4_shoup = simd.splat_u32x8(P0123_INV_MOD_P4_SHOUP);

    let p01 = simd.splat_u64x4((P0 as u64).wrapping_mul(P1 as u64));
    let p012 = simd.splat_u64x4((P0 as u64).wrapping_mul(P1 as u64).wrapping_mul(P2 as u64));
    let p0123 = simd.splat_u64x4(
        (P0 as u64)
            .wrapping_mul(P1 as u64)
            .wrapping_mul(P2 as u64)
            .wrapping_mul(P3 as u64),
    );
    let p01234 = simd.splat_u64x4(
        (P0 as u64)
            .wrapping_mul(P1 as u64)
            .wrapping_mul(P2 as u64)
            .wrapping_mul(P3 as u64)
            .wrapping_mul(P4 as u64),
    );

    let v0 = mod_p0;
    let v1 = mul_mod32_avx2(
        simd,
        p1,
        simd.wrapping_sub_u32x8(simd.wrapping_add_u32x8(two_p1, mod_p1), v0),
        p0_inv_mod_p1,
        p0_inv_mod_p1_shoup,
    );
    let v2 = mul_mod32_avx2(
        simd,
        p2,
        simd.wrapping_sub_u32x8(
            simd.wrapping_add_u32x8(two_p2, mod_p2),
            simd.wrapping_add_u32x8(v0, mul_mod32_avx2(simd, p2, v1, p0, p0_mod_p2_shoup)),
        ),
        p01_inv_mod_p2,
        p01_inv_mod_p2_shoup,
    );
    let v3 = mul_mod32_avx2(
        simd,
        p3,
        simd.wrapping_sub_u32x8(
            simd.wrapping_add_u32x8(two_p3, mod_p3),
            simd.wrapping_add_u32x8(
                v0,
                mul_mod32_avx2(
                    simd,
                    p3,
                    simd.wrapping_add_u32x8(v1, mul_mod32_avx2(simd, p3, v2, p1, p1_mod_p3_shoup)),
                    p0,
                    p0_mod_p3_shoup,
                ),
            ),
        ),
        p012_inv_mod_p3,
        p012_inv_mod_p3_shoup,
    );
    let v4 = mul_mod32_avx2(
        simd,
        p4,
        simd.wrapping_sub_u32x8(
            simd.wrapping_add_u32x8(two_p4, mod_p4),
            simd.wrapping_add_u32x8(
                v0,
                mul_mod32_avx2(
                    simd,
                    p4,
                    simd.wrapping_add_u32x8(
                        v1,
                        mul_mod32_avx2(
                            simd,
                            p4,
                            simd.wrapping_add_u32x8(
                                v2,
                                mul_mod32_avx2(simd, p4, v3, p2, p2_mod_p4_shoup),
                            ),
                            p1,
                            p1_mod_p4_shoup,
                        ),
                    ),
                    p0,
                    p0_mod_p4_shoup,
                ),
            ),
        ),
        p0123_inv_mod_p4,
        p0123_inv_mod_p4_shoup,
    );

    let sign = simd.cmp_gt_u32x8(v4, half_p4);
    let sign: [i32x4; 2] = pulp::cast(sign);
    // sign extend so that -1i32 becomes -1i64
    let sign0: m64x4 = unsafe { core::mem::transmute(simd.convert_i32x4_to_i64x4(sign[0])) };
    let sign1: m64x4 = unsafe { core::mem::transmute(simd.convert_i32x4_to_i64x4(sign[1])) };

    let v0: [u32x4; 2] = pulp::cast(v0);
    let v1: [u32x4; 2] = pulp::cast(v1);
    let v2: [u32x4; 2] = pulp::cast(v2);
    let v3: [u32x4; 2] = pulp::cast(v3);
    let v4: [u32x4; 2] = pulp::cast(v4);
    let v00 = simd.convert_u32x4_to_u64x4(v0[0]);
    let v01 = simd.convert_u32x4_to_u64x4(v0[1]);
    let v10 = simd.convert_u32x4_to_u64x4(v1[0]);
    let v11 = simd.convert_u32x4_to_u64x4(v1[1]);
    let v20 = simd.convert_u32x4_to_u64x4(v2[0]);
    let v21 = simd.convert_u32x4_to_u64x4(v2[1]);
    let v30 = simd.convert_u32x4_to_u64x4(v3[0]);
    let v31 = simd.convert_u32x4_to_u64x4(v3[1]);
    let v40 = simd.convert_u32x4_to_u64x4(v4[0]);
    let v41 = simd.convert_u32x4_to_u64x4(v4[1]);

    let pos0 = v00;
    let pos0 = simd.wrapping_add_u64x4(pos0, simd.mul_low_32_bits_u64x4(pulp::cast(p0), v10));
    let pos0 = simd.wrapping_add_u64x4(
        pos0,
        simd.wrapping_mul_lhs_with_low_32_bits_of_rhs_u64x4(p01, v20),
    );
    let pos0 = simd.wrapping_add_u64x4(
        pos0,
        simd.wrapping_mul_lhs_with_low_32_bits_of_rhs_u64x4(p012, v30),
    );
    let pos0 = simd.wrapping_add_u64x4(
        pos0,
        simd.wrapping_mul_lhs_with_low_32_bits_of_rhs_u64x4(p0123, v40),
    );

    let pos1 = v01;
    let pos1 = simd.wrapping_add_u64x4(pos1, simd.mul_low_32_bits_u64x4(pulp::cast(p0), v11));
    let pos1 = simd.wrapping_add_u64x4(
        pos1,
        simd.wrapping_mul_lhs_with_low_32_bits_of_rhs_u64x4(p01, v21),
    );
    let pos1 = simd.wrapping_add_u64x4(
        pos1,
        simd.wrapping_mul_lhs_with_low_32_bits_of_rhs_u64x4(p012, v31),
    );
    let pos1 = simd.wrapping_add_u64x4(
        pos1,
        simd.wrapping_mul_lhs_with_low_32_bits_of_rhs_u64x4(p0123, v41),
    );

    let neg0 = simd.wrapping_sub_u64x4(pos0, p01234);
    let neg1 = simd.wrapping_sub_u64x4(pos1, p01234);

    [
        simd.select_u64x4(sign0, neg0, pos0),
        simd.select_u64x4(sign1, neg1, pos1),
    ]
}

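// AVX512 analogue of the function above, covering sixteen residues per call;
// the 16-lane comparison mask is split into two 8-lane select masks.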
#[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
#[cfg(feature = "nightly")]
#[allow(dead_code)]
#[inline(always)]
fn reconstruct_32bit_01234_avx512(
    simd: crate::V4IFma,
    mod_p0: u32x16,
    mod_p1: u32x16,
    mod_p2: u32x16,
    mod_p3: u32x16,
    mod_p4: u32x16,
) -> [u64x8; 2] {
    use crate::primes32::*;

    let p0 = simd.splat_u32x16(P0);
    let p1 = simd.splat_u32x16(P1);
    let p2 = simd.splat_u32x16(P2);
    let p3 = simd.splat_u32x16(P3);
    let p4 = simd.splat_u32x16(P4);
    let two_p1 = simd.splat_u32x16(2 * P1);
    let two_p2 = simd.splat_u32x16(2 * P2);
    let two_p3 = simd.splat_u32x16(2 * P3);
    let two_p4 = simd.splat_u32x16(2 * P4);
    let half_p4 = simd.splat_u32x16(P4 / 2);

    let p0_inv_mod_p1 = simd.splat_u32x16(P0_INV_MOD_P1);
    let p0_inv_mod_p1_shoup = simd.splat_u32x16(P0_INV_MOD_P1_SHOUP);
    let p0_mod_p2_shoup = simd.splat_u32x16(P0_MOD_P2_SHOUP);
    let p0_mod_p3_shoup = simd.splat_u32x16(P0_MOD_P3_SHOUP);
    let p1_mod_p3_shoup = simd.splat_u32x16(P1_MOD_P3_SHOUP);
    let p0_mod_p4_shoup = simd.splat_u32x16(P0_MOD_P4_SHOUP);
    let p1_mod_p4_shoup = simd.splat_u32x16(P1_MOD_P4_SHOUP);
    let p2_mod_p4_shoup = simd.splat_u32x16(P2_MOD_P4_SHOUP);

    let p01_inv_mod_p2 = simd.splat_u32x16(P01_INV_MOD_P2);
    let p01_inv_mod_p2_shoup = simd.splat_u32x16(P01_INV_MOD_P2_SHOUP);
    let p012_inv_mod_p3 = simd.splat_u32x16(P012_INV_MOD_P3);
    let p012_inv_mod_p3_shoup = simd.splat_u32x16(P012_INV_MOD_P3_SHOUP);
    let p0123_inv_mod_p4 = simd.splat_u32x16(P0123_INV_MOD_P4);
    let p0123_inv_mod_p4_shoup = simd.splat_u32x16(P0123_INV_MOD_P4_SHOUP);

    let p01 = simd.splat_u64x8((P0 as u64).wrapping_mul(P1 as u64));
    let p012 = simd.splat_u64x8((P0 as u64).wrapping_mul(P1 as u64).wrapping_mul(P2 as u64));
    let p0123 = simd.splat_u64x8(
        (P0 as u64)
            .wrapping_mul(P1 as u64)
            .wrapping_mul(P2 as u64)
            .wrapping_mul(P3 as u64),
    );
    let p01234 = simd.splat_u64x8(
        (P0 as u64)
            .wrapping_mul(P1 as u64)
            .wrapping_mul(P2 as u64)
            .wrapping_mul(P3 as u64)
            .wrapping_mul(P4 as u64),
    );

    let v0 = mod_p0;
    let v1 = mul_mod32_avx512(
        simd,
        p1,
        simd.wrapping_sub_u32x16(simd.wrapping_add_u32x16(two_p1, mod_p1), v0),
        p0_inv_mod_p1,
        p0_inv_mod_p1_shoup,
    );
    let v2 = mul_mod32_avx512(
        simd,
        p2,
        simd.wrapping_sub_u32x16(
            simd.wrapping_add_u32x16(two_p2, mod_p2),
            simd.wrapping_add_u32x16(v0, mul_mod32_avx512(simd, p2, v1, p0, p0_mod_p2_shoup)),
        ),
        p01_inv_mod_p2,
        p01_inv_mod_p2_shoup,
    );
    let v3 = mul_mod32_avx512(
        simd,
        p3,
        simd.wrapping_sub_u32x16(
            simd.wrapping_add_u32x16(two_p3, mod_p3),
            simd.wrapping_add_u32x16(
                v0,
                mul_mod32_avx512(
                    simd,
                    p3,
                    simd.wrapping_add_u32x16(
                        v1,
                        mul_mod32_avx512(simd, p3, v2, p1, p1_mod_p3_shoup),
                    ),
                    p0,
                    p0_mod_p3_shoup,
                ),
            ),
        ),
        p012_inv_mod_p3,
        p012_inv_mod_p3_shoup,
    );
    let v4 = mul_mod32_avx512(
        simd,
        p4,
        simd.wrapping_sub_u32x16(
            simd.wrapping_add_u32x16(two_p4, mod_p4),
            simd.wrapping_add_u32x16(
                v0,
                mul_mod32_avx512(
                    simd,
                    p4,
                    simd.wrapping_add_u32x16(
                        v1,
                        mul_mod32_avx512(
                            simd,
                            p4,
                            simd.wrapping_add_u32x16(
                                v2,
                                mul_mod32_avx512(simd, p4, v3, p2, p2_mod_p4_shoup),
                            ),
                            p1,
                            p1_mod_p4_shoup,
                        ),
                    ),
                    p0,
                    p0_mod_p4_shoup,
                ),
            ),
        ),
        p0123_inv_mod_p4,
        p0123_inv_mod_p4_shoup,
    );

    let sign = simd.cmp_gt_u32x16(v4, half_p4).0;
    let sign0 = b8(sign as u8);
    let sign1 = b8((sign >> 8) as u8);
    let v0: [u32x8; 2] = pulp::cast(v0);
    let v1: [u32x8; 2] = pulp::cast(v1);
    let v2: [u32x8; 2] = pulp::cast(v2);
    let v3: [u32x8; 2] = pulp::cast(v3);
    let v4: [u32x8; 2] = pulp::cast(v4);
    let v00 = simd.convert_u32x8_to_u64x8(v0[0]);
    let v01 = simd.convert_u32x8_to_u64x8(v0[1]);
    let v10 = simd.convert_u32x8_to_u64x8(v1[0]);
    let v11 = simd.convert_u32x8_to_u64x8(v1[1]);
    let v20 = simd.convert_u32x8_to_u64x8(v2[0]);
    let v21 = simd.convert_u32x8_to_u64x8(v2[1]);
    let v30 = simd.convert_u32x8_to_u64x8(v3[0]);
    let v31 = simd.convert_u32x8_to_u64x8(v3[1]);
    let v40 = simd.convert_u32x8_to_u64x8(v4[0]);
    let v41 = simd.convert_u32x8_to_u64x8(v4[1]);

    let pos0 = v00;
    let pos0 = simd.wrapping_add_u64x8(pos0, simd.mul_low_32_bits_u64x8(pulp::cast(p0), v10));
    let pos0 = simd.wrapping_add_u64x8(pos0, simd.wrapping_mul_u64x8(p01, v20));
    let pos0 = simd.wrapping_add_u64x8(pos0, simd.wrapping_mul_u64x8(p012, v30));
    let pos0 = simd.wrapping_add_u64x8(pos0, simd.wrapping_mul_u64x8(p0123, v40));

    let pos1 = v01;
    let pos1 = simd.wrapping_add_u64x8(pos1, simd.mul_low_32_bits_u64x8(pulp::cast(p0), v11));
    let pos1 = simd.wrapping_add_u64x8(pos1, simd.wrapping_mul_u64x8(p01, v21));
    let pos1 = simd.wrapping_add_u64x8(pos1, simd.wrapping_mul_u64x8(p012, v31));
    let pos1 = simd.wrapping_add_u64x8(pos1, simd.wrapping_mul_u64x8(p0123, v41));

    let neg0 = simd.wrapping_sub_u64x8(pos0, p01234);
    let neg1 = simd.wrapping_sub_u64x8(pos1, p01234);

    [
        simd.select_u64x8(sign0, neg0, pos0),
        simd.select_u64x8(sign1, neg1, pos1),
    ]
}

#[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
#[cfg(feature = "nightly")]
#[inline(always)]
fn reconstruct_32bit_01234_v2_avx512(
    simd: crate::V4IFma,
    mod_p0: u32x8,
    mod_p1: u32x8,
    mod_p2: u32x8,
    mod_p3: u32x8,
    mod_p4: u32x8,
) -> u64x8 {
    use crate::primes32::*;

    let p0 = simd.splat_u64x8(P0 as u64);
    let p1 = simd.splat_u64x8(P1 as u64);
    let p2 = simd.splat_u64x8(P2 as u64);
    let p3 = simd.splat_u64x8(P3 as u64);
    let p4 = simd.splat_u64x8(P4 as u64);
    let p12 = simd.splat_u64x8(P12);
    let p34 = simd.splat_u64x8(P34);
    let p012 = simd.splat_u64x8((P0 as u64).wrapping_mul(P12));
    let p01234 = simd.splat_u64x8((P0 as u64).wrapping_mul(P12).wrapping_mul(P34));

    let two_p2 = simd.splat_u64x8(2 * P2 as u64);
    let two_p4 = simd.splat_u64x8(2 * P4 as u64);
    let two_p12 = simd.splat_u64x8(2 * P12);
    let two_p34 = simd.splat_u64x8(2 * P34);
    let half_p34 = simd.splat_u64x8(P34 / 2);

    let p0_inv_mod_p12 = simd.splat_u64x8(P0_INV_MOD_P12);
    let p0_inv_mod_p12_shoup = simd.splat_u64x8(P0_INV_MOD_P12_SHOUP);
    let p1_inv_mod_p2 = simd.splat_u64x8(P1_INV_MOD_P2 as u64);
    let p1_inv_mod_p2_shoup = simd.splat_u64x8(P1_INV_MOD_P2_SHOUP as u64);
    let p3_inv_mod_p4 = simd.splat_u64x8(P3_INV_MOD_P4 as u64);
    let p3_inv_mod_p4_shoup = simd.splat_u64x8(P3_INV_MOD_P4_SHOUP as u64);

    let p012_inv_mod_p34 = simd.splat_u64x8(P012_INV_MOD_P34);
    let p012_inv_mod_p34_shoup = simd.splat_u64x8(P012_INV_MOD_P34_SHOUP);
    let p0_mod_p34_shoup = simd.splat_u64x8(P0_MOD_P34_SHOUP);

    let mod_p0 = simd.convert_u32x8_to_u64x8(mod_p0);
    let mod_p1 = simd.convert_u32x8_to_u64x8(mod_p1);
    let mod_p2 = simd.convert_u32x8_to_u64x8(mod_p2);
    let mod_p3 = simd.convert_u32x8_to_u64x8(mod_p3);
    let mod_p4 = simd.convert_u32x8_to_u64x8(mod_p4);

    let mod_p12 = {
        let v1 = mod_p1;
        let v2 = mul_mod32_v2_avx512(
            simd,
            p2,
            simd.wrapping_sub_u64x8(simd.wrapping_add_u64x8(two_p2, mod_p2), v1),
            p1_inv_mod_p2,
            p1_inv_mod_p2_shoup,
        );
        simd.wrapping_add_u64x8(v1, simd.wrapping_mul_u64x8(v2, p1))
    };
    let mod_p34 = {
        let v3 = mod_p3;
        let v4 = mul_mod32_v2_avx512(
            simd,
            p4,
            simd.wrapping_sub_u64x8(simd.wrapping_add_u64x8(two_p4, mod_p4), v3),
            p3_inv_mod_p4,
            p3_inv_mod_p4_shoup,
        );
        simd.wrapping_add_u64x8(v3, simd.wrapping_mul_u64x8(v4, p3))
    };

    let v0 = mod_p0;
    let v12 = mul_mod64_avx512(
        simd,
        p12,
        simd.wrapping_sub_u64x8(simd.wrapping_add_u64x8(two_p12, mod_p12), v0),
        p0_inv_mod_p12,
        p0_inv_mod_p12_shoup,
    );
    let v34 = mul_mod64_avx512(
        simd,
        p34,
        simd.wrapping_sub_u64x8(
            simd.wrapping_add_u64x8(two_p34, mod_p34),
            simd.wrapping_add_u64x8(v0, mul_mod64_avx512(simd, p34, v12, p0, p0_mod_p34_shoup)),
        ),
        p012_inv_mod_p34,
        p012_inv_mod_p34_shoup,
    );

    let sign = simd.cmp_gt_u64x8(v34, half_p34);
    let pos = v0;
    let pos = simd.wrapping_add_u64x8(pos, simd.wrapping_mul_u64x8(v12, p0));
    let pos = simd.wrapping_add_u64x8(pos, simd.wrapping_mul_u64x8(v34, p012));

    let neg = simd.wrapping_sub_u64x8(pos, p01234);

    simd.select_u64x8(sign, neg, pos)
}

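// 52-bit path: three larger primes (`crate::primes52::P0..P2`, each under
// 2^52) replace the five 32-bit ones, since AVX512-IFMA's 52-bit
// multiply-accumulate handles their residue arithmetic; only two Garner
// steps remain.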
#[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
#[cfg(feature = "nightly")]
#[inline(always)]
fn reconstruct_52bit_012_avx512(
    simd: crate::V4IFma,
    mod_p0: u64x8,
    mod_p1: u64x8,
    mod_p2: u64x8,
) -> u64x8 {
    use crate::primes52::*;

    let p0 = simd.splat_u64x8(P0);
    let p1 = simd.splat_u64x8(P1);
    let p2 = simd.splat_u64x8(P2);
    let neg_p1 = simd.splat_u64x8(P1.wrapping_neg());
    let neg_p2 = simd.splat_u64x8(P2.wrapping_neg());
    let two_p1 = simd.splat_u64x8(2 * P1);
    let two_p2 = simd.splat_u64x8(2 * P2);
    let half_p2 = simd.splat_u64x8(P2 / 2);

    let p0_inv_mod_p1 = simd.splat_u64x8(P0_INV_MOD_P1);
    let p0_inv_mod_p1_shoup = simd.splat_u64x8(P0_INV_MOD_P1_SHOUP);
    let p0_mod_p2_shoup = simd.splat_u64x8(P0_MOD_P2_SHOUP);
    let p01_inv_mod_p2 = simd.splat_u64x8(P01_INV_MOD_P2);
    let p01_inv_mod_p2_shoup = simd.splat_u64x8(P01_INV_MOD_P2_SHOUP);

    let p01 = simd.splat_u64x8(P0.wrapping_mul(P1));
    let p012 = simd.splat_u64x8(P0.wrapping_mul(P1).wrapping_mul(P2));

    let v0 = mod_p0;
    let v1 = mul_mod52_avx512(
        simd,
        p1,
        neg_p1,
        simd.wrapping_sub_u64x8(simd.wrapping_add_u64x8(two_p1, mod_p1), v0),
        p0_inv_mod_p1,
        p0_inv_mod_p1_shoup,
    );
    let v2 = mul_mod52_avx512(
        simd,
        p2,
        neg_p2,
        simd.wrapping_sub_u64x8(
            simd.wrapping_add_u64x8(two_p2, mod_p2),
            simd.wrapping_add_u64x8(
                v0,
                mul_mod52_avx512(simd, p2, neg_p2, v1, p0, p0_mod_p2_shoup),
            ),
        ),
        p01_inv_mod_p2,
        p01_inv_mod_p2_shoup,
    );

    let sign = simd.cmp_gt_u64x8(v2, half_p2);

    let pos = simd.wrapping_add_u64x8(
        simd.wrapping_add_u64x8(v0, simd.wrapping_mul_u64x8(v1, p0)),
        simd.wrapping_mul_u64x8(v2, p01),
    );
    let neg = simd.wrapping_sub_u64x8(pos, p012);

    simd.select_u64x8(sign, neg, pos)
}

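// Slice drivers: each one splits its inputs into full SIMD chunks with
// `as_arrays`. The slice lengths are the NTT size, a power of two at least as
// large as the vector width, so no scalar tail remains.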
#[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
fn reconstruct_slice_32bit_01234_avx2(
    simd: crate::V3,
    value: &mut [u64],
    mod_p0: &[u32],
    mod_p1: &[u32],
    mod_p2: &[u32],
    mod_p3: &[u32],
    mod_p4: &[u32],
) {
    simd.vectorize(
        #[inline(always)]
        move || {
            let value = pulp::as_arrays_mut::<4, _>(value).0;
            let mod_p0 = pulp::as_arrays::<4, _>(mod_p0).0;
            let mod_p1 = pulp::as_arrays::<4, _>(mod_p1).0;
            let mod_p2 = pulp::as_arrays::<4, _>(mod_p2).0;
            let mod_p3 = pulp::as_arrays::<4, _>(mod_p3).0;
            let mod_p4 = pulp::as_arrays::<4, _>(mod_p4).0;
            for (value, &mod_p0, &mod_p1, &mod_p2, &mod_p3, &mod_p4) in
                crate::izip!(value, mod_p0, mod_p1, mod_p2, mod_p3, mod_p4)
            {
                *value = cast(reconstruct_32bit_01234_v2_avx2(
                    simd,
                    cast(mod_p0),
                    cast(mod_p1),
                    cast(mod_p2),
                    cast(mod_p3),
                    cast(mod_p4),
                ));
            }
        },
    );
}

#[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
#[cfg(feature = "nightly")]
fn reconstruct_slice_32bit_01234_avx512(
    simd: crate::V4IFma,
    value: &mut [u64],
    mod_p0: &[u32],
    mod_p1: &[u32],
    mod_p2: &[u32],
    mod_p3: &[u32],
    mod_p4: &[u32],
) {
    simd.vectorize(
        #[inline(always)]
        move || {
            let value = pulp::as_arrays_mut::<8, _>(value).0;
            let mod_p0 = pulp::as_arrays::<8, _>(mod_p0).0;
            let mod_p1 = pulp::as_arrays::<8, _>(mod_p1).0;
            let mod_p2 = pulp::as_arrays::<8, _>(mod_p2).0;
            let mod_p3 = pulp::as_arrays::<8, _>(mod_p3).0;
            let mod_p4 = pulp::as_arrays::<8, _>(mod_p4).0;
            for (value, &mod_p0, &mod_p1, &mod_p2, &mod_p3, &mod_p4) in
                crate::izip!(value, mod_p0, mod_p1, mod_p2, mod_p3, mod_p4)
            {
                *value = cast(reconstruct_32bit_01234_v2_avx512(
                    simd,
                    cast(mod_p0),
                    cast(mod_p1),
                    cast(mod_p2),
                    cast(mod_p3),
                    cast(mod_p4),
                ));
            }
        },
    );
}

#[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
#[cfg(feature = "nightly")]
fn reconstruct_slice_52bit_012_avx512(
    simd: crate::V4IFma,
    value: &mut [u64],
    mod_p0: &[u64],
    mod_p1: &[u64],
    mod_p2: &[u64],
) {
    simd.vectorize(
        #[inline(always)]
        move || {
            let value = pulp::as_arrays_mut::<8, _>(value).0;
            let mod_p0 = pulp::as_arrays::<8, _>(mod_p0).0;
            let mod_p1 = pulp::as_arrays::<8, _>(mod_p1).0;
            let mod_p2 = pulp::as_arrays::<8, _>(mod_p2).0;
            for (value, &mod_p0, &mod_p1, &mod_p2) in crate::izip!(value, mod_p0, mod_p1, mod_p2) {
                *value = cast(reconstruct_52bit_012_avx512(
                    simd,
                    cast(mod_p0),
                    cast(mod_p1),
                    cast(mod_p2),
                ));
            }
        },
    );
}

impl Plan32 {
    /// Returns a negacyclic NTT plan for the given polynomial size, or `None` if no
    /// suitable roots of unity can be found for the wanted parameters.
    pub fn try_new(n: usize) -> Option<Self> {
        use crate::{prime32::Plan, primes32::*};
        Some(Self(
            Plan::try_new(n, P0)?,
            Plan::try_new(n, P1)?,
            Plan::try_new(n, P2)?,
            Plan::try_new(n, P3)?,
            Plan::try_new(n, P4)?,
        ))
    }

    /// Returns the polynomial size of the negacyclic NTT plan.
    #[inline]
    pub fn ntt_size(&self) -> usize {
        self.0.ntt_size()
    }

    #[inline]
    pub fn ntt_0(&self) -> &crate::prime32::Plan {
        &self.0
    }
    #[inline]
    pub fn ntt_1(&self) -> &crate::prime32::Plan {
        &self.1
    }
    #[inline]
    pub fn ntt_2(&self) -> &crate::prime32::Plan {
        &self.2
    }
    #[inline]
    pub fn ntt_3(&self) -> &crate::prime32::Plan {
        &self.3
    }
    #[inline]
    pub fn ntt_4(&self) -> &crate::prime32::Plan {
        &self.4
    }

    pub fn fwd(
        &self,
        value: &[u64],
        mod_p0: &mut [u32],
        mod_p1: &mut [u32],
        mod_p2: &mut [u32],
        mod_p3: &mut [u32],
        mod_p4: &mut [u32],
    ) {
        for (value, mod_p0, mod_p1, mod_p2, mod_p3, mod_p4) in crate::izip!(
            value,
            &mut *mod_p0,
            &mut *mod_p1,
            &mut *mod_p2,
            &mut *mod_p3,
            &mut *mod_p4
        ) {
            *mod_p0 = (value % crate::primes32::P0 as u64) as u32;
            *mod_p1 = (value % crate::primes32::P1 as u64) as u32;
            *mod_p2 = (value % crate::primes32::P2 as u64) as u32;
            *mod_p3 = (value % crate::primes32::P3 as u64) as u32;
            *mod_p4 = (value % crate::primes32::P4 as u64) as u32;
        }
        self.0.fwd(mod_p0);
        self.1.fwd(mod_p1);
        self.2.fwd(mod_p2);
        self.3.fwd(mod_p3);
        self.4.fwd(mod_p4);
    }

    pub fn inv(
        &self,
        value: &mut [u64],
        mod_p0: &mut [u32],
        mod_p1: &mut [u32],
        mod_p2: &mut [u32],
        mod_p3: &mut [u32],
        mod_p4: &mut [u32],
    ) {
        self.0.inv(mod_p0);
        self.1.inv(mod_p1);
        self.2.inv(mod_p2);
        self.3.inv(mod_p3);
        self.4.inv(mod_p4);

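        // Runtime dispatch: prefer the AVX512-IFMA path when available, then
        // AVX2, falling back to the scalar reconstruction below.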
        #[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
        {
            #[cfg(feature = "nightly")]
            if let Some(simd) = crate::V4IFma::try_new() {
                reconstruct_slice_32bit_01234_avx512(
                    simd, value, mod_p0, mod_p1, mod_p2, mod_p3, mod_p4,
                );
                return;
            }
            if let Some(simd) = crate::V3::try_new() {
                reconstruct_slice_32bit_01234_avx2(
                    simd, value, mod_p0, mod_p1, mod_p2, mod_p3, mod_p4,
                );
                return;
            }
        }

        for (value, &mod_p0, &mod_p1, &mod_p2, &mod_p3, &mod_p4) in
            crate::izip!(value, &*mod_p0, &*mod_p1, &*mod_p2, &*mod_p3, &*mod_p4)
        {
            *value = reconstruct_32bit_01234_v2(mod_p0, mod_p1, mod_p2, mod_p3, mod_p4);
        }
    }

    /// Computes the negacyclic polynomial product of `lhs` and `rhs`, and stores the result in
    /// `prod`.
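    ///
    /// A minimal usage sketch (sizes must be accepted by [`Plan32::try_new`];
    /// the product is taken modulo `X^n + 1`):
    ///
    /// ```ignore
    /// let n = 32;
    /// let plan = Plan32::try_new(n).unwrap();
    /// let lhs = vec![1u64; n];
    /// let rhs = vec![1u64; n];
    /// let mut prod = vec![0u64; n];
    /// plan.negacyclic_polymul(&mut prod, &lhs, &rhs);
    /// ```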
    pub fn negacyclic_polymul(&self, prod: &mut [u64], lhs: &[u64], rhs: &[u64]) {
        let n = prod.len();
        assert_eq!(n, lhs.len());
        assert_eq!(n, rhs.len());

        let mut lhs0 = avec![0; n];
        let mut lhs1 = avec![0; n];
        let mut lhs2 = avec![0; n];
        let mut lhs3 = avec![0; n];
        let mut lhs4 = avec![0; n];

        let mut rhs0 = avec![0; n];
        let mut rhs1 = avec![0; n];
        let mut rhs2 = avec![0; n];
        let mut rhs3 = avec![0; n];
        let mut rhs4 = avec![0; n];

        self.fwd(lhs, &mut lhs0, &mut lhs1, &mut lhs2, &mut lhs3, &mut lhs4);
        self.fwd(rhs, &mut rhs0, &mut rhs1, &mut rhs2, &mut rhs3, &mut rhs4);

        self.0.mul_assign_normalize(&mut lhs0, &rhs0);
        self.1.mul_assign_normalize(&mut lhs1, &rhs1);
        self.2.mul_assign_normalize(&mut lhs2, &rhs2);
        self.3.mul_assign_normalize(&mut lhs3, &rhs3);
        self.4.mul_assign_normalize(&mut lhs4, &rhs4);

        self.inv(prod, &mut lhs0, &mut lhs1, &mut lhs2, &mut lhs3, &mut lhs4);
    }
}

#[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
#[cfg(feature = "nightly")]
impl Plan52 {
    /// Returns a negacyclic NTT plan for the given polynomial size, or `None` if no
    /// suitable roots of unity can be found for the wanted parameters, or if the
    /// AVX512-IFMA instruction set isn't detected.
    pub fn try_new(n: usize) -> Option<Self> {
        use crate::{prime64::Plan, primes52::*};
        let simd = crate::V4IFma::try_new()?;
        Some(Self(
            Plan::try_new(n, P0)?,
            Plan::try_new(n, P1)?,
            Plan::try_new(n, P2)?,
            simd,
        ))
    }

    /// Returns the polynomial size of the negacyclic NTT plan.
    #[inline]
    pub fn ntt_size(&self) -> usize {
        self.0.ntt_size()
    }

    #[inline]
    pub fn ntt_0(&self) -> &crate::prime64::Plan {
        &self.0
    }
    #[inline]
    pub fn ntt_1(&self) -> &crate::prime64::Plan {
        &self.1
    }
    #[inline]
    pub fn ntt_2(&self) -> &crate::prime64::Plan {
        &self.2
    }

    pub fn fwd(&self, value: &[u64], mod_p0: &mut [u64], mod_p1: &mut [u64], mod_p2: &mut [u64]) {
        use crate::primes52::*;
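        // `self.3` is the AVX512-IFMA SIMD token; running the reduction loop
        // inside `vectorize` compiles it with those target features enabled,
        // which lets the compiler vectorize the `%` reductions.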
        self.3.vectorize(
            #[inline(always)]
            || {
                for (&value, mod_p0, mod_p1, mod_p2) in
                    crate::izip!(value, &mut *mod_p0, &mut *mod_p1, &mut *mod_p2)
                {
                    *mod_p0 = value % P0;
                    *mod_p1 = value % P1;
                    *mod_p2 = value % P2;
                }
            },
        );
        self.0.fwd(mod_p0);
        self.1.fwd(mod_p1);
        self.2.fwd(mod_p2);
    }

    pub fn inv(
        &self,
        value: &mut [u64],
        mod_p0: &mut [u64],
        mod_p1: &mut [u64],
        mod_p2: &mut [u64],
    ) {
        self.0.inv(mod_p0);
        self.1.inv(mod_p1);
        self.2.inv(mod_p2);

        reconstruct_slice_52bit_012_avx512(self.3, value, mod_p0, mod_p1, mod_p2);
    }

    /// Computes the negacyclic polynomial product of `lhs` and `rhs`, and stores the result in
    /// `prod`.
    pub fn negacyclic_polymul(&self, prod: &mut [u64], lhs: &[u64], rhs: &[u64]) {
        let n = prod.len();
        assert_eq!(n, lhs.len());
        assert_eq!(n, rhs.len());

        let mut lhs0 = avec![0; n];
        let mut lhs1 = avec![0; n];
        let mut lhs2 = avec![0; n];

        let mut rhs0 = avec![0; n];
        let mut rhs1 = avec![0; n];
        let mut rhs2 = avec![0; n];

        self.fwd(lhs, &mut lhs0, &mut lhs1, &mut lhs2);
        self.fwd(rhs, &mut rhs0, &mut rhs1, &mut rhs2);

        self.0.mul_assign_normalize(&mut lhs0, &rhs0);
        self.1.mul_assign_normalize(&mut lhs1, &rhs1);
        self.2.mul_assign_normalize(&mut lhs2, &rhs2);

        self.inv(prod, &mut lhs0, &mut lhs1, &mut lhs2);
    }
}

#[cfg(test)]
mod tests {
    use super::*;
    use crate::prime64::tests::random_lhs_rhs_with_negacyclic_convolution;
    use alloc::{vec, vec::Vec};
    use rand::random;

    extern crate alloc;

    #[test]
    fn reconstruct_32bit() {
        for n in [32, 64, 256, 1024, 2048] {
            let value = (0..n).map(|_| random::<u64>()).collect::<Vec<_>>();
            let mut value_roundtrip = vec![0; n];
            let mut mod_p0 = vec![0; n];
            let mut mod_p1 = vec![0; n];
            let mut mod_p2 = vec![0; n];
            let mut mod_p3 = vec![0; n];
            let mut mod_p4 = vec![0; n];

            let plan = Plan32::try_new(n).unwrap();
            plan.fwd(
                &value,
                &mut mod_p0,
                &mut mod_p1,
                &mut mod_p2,
                &mut mod_p3,
                &mut mod_p4,
            );
            plan.inv(
                &mut value_roundtrip,
                &mut mod_p0,
                &mut mod_p1,
                &mut mod_p2,
                &mut mod_p3,
                &mut mod_p4,
            );
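            // `fwd` followed directly by `inv` (without `mul_assign_normalize`)
            // scales every coefficient by the NTT size, hence the
            // `wrapping_mul(n as u64)` below.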
            for (&value, &value_roundtrip) in crate::izip!(&value, &value_roundtrip) {
                assert_eq!(value_roundtrip, value.wrapping_mul(n as u64));
            }

            let (lhs, rhs, negacyclic_convolution) =
                random_lhs_rhs_with_negacyclic_convolution(n, 0);

            let mut prod = vec![0; n];
            plan.negacyclic_polymul(&mut prod, &lhs, &rhs);
            assert_eq!(prod, negacyclic_convolution);
        }
    }

    #[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
    #[cfg(feature = "nightly")]
    #[test]
    fn reconstruct_52bit() {
        for n in [32, 64, 256, 1024, 2048] {
            if let Some(plan) = Plan52::try_new(n) {
                let value = (0..n).map(|_| random::<u64>()).collect::<Vec<_>>();
                let mut value_roundtrip = vec![0; n];
                let mut mod_p0 = vec![0; n];
                let mut mod_p1 = vec![0; n];
                let mut mod_p2 = vec![0; n];

                plan.fwd(&value, &mut mod_p0, &mut mod_p1, &mut mod_p2);
                plan.inv(&mut value_roundtrip, &mut mod_p0, &mut mod_p1, &mut mod_p2);
                for (&value, &value_roundtrip) in crate::izip!(&value, &value_roundtrip) {
                    assert_eq!(value_roundtrip, value.wrapping_mul(n as u64));
                }

                let (lhs, rhs, negacyclic_convolution) =
                    random_lhs_rhs_with_negacyclic_convolution(n, 0);

                let mut prod = vec![0; n];
                plan.negacyclic_polymul(&mut prod, &lhs, &rhs);
                assert_eq!(prod, negacyclic_convolution);
            }
        }
    }

    #[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
    #[test]
    fn reconstruct_32bit_avx() {
        for n in [16, 32, 64, 256, 1024, 2048] {
            use crate::primes32::*;

            let mut value = vec![0; n];
            let mut value_avx2 = vec![0; n];
            #[cfg(feature = "nightly")]
            let mut value_avx512 = vec![0; n];
            let mod_p0 = (0..n).map(|_| random::<u32>() % P0).collect::<Vec<_>>();
            let mod_p1 = (0..n).map(|_| random::<u32>() % P1).collect::<Vec<_>>();
            let mod_p2 = (0..n).map(|_| random::<u32>() % P2).collect::<Vec<_>>();
            let mod_p3 = (0..n).map(|_| random::<u32>() % P3).collect::<Vec<_>>();
            let mod_p4 = (0..n).map(|_| random::<u32>() % P4).collect::<Vec<_>>();

            for (value, &mod_p0, &mod_p1, &mod_p2, &mod_p3, &mod_p4) in
                crate::izip!(&mut value, &mod_p0, &mod_p1, &mod_p2, &mod_p3, &mod_p4)
            {
                *value = reconstruct_32bit_01234_v2(mod_p0, mod_p1, mod_p2, mod_p3, mod_p4);
            }

            if let Some(simd) = crate::V3::try_new() {
                reconstruct_slice_32bit_01234_avx2(
                    simd,
                    &mut value_avx2,
                    &mod_p0,
                    &mod_p1,
                    &mod_p2,
                    &mod_p3,
                    &mod_p4,
                );
                assert_eq!(value, value_avx2);
            }
            #[cfg(feature = "nightly")]
            if let Some(simd) = crate::V4IFma::try_new() {
                reconstruct_slice_32bit_01234_avx512(
                    simd,
                    &mut value_avx512,
                    &mod_p0,
                    &mod_p1,
                    &mod_p2,
                    &mod_p3,
                    &mod_p4,
                );
                assert_eq!(value, value_avx512);
            }
        }
    }
}