sunscreen_math/poly/
mod.rs

1use std::ops::{Add, Index, IndexMut, Mul, Neg, Sub};
2
3use serde::{Deserialize, Serialize};
4use sunscreen_math_macros::refify_binary_op;
5
6use crate::{One, Zero, ring::Ring};
7
8#[derive(Debug, Clone, Eq, Serialize, Deserialize)]
9/// A polynomial over the ring `T`.
10///
11/// # Remarks
12/// The coefficient at index `i` corresponds to the `x^i` term. E.g.
13/// index 0 is the constant coefficient term, 1 is x, ... n is x^n.
14pub struct Polynomial<R>
15where
16    R: Ring,
17{
18    /// The coefficients of the polynomial.
19    pub coeffs: Vec<R>,
20}
21
22impl<R> PartialEq for Polynomial<R>
23where
24    R: Ring,
25{
26    /// Computes polynomial equality.
27    ///
28    /// # Remarks
29    /// Variable time
30    fn eq(&self, other: &Self) -> bool {
31        // Need to handle zero polynomial specially, as calling degree will panic.
32        let lhs_is_zero = self.vartime_is_zero();
33        let rhs_is_zero = other.vartime_is_zero();
34
35        if lhs_is_zero || rhs_is_zero {
36            return lhs_is_zero && rhs_is_zero;
37        }
38
39        let lhs_degree = self.vartime_degree();
40        let rhs_degree = other.vartime_degree();
41
42        if lhs_degree != rhs_degree {
43            return false;
44        }
45
46        for i in 0..lhs_degree {
47            if self.coeffs[i] != other.coeffs[i] {
48                return false;
49            }
50        }
51
52        true
53    }
54}
55
56impl<R> Polynomial<R>
57where
58    R: Ring,
59{
60    /// Construct a new polynomial with the given coefficients.
61    pub fn new(coeffs: &[R]) -> Self {
62        Self {
63            coeffs: coeffs.to_owned(),
64        }
65    }
66
67    /// Evaluate the polynomial at x.
68    pub fn evaluate(&self, x: &R) -> R {
69        let mut eval = R::zero();
70        let mut cur_pow = R::one();
71
72        for i in &self.coeffs {
73            eval = eval + i.clone() * &cur_pow;
74            cur_pow = cur_pow.clone() * x;
75        }
76
77        eval
78    }
79
80    /// Returns the degree of the polynomial
81    ///
82    /// # Remarks
83    /// Runtime varies depending on the number of leading zeros.
84    ///
85    /// # Panics
86    /// The degree of the zero polynomial is undefined, and thus this function will
87    /// panic.
88    pub fn vartime_degree(&self) -> usize {
89        for (i, coeff) in self.coeffs.iter().rev().enumerate() {
90            if !coeff.vartime_is_zero() {
91                return self.coeffs.len() - i - 1;
92            }
93        }
94
95        panic!("Zero polynomial has undefined degree.");
96    }
97
98    /// Computes the quotient and remainder of `self / rhs`.
99    ///
100    /// # Remarks
101    /// Runtime is variable except for very restricted use cases.
102    /// This function will be constant time so long as:
103    /// * Neither the numerator nor denominator have leading zeros.
104    /// * The numerator is always of higher degree than the denominator
105    /// * The numerator's and denominator's degrees are fixed across invocations
106    /// * The inner type R supports constant time subtraction and multiplication.
107    ///
108    /// In order for polynomial division to work in a ring, `rhs` has restrictions.
109    /// Specifically, the highest order non-zero coefficient in `rhs` must be 1 so as to avoid
110    /// inverse operations. While multiplicative inverses are not guaranteed to exist
111    /// for a [`Ring`] element, the inverse of one will always be one.
112    ///
113    /// # Panics
114    /// If the divisor's leading non-zero coefficient isn't one.
115    ///
116    /// If rhs is the zero polynomial.
117    pub fn vartime_div_rem_restricted_rhs(&self, rhs: &Self) -> (Self, Self) {
118        let mut rem = self.clone();
119
120        if self.vartime_is_zero() {
121            return (Self::zero(), Self::zero());
122        }
123
124        let lhs_degree = self.vartime_degree();
125
126        // Will panic if rhs is `Self::zero()`
127        let rhs_degree = rhs.vartime_degree();
128
129        // If the denominator is higher degree than the numerator, then we're done.
130        if lhs_degree < rhs_degree {
131            return (Self::zero(), rem);
132        }
133
134        let iter_count = lhs_degree - rhs_degree + 1;
135        let mut q = Polynomial {
136            coeffs: vec![R::zero(); iter_count],
137        };
138
139        for i in 0..iter_count {
140            // Normally, we would compute the scale factor as coeff_i(rem) * coeff_i(rhs)^-1,
141            // but inverse isn't defined for rings. Since we leverage the fact that the
142            // leading coefficient is always 1, we don't have this problem.
143            let scale = rem.coeffs[lhs_degree - i].clone();
144
145            for j in 0..=rhs_degree {
146                let lhs_index = lhs_degree - i - j;
147                let rhs_index = rhs_degree - j;
148
149                rem.coeffs[lhs_index] =
150                    rem.coeffs[lhs_index].clone() - rhs.coeffs[rhs_index].clone() * &scale;
151            }
152
153            q.coeffs[iter_count - i - 1] = scale;
154        }
155
156        (q, rem)
157    }
158}
159
160impl<T> Index<usize> for Polynomial<T>
161where
162    T: Ring,
163{
164    type Output = T;
165
166    fn index(&self, index: usize) -> &Self::Output {
167        &self.coeffs[index]
168    }
169}
170
171impl<T> IndexMut<usize> for Polynomial<T>
172where
173    T: Ring,
174{
175    fn index_mut(&mut self, index: usize) -> &mut Self::Output {
176        &mut self.coeffs[index]
177    }
178}
179
180#[refify_binary_op]
181impl<T> Add<&Polynomial<T>> for &Polynomial<T>
182where
183    T: Ring,
184{
185    type Output = Polynomial<T>;
186
187    fn add(self, rhs: &Polynomial<T>) -> Self::Output {
188        let out_len = usize::max(self.coeffs.len(), rhs.coeffs.len());
189
190        let mut out_coeffs = Vec::with_capacity(out_len);
191        let len = usize::max(self.coeffs.len(), rhs.coeffs.len());
192
193        for i in 0..len {
194            let a = self.coeffs.get(i).unwrap_or(&T::zero()).clone();
195            let b = rhs.coeffs.get(i).unwrap_or(&T::zero()).clone();
196
197            out_coeffs.push(a + b);
198        }
199
200        Polynomial { coeffs: out_coeffs }
201    }
202}
203
204#[refify_binary_op]
205impl<T> Sub<&Polynomial<T>> for &Polynomial<T>
206where
207    T: Ring,
208{
209    type Output = Polynomial<T>;
210
211    fn sub(self, rhs: &Polynomial<T>) -> Self::Output {
212        let out_len = usize::max(self.coeffs.len(), rhs.coeffs.len());
213
214        let mut out_coeffs = Vec::with_capacity(out_len);
215        let len = usize::max(self.coeffs.len(), rhs.coeffs.len());
216
217        for i in 0..len {
218            let a = self.coeffs.get(i).unwrap_or(&T::zero()).clone();
219            let b = rhs.coeffs.get(i).unwrap_or(&T::zero()).clone();
220
221            out_coeffs.push(a - b);
222        }
223
224        Polynomial { coeffs: out_coeffs }
225    }
226}
227
228#[refify_binary_op]
229impl<T> Mul<&Polynomial<T>> for &Polynomial<T>
230where
231    T: Ring,
232{
233    type Output = Polynomial<T>;
234
235    fn mul(self, rhs: &Polynomial<T>) -> Self::Output {
236        // TODO: Fix vartime
237        if self.coeffs.is_empty() || rhs.coeffs.is_empty() {
238            return Self::Output::zero();
239        }
240
241        let mut out_coeffs = vec![T::zero(); (self.coeffs.len() - 1) + (rhs.coeffs.len() - 1) + 1];
242
243        for i in 0..self.coeffs.len() {
244            for j in 0..rhs.coeffs.len() {
245                let a = self.coeffs.get(i).unwrap_or(&T::zero()).clone();
246                let b = rhs.coeffs.get(j).unwrap_or(&T::zero()).clone();
247
248                out_coeffs[i + j] = a * b + &out_coeffs[i + j];
249            }
250        }
251
252        Polynomial { coeffs: out_coeffs }
253    }
254}
255
256#[refify_binary_op]
257impl<T> Mul<&T> for &Polynomial<T>
258where
259    T: Ring,
260{
261    type Output = Polynomial<T>;
262
263    fn mul(self, rhs: &T) -> Self::Output {
264        Self::Output {
265            coeffs: self
266                .coeffs
267                .iter()
268                .map(|x| x.clone() * rhs)
269                .collect::<Vec<_>>(),
270        }
271    }
272}
273
274impl<T> Zero for Polynomial<T>
275where
276    T: Ring,
277{
278    #[inline(always)]
279    fn zero() -> Self {
280        Self { coeffs: vec![] }
281    }
282
283    fn vartime_is_zero(&self) -> bool {
284        self.coeffs.iter().all(|x| x.vartime_is_zero())
285    }
286}
287
288impl<T> One for Polynomial<T>
289where
290    T: Ring,
291{
292    #[inline(always)]
293    fn one() -> Self {
294        Self {
295            coeffs: vec![T::one()],
296        }
297    }
298}
299
300impl<T> Neg for Polynomial<T>
301where
302    T: Ring,
303{
304    type Output = Polynomial<T>;
305
306    fn neg(self) -> Self::Output {
307        Self {
308            coeffs: self.coeffs.iter().map(|x| -x.clone()).collect::<Vec<_>>(),
309        }
310    }
311}
312
313impl<T> Ring for Polynomial<T> where T: Ring {}
314
315#[cfg(test)]
316mod tests {
317    use rand::{distr::Uniform, prelude::Distribution, rng};
318    use sunscreen_math_macros::BarrettConfig;
319
320    use crate::{
321        self as sunscreen_math, One, Zero,
322        poly::Polynomial,
323        ring::{BarrettBackend, Zq},
324    };
325
326    #[test]
327    fn can_add_polynomials() {
328        #[derive(BarrettConfig)]
329        #[barrett_config(modulus = "5", num_limbs = 1)]
330        struct Cfg;
331
332        type R = Zq<1, BarrettBackend<1, Cfg>>;
333        type TestPoly = Polynomial<Zq<1, BarrettBackend<1, Cfg>>>;
334
335        let a = TestPoly::new(&[R::from(1), R::from(2), R::from(3)]);
336
337        let b = TestPoly::new(&[R::from(4), R::from(1)]);
338
339        let expected = TestPoly::new(&[R::zero(), R::from(3), R::from(3)]);
340
341        assert_eq!(&a + &b, expected);
342        assert_eq!(b + a, expected);
343    }
344    #[test]
345    fn can_sub_polynomials() {
346        #[derive(BarrettConfig)]
347        #[barrett_config(modulus = "5", num_limbs = 1)]
348        struct Cfg;
349
350        type R = Zq<1, BarrettBackend<1, Cfg>>;
351        type TestPoly = Polynomial<Zq<1, BarrettBackend<1, Cfg>>>;
352
353        let a = TestPoly::new(&[R::from(1), R::from(2), R::from(3)]);
354
355        let b = TestPoly::new(&[R::from(4), R::from(1)]);
356
357        let expected_1 = TestPoly::new(&[R::from(2), R::from(1), R::from(3)]);
358
359        assert_eq!(&a - &b, expected_1);
360
361        let expected_2 = TestPoly::new(&[R::from(3), R::from(4), R::from(2)]);
362
363        assert_eq!(b - a, expected_2);
364    }
365
366    #[test]
367    fn can_mul_polynomials() {
368        #[derive(BarrettConfig)]
369        #[barrett_config(modulus = "5", num_limbs = 1)]
370        struct Cfg;
371
372        type R = Zq<1, BarrettBackend<1, Cfg>>;
373        type TestPoly = Polynomial<Zq<1, BarrettBackend<1, Cfg>>>;
374
375        let a = TestPoly::new(&[R::from(1), R::from(2), R::from(3)]);
376
377        let b = TestPoly::new(&[R::from(4), R::from(1)]);
378
379        let expected = TestPoly::new(&[R::from(4), R::from(4), R::from(4), R::from(3)]);
380
381        assert_eq!(a * b, expected);
382    }
383
384    #[test]
385    fn can_get_poly_degree_constant_coeff() {
386        #[derive(BarrettConfig)]
387        #[barrett_config(modulus = "5", num_limbs = 1)]
388        struct Cfg;
389
390        type R = Zq<1, BarrettBackend<1, Cfg>>;
391        type TestPoly = Polynomial<Zq<1, BarrettBackend<1, Cfg>>>;
392
393        let x = TestPoly {
394            coeffs: vec![R::from(1)],
395        };
396
397        assert_eq!(x.vartime_degree(), 0);
398    }
399
400    #[test]
401    fn can_get_poly_degree() {
402        #[derive(BarrettConfig)]
403        #[barrett_config(modulus = "5", num_limbs = 1)]
404        struct Cfg;
405
406        type R = Zq<1, BarrettBackend<1, Cfg>>;
407        type TestPoly = Polynomial<Zq<1, BarrettBackend<1, Cfg>>>;
408
409        let x = TestPoly {
410            coeffs: vec![R::from(0), R::from(1), R::from(2), R::from(3)],
411        };
412
413        assert_eq!(x.vartime_degree(), 3);
414    }
415
416    #[test]
417    fn can_get_poly_degree_padded_zeros() {
418        #[derive(BarrettConfig)]
419        #[barrett_config(modulus = "5", num_limbs = 1)]
420        struct Cfg;
421
422        type R = Zq<1, BarrettBackend<1, Cfg>>;
423        type TestPoly = Polynomial<Zq<1, BarrettBackend<1, Cfg>>>;
424
425        let x = TestPoly {
426            coeffs: vec![R::from(0), R::from(1), R::from(2), R::from(3), R::from(0)],
427        };
428
429        assert_eq!(x.vartime_degree(), 3);
430    }
431
432    #[test]
433    #[should_panic]
434    fn zero_poly_degree_should_panic() {
435        #[derive(BarrettConfig)]
436        #[barrett_config(modulus = "5", num_limbs = 1)]
437        struct Cfg;
438
439        type TestPoly = Polynomial<Zq<1, BarrettBackend<1, Cfg>>>;
440
441        let x = TestPoly::zero();
442
443        x.vartime_degree();
444    }
445
446    #[test]
447    #[should_panic]
448    fn zero_poly_padded_zeros_degree_should_panic() {
449        #[derive(BarrettConfig)]
450        #[barrett_config(modulus = "5", num_limbs = 1)]
451        struct Cfg;
452
453        type R = Zq<1, BarrettBackend<1, Cfg>>;
454        type TestPoly = Polynomial<Zq<1, BarrettBackend<1, Cfg>>>;
455
456        let x = TestPoly {
457            coeffs: vec![R::zero(); 3],
458        };
459
460        x.vartime_degree();
461    }
462
463    #[test]
464    fn polynomial_equality() {
465        #[derive(BarrettConfig)]
466        #[barrett_config(modulus = "6", num_limbs = 1)]
467        struct Cfg;
468
469        type R = Zq<1, BarrettBackend<1, Cfg>>;
470        type TestPoly = Polynomial<Zq<1, BarrettBackend<1, Cfg>>>;
471
472        assert_eq!(TestPoly::zero(), TestPoly::zero());
473
474        let a = TestPoly {
475            coeffs: vec![R::from(0), R::from(1), R::from(2)],
476        };
477
478        let b = TestPoly {
479            coeffs: vec![R::from(1), R::from(2), R::from(3)],
480        };
481
482        let c = TestPoly {
483            coeffs: vec![R::from(0), R::from(1), R::from(2), R::from(3)],
484        };
485
486        assert_eq!(a, a);
487        assert_ne!(a, b);
488        assert_ne!(a, c);
489    }
490
491    #[test]
492    fn polynomial_equality_padded() {
493        #[derive(BarrettConfig)]
494        #[barrett_config(modulus = "6", num_limbs = 1)]
495        struct Cfg;
496
497        type R = Zq<1, BarrettBackend<1, Cfg>>;
498        type TestPoly = Polynomial<Zq<1, BarrettBackend<1, Cfg>>>;
499
500        assert_eq!(
501            TestPoly::zero(),
502            TestPoly {
503                coeffs: vec![R::zero()]
504            }
505        );
506
507        let a = TestPoly {
508            coeffs: vec![R::from(0), R::from(1), R::from(2), R::from(0)],
509        };
510
511        let b = TestPoly {
512            coeffs: vec![R::from(0), R::from(1), R::from(2), R::from(0), R::from(0)],
513        };
514
515        let c = TestPoly {
516            coeffs: vec![R::from(0), R::from(1), R::from(2), R::from(3), R::from(0)],
517        };
518
519        assert_eq!(a, a);
520        assert_eq!(a, b);
521        assert_ne!(a, c);
522    }
523
524    #[test]
525    fn can_div_rem_basic_polynomial() {
526        #[derive(BarrettConfig)]
527        #[barrett_config(modulus = "6", num_limbs = 1)]
528        struct Cfg;
529
530        type R = Zq<1, BarrettBackend<1, Cfg>>;
531        type TestPoly = Polynomial<Zq<1, BarrettBackend<1, Cfg>>>;
532
533        let a = TestPoly {
534            coeffs: vec![
535                R::from(1),
536                R::from(2),
537                R::from(0),
538                R::from(4),
539                R::from(2),
540                R::from(3),
541            ],
542        };
543
544        let b = TestPoly {
545            coeffs: vec![R::from(1), R::from(1), R::from(1)],
546        };
547
548        let (q, rem) = a.vartime_div_rem_restricted_rhs(&b);
549
550        let actual = q * b + rem;
551
552        assert_eq!(actual, a);
553    }
554
555    #[test]
556    fn can_div_rem_basic_padded_polynomial() {
557        #[derive(BarrettConfig)]
558        #[barrett_config(modulus = "6", num_limbs = 1)]
559        struct Cfg;
560
561        type R = Zq<1, BarrettBackend<1, Cfg>>;
562        type TestPoly = Polynomial<Zq<1, BarrettBackend<1, Cfg>>>;
563
564        let a = TestPoly {
565            coeffs: vec![
566                R::from(1),
567                R::from(2),
568                R::from(0),
569                R::from(4),
570                R::from(2),
571                R::from(3),
572                R::from(0),
573            ],
574        };
575
576        let b = TestPoly {
577            coeffs: vec![R::from(1), R::from(1), R::from(1), R::from(0)],
578        };
579
580        let (q, rem) = a.vartime_div_rem_restricted_rhs(&b);
581
582        let actual = q * b + rem;
583
584        assert_eq!(actual, a);
585    }
586
587    #[test]
588    fn can_div_rem_random_polynomials() {
589        #[derive(BarrettConfig)]
590        #[barrett_config(modulus = "1234", num_limbs = 1)]
591        struct Cfg;
592
593        type R = Zq<1, BarrettBackend<1, Cfg>>;
594        type TestPoly = Polynomial<Zq<1, BarrettBackend<1, Cfg>>>;
595
596        fn test_case() {
597            let target_den_degree = Uniform::try_from(2..50).unwrap().sample(&mut rng());
598            let target_num_degree = Uniform::try_from(1..200).unwrap().sample(&mut rng());
599
600            let mut num = TestPoly { coeffs: vec![] };
601
602            let mut den = num.clone();
603
604            for _ in 0..target_den_degree {
605                let coeff = Uniform::try_from(0..1234u64).unwrap().sample(&mut rng());
606                den.coeffs.push(R::from(coeff));
607            }
608
609            // Leading coefficient in denominator must be a 1.
610            den.coeffs.push(R::one());
611
612            for _ in 0..=target_num_degree {
613                let coeff = Uniform::try_from(0..1234u64).unwrap().sample(&mut rng());
614                num.coeffs.push(R::from(coeff));
615            }
616
617            let (q, rem) = num.vartime_div_rem_restricted_rhs(&den);
618
619            assert_eq!(q * &den + &rem, num);
620            assert!(rem.vartime_degree() < den.vartime_degree());
621        }
622
623        for _ in 0..100 {
624            test_case();
625        }
626    }
627}