1use super::arith::{mod_add, mod_inv, mod_mul_barrett, mod_pow, mod_sub, Ntt64Arith};
30use super::prime::find_primitive_root;
31use alloc::vec;
32use alloc::vec::Vec;
33
/// Reverses the lowest `bits` bits of `x`.
///
/// The whole 32-bit word is reversed and then shifted right so that only the
/// `bits` originally-lowest bits remain, now in mirrored order. Any bits of
/// `x` above position `bits - 1` are discarded by the shift.
///
/// `bits` must be at most 32. A width of zero yields `0`; without the explicit
/// guard the shift amount would be 32, which overflows the shift (panic in
/// debug builds, unshifted garbage in release builds).
#[inline]
fn bit_reverse(x: u32, bits: u32) -> u32 {
    debug_assert!(bits <= u32::BITS, "bit width {bits} exceeds 32");
    if bits == 0 {
        // Reversing an empty bit field: nothing survives.
        return 0;
    }
    x.reverse_bits() >> (u32::BITS - bits)
}
45
46#[derive(Debug, Clone)]
55pub struct Ntt64Context {
56 pub n: usize,
58
59 pub log_n: u32,
61
62 pub arith: Ntt64Arith,
64
65 pub root_powers: Vec<u64>,
70
71 pub inv_root_powers: Vec<u64>,
75
76 pub n_inv: u64,
78}
79
80impl Ntt64Context {
81 pub fn try_new(n: usize, arith: Ntt64Arith) -> Result<Self, crate::NttError> {
95 if n < 2 || !n.is_power_of_two() {
96 return Err(crate::NttError::InvalidSize(n));
97 }
98 let q = arith.modulus;
99 if !super::prime::is_prime(q) {
100 return Err(crate::NttError::NotPrime(q));
101 }
102 if !(q - 1).is_multiple_of(2 * n as u64) {
103 return Err(crate::NttError::NotNttFriendly { q, n });
104 }
105
106 let log_n = n.trailing_zeros();
107
108 let psi = find_primitive_root(n, q);
110 let psi_inv = mod_inv(psi, &arith);
111 let n_inv = mod_inv(n as u64, &arith);
112
113 let mut root_powers = vec![0u64; n];
116 let mut inv_root_powers = vec![0u64; n];
117
118 for i in 0..n {
119 let exp = bit_reverse(i as u32, log_n) as u64;
120 root_powers[i] = mod_pow(psi, exp, &arith);
121 inv_root_powers[i] = mod_pow(psi_inv, exp, &arith);
122 }
123
124 Ok(Self {
125 n,
126 log_n,
127 arith,
128 root_powers,
129 inv_root_powers,
130 n_inv,
131 })
132 }
133
134 pub fn new(n: usize, arith: Ntt64Arith) -> Self {
145 Self::try_new(n, arith).expect("Invalid NTT parameters")
146 }
147
148 #[inline]
150 pub fn forward(&self, data: &mut [u64]) {
151 ntt_forward(data, self);
152 }
153
154 #[inline]
156 pub fn inverse(&self, data: &mut [u64]) {
157 ntt_inverse(data, self);
158 }
159
160 #[inline]
165 pub fn forward_tiled(&self, data: &mut [u64]) {
166 ntt_forward(data, self);
168 }
169
170 pub fn pointwise_mul(&self, a: &[u64], b: &[u64], result: &mut [u64]) {
177 let n = self.n;
178 assert_eq!(a.len(), n);
179 assert_eq!(b.len(), n);
180 assert_eq!(result.len(), n);
181
182 for i in 0..n {
183 result[i] = mod_mul_barrett(a[i], b[i], &self.arith);
184 }
185 }
186
187 pub fn negacyclic_mul(&self, a: &[u64], b: &[u64]) -> Vec<u64> {
192 let n = self.n;
193 assert_eq!(a.len(), n);
194 assert_eq!(b.len(), n);
195
196 let mut a_ntt = a.to_vec();
197 let mut b_ntt = b.to_vec();
198 ntt_forward(&mut a_ntt, self);
199 ntt_forward(&mut b_ntt, self);
200
201 let mut c_ntt = vec![0u64; n];
202 self.pointwise_mul(&a_ntt, &b_ntt, &mut c_ntt);
203
204 ntt_inverse(&mut c_ntt, self);
205 c_ntt
206 }
207}
208
209pub fn ntt_forward(data: &mut [u64], ctx: &Ntt64Context) {
225 let n = ctx.n;
226 let q = ctx.arith.modulus;
227 assert_eq!(data.len(), n, "data length ({}) != N ({})", data.len(), n);
228
229 let mut t = n;
230 let mut m = 1;
231
232 for _ in 0..ctx.log_n {
233 t >>= 1;
234 let mut k = 0;
235
236 for i in 0..m {
237 let w = ctx.root_powers[m + i];
238
239 for j in k..(k + t) {
240 let u = data[j];
241 let v = mod_mul_barrett(data[j + t], w, &ctx.arith);
242 data[j] = mod_add(u, v, q);
243 data[j + t] = mod_sub(u, v, q);
244 }
245 k += 2 * t;
246 }
247 m <<= 1;
248 }
249}
250
251pub fn ntt_inverse(data: &mut [u64], ctx: &Ntt64Context) {
268 let n = ctx.n;
269 let q = ctx.arith.modulus;
270 assert_eq!(data.len(), n, "data length ({}) != N ({})", data.len(), n);
271
272 let mut t = 1;
273 let mut m = n;
274
275 for _ in 0..ctx.log_n {
276 m >>= 1;
277 let mut k = 0;
278
279 for i in 0..m {
280 let w_inv = ctx.inv_root_powers[m + i];
281
282 for j in k..(k + t) {
283 let u = data[j];
284 let v = data[j + t];
285 data[j] = mod_add(u, v, q);
286 data[j + t] = mod_mul_barrett(mod_sub(u, v, q), w_inv, &ctx.arith);
287 }
288 k += 2 * t;
289 }
290 t <<= 1;
291 }
292
293 for coeff in data.iter_mut() {
295 *coeff = mod_mul_barrett(*coeff, ctx.n_inv, &ctx.arith);
296 }
297}
298
299#[allow(dead_code)]
320pub fn ntt_forward_tiled(data: &mut [u64], ctx: &Ntt64Context) {
321 let n = ctx.n;
322
323 if n <= 64 {
324 ntt_forward(data, ctx);
325 return;
326 }
327
328 let log_n = ctx.log_n;
329 let log_n1 = log_n / 2;
330 let log_n2 = log_n - log_n1;
331 let n1 = 1usize << log_n1;
332 let n2 = 1usize << log_n2;
333
334 let arith = &ctx.arith;
335
336 let sub_ctx2 = Ntt64Context::new(n2, arith.clone());
338 for row in 0..n1 {
339 let start = row * n2;
340 ntt_forward(&mut data[start..start + n2], &sub_ctx2);
341 }
342
343 let psi = find_primitive_root(n, arith.modulus);
346 let psi_sq = mod_mul_barrett(psi, psi, arith); for i in 0..n1 {
349 for j in 0..n2 {
350 if i == 0 || j == 0 {
351 continue;
352 }
353 let exp = ((i as u128 * j as u128) % n as u128) as u64;
354 let twiddle = mod_pow(psi_sq, exp, arith);
355 let idx = i * n2 + j;
356 data[idx] = mod_mul_barrett(data[idx], twiddle, arith);
357 }
358 }
359
360 let mut transposed = vec![0u64; n];
362 for i in 0..n1 {
363 for j in 0..n2 {
364 transposed[j * n1 + i] = data[i * n2 + j];
365 }
366 }
367 data.copy_from_slice(&transposed);
368
369 let sub_ctx1 = Ntt64Context::new(n1, arith.clone());
371 for row in 0..n2 {
372 let start = row * n1;
373 ntt_forward(&mut data[start..start + n1], &sub_ctx1);
374 }
375
376 for i in 0..n2 {
378 for j in 0..n1 {
379 transposed[j * n2 + i] = data[i * n1 + j];
380 }
381 }
382 data.copy_from_slice(&transposed);
383}
384
/// Schoolbook reference multiplication of two length-`n` polynomials modulo
/// `q`, with wrap-around terms subtracted (i.e. reduction by `x^n = -1`).
///
/// Every partial product `a[i] * b[j]` is reduced modulo `q`; it is added to
/// coefficient `i + j` when that index is below `n`, and otherwise subtracted
/// from coefficient `i + j - n`.
///
/// # Panics
///
/// Panics if `b.len() != a.len()`.
#[cfg(test)]
#[allow(unused_variables, clippy::needless_range_loop, dead_code)]
fn poly_mul_naive(a: &[u64], b: &[u64], q: u64) -> Vec<u64> {
    let len = a.len();
    assert_eq!(b.len(), len);

    let modulus = u128::from(q);
    let mut coeffs = vec![0u64; len];

    for (i, &lhs) in a.iter().enumerate() {
        for (j, &rhs) in b.iter().enumerate() {
            let term = u128::from(lhs) * u128::from(rhs) % modulus;
            let degree = i + j;

            // Terms of degree >= len wrap around with a sign flip; `term < q`
            // so `modulus - term` is its additive inverse up to one reduction.
            let (slot, contribution) = if degree < len {
                (degree, term)
            } else {
                (degree - len, modulus - term)
            };

            coeffs[slot] = ((u128::from(coeffs[slot]) + contribution) % modulus) as u64;
        }
    }

    coeffs
}
413
/// Unit tests for the root search, the forward/inverse transforms, the
/// transform-based multiplication and the bit-reversal helper.
#[cfg(test)]
#[allow(unused_variables, clippy::needless_range_loop, dead_code)]
mod tests {
    use super::super::arith::{PRIME_60_1, PRIME_SEAL};
    use super::*;

