1use crate::ntt64::arith::{mod_mul_barrett, Ntt64Arith};
28use crate::ntt64::context::{ntt_forward, ntt_inverse, Ntt64Context};
29use alloc::vec;
30use alloc::vec::Vec;
31#[cfg(feature = "rand")]
32use rand::Rng;
33#[cfg(feature = "rand")]
34use rand_distr::{Distribution, Normal};
35
/// A polynomial with `u64` coefficients modulo some modulus `q`.
///
/// The same storage holds either the coefficient representation or the NTT
/// representation; `is_ntt` records which one is currently stored. Two values
/// compare equal only if both the stored words and the domain flag match.
#[derive(Clone, Debug, PartialEq, Eq)]
pub struct Poly64 {
    /// Coefficients (or NTT values when `is_ntt` is `true`); one entry per
    /// ring position, so `data.len()` is the polynomial size `n`.
    pub data: Vec<u64>,
    /// `true` once `forward_ntt` has been applied and until `inverse_ntt`
    /// brings the polynomial back to the coefficient domain.
    pub is_ntt: bool,
}
51
52impl Poly64 {
53 #[inline]
59 pub fn new_zero(n: usize) -> Self {
60 Self {
61 data: vec![0u64; n],
62 is_ntt: false,
63 }
64 }
65
66 #[cfg(feature = "rand")]
70 pub fn new_random(n: usize, arith: &Ntt64Arith) -> Self {
71 let mut rng = rand::thread_rng();
72 let q = arith.modulus;
73 let data: Vec<u64> = (0..n).map(|_| rng.gen_range(0..q)).collect();
74 Self {
75 data,
76 is_ntt: false,
77 }
78 }
79
80 #[cfg(feature = "rand")]
87 pub fn new_ternary(n: usize, arith: &Ntt64Arith) -> Self {
88 let mut rng = rand::thread_rng();
89 let q = arith.modulus;
90 let data: Vec<u64> = (0..n)
91 .map(|_| match rng.gen_range(0u32..3) {
92 0 => 0,
93 1 => 1,
94 _ => q - 1,
95 })
96 .collect();
97 Self {
98 data,
99 is_ntt: false,
100 }
101 }
102
103 #[cfg(feature = "rand")]
110 pub fn new_gaussian(n: usize, sigma: f64, arith: &Ntt64Arith) -> Self {
111 let mut rng = rand::thread_rng();
112 let q = arith.modulus;
113 let normal = Normal::new(0.0, sigma).expect("sigma must be > 0");
114 let data: Vec<u64> = (0..n)
115 .map(|_| {
116 let sample: f64 = normal.sample(&mut rng);
117 let rounded = sample.round() as i64;
118 if rounded >= 0 {
119 (rounded as u64) % q
120 } else {
121 let abs_val = (-rounded) as u64;
122 let r = abs_val % q;
123 if r == 0 {
124 0
125 } else {
126 q - r
127 }
128 }
129 })
130 .collect();
131 Self {
132 data,
133 is_ntt: false,
134 }
135 }
136
137 pub fn forward_ntt(&mut self, ntt_ctx: &Ntt64Context) {
146 assert!(!self.is_ntt, "polynomial is already in NTT domain");
147 ntt_forward(&mut self.data, ntt_ctx);
148 self.is_ntt = true;
149 }
150
151 pub fn inverse_ntt(&mut self, ntt_ctx: &Ntt64Context) {
156 assert!(self.is_ntt, "polynomial is not in NTT domain");
157 ntt_inverse(&mut self.data, ntt_ctx);
158 self.is_ntt = false;
159 }
160
161 pub fn add_assign(&mut self, other: &Poly64, arith: &Ntt64Arith) {
172 assert_eq!(
173 self.is_ntt, other.is_ntt,
174 "polynomials must be in the same domain"
175 );
176 assert_eq!(
177 self.data.len(),
178 other.data.len(),
179 "polynomials must have the same size"
180 );
181 let q = arith.modulus;
182 for (a, &b) in self.data.iter_mut().zip(other.data.iter()) {
183 let sum = *a + b;
184 let (sub, borrow) = sum.overflowing_sub(q);
186 *a = if borrow { sum } else { sub };
187 }
188 }
189
190 pub fn sub_assign(&mut self, other: &Poly64, arith: &Ntt64Arith) {
197 assert_eq!(
198 self.is_ntt, other.is_ntt,
199 "polynomials must be in the same domain"
200 );
201 assert_eq!(
202 self.data.len(),
203 other.data.len(),
204 "polynomials must have the same size"
205 );
206 let q = arith.modulus;
207 for (a, &b) in self.data.iter_mut().zip(other.data.iter()) {
208 let (sub, borrow) = (*a).overflowing_sub(b);
209 *a = if borrow { sub.wrapping_add(q) } else { sub };
210 }
211 }
212
213 pub fn mul_assign(&mut self, other: &Poly64, arith: &Ntt64Arith) {
221 assert!(
222 self.is_ntt && other.is_ntt,
223 "both polynomials must be in NTT domain for multiplication"
224 );
225 assert_eq!(
226 self.data.len(),
227 other.data.len(),
228 "polynomials must have the same size"
229 );
230 for (a, &b) in self.data.iter_mut().zip(other.data.iter()) {
231 *a = mod_mul_barrett(*a, b, arith);
232 }
233 }
234
235 pub fn scalar_mul(&mut self, scalar: u64, arith: &Ntt64Arith) {
237 for a in self.data.iter_mut() {
238 *a = mod_mul_barrett(*a, scalar, arith);
239 }
240 }
241
242 pub fn negate(&mut self, arith: &Ntt64Arith) {
244 let q = arith.modulus;
245 for a in self.data.iter_mut() {
246 *a = if *a == 0 { 0 } else { q - *a };
250 }
251 }
252
253 #[inline]
259 pub fn len(&self) -> usize {
260 self.data.len()
261 }
262
263 #[inline]
265 pub fn is_empty(&self) -> bool {
266 self.data.is_empty()
267 }
268}
269
/// Schoolbook reference multiplication in `Z_q[X] / (X^n + 1)`.
///
/// Every product `a[i] * b[j]` lands at position `i + j`. Products that reach
/// degree `n` or higher wrap to `i + j - n` and are subtracted, because
/// `X^n = -1` in this ring. Intermediate values are held in `u128`, so the
/// operands only need to fit in `u64`.
///
/// # Panics
///
/// Panics if `a` and `b` differ in length, or if `q == 0` and the inputs are
/// non-empty.
#[cfg(test)]
#[allow(unused_variables, clippy::needless_range_loop, dead_code)]
fn naive_poly_mul(a: &[u64], b: &[u64], q: u64) -> Vec<u64> {
    let n = a.len();
    assert_eq!(n, b.len());
    let modulus = q as u128;
    let mut acc = vec![0u64; n];

    for (i, &ai) in a.iter().enumerate() {
        for (j, &bj) in b.iter().enumerate() {
            let term = (ai as u128) * (bj as u128);
            let k = i + j;
            if k >= n {
                // Wrapped term: X^n = -1, so subtract it (kept non-negative by adding q).
                let slot = &mut acc[k - n];
                *slot = ((*slot as u128 + modulus - term % modulus) % modulus) as u64;
            } else {
                let slot = &mut acc[k];
                *slot = ((*slot as u128 + term) % modulus) as u64;
            }
        }
    }
    acc
}
301
// Unit tests for `Poly64`.
//
// Tests that sample random polynomials depend on the `rand`-gated
// constructors, so each one is gated on that feature individually. The
// deterministic tests still compile and run without it.
#[cfg(test)]
mod tests {
    use super::*;

