// vaea_ntt/ntt64/context.rs

// Copyright (C) 2024-2026 Vaea SAS
// SPDX-License-Identifier: AGPL-3.0-or-later
//
// This file is part of VaeaNTT.
//
// VaeaNTT is free software: you can redistribute it and/or modify it under
// the terms of the GNU Affero General Public License as published by the
// Free Software Foundation, either version 3 of the License, or (at your
// option) any later version.
//
// VaeaNTT is distributed in the hope that it will be useful, but WITHOUT
// ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or
// FITNESS FOR A PARTICULAR PURPOSE. See the GNU Affero General Public
// License for more details.
//
// You should have received a copy of the GNU Affero General Public License
// along with VaeaNTT. If not, see <https://www.gnu.org/licenses/>.

//! # NTT Context — Forward and Inverse Transforms
//!
//! High-performance Number Theoretic Transform using the Longa-Naehrig ordering
//! (SEAL/OpenFHE style) with integrated negacyclic twiddle factors.
//!
//! ## Algorithms
//! - **Forward NTT** — Cooley-Tukey radix-2 DIT (Decimation In Time)
//! - **Inverse NTT** — Gentleman-Sande radix-2 DIF (Decimation In Frequency)
//! - **Tiled NTT** — cache-oriented variant intended for large N (the
//!   `Ntt64Context::forward_tiled` method currently delegates to the standard forward NTT)

use super::arith::{mod_add, mod_inv, mod_mul_barrett, mod_pow, mod_sub, Ntt64Arith};
use super::prime::find_primitive_root;
use alloc::vec;
use alloc::vec::Vec;

// ===========================================================================
// Bit-reversal utility
// ===========================================================================

/// Reverses the `bits` least-significant bits of `x`; bits of `x` above
/// position `bits` are ignored.
///
/// Example: `bit_reverse(0b110, 3) = 0b011`
///
/// `bits == 0` yields `0` (the reversal of an empty bit string). `bits` must
/// not exceed 32.
#[inline]
fn bit_reverse(x: u32, bits: u32) -> u32 {
    debug_assert!(bits <= u32::BITS, "bit_reverse: bits ({bits}) > 32");
    // For bits == 0 the shift amount is 32, which `>>` cannot express (debug
    // panic / release masks it to 0). `checked_shr` returns None there, and
    // shifting every bit out means the result is 0.
    x.reverse_bits().checked_shr(u32::BITS - bits).unwrap_or(0)
}

