Skip to main content

vaea_ntt/
rns.rs

// Copyright (C) 2024-2026 Vaea SAS
// SPDX-License-Identifier: AGPL-3.0-or-later
//
// This file is part of VaeaNTT.
//
// VaeaNTT is free software: you can redistribute it and/or modify it under
// the terms of the GNU Affero General Public License as published by the
// Free Software Foundation, either version 3 of the License, or (at your
// option) any later version.
//
// VaeaNTT is distributed in the hope that it will be useful, but WITHOUT
// ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or
// FITNESS FOR A PARTICULAR PURPOSE. See the GNU Affero General Public
// License for more details.
//
// You should have received a copy of the GNU Affero General Public License
// along with VaeaNTT. If not, see <https://www.gnu.org/licenses/>.

//! # Residue Number System (RNS) — Multi-Moduli Decomposition
//!
//! RNS represents a large integer by its residues modulo several small,
//! pairwise coprime moduli. Each residue can be processed independently,
//! which makes the representation well suited to parallelism and avoids
//! multi-precision arithmetic.
//!
//! For CKKS, the product Q = q₁·q₂·…·q_L is the ciphertext modulus, and the
//! number of moduli L bounds how many levels are available. Rescaling removes
//! one modulus per level.
28
29use crate::ntt64::arith::Ntt64Arith;
30use crate::ntt64::context::Ntt64Context;
31use crate::poly::Poly64;
32use alloc::vec::Vec;
33
34// ---------------------------------------------------------------------------
35// RnsContext — RNS context
36// ---------------------------------------------------------------------------
37
38/// RNS context: a set of coprime moduli.
39///
40/// Precomputes modular arithmetic and NTT contexts for each modulus,
41/// enabling efficient component-wise polynomial operations.
42pub struct RnsContext {
43    /// The moduli q₁, q₂, …, q_L.
44    pub moduli: Vec<u64>,
45    /// Modular arithmetic contexts for each modulus (Barrett, Montgomery).
46    pub ariths: Vec<Ntt64Arith>,
47    /// NTT contexts for each modulus.
48    pub ntt_ctxs: Vec<Ntt64Context>,
49    /// Polynomial degree N.
50    pub poly_degree: usize,
51}
52
53impl RnsContext {
54    /// Creates an RNS context with the given moduli.
55    ///
56    /// Precomputes all modular arithmetic and NTT contexts.
57    /// Each modulus must be NTT-friendly for the given polynomial degree.
58    ///
59    /// # Panics
60    /// - If `poly_degree` is not a power of 2
61    /// - If `moduli` is empty
62    /// - If any modulus is not NTT-friendly for the given degree
63    pub fn new(poly_degree: usize, moduli: Vec<u64>) -> Self {
64        assert!(
65            poly_degree.is_power_of_two(),
66            "poly_degree must be a power of 2"
67        );
68        assert!(!moduli.is_empty(), "at least one modulus is required");
69
70        let ariths: Vec<Ntt64Arith> = moduli.iter().map(|&q| Ntt64Arith::new(q)).collect();
71
72        let ntt_ctxs: Vec<Ntt64Context> = ariths
73            .iter()
74            .map(|arith| Ntt64Context::new(poly_degree, arith.clone()))
75            .collect();
76
77        Self {
78            moduli,
79            ariths,
80            ntt_ctxs,
81            poly_degree,
82        }
83    }
84
85    /// Number of moduli (= total number of levels).
86    #[inline]
87    pub fn num_moduli(&self) -> usize {
88        self.moduli.len()
89    }
90}
91
92// ---------------------------------------------------------------------------
93// RnsPoly — polynomial in RNS representation
94// ---------------------------------------------------------------------------
95
96/// Polynomial in RNS representation: one component per modulus.
97///
98/// Each component `components[i]` is a polynomial in Z_{q_i}\[X\]/(X^N+1),
99/// stored in NTT domain by default for performance.
100///
101/// The `level` indicates the number of active moduli. CKKS rescaling reduces
102/// the level by removing the last modulus.
103#[derive(Clone, Debug)]
104pub struct RnsPoly {
105    /// `components[i]` = polynomial modulo `moduli[i]`.
106    pub components: Vec<Poly64>,
107    /// Current level (number of active moduli).
108    pub level: usize,
109}
110
111impl RnsPoly {
112    /// Encodes a signed-integer polynomial into RNS representation.
113    ///
114    /// For each modulus q_i:
115    /// 1. Reduces each coefficient mod q_i (handles negatives)
116    /// 2. Converts to NTT domain
117    ///
118    /// # Arguments
119    /// * `coeffs` — polynomial coefficients in Z (signed, coefficient domain)
120    /// * `ctx` — RNS context
121    pub fn from_coefficients(coeffs: &[i64], ctx: &RnsContext) -> Self {
122        let n = ctx.poly_degree;
123        assert!(
124            coeffs.len() <= n,
125            "too many coefficients: {} > {}",
126            coeffs.len(),
127            n
128        );
129
130        let level = ctx.num_moduli();
131        let mut components = Vec::with_capacity(level);
132
133        for i in 0..level {
134            let q = ctx.moduli[i];
135            let q_i64 = q as i64;
136
137            let mut poly = Poly64::new_zero(n);
138            for (j, &c) in coeffs.iter().enumerate() {
139                let r = c % q_i64;
140                poly.data[j] = if r < 0 { (r + q_i64) as u64 } else { r as u64 };
141            }
142
143            poly.forward_ntt(&ctx.ntt_ctxs[i]);
144            components.push(poly);
145        }
146
147        Self { components, level }
148    }
149
150    /// Component-wise addition in RNS.
151    ///
152    /// Both polynomials must have the same level.
153    pub fn add(&self, other: &RnsPoly, ctx: &RnsContext) -> RnsPoly {
154        assert_eq!(
155            self.level, other.level,
156            "levels must match: {} != {}",
157            self.level, other.level
158        );
159
160        let mut result = self.clone();
161        for i in 0..self.level {
162            result.components[i].add_assign(&other.components[i], &ctx.ariths[i]);
163        }
164        result
165    }
166
167    /// Component-wise subtraction in RNS.
168    pub fn sub(&self, other: &RnsPoly, ctx: &RnsContext) -> RnsPoly {
169        assert_eq!(self.level, other.level, "levels must match");
170
171        let mut result = self.clone();
172        for i in 0..self.level {
173            result.components[i].sub_assign(&other.components[i], &ctx.ariths[i]);
174        }
175        result
176    }
177
178    /// Component-wise multiplication in RNS (NTT domain).
179    ///
180    /// All components must be in NTT domain.
181    pub fn mul(&self, other: &RnsPoly, ctx: &RnsContext) -> RnsPoly {
182        assert_eq!(self.level, other.level, "levels must match");
183
184        let mut result = self.clone();
185        for i in 0..self.level {
186            result.components[i].mul_assign(&other.components[i], &ctx.ariths[i]);
187        }
188        result
189    }
190
191    /// Drops the last modulus (CKKS rescaling).
192    ///
193    /// After this operation, the level decreases by 1 and the last component
194    /// is removed. The scale factor Δ is implicitly divided by q_L.
195    ///
196    /// # Panics
197    /// Panics if the level is already 1.
198    pub fn drop_last_modulus(&mut self) {
199        assert!(self.level > 1, "cannot reduce level below 1");
200        self.components.pop();
201        self.level -= 1;
202    }
203
204    /// Converts all components from coefficient domain to NTT domain.
205    ///
206    /// Components already in NTT domain are skipped.
207    pub fn forward_all(&mut self, ctx: &RnsContext) {
208        for i in 0..self.level {
209            if !self.components[i].is_ntt {
210                self.components[i].forward_ntt(&ctx.ntt_ctxs[i]);
211            }
212        }
213    }
214
215    /// Converts all components from NTT domain to coefficient domain.
216    ///
217    /// Components already in coefficient domain are skipped.
218    pub fn inverse_all(&mut self, ctx: &RnsContext) {
219        for i in 0..self.level {
220            if self.components[i].is_ntt {
221                self.components[i].inverse_ntt(&ctx.ntt_ctxs[i]);
222            }
223        }
224    }
225}
226
227// ---------------------------------------------------------------------------
228// Tests
229// ---------------------------------------------------------------------------
230
#[cfg(test)]
mod tests {
    use super::*;
    use crate::ntt64::prime::is_prime;
    use alloc::vec;
    use alloc::vec::Vec;

