1use num_bigint::{BigUint, RandBigInt};
7use num_traits::identities::{One, Zero};
8use thiserror::Error;
9
10pub mod extensions;
11
12#[derive(Error, Debug)]
14pub enum SssError {
15 #[error("threshold k must be be in this range: 0 < k ≤ n")]
16 InvalidThreshold,
17 #[error("not enough shares to reconstruct secret (need {threshold}, got {share_count})")]
18 NotEnoughShares {
19 threshold: usize,
21 share_count: usize,
23 },
24
25 #[error("duplicate share indices found")]
27 DuplicateShares,
28}
29
30#[derive(Clone, Debug, PartialEq)]
32pub struct Share {
33 pub index: u32,
35 pub value: BigUint,
37}
38
39#[derive(Debug)]
41pub struct Scheme {
42 prime_modulus: BigUint,
44 threshold: usize,
46 total_shares: usize,
48}
49
50impl Scheme {
51 pub fn new(
72 threshold: usize,
73 total_shares: usize,
74 prime_modulus: BigUint,
75 ) -> Result<Self, SssError> {
76 if threshold == 0 || threshold > total_shares {
77 return Err(SssError::InvalidThreshold);
78 }
79
80 Ok(Scheme {
81 prime_modulus,
82 threshold,
83 total_shares,
84 })
85 }
86
87 pub fn split_secret(&self, secret: &BigUint) -> Vec<Share> {
107 let secret = secret % &self.prime_modulus;
108 let coefficients = self.create_polynomial(&secret);
109
110 (1..=self.total_shares)
111 .map(|x| Share {
112 index: x as u32,
113 value: self.evaluate_polynomial(&coefficients, x as u32),
114 })
115 .collect()
116 }
117
118 pub fn reconstruct_secret(&self, shares: &[Share]) -> Result<BigUint, SssError> {
141 if shares.len() < self.threshold {
143 return Err(SssError::NotEnoughShares {
144 threshold: self.threshold,
145 share_count: shares.len(),
146 });
147 }
148
149 let mut seen_indices = std::collections::HashSet::new();
151 for share in shares {
152 if !seen_indices.insert(share.index) {
153 return Err(SssError::DuplicateShares);
154 }
155 }
156
157 let shares = &shares[0..self.threshold];
159
160 let x = BigUint::from(0u32);
162 let mut secret = BigUint::from(0u32);
163
164 for i in 0..shares.len() {
165 let basis = self.lagrange_basis(shares, i, &x);
166 let term = (basis * &shares[i].value) % &self.prime_modulus;
167 secret = (secret + term) % &self.prime_modulus;
168 }
169
170 Ok(secret)
171 }
172
173 pub(crate) fn create_polynomial(&self, secret: &BigUint) -> Vec<BigUint> {
178 let mut rng = rand::thread_rng();
179 let mut coefficients = Vec::with_capacity(self.threshold);
180
181 coefficients.push(secret.clone());
183
184 for _ in 1..self.threshold {
186 let coeff = rng.gen_biguint_range(&BigUint::from(0u32), &self.prime_modulus);
188 coefficients.push(coeff);
189 }
190
191 coefficients
192 }
193
194 pub(crate) fn evaluate_polynomial(&self, coefficients: &[BigUint], x: u32) -> BigUint {
197 let mut result = BigUint::from(0u32); let x_big = BigUint::from(x);
199 let mut x_power = BigUint::from(1u32); for coeff in coefficients {
202 let term = coeff * &x_power;
203 result += term;
204 x_power *= &x_big;
205 }
206
207 result % &self.prime_modulus
208 }
209
210 fn mod_inverse(number: &BigUint, modulus: &BigUint) -> Option<BigUint> {
212 let mut s = BigUint::zero();
213 let mut old_s = BigUint::one();
214 let mut t = BigUint::one();
215 let mut old_t = BigUint::zero();
216 let mut r = modulus.clone();
217 let mut old_r = number.clone();
218
219 while !r.is_zero() {
220 let quotient = &old_r / &r;
221
222 let temp_r = r.clone();
224 r = old_r - "ient * &r;
225 old_r = temp_r;
226
227 let temp_s = s.clone();
229 s = if quotient.clone() * &s <= old_s {
230 old_s - quotient.clone() * &s
231 } else {
232 modulus - ((quotient.clone() * &s - &old_s) % modulus)
233 };
234 old_s = temp_s;
235
236 let temp_t = t.clone();
238 t = if quotient.clone() * &t <= old_t {
239 old_t - quotient * &t
240 } else {
241 modulus - ((quotient * &t - &old_t) % modulus)
242 };
243 old_t = temp_t;
244 }
245
246 if old_r > BigUint::one() {
247 return None; }
249
250 Some(old_s % modulus)
251 }
252
253 fn lagrange_basis(&self, shares: &[Share], i: usize, x: &BigUint) -> BigUint {
256 let mut numerator = BigUint::from(1u32);
257 let mut denominator = BigUint::from(1u32);
258 let x_i = BigUint::from(shares[i].index);
259
260 for j in 0..shares.len() {
261 if i != j {
262 let x_j = BigUint::from(shares[j].index);
263
264 let term = self.mod_sub(x, &x_j);
266 numerator = (numerator * term) % &self.prime_modulus;
267
268 let diff = self.mod_sub(&x_i, &x_j);
270 denominator = (denominator * diff) % &self.prime_modulus;
271 }
272 }
273
274 let denominator_inv = Self::mod_inverse(&denominator, &self.prime_modulus)
276 .expect("shares should have unique indices");
277
278 (numerator * denominator_inv) % &self.prime_modulus
279 }
280
281 fn mod_sub(&self, a: &BigUint, b: &BigUint) -> BigUint {
283 if a >= b {
284 (a - b) % &self.prime_modulus
285 } else {
286 let mut result = &self.prime_modulus - b;
287 result += a;
288 result % &self.prime_modulus
289 }
290 }
291}
292
#[cfg(test)]
mod tests {
    use super::*;
    use std::collections::HashSet;

