1use std::ops::{Add, Index, IndexMut, Mul, Neg, Sub};
2
3use serde::{Deserialize, Serialize};
4use sunscreen_math_macros::refify_binary_op;
5
6use crate::{One, Zero, ring::Ring};
7
/// A univariate polynomial whose coefficients are elements of the ring `R`.
///
/// Coefficients are stored lowest power first: `coeffs[i]` is the coefficient
/// of `x^i` (this is the order in which [`Polynomial::evaluate`] consumes
/// them). Trailing zero coefficients are permitted, and an empty coefficient
/// list is what `Zero::zero()` produces for this type.
///
/// `PartialEq` is implemented by hand rather than derived so that two
/// polynomials differing only in trailing zero padding compare equal.
#[derive(Debug, Clone, Eq, Serialize, Deserialize)]
pub struct Polynomial<R>
where
    R: Ring,
{
    /// The coefficients, lowest power first.
    pub coeffs: Vec<R>,
}
21
22impl<R> PartialEq for Polynomial<R>
23where
24 R: Ring,
25{
26 fn eq(&self, other: &Self) -> bool {
31 let lhs_is_zero = self.vartime_is_zero();
33 let rhs_is_zero = other.vartime_is_zero();
34
35 if lhs_is_zero || rhs_is_zero {
36 return lhs_is_zero && rhs_is_zero;
37 }
38
39 let lhs_degree = self.vartime_degree();
40 let rhs_degree = other.vartime_degree();
41
42 if lhs_degree != rhs_degree {
43 return false;
44 }
45
46 for i in 0..lhs_degree {
47 if self.coeffs[i] != other.coeffs[i] {
48 return false;
49 }
50 }
51
52 true
53 }
54}
55
56impl<R> Polynomial<R>
57where
58 R: Ring,
59{
60 pub fn new(coeffs: &[R]) -> Self {
62 Self {
63 coeffs: coeffs.to_owned(),
64 }
65 }
66
67 pub fn evaluate(&self, x: &R) -> R {
69 let mut eval = R::zero();
70 let mut cur_pow = R::one();
71
72 for i in &self.coeffs {
73 eval = eval + i.clone() * &cur_pow;
74 cur_pow = cur_pow.clone() * x;
75 }
76
77 eval
78 }
79
80 pub fn vartime_degree(&self) -> usize {
89 for (i, coeff) in self.coeffs.iter().rev().enumerate() {
90 if !coeff.vartime_is_zero() {
91 return self.coeffs.len() - i - 1;
92 }
93 }
94
95 panic!("Zero polynomial has undefined degree.");
96 }
97
98 pub fn vartime_div_rem_restricted_rhs(&self, rhs: &Self) -> (Self, Self) {
118 let mut rem = self.clone();
119
120 if self.vartime_is_zero() {
121 return (Self::zero(), Self::zero());
122 }
123
124 let lhs_degree = self.vartime_degree();
125
126 let rhs_degree = rhs.vartime_degree();
128
129 if lhs_degree < rhs_degree {
131 return (Self::zero(), rem);
132 }
133
134 let iter_count = lhs_degree - rhs_degree + 1;
135 let mut q = Polynomial {
136 coeffs: vec![R::zero(); iter_count],
137 };
138
139 for i in 0..iter_count {
140 let scale = rem.coeffs[lhs_degree - i].clone();
144
145 for j in 0..=rhs_degree {
146 let lhs_index = lhs_degree - i - j;
147 let rhs_index = rhs_degree - j;
148
149 rem.coeffs[lhs_index] =
150 rem.coeffs[lhs_index].clone() - rhs.coeffs[rhs_index].clone() * &scale;
151 }
152
153 q.coeffs[iter_count - i - 1] = scale;
154 }
155
156 (q, rem)
157 }
158}
159
160impl<T> Index<usize> for Polynomial<T>
161where
162 T: Ring,
163{
164 type Output = T;
165
166 fn index(&self, index: usize) -> &Self::Output {
167 &self.coeffs[index]
168 }
169}
170
171impl<T> IndexMut<usize> for Polynomial<T>
172where
173 T: Ring,
174{
175 fn index_mut(&mut self, index: usize) -> &mut Self::Output {
176 &mut self.coeffs[index]
177 }
178}
179
180#[refify_binary_op]
181impl<T> Add<&Polynomial<T>> for &Polynomial<T>
182where
183 T: Ring,
184{
185 type Output = Polynomial<T>;
186
187 fn add(self, rhs: &Polynomial<T>) -> Self::Output {
188 let out_len = usize::max(self.coeffs.len(), rhs.coeffs.len());
189
190 let mut out_coeffs = Vec::with_capacity(out_len);
191 let len = usize::max(self.coeffs.len(), rhs.coeffs.len());
192
193 for i in 0..len {
194 let a = self.coeffs.get(i).unwrap_or(&T::zero()).clone();
195 let b = rhs.coeffs.get(i).unwrap_or(&T::zero()).clone();
196
197 out_coeffs.push(a + b);
198 }
199
200 Polynomial { coeffs: out_coeffs }
201 }
202}
203
204#[refify_binary_op]
205impl<T> Sub<&Polynomial<T>> for &Polynomial<T>
206where
207 T: Ring,
208{
209 type Output = Polynomial<T>;
210
211 fn sub(self, rhs: &Polynomial<T>) -> Self::Output {
212 let out_len = usize::max(self.coeffs.len(), rhs.coeffs.len());
213
214 let mut out_coeffs = Vec::with_capacity(out_len);
215 let len = usize::max(self.coeffs.len(), rhs.coeffs.len());
216
217 for i in 0..len {
218 let a = self.coeffs.get(i).unwrap_or(&T::zero()).clone();
219 let b = rhs.coeffs.get(i).unwrap_or(&T::zero()).clone();
220
221 out_coeffs.push(a - b);
222 }
223
224 Polynomial { coeffs: out_coeffs }
225 }
226}
227
228#[refify_binary_op]
229impl<T> Mul<&Polynomial<T>> for &Polynomial<T>
230where
231 T: Ring,
232{
233 type Output = Polynomial<T>;
234
235 fn mul(self, rhs: &Polynomial<T>) -> Self::Output {
236 if self.coeffs.is_empty() || rhs.coeffs.is_empty() {
238 return Self::Output::zero();
239 }
240
241 let mut out_coeffs = vec![T::zero(); (self.coeffs.len() - 1) + (rhs.coeffs.len() - 1) + 1];
242
243 for i in 0..self.coeffs.len() {
244 for j in 0..rhs.coeffs.len() {
245 let a = self.coeffs.get(i).unwrap_or(&T::zero()).clone();
246 let b = rhs.coeffs.get(j).unwrap_or(&T::zero()).clone();
247
248 out_coeffs[i + j] = a * b + &out_coeffs[i + j];
249 }
250 }
251
252 Polynomial { coeffs: out_coeffs }
253 }
254}
255
256#[refify_binary_op]
257impl<T> Mul<&T> for &Polynomial<T>
258where
259 T: Ring,
260{
261 type Output = Polynomial<T>;
262
263 fn mul(self, rhs: &T) -> Self::Output {
264 Self::Output {
265 coeffs: self
266 .coeffs
267 .iter()
268 .map(|x| x.clone() * rhs)
269 .collect::<Vec<_>>(),
270 }
271 }
272}
273
274impl<T> Zero for Polynomial<T>
275where
276 T: Ring,
277{
278 #[inline(always)]
279 fn zero() -> Self {
280 Self { coeffs: vec![] }
281 }
282
283 fn vartime_is_zero(&self) -> bool {
284 self.coeffs.iter().all(|x| x.vartime_is_zero())
285 }
286}
287
288impl<T> One for Polynomial<T>
289where
290 T: Ring,
291{
292 #[inline(always)]
293 fn one() -> Self {
294 Self {
295 coeffs: vec![T::one()],
296 }
297 }
298}
299
300impl<T> Neg for Polynomial<T>
301where
302 T: Ring,
303{
304 type Output = Polynomial<T>;
305
306 fn neg(self) -> Self::Output {
307 Self {
308 coeffs: self.coeffs.iter().map(|x| -x.clone()).collect::<Vec<_>>(),
309 }
310 }
311}
312
313impl<T> Ring for Polynomial<T> where T: Ring {}
314
#[cfg(test)]
mod tests {
    use rand::{distr::Uniform, prelude::Distribution, rng};
    use sunscreen_math_macros::BarrettConfig;

