1use crate::group::{Curve, Element, Point, Scalar};
2use rand_core::RngCore;
3use serde::{Deserialize, Serialize};
4use std::{collections::BTreeMap, fmt};
5use thiserror::Error;
6
7pub type PrivatePoly<C> = Poly<<C as Curve>::Scalar>;
8pub type PublicPoly<C> = Poly<<C as Curve>::Point>;
9
10pub type Idx = u32;
11
/// A single evaluation ("share") of a polynomial, tagged with the index it was computed for.
///
/// `Poly::eval(i)` evaluates the polynomial at x = `i + 1` and stores `i` in `index`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct Eval<A> {
    /// The polynomial's value at the x-coordinate derived from `index`.
    pub value: A,
    /// The share index this value belongs to.
    pub index: Idx,
}
17
18impl<A: fmt::Display> fmt::Display for Eval<A> {
19 fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
20 write!(f, "{{ idx: {}, value: {} }}", self.index, self.value)
21 }
22}
23
/// A polynomial stored as its coefficient vector, constant term first: the entry at
/// position `i` is the coefficient of `x^i` (this is the order `eval`'s Horner loop assumes).
///
/// `C` is either a scalar type (see `PrivatePoly`) or a group point type (see `PublicPoly`).
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
pub struct Poly<C>(Vec<C>);
29
30impl<C> Poly<C> {
31 pub fn degree(&self) -> usize {
33 self.0.len() - 1
36 }
37
38 #[cfg(test)]
39 fn len(&self) -> usize {
41 self.0.len()
42 }
43}
44
45impl<C: Element> Poly<C> {
46 pub fn new_from<R: RngCore>(degree: usize, rng: &mut R) -> Self {
50 let coeffs: Vec<C> = (0..=degree).map(|_| C::rand(rng)).collect();
51 Self::from(coeffs)
52 }
53
54 pub fn get(&self, i: Idx) -> C {
57 self.0[i as usize].clone()
58 }
59
60 pub fn set(&mut self, index: usize, value: C) {
63 self.0[index] = value;
64 }
65
66 pub fn new(degree: usize) -> Self {
71 use rand::prelude::*;
72 Self::new_from(degree, &mut thread_rng())
73 }
74
75 pub fn zero() -> Self {
80 Self::from(vec![C::zero()])
81 }
82
83 fn is_zero(&self) -> bool {
84 self.0.is_empty() || self.0.iter().all(|coeff| coeff == &C::zero())
85 }
86
87 pub fn add(&mut self, other: &Self) {
89 if self.0.len() < other.0.len() {
91 self.0.resize(other.0.len(), C::zero())
92 }
93
94 self.0.iter_mut().zip(&other.0).for_each(|(a, b)| a.add(b))
95 }
96}
97
/// Errors returned when recovering a value or polynomial from shares.
#[derive(Debug, Error)]
pub enum PolyError {
    /// Not enough shares were supplied; fields are `(available, required)`.
    #[error("Invalid recovery: only has {0}/{1} shares")]
    InvalidRecovery(usize, usize),
    /// A Lagrange denominator had no multiplicative inverse.
    #[error("Could not invert scalar")]
    NoInverse,
}
105
106impl<C> Poly<C>
107where
108 C: Element,
109 C::RHS: Scalar<RHS = C::RHS>,
110{
111 pub fn eval(&self, i: Idx) -> Eval<C> {
113 let mut xi = C::RHS::new();
114 xi.set_int((i + 1).into());
118
119 let res = self.0.iter().rev().fold(C::zero(), |mut sum, coeff| {
120 sum.mul(&xi);
121 sum.add(coeff);
122 sum
123 });
124
125 Eval {
126 value: res,
127 index: i,
128 }
129 }
130
131 pub fn recover(t: usize, shares: Vec<Eval<C>>) -> Result<C, PolyError> {
134 let xs = Self::share_map(t, shares)?;
135
136 let mut acc = C::zero();
139 for (i, xi) in &xs {
140 let mut yi = xi.1.clone();
141 let mut num = C::RHS::one();
142 let mut den = C::RHS::one();
143
144 for (j, xj) in &xs {
145 if i == j {
146 continue;
147 }
148
149 num.mul(&xj.0);
151
152 let mut tmp = xj.0.clone();
154 tmp.sub(&xi.0);
155 den.mul(&tmp);
156 }
157
158 let inv = den.inverse().ok_or(PolyError::NoInverse)?;
159 num.mul(&inv);
160 yi.mul(&num);
161 acc.add(&yi);
162 }
163
164 Ok(acc)
165 }
166
167 pub fn full_recover(t: usize, shares: Vec<Eval<C>>) -> Result<Self, PolyError> {
169 let xs = Self::share_map(t, shares)?;
170
171 let res = xs
174 .iter()
175 .map(|(i, share)| (share, Poly::<C::RHS>::lagrange_basis(*i, &xs)))
177 .map(|(share, basis)| {
179 let linear_coeffs = basis
181 .0
182 .into_iter()
183 .map(move |c| {
184 let mut s = share.1.clone();
187 s.mul(&c);
188 s
189 })
190 .collect::<Vec<_>>();
191
192 Self::from(linear_coeffs)
193 })
194 .fold(Self::zero(), |mut acc, poly| {
195 acc.add(&poly);
196 acc
197 });
198
199 Ok(res)
200 }
201
202 fn share_map(
203 t: usize,
204 mut shares: Vec<Eval<C>>,
205 ) -> Result<BTreeMap<Idx, (C::RHS, C)>, PolyError> {
206 if shares.len() < t {
207 return Err(PolyError::InvalidRecovery(shares.len(), t));
208 }
209
210 shares.sort_by(|a, b| a.index.cmp(&b.index));
213
214 let xs = shares
216 .into_iter()
217 .take(t)
218 .fold(BTreeMap::new(), |mut m, sh| {
219 let mut xi = C::RHS::new();
220 xi.set_int((sh.index + 1).into());
221 m.insert(sh.index, (xi, sh.value));
222 m
223 });
224
225 debug_assert_eq!(xs.len(), t);
226
227 Ok(xs)
228 }
229
230 pub fn public_key(&self) -> &C {
233 &self.0[0]
234 }
235}
236
237impl<C: Element> From<Vec<C>> for Poly<C> {
238 fn from(c: Vec<C>) -> Self {
239 Self(c)
240 }
241}
242
243impl<C: Element> From<Poly<C>> for Vec<C> {
244 fn from(poly: Poly<C>) -> Self {
245 poly.0
246 }
247}
248
249impl<X: Scalar<RHS = X>> Poly<X> {
250 fn mul(&mut self, other: &Self) {
258 if self.is_zero() || other.is_zero() {
259 *self = Self::zero();
260 return;
261 }
262
263 let d3 = self.degree() + other.degree();
264
265 let mut coeffs = (0..=d3).map(|_| X::zero()).collect::<Vec<X>>();
267
268 for (i, c1) in self.0.iter().enumerate() {
269 for (j, c2) in other.0.iter().enumerate() {
270 let mut tmp = X::one();
272 tmp.mul(c1);
273 tmp.mul(c2);
274 coeffs[i + j].add(&tmp);
275 }
276 }
277
278 self.0 = coeffs;
279 }
280
281 fn new_neg_constant(mut c: X) -> Poly<X> {
283 c.negate();
284 Poly::from(vec![c, X::one()])
285 }
286
287 fn lagrange_basis<E: Element<RHS = X>>(i: Idx, xs: &BTreeMap<Idx, (X, E)>) -> Poly<X> {
289 let mut basis = Poly::<X>::from(vec![X::one()]);
290
291 let mut acc = X::one();
293
294 let xi = xs.get(&i).unwrap().clone().0;
296 for (idx, sc) in xs.iter() {
297 if *idx == i {
298 continue;
299 }
300
301 let minus_sc = Poly::<X>::new_neg_constant(sc.0.clone());
303 basis.mul(&minus_sc);
304
305 let mut den = X::zero();
307 den.add(&xi);
308 den.sub(&sc.0);
309
310 den = den.inverse().unwrap();
312
313 acc.mul(&den);
315 }
316
317 basis.mul(&Poly::from(vec![acc]));
319 basis
320 }
321
322 pub fn commit<P: Point<RHS = X>>(&self) -> Poly<P> {
328 let commits = self
329 .0
330 .iter()
331 .map(|c| {
332 let mut commitment = P::one();
333 commitment.mul(c);
334 commitment
335 })
336 .collect::<Vec<P>>();
337
338 Poly::<P>::from(commits)
339 }
340}
341
342impl<C: fmt::Display> fmt::Display for Poly<C> {
343 fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
344 let s = self
345 .0
346 .iter()
347 .enumerate()
348 .map(|(i, c)| format!("{}: {}", i, c))
349 .collect::<Vec<String>>()
350 .join(", ");
351 write!(f, "[deg: {}, coeffs: [{}]]", self.degree(), s)
352 }
353}
354
#[cfg(feature = "bls12_381")]
#[cfg(test)]
pub mod tests {
    use super::*;
    use crate::curve::bls12381::Scalar as Sc;
    use crate::curve::bls12381::G1;
    use rand::prelude::*;