46// ===========================================================================
47// NTT Context
48// ===========================================================================
49
50/// Precomputed NTT context for a given (N, modulus) pair.
51///
52/// Contains twiddle-factor tables for both forward and inverse NTT,
53/// organized in Longa-Naehrig ordering for negacyclic convolution.
54#[derive(Debug, Clone)]
55pub struct Ntt64Context {
56    /// Polynomial size (power of 2).
57    pub n: usize,
58
59    /// log₂(n).
60    pub log_n: u32,
61
62    /// Modular arithmetic context (Barrett/Montgomery constants).
63    pub arith: Ntt64Arith,
64
65    /// Twiddle factors for forward NTT.
66    ///
67    /// Organized for sequential access in the Cooley-Tukey butterfly:
68    /// `root_powers[m + j]` for layer with half-size `m` and group index `j`.
69    pub root_powers: Vec<u64>,
70
71    /// Inverse twiddle factors for inverse NTT.
72    ///
73    /// Organized for sequential access in the Gentleman-Sande butterfly.
74    pub inv_root_powers: Vec<u64>,
75
76    /// N⁻¹ mod q — normalization factor for the INTT.
77    pub n_inv: u64,
78}
79
80impl Ntt64Context {
81    /// Fallible constructor for an NTT context.
82    ///
83    /// Validates all preconditions and returns an error instead of panicking.
84    ///
85    /// # Arguments
86    /// - `n` — polynomial size, must be a power of 2 (≥ 2)
87    /// - `arith` — precomputed modular arithmetic context; the modulus must be prime
88    ///   and satisfy q ≡ 1 (mod 2N)
89    ///
90    /// # Errors
91    /// - [`crate::NttError::InvalidSize`] if `n` is not a power of 2 ≥ 2
92    /// - [`crate::NttError::NotPrime`] if the modulus is not prime
93    /// - [`crate::NttError::NotNttFriendly`] if `q − 1` is not divisible by `2N`
94    pub fn try_new(n: usize, arith: Ntt64Arith) -> Result<Self, crate::NttError> {
95        if n < 2 || !n.is_power_of_two() {
96            return Err(crate::NttError::InvalidSize(n));
97        }
98        let q = arith.modulus;
99        if !super::prime::is_prime(q) {
100            return Err(crate::NttError::NotPrime(q));
101        }
102        if !(q - 1).is_multiple_of(2 * n as u64) {
103            return Err(crate::NttError::NotNttFriendly { q, n });
104        }
105
106        let log_n = n.trailing_zeros();
107
108        // Find primitive 2N-th root of unity
109        let psi = find_primitive_root(n, q);
110        let psi_inv = mod_inv(psi, &arith);
111        let n_inv = mod_inv(n as u64, &arith);
112
113        // Precompute twiddle factors in Longa-Naehrig ordering:
114        //   root_powers[i] = ψ^{bit_reverse(i, log_n)}  for i in 0..N
115        let mut root_powers = vec![0u64; n];
116        let mut inv_root_powers = vec![0u64; n];
117
118        for i in 0..n {
119            let exp = bit_reverse(i as u32, log_n) as u64;
120            root_powers[i] = mod_pow(psi, exp, &arith);
121            inv_root_powers[i] = mod_pow(psi_inv, exp, &arith);
122        }
123
124        Ok(Self {
125            n,
126            log_n,
127            arith,
128            root_powers,
129            inv_root_powers,
130            n_inv,
131        })
132    }
133
134    /// Creates a new NTT context for polynomial size `n` and the given arithmetic context.
135    ///
136    /// # Arguments
137    /// - `n` — polynomial size, must be a power of 2 (≥ 2)
138    /// - `arith` — precomputed modular arithmetic context; the modulus must satisfy q ≡ 1 (mod 2N)
139    ///
140    /// # Panics
141    /// - If `n` is not a power of 2
142    /// - If the modulus is not prime
143    /// - If q − 1 is not divisible by 2N
144    pub fn new(n: usize, arith: Ntt64Arith) -> Self {
145        Self::try_new(n, arith).expect("Invalid NTT parameters")
146    }
147
148    /// Applies forward NTT in-place.
149    #[inline]
150    pub fn forward(&self, data: &mut [u64]) {
151        ntt_forward(data, self);
152    }
153
154    /// Applies inverse NTT in-place.
155    #[inline]
156    pub fn inverse(&self, data: &mut [u64]) {
157        ntt_inverse(data, self);
158    }
159
160    /// Applies the tiled forward NTT in-place.
161    ///
162    /// Currently delegates to the standard forward NTT.
163    /// A cache-optimized four-step variant is planned for v0.2.
164    #[inline]
165    pub fn forward_tiled(&self, data: &mut [u64]) {
166        // TODO: implement proper four-step NTT with correct negacyclic twiddle decomposition
167        ntt_forward(data, self);
168    }
169
170    /// Pointwise multiplication of two NTT-domain vectors.
171    ///
172    /// `result[i] = a[i] * b[i] mod q`
173    ///
174    /// This is the core operation: in NTT domain, polynomial convolution
175    /// becomes element-wise multiplication.
176    pub fn pointwise_mul(&self, a: &[u64], b: &[u64], result: &mut [u64]) {
177        let n = self.n;
178        assert_eq!(a.len(), n);
179        assert_eq!(b.len(), n);
180        assert_eq!(result.len(), n);
181
182        for i in 0..n {
183            result[i] = mod_mul_barrett(a[i], b[i], &self.arith);
184        }
185    }
186
187    /// Full negacyclic polynomial multiplication: `c = a * b mod (X^N + 1)`.
188    ///
189    /// Performs forward NTT on both inputs, pointwise multiplication,
190    /// and inverse NTT on the result.
191    pub fn negacyclic_mul(&self, a: &[u64], b: &[u64]) -> Vec<u64> {
192        let n = self.n;
193        assert_eq!(a.len(), n);
194        assert_eq!(b.len(), n);
195
196        let mut a_ntt = a.to_vec();
197        let mut b_ntt = b.to_vec();
198        ntt_forward(&mut a_ntt, self);
199        ntt_forward(&mut b_ntt, self);
200
201        let mut c_ntt = vec![0u64; n];
202        self.pointwise_mul(&a_ntt, &b_ntt, &mut c_ntt);
203
204        ntt_inverse(&mut c_ntt, self);
205        c_ntt
206    }
207}
208
209// ===========================================================================
210// Forward NTT (Cooley-Tukey, Decimation In Time)
211// ===========================================================================
212
213/// Forward NTT in-place (negacyclic convolution, Longa-Naehrig ordering).
214///
215/// Transforms N polynomial coefficients in Z_q into their NTT representation.
216///
217/// ## Butterfly
218/// ```text
219/// u' = u + w·v
220/// v' = u − w·v
221/// ```
222///
223/// Layers are traversed from coarsest (gap = N/2) to finest (gap = 1).
224pub fn ntt_forward(data: &mut [u64], ctx: &Ntt64Context) {
225    let n = ctx.n;
226    let q = ctx.arith.modulus;
227    assert_eq!(data.len(), n, "data length ({}) != N ({})", data.len(), n);
228
229    let mut t = n;
230    let mut m = 1;
231
232    for _ in 0..ctx.log_n {
233        t >>= 1;
234        let mut k = 0;
235
236        for i in 0..m {
237            let w = ctx.root_powers[m + i];
238
239            for j in k..(k + t) {
240                let u = data[j];
241                let v = mod_mul_barrett(data[j + t], w, &ctx.arith);
242                data[j] = mod_add(u, v, q);
243                data[j + t] = mod_sub(u, v, q);
244            }
245            k += 2 * t;
246        }
247        m <<= 1;
248    }
249}
250
251// ===========================================================================
252// Inverse NTT (Gentleman-Sande, Decimation In Frequency)
253// ===========================================================================
254
255/// Inverse NTT in-place (negacyclic convolution, Longa-Naehrig ordering).
256///
257/// Transforms an NTT representation of N elements back to polynomial coefficients.
258///
259/// ## Butterfly
260/// ```text
261/// u' = u + v
262/// v' = (u − v) · w_inv
263/// ```
264///
265/// Layers are traversed from finest (gap = 1) to coarsest (gap = N/2).
266/// Each coefficient is multiplied by N⁻¹ mod q at the end.
267pub fn ntt_inverse(data: &mut [u64], ctx: &Ntt64Context) {
268    let n = ctx.n;
269    let q = ctx.arith.modulus;
270    assert_eq!(data.len(), n, "data length ({}) != N ({})", data.len(), n);
271
272    let mut t = 1;
273    let mut m = n;
274
275    for _ in 0..ctx.log_n {
276        m >>= 1;
277        let mut k = 0;
278
279        for i in 0..m {
280            let w_inv = ctx.inv_root_powers[m + i];
281
282            for j in k..(k + t) {
283                let u = data[j];
284                let v = data[j + t];
285                data[j] = mod_add(u, v, q);
286                data[j + t] = mod_mul_barrett(mod_sub(u, v, q), w_inv, &ctx.arith);
287            }
288            k += 2 * t;
289        }
290        t <<= 1;
291    }
292
293    // Normalize by N⁻¹
294    for coeff in data.iter_mut() {
295        *coeff = mod_mul_barrett(*coeff, ctx.n_inv, &ctx.arith);
296    }
297}
298
299// ===========================================================================
300// Four-Step Tiled NTT (cache-friendly)
301// ===========================================================================
302
303/// Four-step tiled forward NTT for improved cache locality.
304///
305/// Views the length-N vector as an N1×N2 matrix (row-major) with
306/// N = N1·N2 and N1, N2 powers of 2 (N1 ≈ √N).
307///
308/// ## Steps
309/// 1. NTT of size N2 on each row (fits in L1 cache)
310/// 2. Multiply by transposition twiddle factors ω^{i·j}
311/// 3. Transpose (N1×N2 → N2×N1)
312/// 4. NTT of size N1 on each row
313/// 5. Transpose back
314///
315/// For small N (≤ 64), delegates to the standard NTT.
316///
317/// NOTE: Currently unused — the negacyclic twiddle decomposition has a known bug.
318/// Kept for future v0.2 implementation.
319#[allow(dead_code)]
320pub fn ntt_forward_tiled(data: &mut [u64], ctx: &Ntt64Context) {
321    let n = ctx.n;
322
323    if n <= 64 {
324        ntt_forward(data, ctx);
325        return;
326    }
327
328    let log_n = ctx.log_n;
329    let log_n1 = log_n / 2;
330    let log_n2 = log_n - log_n1;
331    let n1 = 1usize << log_n1;
332    let n2 = 1usize << log_n2;
333
334    let arith = &ctx.arith;
335
336    // Step 1: NTT of size N2 on each row
337    let sub_ctx2 = Ntt64Context::new(n2, arith.clone());
338    for row in 0..n1 {
339        let start = row * n2;
340        ntt_forward(&mut data[start..start + n2], &sub_ctx2);
341    }
342
343    // Step 2: Multiply by transposition twiddle factors
344    // ω = ψ² (N-th root of unity), twiddle = ω^{i·j}
345    let psi = find_primitive_root(n, arith.modulus);
346    let psi_sq = mod_mul_barrett(psi, psi, arith); // ω = ψ², N-th root
347
348    for i in 0..n1 {
349        for j in 0..n2 {
350            if i == 0 || j == 0 {
351                continue;
352            }
353            let exp = ((i as u128 * j as u128) % n as u128) as u64;
354            let twiddle = mod_pow(psi_sq, exp, arith);
355            let idx = i * n2 + j;
356            data[idx] = mod_mul_barrett(data[idx], twiddle, arith);
357        }
358    }
359
360    // Step 3: Transpose (N1×N2 → N2×N1)
361    let mut transposed = vec![0u64; n];
362    for i in 0..n1 {
363        for j in 0..n2 {
364            transposed[j * n1 + i] = data[i * n2 + j];
365        }
366    }
367    data.copy_from_slice(&transposed);
368
369    // Step 4: NTT of size N1 on each row
370    let sub_ctx1 = Ntt64Context::new(n1, arith.clone());
371    for row in 0..n2 {
372        let start = row * n1;
373        ntt_forward(&mut data[start..start + n1], &sub_ctx1);
374    }
375
376    // Step 5: Transpose back (N2×N1 → N1×N2)
377    for i in 0..n2 {
378        for j in 0..n1 {
379            transposed[j * n2 + i] = data[i * n1 + j];
380        }
381    }
382    data.copy_from_slice(&transposed);
383}
384
// ===========================================================================
// Naive polynomial multiplication (test-only)
// ===========================================================================

