1pub(crate) use crate::native64::{mul_mod32, mul_mod64};
2use aligned_vec::avec;
3
/// Bundle of ten `crate::prime32::Plan` values that are always used together.
///
/// `Plan32::try_new` fills field `k` with the plan built for the modulus
/// `crate::primes32::Pk` (`P0` through `P9`), all for the same size `n`.
/// `u128` inputs are handled through their ten `u32` residues, one residue
/// buffer per plan (see `Plan32::fwd` and `Plan32::inv`).
#[derive(Clone, Debug)]
pub struct Plan32(
    // Plan built for `P0`.
    crate::prime32::Plan,
    // Plan built for `P1`.
    crate::prime32::Plan,
    // Plan built for `P2`.
    crate::prime32::Plan,
    // Plan built for `P3`.
    crate::prime32::Plan,
    // Plan built for `P4`.
    crate::prime32::Plan,
    // Plan built for `P5`.
    crate::prime32::Plan,
    // Plan built for `P6`.
    crate::prime32::Plan,
    // Plan built for `P7`.
    crate::prime32::Plan,
    // Plan built for `P8`.
    crate::prime32::Plan,
    // Plan built for `P9`.
    crate::prime32::Plan,
);
18
19#[inline(always)]
20fn reconstruct_32bit_0123456789_v2(
21 mod_p0: u32,
22 mod_p1: u32,
23 mod_p2: u32,
24 mod_p3: u32,
25 mod_p4: u32,
26 mod_p5: u32,
27 mod_p6: u32,
28 mod_p7: u32,
29 mod_p8: u32,
30 mod_p9: u32,
31) -> u128 {
32 use crate::primes32::*;
33
34 let mod_p01 = {
35 let v0 = mod_p0;
36 let v1 = mul_mod32(P1, P0_INV_MOD_P1, 2 * P1 + mod_p1 - v0);
37 v0 as u64 + (v1 as u64 * P0 as u64)
38 };
39 let mod_p23 = {
40 let v2 = mod_p2;
41 let v3 = mul_mod32(P3, P2_INV_MOD_P3, 2 * P3 + mod_p3 - v2);
42 v2 as u64 + (v3 as u64 * P2 as u64)
43 };
44 let mod_p45 = {
45 let v4 = mod_p4;
46 let v5 = mul_mod32(P5, P4_INV_MOD_P5, 2 * P5 + mod_p5 - v4);
47 v4 as u64 + (v5 as u64 * P4 as u64)
48 };
49 let mod_p67 = {
50 let v6 = mod_p6;
51 let v7 = mul_mod32(P7, P6_INV_MOD_P7, 2 * P7 + mod_p7 - v6);
52 v6 as u64 + (v7 as u64 * P6 as u64)
53 };
54 let mod_p89 = {
55 let v8 = mod_p8;
56 let v9 = mul_mod32(P9, P8_INV_MOD_P9, 2 * P9 + mod_p9 - v8);
57 v8 as u64 + (v9 as u64 * P8 as u64)
58 };
59
60 let v01 = mod_p01;
61 let v23 = mul_mod64(
62 P23.wrapping_neg(),
63 2 * P23 + mod_p23 - v01,
64 P01_INV_MOD_P23,
65 P01_INV_MOD_P23_SHOUP,
66 );
67 let v45 = mul_mod64(
68 P45.wrapping_neg(),
69 2 * P45 + mod_p45 - (v01 + mul_mod64(P45.wrapping_neg(), v23, P01, P01_MOD_P45_SHOUP)),
70 P0123_INV_MOD_P45,
71 P0123_INV_MOD_P45_SHOUP,
72 );
73 let v67 = mul_mod64(
74 P67.wrapping_neg(),
75 2 * P67 + mod_p67
76 - (v01
77 + mul_mod64(
78 P67.wrapping_neg(),
79 v23 + mul_mod64(P67.wrapping_neg(), v45, P23, P23_MOD_P67_SHOUP),
80 P01,
81 P01_MOD_P67_SHOUP,
82 )),
83 P012345_INV_MOD_P67,
84 P012345_INV_MOD_P67_SHOUP,
85 );
86 let v89 = mul_mod64(
87 P89.wrapping_neg(),
88 2 * P89 + mod_p89
89 - (v01
90 + mul_mod64(
91 P89.wrapping_neg(),
92 v23 + mul_mod64(
93 P89.wrapping_neg(),
94 v45 + mul_mod64(P89.wrapping_neg(), v67, P45, P45_MOD_P89_SHOUP),
95 P23,
96 P23_MOD_P89_SHOUP,
97 ),
98 P01,
99 P01_MOD_P89_SHOUP,
100 )),
101 P01234567_INV_MOD_P89,
102 P01234567_INV_MOD_P89_SHOUP,
103 );
104
105 let sign = v89 > (P89 / 2);
106 let pos = (v01 as u128)
107 .wrapping_add(u128::wrapping_mul(v23 as u128, P01 as u128))
108 .wrapping_add(u128::wrapping_mul(v45 as u128, P0123))
109 .wrapping_add(u128::wrapping_mul(v67 as u128, P012345))
110 .wrapping_add(u128::wrapping_mul(v89 as u128, P01234567));
111 let neg = pos.wrapping_sub(P0123456789);
112
113 if sign {
114 neg
115 } else {
116 pos
117 }
118}
119
120impl Plan32 {
121 pub fn try_new(n: usize) -> Option<Self> {
124 use crate::{prime32::Plan, primes32::*};
125 Some(Self(
126 Plan::try_new(n, P0)?,
127 Plan::try_new(n, P1)?,
128 Plan::try_new(n, P2)?,
129 Plan::try_new(n, P3)?,
130 Plan::try_new(n, P4)?,
131 Plan::try_new(n, P5)?,
132 Plan::try_new(n, P6)?,
133 Plan::try_new(n, P7)?,
134 Plan::try_new(n, P8)?,
135 Plan::try_new(n, P9)?,
136 ))
137 }
138
139 #[inline]
141 pub fn ntt_size(&self) -> usize {
142 self.0.ntt_size()
143 }
144
145 #[inline]
146 pub fn ntt_0(&self) -> &crate::prime32::Plan {
147 &self.0
148 }
149 #[inline]
150 pub fn ntt_1(&self) -> &crate::prime32::Plan {
151 &self.1
152 }
153 #[inline]
154 pub fn ntt_2(&self) -> &crate::prime32::Plan {
155 &self.2
156 }
157 #[inline]
158 pub fn ntt_3(&self) -> &crate::prime32::Plan {
159 &self.3
160 }
161 #[inline]
162 pub fn ntt_4(&self) -> &crate::prime32::Plan {
163 &self.4
164 }
165 #[inline]
166 pub fn ntt_5(&self) -> &crate::prime32::Plan {
167 &self.5
168 }
169 #[inline]
170 pub fn ntt_6(&self) -> &crate::prime32::Plan {
171 &self.6
172 }
173 #[inline]
174 pub fn ntt_7(&self) -> &crate::prime32::Plan {
175 &self.7
176 }
177 #[inline]
178 pub fn ntt_8(&self) -> &crate::prime32::Plan {
179 &self.8
180 }
181 #[inline]
182 pub fn ntt_9(&self) -> &crate::prime32::Plan {
183 &self.9
184 }
185
186 pub fn fwd(
187 &self,
188 value: &[u128],
189 mod_p0: &mut [u32],
190 mod_p1: &mut [u32],
191 mod_p2: &mut [u32],
192 mod_p3: &mut [u32],
193 mod_p4: &mut [u32],
194 mod_p5: &mut [u32],
195 mod_p6: &mut [u32],
196 mod_p7: &mut [u32],
197 mod_p8: &mut [u32],
198 mod_p9: &mut [u32],
199 ) {
200 for (
201 value,
202 mod_p0,
203 mod_p1,
204 mod_p2,
205 mod_p3,
206 mod_p4,
207 mod_p5,
208 mod_p6,
209 mod_p7,
210 mod_p8,
211 mod_p9,
212 ) in crate::izip!(
213 value,
214 &mut *mod_p0,
215 &mut *mod_p1,
216 &mut *mod_p2,
217 &mut *mod_p3,
218 &mut *mod_p4,
219 &mut *mod_p5,
220 &mut *mod_p6,
221 &mut *mod_p7,
222 &mut *mod_p8,
223 &mut *mod_p9,
224 ) {
225 *mod_p0 = (value % crate::primes32::P0 as u128) as u32;
226 *mod_p1 = (value % crate::primes32::P1 as u128) as u32;
227 *mod_p2 = (value % crate::primes32::P2 as u128) as u32;
228 *mod_p3 = (value % crate::primes32::P3 as u128) as u32;
229 *mod_p4 = (value % crate::primes32::P4 as u128) as u32;
230 *mod_p5 = (value % crate::primes32::P5 as u128) as u32;
231 *mod_p6 = (value % crate::primes32::P6 as u128) as u32;
232 *mod_p7 = (value % crate::primes32::P7 as u128) as u32;
233 *mod_p8 = (value % crate::primes32::P8 as u128) as u32;
234 *mod_p9 = (value % crate::primes32::P9 as u128) as u32;
235 }
236 self.0.fwd(mod_p0);
237 self.1.fwd(mod_p1);
238 self.2.fwd(mod_p2);
239 self.3.fwd(mod_p3);
240 self.4.fwd(mod_p4);
241 self.5.fwd(mod_p5);
242 self.6.fwd(mod_p6);
243 self.7.fwd(mod_p7);
244 self.8.fwd(mod_p8);
245 self.9.fwd(mod_p9);
246 }
247
248 pub fn inv(
249 &self,
250 value: &mut [u128],
251 mod_p0: &mut [u32],
252 mod_p1: &mut [u32],
253 mod_p2: &mut [u32],
254 mod_p3: &mut [u32],
255 mod_p4: &mut [u32],
256 mod_p5: &mut [u32],
257 mod_p6: &mut [u32],
258 mod_p7: &mut [u32],
259 mod_p8: &mut [u32],
260 mod_p9: &mut [u32],
261 ) {
262 self.0.inv(mod_p0);
263 self.1.inv(mod_p1);
264 self.2.inv(mod_p2);
265 self.3.inv(mod_p3);
266 self.4.inv(mod_p4);
267 self.5.inv(mod_p5);
268 self.6.inv(mod_p6);
269 self.7.inv(mod_p7);
270 self.8.inv(mod_p8);
271 self.9.inv(mod_p9);
272
273 for (
274 value,
275 &mod_p0,
276 &mod_p1,
277 &mod_p2,
278 &mod_p3,
279 &mod_p4,
280 &mod_p5,
281 &mod_p6,
282 &mod_p7,
283 &mod_p8,
284 &mod_p9,
285 ) in crate::izip!(
286 value, &*mod_p0, &*mod_p1, &*mod_p2, &*mod_p3, &*mod_p4, &*mod_p5, &*mod_p6, &*mod_p7,
287 &*mod_p8, &*mod_p9,
288 ) {
289 *value = reconstruct_32bit_0123456789_v2(
290 mod_p0, mod_p1, mod_p2, mod_p3, mod_p4, mod_p5, mod_p6, mod_p7, mod_p8, mod_p9,
291 );
292 }
293 }
294
295 pub fn negacyclic_polymul(&self, prod: &mut [u128], lhs: &[u128], rhs: &[u128]) {
298 let n = prod.len();
299 assert_eq!(n, lhs.len());
300 assert_eq!(n, rhs.len());
301
302 let mut lhs0 = avec![0; n];
303 let mut lhs1 = avec![0; n];
304 let mut lhs2 = avec![0; n];
305 let mut lhs3 = avec![0; n];
306 let mut lhs4 = avec![0; n];
307 let mut lhs5 = avec![0; n];
308 let mut lhs6 = avec![0; n];
309 let mut lhs7 = avec![0; n];
310 let mut lhs8 = avec![0; n];
311 let mut lhs9 = avec![0; n];
312
313 let mut rhs0 = avec![0; n];
314 let mut rhs1 = avec![0; n];
315 let mut rhs2 = avec![0; n];
316 let mut rhs3 = avec![0; n];
317 let mut rhs4 = avec![0; n];
318 let mut rhs5 = avec![0; n];
319 let mut rhs6 = avec![0; n];
320 let mut rhs7 = avec![0; n];
321 let mut rhs8 = avec![0; n];
322 let mut rhs9 = avec![0; n];
323
324 self.fwd(
325 lhs, &mut lhs0, &mut lhs1, &mut lhs2, &mut lhs3, &mut lhs4, &mut lhs5, &mut lhs6,
326 &mut lhs7, &mut lhs8, &mut lhs9,
327 );
328 self.fwd(
329 rhs, &mut rhs0, &mut rhs1, &mut rhs2, &mut rhs3, &mut rhs4, &mut rhs5, &mut rhs6,
330 &mut rhs7, &mut rhs8, &mut rhs9,
331 );
332
333 self.0.mul_assign_normalize(&mut lhs0, &rhs0);
334 self.1.mul_assign_normalize(&mut lhs1, &rhs1);
335 self.2.mul_assign_normalize(&mut lhs2, &rhs2);
336 self.3.mul_assign_normalize(&mut lhs3, &rhs3);
337 self.4.mul_assign_normalize(&mut lhs4, &rhs4);
338 self.5.mul_assign_normalize(&mut lhs5, &rhs5);
339 self.6.mul_assign_normalize(&mut lhs6, &rhs6);
340 self.7.mul_assign_normalize(&mut lhs7, &rhs7);
341 self.8.mul_assign_normalize(&mut lhs8, &rhs8);
342 self.9.mul_assign_normalize(&mut lhs9, &rhs9);
343
344 self.inv(
345 prod, &mut lhs0, &mut lhs1, &mut lhs2, &mut lhs3, &mut lhs4, &mut lhs5, &mut lhs6,
346 &mut lhs7, &mut lhs8, &mut lhs9,
347 );
348 }
349}
350
#[cfg(test)]
pub mod tests {
    use super::*;
    use alloc::{vec, vec::Vec};
    use rand::random;