    // NOTE(review): `self as sunscreen_math` is presumably required so that
    // paths emitted by `#[derive(BarrettConfig)]` resolve inside this crate;
    // confirm against the macro.
    use crate::{
        self as sunscreen_math, One, Zero,
        poly::Polynomial,
        ring::{BarrettBackend, Zq},
    };

    // Addition pads the shorter operand and works for both borrowed and
    // owned operands.
    #[test]
    fn can_add_polynomials() {
        #[derive(BarrettConfig)]
        #[barrett_config(modulus = "5", num_limbs = 1)]
        struct Cfg;

        type R = Zq<1, BarrettBackend<1, Cfg>>;
        type TestPoly = Polynomial<Zq<1, BarrettBackend<1, Cfg>>>;

        let a = TestPoly::new(&[R::from(1), R::from(2), R::from(3)]);

        let b = TestPoly::new(&[R::from(4), R::from(1)]);

        let expected = TestPoly::new(&[R::zero(), R::from(3), R::from(3)]);

        assert_eq!(&a + &b, expected);
        assert_eq!(b + a, expected);
    }

    // Subtraction is checked in both operand orders, since it does not
    // commute.
    #[test]
    fn can_sub_polynomials() {
        #[derive(BarrettConfig)]
        #[barrett_config(modulus = "5", num_limbs = 1)]
        struct Cfg;

        type R = Zq<1, BarrettBackend<1, Cfg>>;
        type TestPoly = Polynomial<Zq<1, BarrettBackend<1, Cfg>>>;

        let a = TestPoly::new(&[R::from(1), R::from(2), R::from(3)]);

        let b = TestPoly::new(&[R::from(4), R::from(1)]);

        let expected_1 = TestPoly::new(&[R::from(2), R::from(1), R::from(3)]);

        assert_eq!(&a - &b, expected_1);

        let expected_2 = TestPoly::new(&[R::from(3), R::from(4), R::from(2)]);

        assert_eq!(b - a, expected_2);
    }