    #[test]
    fn poly_degree() {
        let s = 5;
        let p = Poly::<Sc>::new(s);
        assert_eq!(p.0.len(), s + 1);
        assert_eq!(p.degree(), s);
    }

    #[test]
    fn add_zero() {
        let p1 = Poly::<Sc>::new(3);
        let p2 = Poly::<Sc>::zero();
        let mut res = p1.clone();
        res.add(&p2);
        assert_eq!(res, p1);

        let p1 = Poly::<Sc>::zero();
        let p2 = Poly::<Sc>::new(3);
        let mut res = p1;
        res.add(&p2);
        assert_eq!(res, p2);
    }

    #[test]
    fn mul_by_zero() {
        let p1 = Poly::<Sc>::new(3);
        let p2 = Poly::<Sc>::zero();
        let mut res = p1;
        res.mul(&p2);
        assert_eq!(res, Poly::<Sc>::zero());

        let p1 = Poly::<Sc>::zero();
        let p2 = Poly::<Sc>::new(3);
        let mut res = p1;
        res.mul(&p2);
        assert_eq!(res, Poly::<Sc>::zero());
    }

    use proptest::prelude::*;

    proptest! {

        #[test]
        fn addition(deg1 in 0..100usize, deg2 in 0..100usize) {
            let p1 = Poly::<Sc>::new(deg1);
            let p2 = Poly::<Sc>::new(deg2);
            let mut res = p1.clone();
            res.add(&p2);

            let (larger, smaller) = if p1.degree() > p2.degree() {
                (&p1, &p2)
            } else {
                (&p2, &p1)
            };

            for i in 0..larger.len() {
                if i < smaller.len() {
                    // Overlapping positions hold the sum of both coefficients.
                    let mut coeff_sum = p1.0[i];
                    coeff_sum.add(&p2.0[i]);
                    assert_eq!(res.0[i], coeff_sum);
                } else {
                    // Positions beyond the shorter operand are copied from the longer one.
                    assert_eq!(res.0[i], larger.0[i]);
                }
            }

            assert_eq!(res.degree(), larger.degree());
        }

        #[test]
        fn interpolation(degree in 0..10usize, num_evals in 0..10usize) {
            let poly = Poly::<Sc>::new(degree);
            let expected = poly.0[0];

            let shares = (0..num_evals)
                .map(|i| poly.eval(i as Idx))
                .collect::<Vec<_>>();

            let recovered_poly = Poly::<Sc>::full_recover(num_evals, shares.clone()).unwrap();
            let computed = recovered_poly.0[0];

            let recovered_constant = Poly::<Sc>::recover(num_evals, shares).unwrap();

            if num_evals > degree {
                // Enough points: interpolation must reproduce the constant term.
                assert_eq!(expected, computed);
                assert_eq!(expected, recovered_constant);
            } else {
                // Too few points: recovery yields an unrelated value (with overwhelming probability).
                assert_ne!(expected, computed);
                assert_ne!(expected, recovered_constant);
            }
        }

        #[test]
        fn eval(d in 0..100usize, idx in 0..(100 as Idx)) {
            let mut x = Sc::new();
            x.set_int(idx as u64 + 1);

            let p1 = Poly::<Sc>::new(d);
            let evaluation = p1.eval(idx).value;

            // Naive evaluation sum_i c_i * x^i to cross-check Horner's rule in `eval`.
            let coeffs = p1.0;
            let mut sum = coeffs[0];
            for (i, coeff) in coeffs.into_iter().enumerate().take(d + 1).skip(1) {
                let xi = pow(x, i);
                let mut var = coeff;
                var.mul(&xi);
                sum.add(&var);
            }

            assert_eq!(sum, evaluation);

            fn pow(base: Sc, pow: usize) -> Sc {
                let mut res = Sc::one();
                for _ in 0..pow {
                    res.mul(&base)
                }
                res
            }
        }

    }