    const TEST_N: usize = 256;
    const TEST_Q1: u64 = 7681; // 15·512+1
    const TEST_Q2: u64 = 12289; // 24·512+1

    fn test_rns_ctx() -> RnsContext {
        RnsContext::new(TEST_N, vec![TEST_Q1, TEST_Q2])
    }

    #[test]
    fn test_rns_encode_decode() {
        let ctx = test_rns_ctx();
        let coeffs = vec![5i64, -3, 0, 7];
        let mut rns_poly = RnsPoly::from_coefficients(&coeffs, &ctx);

        rns_poly.inverse_all(&ctx);

        assert_eq!(rns_poly.components[0].data[0], 5);
        assert_eq!(rns_poly.components[0].data[1], TEST_Q1 - 3);
        assert_eq!(rns_poly.components[0].data[2], 0);
        assert_eq!(rns_poly.components[0].data[3], 7);

        assert_eq!(rns_poly.components[1].data[0], 5);
        assert_eq!(rns_poly.components[1].data[1], TEST_Q2 - 3);
        assert_eq!(rns_poly.components[1].data[2], 0);
        assert_eq!(rns_poly.components[1].data[3], 7);
    }

    /// Reduction must be correct at the extremes of `i64`, and coefficients
    /// beyond the input length must be zero.
    #[test]
    fn test_rns_encode_extreme_coefficients() {
        let ctx = test_rns_ctx();
        let coeffs = vec![i64::MIN, i64::MAX, -1, 1];
        let mut rns_poly = RnsPoly::from_coefficients(&coeffs, &ctx);
        rns_poly.inverse_all(&ctx);

        for (i, &q) in ctx.moduli.iter().enumerate() {
            for (j, &c) in coeffs.iter().enumerate() {
                let expected = i128::from(c).rem_euclid(i128::from(q)) as u64;
                assert_eq!(
                    rns_poly.components[i].data[j], expected,
                    "wrong residue — modulus {}, coeff {}",
                    q, j
                );
            }
            for j in coeffs.len()..TEST_N {
                assert_eq!(
                    rns_poly.components[i].data[j], 0,
                    "padding not zero — modulus {}, coeff {}",
                    q, j
                );
            }
        }
    }