    /// Parses the 256-bit prime 2^256 - 189 used by the large-modulus tests.
    fn large_prime() -> BigUint {
        BigUint::parse_bytes(
            b"115792089237316195423570985008687907853269984665640564039457584007913129639747",
            10,
        )
        .unwrap()
    }

    #[test]
    fn test_polynomial_evaluation() {
        let scheme = Scheme::new(3, 5, BigUint::from(17u32)).unwrap();

        // f(x) = 3 + 2x + x^2
        let coefficients: Vec<BigUint> = vec![3u32, 2u32, 1u32]
            .into_iter()
            .map(BigUint::from)
            .collect();

        assert_eq!(
            scheme.evaluate_polynomial(&coefficients, 1),
            BigUint::from(6u32)
        );

        assert_eq!(
            scheme.evaluate_polynomial(&coefficients, 2),
            BigUint::from(11u32)
        );
    }

    #[test]
    fn test_split_secret() {
        let scheme = Scheme::new(3, 5, BigUint::from(17u32)).unwrap();

        let shares = scheme.split_secret(&BigUint::from(10u32));

        assert_eq!(shares.len(), 5);
        for share in shares {
            assert!(share.value < BigUint::from(17u32));
        }
    }

    #[test]
    fn test_polynomial_randomness() {
        // With a tiny modulus such as 17, the two random tails below would
        // collide with probability 1/17^2 and the `assert_ne!` would be flaky.
        // A 256-bit prime makes a collision negligible.
        let scheme = Scheme::new(3, 5, large_prime()).unwrap();

        let secret = BigUint::from(10u32);
        let poly1 = scheme.create_polynomial(&secret);
        let poly2 = scheme.create_polynomial(&secret);

        // The constant term is the secret itself.
        assert_eq!(poly1[0], poly2[0]);

        // The remaining coefficients are drawn independently on every call.
        assert_ne!(poly1[1..], poly2[1..]);

        for coeff in poly1.iter().chain(poly2.iter()) {
            assert!(coeff < &scheme.prime_modulus);
        }
    }

