1use crate::ntt64::arith::Ntt64Arith;
30use crate::ntt64::context::Ntt64Context;
31use crate::poly::Poly64;
32use alloc::vec::Vec;
33
34pub struct RnsContext {
43 pub moduli: Vec<u64>,
45 pub ariths: Vec<Ntt64Arith>,
47 pub ntt_ctxs: Vec<Ntt64Context>,
49 pub poly_degree: usize,
51}
52
53impl RnsContext {
54 pub fn new(poly_degree: usize, moduli: Vec<u64>) -> Self {
64 assert!(
65 poly_degree.is_power_of_two(),
66 "poly_degree must be a power of 2"
67 );
68 assert!(!moduli.is_empty(), "at least one modulus is required");
69
70 let ariths: Vec<Ntt64Arith> = moduli.iter().map(|&q| Ntt64Arith::new(q)).collect();
71
72 let ntt_ctxs: Vec<Ntt64Context> = ariths
73 .iter()
74 .map(|arith| Ntt64Context::new(poly_degree, arith.clone()))
75 .collect();
76
77 Self {
78 moduli,
79 ariths,
80 ntt_ctxs,
81 poly_degree,
82 }
83 }
84
85 #[inline]
87 pub fn num_moduli(&self) -> usize {
88 self.moduli.len()
89 }
90}
91
92#[derive(Clone, Debug)]
104pub struct RnsPoly {
105 pub components: Vec<Poly64>,
107 pub level: usize,
109}
110
111impl RnsPoly {
112 pub fn from_coefficients(coeffs: &[i64], ctx: &RnsContext) -> Self {
122 let n = ctx.poly_degree;
123 assert!(
124 coeffs.len() <= n,
125 "too many coefficients: {} > {}",
126 coeffs.len(),
127 n
128 );
129
130 let level = ctx.num_moduli();
131 let mut components = Vec::with_capacity(level);
132
133 for i in 0..level {
134 let q = ctx.moduli[i];
135 let q_i64 = q as i64;
136
137 let mut poly = Poly64::new_zero(n);
138 for (j, &c) in coeffs.iter().enumerate() {
139 let r = c % q_i64;
140 poly.data[j] = if r < 0 { (r + q_i64) as u64 } else { r as u64 };
141 }
142
143 poly.forward_ntt(&ctx.ntt_ctxs[i]);
144 components.push(poly);
145 }
146
147 Self { components, level }
148 }
149
150 pub fn add(&self, other: &RnsPoly, ctx: &RnsContext) -> RnsPoly {
154 assert_eq!(
155 self.level, other.level,
156 "levels must match: {} != {}",
157 self.level, other.level
158 );
159
160 let mut result = self.clone();
161 for i in 0..self.level {
162 result.components[i].add_assign(&other.components[i], &ctx.ariths[i]);
163 }
164 result
165 }
166
167 pub fn sub(&self, other: &RnsPoly, ctx: &RnsContext) -> RnsPoly {
169 assert_eq!(self.level, other.level, "levels must match");
170
171 let mut result = self.clone();
172 for i in 0..self.level {
173 result.components[i].sub_assign(&other.components[i], &ctx.ariths[i]);
174 }
175 result
176 }
177
178 pub fn mul(&self, other: &RnsPoly, ctx: &RnsContext) -> RnsPoly {
182 assert_eq!(self.level, other.level, "levels must match");
183
184 let mut result = self.clone();
185 for i in 0..self.level {
186 result.components[i].mul_assign(&other.components[i], &ctx.ariths[i]);
187 }
188 result
189 }
190
191 pub fn drop_last_modulus(&mut self) {
199 assert!(self.level > 1, "cannot reduce level below 1");
200 self.components.pop();
201 self.level -= 1;
202 }
203
204 pub fn forward_all(&mut self, ctx: &RnsContext) {
208 for i in 0..self.level {
209 if !self.components[i].is_ntt {
210 self.components[i].forward_ntt(&ctx.ntt_ctxs[i]);
211 }
212 }
213 }
214
215 pub fn inverse_all(&mut self, ctx: &RnsContext) {
219 for i in 0..self.level {
220 if self.components[i].is_ntt {
221 self.components[i].inverse_ntt(&ctx.ntt_ctxs[i]);
222 }
223 }
224 }
225}
226
#[cfg(test)]
mod tests {
    use super::*;
    use crate::ntt64::prime::is_prime;
    use alloc::vec;
    use alloc::vec::Vec;

    /// Polynomial length used by every test.
    const TEST_N: usize = 256;
    /// First test modulus.
    const TEST_Q1: u64 = 7681;
    /// Second test modulus.
    const TEST_Q2: u64 = 12289;