/// Schoolbook multiplication in Z_q[X]/(X^N + 1) — O(N²) complexity.
///
/// Reference implementation used only by the tests to cross-check the NTT.
/// A product landing at degree d ≥ N wraps to degree d − N with its sign
/// flipped, because X^N ≡ −1.
#[cfg(test)]
#[allow(unused_variables, clippy::needless_range_loop, dead_code)]
fn poly_mul_naive(a: &[u64], b: &[u64], q: u64) -> Vec<u64> {
    let n = a.len();
    assert_eq!(b.len(), n);

    let modulus = u128::from(q);
    let mut out = vec![0u64; n];

    for (i, &ai) in a.iter().enumerate() {
        for (j, &bj) in b.iter().enumerate() {
            let term = (u128::from(ai) * u128::from(bj)) % modulus;
            let degree = i + j;
            let (slot, wrapped) = if degree < n {
                (degree, false)
            } else {
                (degree - n, true)
            };
            let current = u128::from(out[slot]);
            let updated = if wrapped {
                current + modulus - term
            } else {
                current + term
            };
            out[slot] = (updated % modulus) as u64;
        }
    }

    out
}

// ===========================================================================
// Tests
// ===========================================================================

#[cfg(test)]
#[allow(unused_variables, clippy::needless_range_loop, dead_code)]
mod tests {
    use super::super::arith::{PRIME_60_1, PRIME_SEAL};
    use super::*;