    #[test]
    fn test_secret_reconstruction() {
        let prime = BigUint::from(17u32);
        let scheme = Scheme::new(3, 5, prime).unwrap();

        let secret = BigUint::from(10u32);
        let shares = scheme.split_secret(&secret);

        let result = scheme.reconstruct_secret(&shares[0..3]).unwrap();
        assert_eq!(result, secret);

        let result = scheme.reconstruct_secret(&shares[0..4]).unwrap();
        assert_eq!(result, secret);
    }

    #[test]
    fn test_reconstruction_errors() {
        let prime = BigUint::from(17u32);
        let scheme = Scheme::new(3, 5, prime).unwrap();

        let secret = BigUint::from(10u32);
        let shares = scheme.split_secret(&secret);

        assert!(matches!(
            scheme.reconstruct_secret(&shares[0..2]),
            Err(SssError::NotEnoughShares { .. })
        ));

        let mut duplicate_shares = shares[0..3].to_vec();
        duplicate_shares[1] = duplicate_shares[0].clone();
        assert!(matches!(
            scheme.reconstruct_secret(&duplicate_shares),
            Err(SssError::DuplicateShares)
        ));
    }

    #[test]
    fn test_mod_sub() {
        let prime = BigUint::from(17u32);
        let scheme = Scheme::new(3, 5, prime.clone()).unwrap();

        assert_eq!(
            scheme.mod_sub(&BigUint::from(10u32), &BigUint::from(3u32)),
            BigUint::from(7u32)
        );
        assert_eq!(
            scheme.mod_sub(&BigUint::from(3u32), &BigUint::from(10u32)),
            BigUint::from(10u32)
        );
    }

    #[test]
    fn test_reconstruction_with_different_share_combinations() {
        let prime = BigUint::from(257u32);
        let scheme = Scheme::new(3, 5, prime).unwrap();
        let secret = BigUint::from(123u32);
        let shares = scheme.split_secret(&secret);

        let mut seen_secrets = HashSet::new();

        // Every 3-element subset of the 5 shares.
        let combinations = vec![
            vec![0, 1, 2],
            vec![0, 1, 3],
            vec![0, 1, 4],
            vec![0, 2, 3],
            vec![0, 2, 4],
            vec![0, 3, 4],
            vec![1, 2, 3],
            vec![1, 2, 4],
            vec![1, 3, 4],
            vec![2, 3, 4],
        ];

        for combo in combinations {
            let share_subset: Vec<Share> = combo.iter().map(|&i| shares[i].clone()).collect();

            let reconstructed = scheme.reconstruct_secret(&share_subset).unwrap();
            seen_secrets.insert(reconstructed);
        }

        assert_eq!(seen_secrets.len(), 1);
        assert_eq!(seen_secrets.into_iter().next().unwrap(), secret);
    }

    #[test]
    fn test_insufficient_shares_reveal_nothing() {
        let prime = BigUint::from(257u32);
        let scheme = Scheme::new(3, 5, prime.clone()).unwrap();
        let secret = BigUint::from(123u32);
        let shares = scheme.split_secret(&secret);

        // Every 2-element subset of the 5 shares.
        let pairs = vec![
            vec![0, 1],
            vec![0, 2],
            vec![0, 3],
            vec![0, 4],
            vec![1, 2],
            vec![1, 3],
            vec![1, 4],
            vec![2, 3],
            vec![2, 4],
            vec![3, 4],
        ];

        for pair in pairs {
            let share_pair: Vec<Share> = pair.iter().map(|&i| shares[i].clone()).collect();

            assert!(matches!(
                scheme.reconstruct_secret(&share_pair),
                Err(SssError::NotEnoughShares { .. })
            ));
        }
    }