    #[test]
    fn interpolation_insufficient_shares() {
        let degree = 4;
        let threshold = degree + 1;
        let poly = Poly::<Sc>::new(degree);

        let shares = (0..threshold - 1)
            .map(|i| poly.eval(i as Idx))
            .collect::<Vec<_>>();

        Poly::<Sc>::recover(threshold, shares.clone()).unwrap_err();
        Poly::<Sc>::full_recover(threshold, shares).unwrap_err();
    }

    #[test]
    fn benchy() {
        // `Instant` is monotonic, unlike `SystemTime`, so `elapsed` cannot fail.
        use std::time::Instant;

        let degree = 49;
        let threshold = degree + 1;
        let poly = Poly::<Sc>::new(degree);

        let shares = (0..threshold)
            .map(|i| poly.eval(i as Idx))
            .collect::<Vec<Eval<Sc>>>();
        let now = Instant::now();
        Poly::<Sc>::recover(threshold, shares).unwrap();
        println!("single recover: time elapsed {:?}", now.elapsed());

        let shares = (0..threshold)
            .map(|i| poly.eval(i as Idx))
            .collect::<Vec<Eval<Sc>>>();
        let now = Instant::now();
        Poly::<Sc>::full_recover(threshold, shares).unwrap();
        println!("full_recover: time elapsed {:?}", now.elapsed());
    }