    /// Test modulus: `TEST_Q - 1 = 7680 = 15 * 512`, so `2 * TEST_N` divides `q - 1`.
    const TEST_Q: u64 = 7681;
    /// Polynomial size used by most tests.
    const TEST_N: usize = 256;

    fn test_arith() -> Ntt64Arith {
        Ntt64Arith::new(TEST_Q)
    }

    fn test_ntt_ctx() -> Ntt64Context {
        Ntt64Context::new(TEST_N, test_arith())
    }

    /// Asserts `got == want` coefficient by coefficient; the failure message
    /// is `what` followed by the first mismatching index.
    fn assert_coeffs_eq(got: &[u64], want: &[u64], what: &str) {
        assert_eq!(got.len(), want.len(), "{what}: length mismatch");
        for (i, (g, w)) in got.iter().zip(want).enumerate() {
            assert_eq!(g, w, "{what} at index {i}");
        }
    }

    #[test]
    fn test_poly_add_wraps_modulus() {
        let arith = test_arith();
        let mut a = Poly64 {
            data: vec![TEST_Q - 1, 5, 0],
            is_ntt: false,
        };
        let b = Poly64 {
            data: vec![1, TEST_Q - 2, 0],
            is_ntt: false,
        };
        a.add_assign(&b, &arith);
        assert_coeffs_eq(&a.data, &[0, 3, 0], "modular add wrong");
    }