    /// Builds a context of length `n` over `PRIME_SEAL`.
    fn seal_context(n: usize) -> Ntt64Context {
        Ntt64Context::new(n, Ntt64Arith::new(PRIME_SEAL))
    }

    /// Applies `ntt_forward` then `ntt_inverse` to a copy of `input`.
    fn roundtrip(input: &[u64], ctx: &Ntt64Context) -> Vec<u64> {
        let mut buf = input.to_vec();
        ntt_forward(&mut buf, ctx);
        ntt_inverse(&mut buf, ctx);
        buf
    }

    /// The root for n = 8, q = 17 has order 2n and its n-th power is -1.
    #[test]
    fn test_primitive_root_small() {
        let modulus: u64 = 17;
        let size = 8;
        let arith = Ntt64Arith::new(modulus);
        let psi = find_primitive_root(size, modulus);

        assert_eq!(mod_pow(psi, 2 * size as u64, &arith), 1);
        assert_eq!(mod_pow(psi, size as u64, &arith), 16);
    }

    /// Same order checks over `PRIME_SEAL` for several sizes.
    #[test]
    fn test_primitive_root_seal() {
        let arith = Ntt64Arith::new(PRIME_SEAL);
        for size in [16usize, 64, 1024, 4096] {
            let psi = find_primitive_root(size, PRIME_SEAL);
            assert_eq!(mod_pow(psi, 2 * size as u64, &arith), 1);
            assert_eq!(mod_pow(psi, size as u64, &arith), arith.modulus - 1);
        }
    }