    #[test]
    fn mul() {
        let d = 1;
        let p1 = Poly::<Sc>::new(d);
        let p2 = Poly::<Sc>::new(d);
        let mut p3 = p1.clone();
        p3.mul(&p2);
        assert_eq!(p3.degree(), d + d);

        // Expand (a0 + a1 x)(b0 + b1 x) by hand and compare at x = 1 (share index 0).
        let mut l1 = p1.0[0];
        l1.mul(&p2.0[0]);

        let mut l21 = p1.0[0];
        l21.mul(&p2.0[1]);

        let mut l22 = p1.0[1];
        l22.mul(&p2.0[0]);
        let mut l2 = Sc::new();
        l2.add(&l21);
        l2.add(&l22);
        let mut l3 = p1.0[1];
        l3.mul(&p2.0[1]);

        let mut total = Sc::new();
        total.add(&l1);
        total.add(&l2);
        total.add(&l3);
        let res = p3.eval(0);
        assert_eq!(total, res.value);
    }

    #[test]
    fn new_neg_constant() {
        let mut constant = Sc::rand(&mut thread_rng());
        let p = Poly::<Sc>::new_neg_constant(constant);

        constant.negate();
        let v = vec![constant, Sc::one()];
        let res = Poly::from(v);

        assert_eq!(res, p);
    }

    #[test]
    fn commit() {
        let secret = Poly::<Sc>::new(5);

        let coeffs = secret.0.clone();
        let commitment = coeffs
            .iter()
            .map(|coeff| {
                let mut p = G1::one();
                p.mul(coeff);
                p
            })
            .collect::<Vec<_>>();
        let commitment = Poly::from(commitment);

        assert_eq!(commitment, secret.commit::<G1>());
    }
}