    #[test]
    fn test_different_threshold_combinations() {
        let prime = BigUint::from(257u32);

        let configs = vec![(2, 3), (3, 5), (5, 8), (7, 10)];

        for (k, n) in configs {
            let scheme = Scheme::new(k, n, prime.clone()).unwrap();
            let secret = BigUint::from(123u32);
            let shares = scheme.split_secret(&secret);

            assert_eq!(shares.len(), n);

            let reconstructed = scheme.reconstruct_secret(&shares[0..k]).unwrap();
            assert_eq!(reconstructed, secret);

            assert!(matches!(
                scheme.reconstruct_secret(&shares[0..k - 1]),
                Err(SssError::NotEnoughShares { .. })
            ));
        }
    }

    #[test]
    fn test_edge_case_secrets() {
        let prime = BigUint::from(17u32);
        let scheme = Scheme::new(3, 5, prime.clone()).unwrap();

        let secret = BigUint::from(0u32);
        let shares = scheme.split_secret(&secret);
        let reconstructed = scheme.reconstruct_secret(&shares[0..3]).unwrap();
        assert_eq!(reconstructed, secret);

        let secret = prime.clone() - BigUint::from(1u32);
        let shares = scheme.split_secret(&secret);
        let reconstructed = scheme.reconstruct_secret(&shares[0..3]).unwrap();
        assert_eq!(reconstructed, secret);

        // Secrets are reduced modulo the prime before splitting.
        let secret = prime.clone() * BigUint::from(2u32);
        let shares = scheme.split_secret(&secret);
        let reconstructed = scheme.reconstruct_secret(&shares[0..3]).unwrap();
        assert_eq!(reconstructed, BigUint::from(0u32));
    }

    #[test]
    fn test_all_shares_reconstruction() {
        let prime = BigUint::from(17u32);
        let scheme = Scheme::new(3, 5, prime).unwrap();
        let secret = BigUint::from(10u32);
        let shares = scheme.split_secret(&secret);

        let result = scheme.reconstruct_secret(&shares).unwrap();
        assert_eq!(result, secret);
    }

    #[test]
    fn test_invalid_parameters() {
        let prime = BigUint::from(17u32);

        assert!(matches!(
            Scheme::new(0, 5, prime.clone()),
            Err(SssError::InvalidThreshold)
        ));

        assert!(matches!(
            Scheme::new(6, 5, prime.clone()),
            Err(SssError::InvalidThreshold)
        ));

        assert!(Scheme::new(5, 5, prime.clone()).is_ok());
    }

    #[test]
    fn test_minimum_viable_scheme() {
        let prime = BigUint::from(17u32);
        let scheme = Scheme::new(2, 2, prime).unwrap();
        let secret = BigUint::from(10u32);
        let shares = scheme.split_secret(&secret);

        assert_eq!(shares.len(), 2);
        let reconstructed = scheme.reconstruct_secret(&shares).unwrap();
        assert_eq!(reconstructed, secret);
    }

    #[test]
    fn test_large_numbers() {
        let prime = large_prime();
        let scheme = Scheme::new(3, 5, prime.clone()).unwrap();

        let secret = prime.clone() - BigUint::from(1u32);
        let shares = scheme.split_secret(&secret);
        let reconstructed = scheme.reconstruct_secret(&shares[0..3]).unwrap();
        assert_eq!(reconstructed, secret);
    }

    #[test]
    fn test_share_index_range() {
        let prime = BigUint::from(17u32);
        let scheme = Scheme::new(3, 5, prime).unwrap();
        let secret = BigUint::from(10u32);
        let shares = scheme.split_secret(&secret);

        for (i, share) in shares.iter().enumerate() {
            assert_eq!(share.index as usize, i + 1);
        }
    }

    #[test]
    fn test_shuffled_shares() {
        use rand::seq::SliceRandom;
        let mut rng = rand::thread_rng();

        let prime = BigUint::from(17u32);
        let scheme = Scheme::new(3, 5, prime).unwrap();
        let secret = BigUint::from(10u32);
        let mut shares = scheme.split_secret(&secret);

        shares.shuffle(&mut rng);

        let reconstructed = scheme.reconstruct_secret(&shares[0..3]).unwrap();
        assert_eq!(reconstructed, secret);
    }
}