    #[cfg(feature = "rand")]
    #[test]
    fn test_poly_add_sub() {
        let arith = test_arith();
        let a = Poly64::new_random(TEST_N, &arith);
        let b = Poly64::new_random(TEST_N, &arith);

        let mut c = a.clone();
        c.add_assign(&b, &arith);
        c.sub_assign(&b, &arith);

        assert_coeffs_eq(&c.data, &a.data, "add/sub roundtrip fails");
    }

    #[cfg(feature = "rand")]
    #[test]
    fn test_poly_add_commutative() {
        let arith = test_arith();
        let a = Poly64::new_random(TEST_N, &arith);
        let b = Poly64::new_random(TEST_N, &arith);

        let mut ab = a.clone();
        ab.add_assign(&b, &arith);

        let mut ba = b.clone();
        ba.add_assign(&a, &arith);

        assert_coeffs_eq(&ab.data, &ba.data, "add not commutative");
    }

    #[cfg(feature = "rand")]
    #[test]
    fn test_poly_negate() {
        let arith = test_arith();
        let a = Poly64::new_random(TEST_N, &arith);

        let mut neg_a = a.clone();
        neg_a.negate(&arith);

        let mut sum = a.clone();
        sum.add_assign(&neg_a, &arith);

        assert_coeffs_eq(&sum.data, &vec![0u64; TEST_N], "a + (-a) != 0");
    }

    #[cfg(feature = "rand")]
    #[test]
    fn test_poly_scalar_mul() {
        let arith = test_arith();
        let a = Poly64::new_random(TEST_N, &arith);

        let mut doubled = a.clone();
        doubled.scalar_mul(2, &arith);

        let mut sum = a.clone();
        sum.add_assign(&a, &arith);

        assert_coeffs_eq(&doubled.data, &sum.data, "2*a != a+a");
    }

    #[test]
    fn test_poly_mul_ntt() {
        let arith = test_arith();
        let ntt_ctx = test_ntt_ctx();

        // (1 + X) * (1 + X^2): low degree, so the wrap-around by X^n + 1 is not exercised.
        let mut a = Poly64::new_zero(TEST_N);
        a.data[0] = 1;
        a.data[1] = 1;

        let mut b = Poly64::new_zero(TEST_N);
        b.data[0] = 1;
        b.data[2] = 1;

        let expected = naive_poly_mul(&a.data, &b.data, TEST_Q);

        a.forward_ntt(&ntt_ctx);
        b.forward_ntt(&ntt_ctx);
        a.mul_assign(&b, &arith);
        a.inverse_ntt(&ntt_ctx);

        assert_coeffs_eq(&a.data, &expected, "NTT mul != naive");
    }