    // (1 + 2x + 3x^2)(4 + x) = 4 + 9x + 14x^2 + 3x^3, reduced modulo 5.
    #[test]
    fn can_mul_polynomials() {
        #[derive(BarrettConfig)]
        #[barrett_config(modulus = "5", num_limbs = 1)]
        struct Cfg;

        type R = Zq<1, BarrettBackend<1, Cfg>>;
        type TestPoly = Polynomial<Zq<1, BarrettBackend<1, Cfg>>>;

        let a = TestPoly::new(&[R::from(1), R::from(2), R::from(3)]);

        let b = TestPoly::new(&[R::from(4), R::from(1)]);

        let expected = TestPoly::new(&[R::from(4), R::from(4), R::from(4), R::from(3)]);

        assert_eq!(a * b, expected);
    }

    // A non-zero constant has degree 0.
    #[test]
    fn can_get_poly_degree_constant_coeff() {
        #[derive(BarrettConfig)]
        #[barrett_config(modulus = "5", num_limbs = 1)]
        struct Cfg;

        type R = Zq<1, BarrettBackend<1, Cfg>>;
        type TestPoly = Polynomial<Zq<1, BarrettBackend<1, Cfg>>>;

        let x = TestPoly {
            coeffs: vec![R::from(1)],
        };

        assert_eq!(x.vartime_degree(), 0);
    }