    /// Two-modulus context over `TEST_Q1` and `TEST_Q2`.
    fn test_rns_ctx() -> RnsContext {
        RnsContext::new(TEST_N, vec![TEST_Q1, TEST_Q2])
    }

    /// Encoding followed by `inverse_all` yields the residues of the input
    /// coefficients under each modulus.
    #[test]
    fn test_rns_encode_decode() {
        let ctx = test_rns_ctx();
        let input = vec![5i64, -3, 0, 7];
        let mut encoded = RnsPoly::from_coefficients(&input, &ctx);

        encoded.inverse_all(&ctx);

        for (i, &q) in [TEST_Q1, TEST_Q2].iter().enumerate() {
            let data = &encoded.components[i].data;
            assert_eq!(data[0], 5);
            // -3 is represented as q - 3.
            assert_eq!(data[1], q - 3);
            assert_eq!(data[2], 0);
            assert_eq!(data[3], 7);
        }
    }

    /// `(a + b) * c` and `a * c + b * c` agree in every component.
    #[test]
    fn test_rns_add_mul_distributivity() {
        let ctx = test_rns_ctx();
        let degree = TEST_N as i64;

        let a_coeffs: Vec<i64> = (0..degree).map(|i| i % 100).collect();
        let b_coeffs: Vec<i64> = (0..degree).map(|i| (i * 3 + 7) % 100).collect();
        let c_coeffs: Vec<i64> = (0..degree).map(|i| (i * 2 + 1) % 50).collect();

        let a = RnsPoly::from_coefficients(&a_coeffs, &ctx);
        let b = RnsPoly::from_coefficients(&b_coeffs, &ctx);
        let c = RnsPoly::from_coefficients(&c_coeffs, &ctx);

        // Left-hand side: (a + b) * c.
        let sum = a.add(&b, &ctx);
        let mut lhs = sum.mul(&c, &ctx);

        // Right-hand side: a * c + b * c.
        let a_times_c = a.mul(&c, &ctx);
        let b_times_c = b.mul(&c, &ctx);
        let mut rhs = a_times_c.add(&b_times_c, &ctx);

        lhs.inverse_all(&ctx);
        rhs.inverse_all(&ctx);

        for (i, &q) in ctx.moduli.iter().enumerate() {
            for j in 0..TEST_N {
                assert_eq!(
                    lhs.components[i].data[j], rhs.components[i].data[j],
                    "(a+b)*c != a*c+b*c — modulus {}, coeff {}",
                    q, j
                );
            }
        }
    }

    /// Dropping the last modulus shrinks both `level` and `components`.
    #[test]
    fn test_rns_drop_last_modulus() {
        let ctx = test_rns_ctx();
        let input = vec![1i64, 2, 3];
        let mut poly = RnsPoly::from_coefficients(&input, &ctx);

        assert_eq!(poly.level, 2);
        assert_eq!(poly.components.len(), 2);

        poly.drop_last_modulus();

        assert_eq!(poly.level, 1);
        assert_eq!(poly.components.len(), 1);
    }

    /// A single-modulus polynomial refuses to drop its only component.
    #[test]
    #[should_panic(expected = "cannot reduce")]
    fn test_rns_drop_last_modulus_panics_at_level_1() {
        let single = RnsContext::new(TEST_N, vec![TEST_Q1]);
        let input = vec![1i64];
        let mut poly = RnsPoly::from_coefficients(&input, &single);
        poly.drop_last_modulus();
    }

    /// `a - a` is zero in every component, including for negative inputs.
    #[test]
    fn test_rns_sub() {
        let ctx = test_rns_ctx();
        let input: Vec<i64> = (0..TEST_N as i64).map(|i| i % 1000 - 500).collect();
        let a = RnsPoly::from_coefficients(&input, &ctx);

        let mut difference = a.sub(&a, &ctx);
        difference.inverse_all(&ctx);

        for (i, &q) in ctx.moduli.iter().enumerate() {
            for j in 0..TEST_N {
                assert_eq!(
                    difference.components[i].data[j], 0,
                    "a - a != 0 — modulus {}, coeff {}",
                    q, j
                );
            }
        }
    }

    /// Both test moduli are prime and congruent to 1 modulo `2 * TEST_N`.
    #[test]
    fn test_ntt_friendly_primes_are_valid() {
        assert!(is_prime(TEST_Q1), "q1 = {TEST_Q1} should be prime");
        assert!(is_prime(TEST_Q2), "q2 = {TEST_Q2} should be prime");

        let two_n = 2 * TEST_N as u64;
        for q in [TEST_Q1, TEST_Q2].iter() {
            assert_eq!((q - 1) % two_n, 0);
        }
    }
}