    // --- Primitive root ---

    #[test]
    fn test_primitive_root_small() {
        let q: u64 = 17;
        let n = 8;
        let psi = find_primitive_root(n, q);

        let arith = Ntt64Arith::new(q);
        assert_eq!(mod_pow(psi, 2 * n as u64, &arith), 1);
        assert_eq!(mod_pow(psi, n as u64, &arith), 16);
    }

    #[test]
    fn test_primitive_root_seal() {
        let arith = Ntt64Arith::new(PRIME_SEAL);
        for &n in &[16, 64, 1024, 4096] {
            let psi = find_primitive_root(n, PRIME_SEAL);
            assert_eq!(mod_pow(psi, 2 * n as u64, &arith), 1);
            assert_eq!(mod_pow(psi, n as u64, &arith), arith.modulus - 1);
        }
    }

    // --- Constructor validation ---

    #[test]
    fn test_try_new_rejects_invalid_sizes() {
        let arith = Ntt64Arith::new(PRIME_SEAL);
        for &n in &[0usize, 1, 3, 12, 100] {
            assert!(
                matches!(
                    Ntt64Context::try_new(n, arith.clone()),
                    Err(crate::NttError::InvalidSize(m)) if m == n
                ),
                "try_new should reject N={n} with InvalidSize"
            );
        }
    }