    extern crate alloc;

    /// Reference negacyclic convolution of the first `n` entries of `lhs` and `rhs`.
    ///
    /// The full `2 * n`-term product is accumulated with wrapping `u128` arithmetic, then
    /// output `i` is `full[i] - full[i + n]` (wrapping). Runs in quadratic time and panics
    /// if either input holds fewer than `n` elements.
    pub fn negacyclic_convolution(n: usize, lhs: &[u128], rhs: &[u128]) -> Vec<u128> {
        let mut full = vec![0u128; 2 * n];
        for i in 0..n {
            for j in 0..n {
                let term = lhs[i].wrapping_mul(rhs[j]);
                let slot = &mut full[i + j];
                *slot = slot.wrapping_add(term);
            }
        }

        // Fold the upper half back onto the lower half with a sign flip.
        (0..n)
            .map(|i| full[i].wrapping_sub(full[i + n]))
            .collect()
    }

    /// Draws two random length-`n` operands and returns them together with their
    /// reference negacyclic convolution, as `(lhs, rhs, convolution)`.
    pub fn random_lhs_rhs_with_negacyclic_convolution(
        n: usize,
    ) -> (Vec<u128>, Vec<u128>, Vec<u128>) {
        // `lhs` is drawn completely before `rhs`.
        let lhs: Vec<u128> = (0..n).map(|_| random()).collect();
        let rhs: Vec<u128> = (0..n).map(|_| random()).collect();

        let expected = negacyclic_convolution(n, &lhs, &rhs);
        (lhs, rhs, expected)
    }