    #[test]
    fn test_rns_add_mul_distributivity() {
        let ctx = test_rns_ctx();

        let a_coeffs: Vec<i64> = (0..TEST_N as i64).map(|i| i % 100).collect();
        let b_coeffs: Vec<i64> = (0..TEST_N as i64).map(|i| (i * 3 + 7) % 100).collect();
        let c_coeffs: Vec<i64> = (0..TEST_N as i64).map(|i| (i * 2 + 1) % 50).collect();

        let a = RnsPoly::from_coefficients(&a_coeffs, &ctx);
        let b = RnsPoly::from_coefficients(&b_coeffs, &ctx);
        let c = RnsPoly::from_coefficients(&c_coeffs, &ctx);

        // (a + b) * c
        let ab = a.add(&b, &ctx);
        let mut lhs = ab.mul(&c, &ctx);

        // a*c + b*c
        let ac = a.mul(&c, &ctx);
        let bc = b.mul(&c, &ctx);
        let mut rhs = ac.add(&bc, &ctx);

        lhs.inverse_all(&ctx);
        rhs.inverse_all(&ctx);

        for i in 0..ctx.num_moduli() {
            for j in 0..TEST_N {
                assert_eq!(
                    lhs.components[i].data[j], rhs.components[i].data[j],
                    "(a+b)*c != a*c+b*c — modulus {}, coeff {}",
                    ctx.moduli[i], j
                );
            }
        }
    }