    #[test]
    fn can_get_poly_degree() {
        #[derive(BarrettConfig)]
        #[barrett_config(modulus = "5", num_limbs = 1)]
        struct Cfg;

        type R = Zq<1, BarrettBackend<1, Cfg>>;
        type TestPoly = Polynomial<Zq<1, BarrettBackend<1, Cfg>>>;

        let x = TestPoly {
            coeffs: vec![R::from(0), R::from(1), R::from(2), R::from(3)],
        };

        assert_eq!(x.vartime_degree(), 3);
    }

    // A trailing zero coefficient must not raise the reported degree.
    #[test]
    fn can_get_poly_degree_padded_zeros() {
        #[derive(BarrettConfig)]
        #[barrett_config(modulus = "5", num_limbs = 1)]
        struct Cfg;

        type R = Zq<1, BarrettBackend<1, Cfg>>;
        type TestPoly = Polynomial<Zq<1, BarrettBackend<1, Cfg>>>;

        let x = TestPoly {
            coeffs: vec![R::from(0), R::from(1), R::from(2), R::from(3), R::from(0)],
        };

        assert_eq!(x.vartime_degree(), 3);
    }

    // The zero polynomial (no coefficients) has no degree.
    #[test]
    #[should_panic]
    fn zero_poly_degree_should_panic() {
        #[derive(BarrettConfig)]
        #[barrett_config(modulus = "5", num_limbs = 1)]
        struct Cfg;

        type TestPoly = Polynomial<Zq<1, BarrettBackend<1, Cfg>>>;

        let x = TestPoly::zero();

        x.vartime_degree();
    }

    // A polynomial made only of zero coefficients has no degree either.
    #[test]
    #[should_panic]
    fn zero_poly_padded_zeros_degree_should_panic() {
        #[derive(BarrettConfig)]
        #[barrett_config(modulus = "5", num_limbs = 1)]
        struct Cfg;

        type R = Zq<1, BarrettBackend<1, Cfg>>;
        type TestPoly = Polynomial<Zq<1, BarrettBackend<1, Cfg>>>;

        let x = TestPoly {
            coeffs: vec![R::zero(); 3],
        };

        x.vartime_degree();
    }

    #[test]
    fn polynomial_equality() {
        #[derive(BarrettConfig)]
        #[barrett_config(modulus = "6", num_limbs = 1)]
        struct Cfg;

        type R = Zq<1, BarrettBackend<1, Cfg>>;
        type TestPoly = Polynomial<Zq<1, BarrettBackend<1, Cfg>>>;

        assert_eq!(TestPoly::zero(), TestPoly::zero());

        let a = TestPoly {
            coeffs: vec![R::from(0), R::from(1), R::from(2)],
        };

        let b = TestPoly {
            coeffs: vec![R::from(1), R::from(2), R::from(3)],
        };

        let c = TestPoly {
            coeffs: vec![R::from(0), R::from(1), R::from(2), R::from(3)],
        };

        assert_eq!(a, a);
        assert_ne!(a, b);
        assert_ne!(a, c);
    }

    // Equality must ignore trailing zero padding on either side.
    #[test]
    fn polynomial_equality_padded() {
        #[derive(BarrettConfig)]
        #[barrett_config(modulus = "6", num_limbs = 1)]
        struct Cfg;

        type R = Zq<1, BarrettBackend<1, Cfg>>;
        type TestPoly = Polynomial<Zq<1, BarrettBackend<1, Cfg>>>;

        assert_eq!(
            TestPoly::zero(),
            TestPoly {
                coeffs: vec![R::zero()]
            }
        );

        let a = TestPoly {
            coeffs: vec![R::from(0), R::from(1), R::from(2), R::from(0)],
        };

        let b = TestPoly {
            coeffs: vec![R::from(0), R::from(1), R::from(2), R::from(0), R::from(0)],
        };

        let c = TestPoly {
            coeffs: vec![R::from(0), R::from(1), R::from(2), R::from(3), R::from(0)],
        };

        assert_eq!(a, a);
        assert_eq!(a, b);
        assert_ne!(a, c);
    }

