// sapling_crypto_ce/interpolation.rs

use bellman::pairing::Engine;

use bellman::pairing::ff::{
    Field,
    PrimeField
};

8/// Perform a Lagrange interpolation for a set of points
9/// It's O(n^2) operations, so use with caution
10pub fn interpolate<E: Engine>(
11    points: &[(E::Fr, E::Fr)]
12) -> Option<Vec<E::Fr>> {
13    let max_degree_plus_one = points.len();
14    assert!(max_degree_plus_one >= 2, "should interpolate for degree >= 1");
15    let mut coeffs = vec![E::Fr::zero(); max_degree_plus_one];
16    // external iterator
17    for (k, p_k) in points.iter().enumerate() {
18        let (x_k, y_k) = p_k;
19        // coeffs from 0 to max_degree - 1
20        let mut contribution = vec![E::Fr::zero(); max_degree_plus_one];
21        let mut demoninator = E::Fr::one();
22        let mut max_contribution_degree = 0;
23        // internal iterator
24        for (j, p_j) in points.iter().enumerate() {
25            let (x_j, _) = p_j;
26            if j == k {
27                continue;
28            }
29
30            let mut diff = *x_k;
31            diff.sub_assign(&x_j);
32            demoninator.mul_assign(&diff);
33
34            if max_contribution_degree == 0 {
35                max_contribution_degree = 1;
36                contribution.get_mut(0).expect("must have enough coefficients").sub_assign(&x_j);
37                contribution.get_mut(1).expect("must have enough coefficients").add_assign(&E::Fr::one());
38            } else {
39                let mul_by_minus_x_j: Vec<E::Fr> = contribution.iter().map(|el| {
40                    let mut tmp = *el;
41                    tmp.mul_assign(&x_j);
42                    tmp.negate();
43
44                    tmp
45                }).collect();
46
47                contribution.insert(0, E::Fr::zero());
48                contribution.truncate(max_degree_plus_one);
49
50                assert_eq!(mul_by_minus_x_j.len(), max_degree_plus_one);
51                for (i, c) in contribution.iter_mut().enumerate() {
52                    let other = mul_by_minus_x_j.get(i).expect("should have enough elements");
53                    c.add_assign(&other);
54                }
55            }
56        }
57
58        demoninator = demoninator.inverse().expect("denominator must be non-zero");
59        for (i, this_contribution) in contribution.into_iter().enumerate() {
60            let c = coeffs.get_mut(i).expect("should have enough coefficients");
61            let mut tmp = this_contribution;
62            tmp.mul_assign(&demoninator);
63            tmp.mul_assign(&y_k);
64            c.add_assign(&tmp);
65        }
66
67    }
68
69    Some(coeffs)
70}
71
72pub fn evaluate_at_x<E: Engine>(
73    coeffs: &[E::Fr],
74    x: &E::Fr
75) -> E::Fr {
76    let mut res = E::Fr::zero();
77    let mut pow = E::Fr::one();
78    for c in coeffs.iter() {
79        let mut tmp = c.clone();
80        tmp.mul_assign(&pow);
81        res.add_assign(&tmp);
82
83        pow.mul_assign(&x);
84    }
85
86    res
87}
88
#[test]
fn test_interpolation_1(){
    use bellman::pairing::bn256::{Bn256, Fr};

    // Two points on the line y = 1 + x.
    let points = vec![(Fr::zero(), Fr::one()), (Fr::one(), Fr::from_str("2").unwrap())];
    let interpolation_res = interpolate::<Bn256>(&points[..]).expect("must interpolate a linear func");

    // Coefficients are lowest degree first: p(x) = 1 + 1 * x.
    assert_eq!(interpolation_res, vec![Fr::one(), Fr::one()]);

    for (x, y) in points.iter() {
        let val = evaluate_at_x::<Bn256>(&interpolation_res[..], x);
        assert_eq!(*y, val, "interpolated polynomial must pass through ({}, {})", x, y);
    }
}

#[test]
fn test_interpolation_powers_of_2(){
    use bellman::pairing::bn256::{Bn256, Fr};
    // Number of sample points; the interpolated polynomial has degree < MAX_POWER.
    const MAX_POWER: u32 = Fr::CAPACITY;

    // Sample f(i) = 2^i at x = 0, 1, ..., MAX_POWER - 1.
    let two = Fr::from_str("2").unwrap();
    let mut points: Vec<(Fr, Fr)> = Vec::with_capacity(MAX_POWER as usize);
    let mut power = Fr::one();
    for i in 0..MAX_POWER {
        let x = Fr::from_str(&i.to_string()).unwrap();
        points.push((x, power));

        power.mul_assign(&two);
    }

    let interpolation_res = interpolate::<Bn256>(&points[..]).expect("must interpolate");
    assert_eq!(interpolation_res.len(), points.len(), "array sizes must match");
    assert_eq!(interpolation_res.len(), MAX_POWER as usize, "array size must be equal to the max power");
    // p(0) = 2^0 = 1 is exactly the constant coefficient.
    assert_eq!(interpolation_res[0], Fr::one());

    for (x, y) in points.iter() {
        let val = evaluate_at_x::<Bn256>(&interpolation_res[..], x);
        assert_eq!(*y, val, "must assert equality for x = {}", x);
    }
}