    #[cfg(feature = "rand")]
    #[test]
    fn test_poly_mul_random_ntt() {
        let arith = test_arith();
        let ntt_ctx = test_ntt_ctx();

        let a_orig = Poly64::new_random(TEST_N, &arith);
        let b_orig = Poly64::new_random(TEST_N, &arith);

        let expected = naive_poly_mul(&a_orig.data, &b_orig.data, TEST_Q);

        let mut a = a_orig.clone();
        let mut b = b_orig.clone();
        a.forward_ntt(&ntt_ctx);
        b.forward_ntt(&ntt_ctx);
        a.mul_assign(&b, &arith);
        a.inverse_ntt(&ntt_ctx);

        assert_coeffs_eq(&a.data, &expected, "NTT mul != naive");
    }

    #[cfg(feature = "rand")]
    #[test]
    fn test_ternary_distribution() {
        let arith = test_arith();
        let poly = Poly64::new_ternary(1024, &arith);

        for (i, &coeff) in poly.data.iter().enumerate() {
            assert!(
                coeff == 0 || coeff == 1 || coeff == TEST_Q - 1,
                "invalid ternary coefficient at index {i}: {coeff}"
            );
        }

        let count_zero = poly.data.iter().filter(|&&c| c == 0).count();
        let count_one = poly.data.iter().filter(|&&c| c == 1).count();
        let count_neg = poly.data.iter().filter(|&&c| c == TEST_Q - 1).count();

        assert!(count_zero > 0);
        assert!(count_one > 0);
        assert!(count_neg > 0);
    }

    #[cfg(feature = "rand")]
    #[test]
    fn test_gaussian_distribution() {
        let arith = test_arith();
        let sigma = 3.2;
        let n = 8192;
        let poly = Poly64::new_gaussian(n, sigma, &arith);

        // Lift coefficients back to the centered range (-q/2, q/2].
        let q = TEST_Q as f64;
        let half_q = q / 2.0;
        let centered: Vec<f64> = poly
            .data
            .iter()
            .map(|&c| {
                let c = c as f64;
                if c > half_q {
                    c - q
                } else {
                    c
                }
            })
            .collect();

        let mean = centered.iter().sum::<f64>() / n as f64;
        assert!(mean.abs() < 0.5, "mean too far from 0: {mean}");

        let variance = centered.iter().map(|&x| (x - mean).powi(2)).sum::<f64>() / n as f64;
        let std_dev = variance.sqrt();
        assert!(
            (std_dev - sigma).abs() < 1.0,
            "stddev too far from {sigma}: {std_dev}"
        );
    }

    #[cfg(feature = "rand")]
    #[test]
    fn test_ntt_roundtrip() {
        let arith = test_arith();
        let ntt_ctx = test_ntt_ctx();
        let original = Poly64::new_random(TEST_N, &arith);

        let mut poly = original.clone();
        poly.forward_ntt(&ntt_ctx);
        assert!(poly.is_ntt);
        poly.inverse_ntt(&ntt_ctx);
        assert!(!poly.is_ntt);

        assert_coeffs_eq(&poly.data, &original.data, "NTT roundtrip fails");
    }

    #[test]
    fn test_new_zero() {
        let poly = Poly64::new_zero(64);
        assert_eq!(poly.len(), 64);
        assert!(!poly.is_ntt);
        assert!(poly.data.iter().all(|&c| c == 0));
    }

    #[test]
    #[should_panic(expected = "already in NTT domain")]
    fn test_double_forward_ntt_panics() {
        let ntt_ctx = test_ntt_ctx();
        let mut poly = Poly64::new_zero(TEST_N);
        poly.forward_ntt(&ntt_ctx);
        poly.forward_ntt(&ntt_ctx);
    }

    #[test]
    #[should_panic(expected = "not in NTT domain")]
    fn test_inverse_ntt_without_forward_panics() {
        let ntt_ctx = test_ntt_ctx();
        let mut poly = Poly64::new_zero(TEST_N);
        poly.inverse_ntt(&ntt_ctx);
    }
}