    // --- NTT roundtrip ---

    #[test]
    fn test_ntt_roundtrip_small() {
        let arith = Ntt64Arith::new(PRIME_SEAL);
        let q = arith.modulus;

        for &n in &[16, 64] {
            let ctx = Ntt64Context::new(n, arith.clone());
            let original: Vec<u64> = (0..n).map(|i| (i as u64 * 7 + 3) % q).collect();
            let mut data = original.clone();

            ntt_forward(&mut data, &ctx);
            assert_ne!(data, original);
            ntt_inverse(&mut data, &ctx);
            assert_eq!(data, original, "NTT roundtrip fails for N={n}");
        }
    }

    #[test]
    fn test_ntt_roundtrip_medium() {
        let arith = Ntt64Arith::new(PRIME_SEAL);
        let q = arith.modulus;

        for &n in &[1024, 4096] {
            let ctx = Ntt64Context::new(n, arith.clone());
            let original: Vec<u64> = (0..n)
                .map(|i| ((i as u128 * 123456789 + 987654321) % q as u128) as u64)
                .collect();
            let mut data = original.clone();

            ntt_forward(&mut data, &ctx);
            ntt_inverse(&mut data, &ctx);
            assert_eq!(data, original, "NTT roundtrip fails for N={n}");
        }
    }

    #[test]
    fn test_ntt_roundtrip_zeros() {
        let arith = Ntt64Arith::new(PRIME_SEAL);
        let n = 64;
        let ctx = Ntt64Context::new(n, arith);
        let mut data = vec![0u64; n];
        ntt_forward(&mut data, &ctx);
        ntt_inverse(&mut data, &ctx);
        assert_eq!(data, vec![0u64; n]);
    }