    // Division by a monic divisor must satisfy `a == q * b + rem`.
    #[test]
    fn can_div_rem_basic_polynomial() {
        #[derive(BarrettConfig)]
        #[barrett_config(modulus = "6", num_limbs = 1)]
        struct Cfg;

        type R = Zq<1, BarrettBackend<1, Cfg>>;
        type TestPoly = Polynomial<Zq<1, BarrettBackend<1, Cfg>>>;

        let a = TestPoly {
            coeffs: vec![
                R::from(1),
                R::from(2),
                R::from(0),
                R::from(4),
                R::from(2),
                R::from(3),
            ],
        };

        let b = TestPoly {
            coeffs: vec![R::from(1), R::from(1), R::from(1)],
        };

        let (q, rem) = a.vartime_div_rem_restricted_rhs(&b);

        let actual = q * b + rem;

        assert_eq!(actual, a);
    }

    // Same as above, but both operands carry a trailing zero coefficient.
    #[test]
    fn can_div_rem_basic_padded_polynomial() {
        #[derive(BarrettConfig)]
        #[barrett_config(modulus = "6", num_limbs = 1)]
        struct Cfg;

        type R = Zq<1, BarrettBackend<1, Cfg>>;
        type TestPoly = Polynomial<Zq<1, BarrettBackend<1, Cfg>>>;

        let a = TestPoly {
            coeffs: vec![
                R::from(1),
                R::from(2),
                R::from(0),
                R::from(4),
                R::from(2),
                R::from(3),
                R::from(0),
            ],
        };

        let b = TestPoly {
            coeffs: vec![R::from(1), R::from(1), R::from(1), R::from(0)],
        };

        let (q, rem) = a.vartime_div_rem_restricted_rhs(&b);

        let actual = q * b + rem;

        assert_eq!(actual, a);
    }

    // Randomized check of the division identity against monic divisors.
    #[test]
    fn can_div_rem_random_polynomials() {
        #[derive(BarrettConfig)]
        #[barrett_config(modulus = "1234", num_limbs = 1)]
        struct Cfg;

        type R = Zq<1, BarrettBackend<1, Cfg>>;
        type TestPoly = Polynomial<Zq<1, BarrettBackend<1, Cfg>>>;

        fn test_case() {
            let target_den_degree = Uniform::try_from(2..50).unwrap().sample(&mut rng());
            let target_num_degree = Uniform::try_from(1..200).unwrap().sample(&mut rng());

            let mut num = TestPoly { coeffs: vec![] };

            let mut den = num.clone();

            for _ in 0..target_den_degree {
                let coeff = Uniform::try_from(0..1234u64).unwrap().sample(&mut rng());
                den.coeffs.push(R::from(coeff));
            }

            // The divisor's leading coefficient is one (monic divisor).
            den.coeffs.push(R::one());

            for _ in 0..=target_num_degree {
                let coeff = Uniform::try_from(0..1234u64).unwrap().sample(&mut rng());
                num.coeffs.push(R::from(coeff));
            }

            let (q, rem) = num.vartime_div_rem_restricted_rhs(&den);

            assert_eq!(q * &den + &rem, num);
            // A zero remainder is a valid outcome, but `vartime_degree` panics
            // on the zero polynomial, so only compare degrees when `rem` is
            // non-zero.
            assert!(rem.vartime_is_zero() || rem.vartime_degree() < den.vartime_degree());
        }

        for _ in 0..100 {
            test_case();
        }
    }
}