    /// Forward changes the data; inverse restores it.
    #[test]
    fn test_ntt_roundtrip_small() {
        for n in [16usize, 64] {
            let ctx = seal_context(n);
            let q = ctx.arith.modulus;
            let original: Vec<u64> = (0..n as u64).map(|i| (i * 7 + 3) % q).collect();

            let mut data = original.clone();
            ntt_forward(&mut data, &ctx);
            assert_ne!(data, original);
            ntt_inverse(&mut data, &ctx);
            assert_eq!(data, original, "NTT roundtrip fails for N={n}");
        }
    }

    #[test]
    fn test_ntt_roundtrip_medium() {
        for n in [1024usize, 4096] {
            let ctx = seal_context(n);
            let q = u128::from(ctx.arith.modulus);
            let original: Vec<u64> = (0..n as u128)
                .map(|i| ((i * 123456789 + 987654321) % q) as u64)
                .collect();

            let data = roundtrip(&original, &ctx);
            assert_eq!(data, original, "NTT roundtrip fails for N={n}");
        }
    }

    #[test]
    fn test_ntt_roundtrip_zeros() {
        let n = 64;
        let ctx = seal_context(n);
        let zeros = vec![0u64; n];
        assert_eq!(roundtrip(&zeros, &ctx), vec![0u64; n]);
    }

    #[test]
    fn test_ntt_roundtrip_one() {
        let n = 64;
        let ctx = seal_context(n);
        let mut original = vec![0u64; n];
        original[0] = 1;
        assert_eq!(roundtrip(&original, &ctx), original);
    }

    /// Transform-based product must match the schoolbook reference.
    #[test]
    fn test_ntt_convolution_n16() {
        let n = 16;
        let ctx = seal_context(n);
        let q = ctx.arith.modulus;

        let a: Vec<u64> = (1..=n as u64).map(|c| c % q).collect();
        let b = vec![1u64; n];

        let expected = poly_mul_naive(&a, &b, q);
        let result = ctx.negacyclic_mul(&a, &b);
        assert_eq!(result, expected, "NTT convolution != naive for N=16");
    }

    #[test]
    fn test_ntt_convolution_n64() {
        let n = 64;
        let ctx = seal_context(n);
        let q = ctx.arith.modulus;

        let a: Vec<u64> = (0..n as u64).map(|i| (i * i + 3 * i + 7) % q).collect();
        let b: Vec<u64> = (0..n as u64).map(|i| (2 * i + 1) % q).collect();

        let expected = poly_mul_naive(&a, &b, q);
        let result = ctx.negacyclic_mul(&a, &b);
        assert_eq!(result, expected, "NTT convolution != naive for N=64");
    }