    #[test]
    fn test_ntt_roundtrip_one() {
        let arith = Ntt64Arith::new(PRIME_SEAL);
        let n = 64;
        let ctx = Ntt64Context::new(n, arith);
        let mut data = vec![0u64; n];
        data[0] = 1;
        let original = data.clone();
        ntt_forward(&mut data, &ctx);
        ntt_inverse(&mut data, &ctx);
        assert_eq!(data, original);
    }

    #[test]
    fn test_ntt_roundtrip_n256() {
        let arith = Ntt64Arith::new(PRIME_SEAL);
        let q = arith.modulus;
        let n = 256;
        let ctx = Ntt64Context::new(n, arith);

        let original: Vec<u64> = (0..n)
            .map(|i| ((i as u128 * 999999937 + 42) % q as u128) as u64)
            .collect();
        let mut data = original.clone();

        ntt_forward(&mut data, &ctx);
        ntt_inverse(&mut data, &ctx);
        assert_eq!(data, original, "standard roundtrip fails for N=256");
    }

    // --- Negacyclic convolution ---

    #[test]
    fn test_ntt_convolution_n16() {
        let arith = Ntt64Arith::new(PRIME_SEAL);
        let q = arith.modulus;
        let n = 16;
        let ctx = Ntt64Context::new(n, arith);

        let a: Vec<u64> = (0..n).map(|i| (i as u64 + 1) % q).collect();
        let b: Vec<u64> = (0..n).map(|_| 1u64).collect();

        let expected = poly_mul_naive(&a, &b, q);
        let result = ctx.negacyclic_mul(&a, &b);
        assert_eq!(result, expected, "NTT convolution != naive for N=16");
    }

    #[test]
    fn test_ntt_convolution_n64() {
        let arith = Ntt64Arith::new(PRIME_SEAL);
        let q = arith.modulus;
        let n = 64;
        let ctx = Ntt64Context::new(n, arith);

        let a: Vec<u64> = (0..n).map(|i| ((i * i + 3 * i + 7) as u64) % q).collect();
        let b: Vec<u64> = (0..n).map(|i| ((2 * i + 1) as u64) % q).collect();

        let expected = poly_mul_naive(&a, &b, q);
        let result = ctx.negacyclic_mul(&a, &b);
        assert_eq!(result, expected, "NTT convolution != naive for N=64");
    }

    #[test]
    fn test_ntt_convolution_identity() {
        let arith = Ntt64Arith::new(PRIME_SEAL);
        let q = arith.modulus;
        let n = 64;
        let ctx = Ntt64Context::new(n, arith);

        let a: Vec<u64> = (0..n).map(|i| ((i * 17 + 5) as u64) % q).collect();
        let mut one = vec![0u64; n];
        one[0] = 1;

        let result = ctx.negacyclic_mul(&a, &one);
        assert_eq!(result, a, "Multiplying by 1 should give identity");
    }

    // --- Tiled NTT ---

    #[test]
    fn test_ntt_tiled_matches_standard_small() {
        let arith = Ntt64Arith::new(PRIME_SEAL);
        let q = arith.modulus;

        for &n in &[16, 64] {
            let ctx = Ntt64Context::new(n, arith.clone());
            let original: Vec<u64> = (0..n).map(|i| (i as u64 * 13 + 7) % q).collect();

            let mut data_std = original.clone();
            let mut data_tiled = original.clone();

            ntt_forward(&mut data_std, &ctx);
            ntt_forward_tiled(&mut data_tiled, &ctx);

            assert_eq!(data_tiled, data_std, "tiled NTT != standard for N={n}");
        }
    }