    /// Checks, for several sizes, that `fwd` followed by `inv` returns every input
    /// multiplied (wrapping) by `n`, and that `negacyclic_polymul` agrees with the
    /// reference convolution.
    #[test]
    fn reconstruct_32bit() {
        for n in [32, 64, 256, 1024, 2048] {
            let value: Vec<u128> = (0..n).map(|_| random()).collect();
            let mut value_roundtrip = vec![0u128; n];

            let mut res0 = vec![0u32; n];
            let mut res1 = vec![0u32; n];
            let mut res2 = vec![0u32; n];
            let mut res3 = vec![0u32; n];
            let mut res4 = vec![0u32; n];
            let mut res5 = vec![0u32; n];
            let mut res6 = vec![0u32; n];
            let mut res7 = vec![0u32; n];
            let mut res8 = vec![0u32; n];
            let mut res9 = vec![0u32; n];

            let plan = Plan32::try_new(n).unwrap();

            plan.fwd(
                &value, &mut res0, &mut res1, &mut res2, &mut res3, &mut res4, &mut res5,
                &mut res6, &mut res7, &mut res8, &mut res9,
            );
            plan.inv(
                &mut value_roundtrip,
                &mut res0,
                &mut res1,
                &mut res2,
                &mut res3,
                &mut res4,
                &mut res5,
                &mut res6,
                &mut res7,
                &mut res8,
                &mut res9,
            );

            // The round trip is expected to scale every element by `n`.
            for (original, round_tripped) in value.iter().zip(value_roundtrip.iter()) {
                assert_eq!(*round_tripped, original.wrapping_mul(n as u128));
            }

            let (lhs, rhs, expected) = random_lhs_rhs_with_negacyclic_convolution(n);

            let mut prod = vec![0u128; n];
            plan.negacyclic_polymul(&mut prod, &lhs, &rhs);
            assert_eq!(prod, expected);
        }
    }
}