    #[test]
    fn test_ntt_convolution_identity() {
        let n = 64;
        let ctx = seal_context(n);
        let q = ctx.arith.modulus;

        let a: Vec<u64> = (0..n as u64).map(|i| (i * 17 + 5) % q).collect();
        let mut one = vec![0u64; n];
        one[0] = 1;

        let result = ctx.negacyclic_mul(&a, &one);
        assert_eq!(result, a, "Multiplying by 1 should give identity");
    }

    /// Both sizes here are `<= 64`, so `ntt_forward_tiled` takes its fallback
    /// path to `ntt_forward`.
    #[test]
    fn test_ntt_tiled_matches_standard_small() {
        for n in [16usize, 64] {
            let ctx = seal_context(n);
            let q = ctx.arith.modulus;
            let original: Vec<u64> = (0..n as u64).map(|i| (i * 13 + 7) % q).collect();

            let mut data_std = original.clone();
            ntt_forward(&mut data_std, &ctx);

            let mut data_tiled = original;
            ntt_forward_tiled(&mut data_tiled, &ctx);

            assert_eq!(data_tiled, data_std, "tiled NTT != standard for N={n}");
        }
    }

    /// Despite its name, this exercises the standard forward/inverse pair at
    /// N = 256.
    #[test]
    fn test_ntt_tiled_roundtrip() {
        let n = 256;
        let ctx = seal_context(n);
        let q = u128::from(ctx.arith.modulus);

        let original: Vec<u64> = (0..n as u128)
            .map(|i| ((i * 999999937 + 42) % q) as u64)
            .collect();

        let data = roundtrip(&original, &ctx);
        assert_eq!(data, original, "standard roundtrip fails for N=256");
    }

    /// Roundtrip with a second modulus.
    #[test]
    fn test_ntt_with_prime_60_1() {
        let arith = Ntt64Arith::new(PRIME_60_1);
        let q = arith.modulus;

        for n in [16usize, 64] {
            assert_eq!((q - 1) % (2 * n as u64), 0);
            let ctx = Ntt64Context::new(n, arith.clone());
            let original: Vec<u64> = (0..n as u64).map(|i| (i * 31 + 11) % q).collect();

            let data = roundtrip(&original, &ctx);
            assert_eq!(
                data, original,
                "NTT roundtrip fails for N={n} with PRIME_60_1"
            );
        }
    }

    #[test]
    fn test_bit_reverse() {
        // Expected 3-bit reversal of 0..8, indexed by input value.
        let three_bit: [u32; 8] = [0, 4, 2, 6, 1, 5, 3, 7];
        for (x, &want) in three_bit.iter().enumerate() {
            assert_eq!(bit_reverse(x as u32, 3), want);
        }

        assert_eq!(bit_reverse(0, 1), 0);
        assert_eq!(bit_reverse(1, 1), 1);
    }

    /// Transform of a sum equals the sum of the transforms.
    #[test]
    fn test_ntt_linearity() {
        let n = 64;
        let ctx = seal_context(n);
        let q = ctx.arith.modulus;

        let a: Vec<u64> = (0..n as u64).map(|i| (i * 3 + 1) % q).collect();
        let b: Vec<u64> = (0..n as u64).map(|i| (i * 7 + 2) % q).collect();

        let mut sum: Vec<u64> = a
            .iter()
            .zip(b.iter())
            .map(|(&x, &y)| mod_add(x, y, q))
            .collect();

        let mut a_ntt = a.clone();
        let mut b_ntt = b.clone();
        ntt_forward(&mut a_ntt, &ctx);
        ntt_forward(&mut b_ntt, &ctx);
        ntt_forward(&mut sum, &ctx);

        for i in 0..n {
            let expected = mod_add(a_ntt[i], b_ntt[i], q);
            assert_eq!(sum[i], expected, "linearity violated at index {i}");
        }
    }

    #[test]
    fn test_ntt_roundtrip_large() {
        let n = 32768;
        let arith = Ntt64Arith::new(PRIME_SEAL);
        let q = arith.modulus;

        assert_eq!((q - 1) % (2 * n as u64), 0);
        let ctx = Ntt64Context::new(n, arith);

        let wide_q = u128::from(q);
        let original: Vec<u64> = (0..n as u128)
            .map(|i| ((i * 314159265 + 271828182) % wide_q) as u64)
            .collect();

        let data = roundtrip(&original, &ctx);
        assert_eq!(data, original, "NTT roundtrip fails for N=32768");
    }

    // Compile-time check that the context is `Send` and `Sync`.
    const _: fn() = || {
        fn assert_send_sync<T: Send + Sync>() {}
        assert_send_sync::<super::Ntt64Context>();
    };
}