    #[test]
    fn test_forward_tiled_method_matches_forward() {
        let arith = Ntt64Arith::new(PRIME_SEAL);
        let q = arith.modulus;
        let n = 256;
        let ctx = Ntt64Context::new(n, arith);

        let original: Vec<u64> = (0..n)
            .map(|i| ((i as u128 * 999999937 + 42) % q as u128) as u64)
            .collect();
        let mut via_forward = original.clone();
        let mut via_tiled = original;

        ctx.forward(&mut via_forward);
        ctx.forward_tiled(&mut via_tiled);
        assert_eq!(via_tiled, via_forward, "forward_tiled != forward for N=256");
    }

    // --- With PRIME_60_1 ---

    #[test]
    fn test_ntt_with_prime_60_1() {
        let arith = Ntt64Arith::new(PRIME_60_1);
        let q = arith.modulus;

        for &n in &[16, 64] {
            assert_eq!((q - 1) % (2 * n as u64), 0);
            let ctx = Ntt64Context::new(n, arith.clone());
            let original: Vec<u64> = (0..n).map(|i| (i as u64 * 31 + 11) % q).collect();
            let mut data = original.clone();

            ntt_forward(&mut data, &ctx);
            ntt_inverse(&mut data, &ctx);
            assert_eq!(
                data, original,
                "NTT roundtrip fails for N={n} with PRIME_60_1"
            );
        }
    }

    // --- Bit-reverse ---

    #[test]
    fn test_bit_reverse() {
        assert_eq!(bit_reverse(0, 3), 0);
        assert_eq!(bit_reverse(1, 3), 4);
        assert_eq!(bit_reverse(2, 3), 2);
        assert_eq!(bit_reverse(3, 3), 6);
        assert_eq!(bit_reverse(4, 3), 1);
        assert_eq!(bit_reverse(5, 3), 5);
        assert_eq!(bit_reverse(6, 3), 3);
        assert_eq!(bit_reverse(7, 3), 7);
        assert_eq!(bit_reverse(0, 1), 0);
        assert_eq!(bit_reverse(1, 1), 1);
    }

    // --- Linearity ---

    #[test]
    fn test_ntt_linearity() {
        let arith = Ntt64Arith::new(PRIME_SEAL);
        let q = arith.modulus;
        let n = 64;
        let ctx = Ntt64Context::new(n, arith);

        let a: Vec<u64> = (0..n).map(|i| (i as u64 * 3 + 1) % q).collect();
        let b: Vec<u64> = (0..n).map(|i| (i as u64 * 7 + 2) % q).collect();

        let mut a_ntt = a.clone();
        let mut b_ntt = b.clone();
        ntt_forward(&mut a_ntt, &ctx);
        ntt_forward(&mut b_ntt, &ctx);

        let mut sum: Vec<u64> = (0..n).map(|i| mod_add(a[i], b[i], q)).collect();
        ntt_forward(&mut sum, &ctx);

        for i in 0..n {
            let expected = mod_add(a_ntt[i], b_ntt[i], q);
            assert_eq!(sum[i], expected, "linearity violated at index {i}");
        }
    }

    // --- Large N roundtrip ---

    #[test]
    fn test_ntt_roundtrip_large() {
        let arith = Ntt64Arith::new(PRIME_SEAL);
        let q = arith.modulus;
        let n = 32768;

        assert_eq!((q - 1) % (2 * n as u64), 0);
        let ctx = Ntt64Context::new(n, arith);

        let original: Vec<u64> = (0..n)
            .map(|i| ((i as u128 * 314159265 + 271828182) % q as u128) as u64)
            .collect();
        let mut data = original.clone();

        ntt_forward(&mut data, &ctx);
        ntt_inverse(&mut data, &ctx);
        assert_eq!(data, original, "NTT roundtrip fails for N=32768");
    }

    // Compile-time check: Ntt64Context must be Send + Sync. `check` is never
    // called; type-checking its body is what enforces the bounds.
    const _: () = {
        fn assert_send<T: Send>() {}
        fn assert_sync<T: Sync>() {}
        fn check() {
            assert_send::<super::Ntt64Context>();
            assert_sync::<super::Ntt64Context>();
        }
    };
}