    /// `forward_all` / `inverse_all` must skip components already in the
    /// target domain, so repeated calls are no-ops and a round trip is exact.
    #[test]
    fn test_rns_domain_conversions_are_idempotent() {
        let ctx = test_rns_ctx();
        let coeffs: Vec<i64> = (0..TEST_N as i64).map(|i| (i * 7) % 200 - 100).collect();
        let mut poly = RnsPoly::from_coefficients(&coeffs, &ctx);
        assert!(poly.components.iter().all(|c| c.is_ntt));

        poly.inverse_all(&ctx);
        assert!(poly.components.iter().all(|c| !c.is_ntt));
        let reference = poly.clone();

        // Already in coefficient domain: must not change anything.
        poly.inverse_all(&ctx);
        // Forward twice (second call is a no-op), then back.
        poly.forward_all(&ctx);
        poly.forward_all(&ctx);
        assert!(poly.components.iter().all(|c| c.is_ntt));
        poly.inverse_all(&ctx);

        for i in 0..ctx.num_moduli() {
            for j in 0..TEST_N {
                assert_eq!(
                    poly.components[i].data[j], reference.components[i].data[j],
                    "round trip mismatch — modulus {}, coeff {}",
                    ctx.moduli[i], j
                );
            }
        }
    }

    #[test]
    fn test_rns_drop_last_modulus() {
        let ctx = test_rns_ctx();
        let coeffs = vec![1i64, 2, 3];
        let mut poly = RnsPoly::from_coefficients(&coeffs, &ctx);

        assert_eq!(poly.level, 2);
        assert_eq!(poly.components.len(), 2);

        poly.drop_last_modulus();

        assert_eq!(poly.level, 1);
        assert_eq!(poly.components.len(), 1);
    }

    #[test]
    #[should_panic(expected = "cannot reduce")]
    fn test_rns_drop_last_modulus_panics_at_level_1() {
        let ctx = RnsContext::new(TEST_N, vec![TEST_Q1]);
        let coeffs = vec![1i64];
        let mut poly = RnsPoly::from_coefficients(&coeffs, &ctx);
        poly.drop_last_modulus();
    }

    #[test]
    #[should_panic(expected = "levels must match")]
    fn test_rns_add_level_mismatch_panics() {
        let ctx = test_rns_ctx();
        let a = RnsPoly::from_coefficients(&[1, 2, 3], &ctx);
        let mut b = a.clone();
        b.drop_last_modulus();
        let _ = a.add(&b, &ctx);
    }

    #[test]
    fn test_rns_sub() {
        let ctx = test_rns_ctx();
        let coeffs: Vec<i64> = (0..TEST_N as i64).map(|i| i % 1000 - 500).collect();
        let a = RnsPoly::from_coefficients(&coeffs, &ctx);

        let mut zero = a.sub(&a, &ctx);
        zero.inverse_all(&ctx);

        for i in 0..ctx.num_moduli() {
            for j in 0..TEST_N {
                assert_eq!(
                    zero.components[i].data[j], 0,
                    "a - a != 0 — modulus {}, coeff {}",
                    ctx.moduli[i], j
                );
            }
        }
    }

    #[test]
    fn test_ntt_friendly_primes_are_valid() {
        assert!(is_prime(TEST_Q1), "q1 = {TEST_Q1} should be prime");
        assert!(is_prime(TEST_Q2), "q2 = {TEST_Q2} should be prime");

        let two_n = 2 * TEST_N as u64;
        assert_eq!((TEST_Q1 - 1) % two_n, 0);
        assert_eq!((TEST_Q2 - 1) % two_n, 0);
    }
}