shamir_rs/
lib.rs

//! Implementation of Shamir's Secret Sharing scheme.
//!
//! The scheme splits a secret into `n` shares such that any `k` of them can
//! reconstruct the secret, while any `k - 1` shares reveal no information about it.
//! Based on Adi Shamir's paper "How to Share a Secret" (Communications of the ACM, 1979).
6use num_bigint::{BigUint, RandBigInt};
7use num_traits::identities::{One, Zero};
8use thiserror::Error;
9
10pub mod extensions;
11
12/// These are errors that can occur during secret sharing operations
13#[derive(Error, Debug)]
14pub enum SssError {
15    #[error("threshold k must be be in this range: 0 < k ≤ n")]
16    InvalidThreshold,
17    #[error("not enough shares to reconstruct secret (need {threshold}, got {share_count})")]
18    NotEnoughShares {
19        /// Required number of shares (k)
20        threshold: usize,
21        /// Actual number of shares provided
22        share_count: usize,
23    },
24    
25    /// Duplicate share indices found during reconstruction
26    #[error("duplicate share indices found")]
27    DuplicateShares,
28}
29
30/// A single share of a split secret, representing a point on the polynomial
31#[derive(Clone, Debug, PartialEq)]
32pub struct Share {
33    /// The x-coordinate of the polynomial point (share index)
34    pub index: u32,
35    /// The y-coordinate of the polynomial point (share value)
36    pub value: BigUint,
37}
38
39/// The main struct implementing Shamir's Secret Sharing scheme
40#[derive(Debug)]
41pub struct Scheme {
42    /// Prime modulus defining the finite field ℤ/pℤ
43    prime_modulus: BigUint,
44    /// Minimum number of shares needed to reconstruct (k)
45    threshold: usize,
46    /// Total number of shares to generate (n)
47    total_shares: usize,
48}
49
50impl Scheme {
51    /// Creates a new Shamir's Secret Sharing scheme with the specified parameters.
52    /// 
53    /// # Arguments
54    /// * `threshold` - Minimum number of shares needed to reconstruct the secret (k)
55    /// * `total_shares` - Total number of shares to generate (n)
56    /// * `prime_modulus` - Prime number defining the finite field. Must be larger
57    ///                     than both the secret and total_shares.
58    /// 
59    /// # Returns
60    /// * `Ok(Scheme)` - If parameters are valid
61    /// * `Err(SssError::InvalidThreshold)` - If threshold is 0 or greater than total_shares
62    /// 
63    /// # Example
64    /// ```
65    /// use num_bigint::BigUint;
66    /// use shamir_rs::Scheme;
67    /// 
68    /// let prime = BigUint::from(257u32);
69    /// let scheme = Scheme::new(3, 5, prime).unwrap();
70    /// ```
71    pub fn new(
72        threshold: usize,
73        total_shares: usize,
74        prime_modulus: BigUint,
75    ) -> Result<Self, SssError> {
76        if threshold == 0 || threshold > total_shares {
77            return Err(SssError::InvalidThreshold);
78        }
79
80        Ok(Scheme {
81            prime_modulus,
82            threshold,
83            total_shares,
84        })
85    }
86
87    /// Splits a secret into n shares where k shares are required to reconstruct.
88    /// 
89    /// # Arguments
90    /// * `secret` - The secret to split. Must be less than prime_modulus.
91    /// 
92    /// # Returns
93    /// A vector of n shares. Each share is a point on a random polynomial of
94    /// degree k-1 where the constant term is the secret.
95    /// 
96    /// # Example
97    /// ```
98    /// # use num_bigint::BigUint;
99    /// # use shamir_rs::Scheme;
100    /// # let prime = BigUint::from(257u32);
101    /// # let scheme = Scheme::new(3, 5, prime).unwrap();
102    /// let secret = BigUint::from(123u32);
103    /// let shares = scheme.split_secret(&secret);
104    /// assert_eq!(shares.len(), 5);
105    /// ```
106    pub fn split_secret(&self, secret: &BigUint) -> Vec<Share> {
107        let secret = secret % &self.prime_modulus;
108        let coefficients = self.create_polynomial(&secret);
109
110        (1..=self.total_shares)
111            .map(|x| Share {
112                index: x as u32,
113                value: self.evaluate_polynomial(&coefficients, x as u32),
114            })
115            .collect()
116    }
117
118    /// Reconstructs a secret from k or more shares using Lagrange interpolation.
119    /// 
120    /// # Arguments
121    /// * `shares` - Slice of shares to use for reconstruction. Must contain at
122    ///             least k shares with unique indices.
123    /// 
124    /// # Returns
125    /// * `Ok(BigUint)` - The reconstructed secret
126    /// * `Err(SssError::NotEnoughShares)` - If fewer than k shares provided
127    /// * `Err(SssError::DuplicateShares)` - If shares contain duplicate indices
128    /// 
129    /// # Example
130    /// ```
131    /// # use num_bigint::BigUint;
132    /// # use shamir_rs::Scheme;
133    /// # let prime = BigUint::from(257u32);
134    /// # let scheme = Scheme::new(3, 5, prime).unwrap();
135    /// # let secret = BigUint::from(123u32);
136    /// # let shares = scheme.split_secret(&secret);
137    /// let reconstructed = scheme.reconstruct_secret(&shares[0..3]).unwrap();
138    /// assert_eq!(reconstructed, secret);
139    /// ```
140    pub fn reconstruct_secret(&self, shares: &[Share]) -> Result<BigUint, SssError> {
141        // Check if we have enough shares
142        if shares.len() < self.threshold {
143            return Err(SssError::NotEnoughShares {
144                threshold: self.threshold,
145                share_count: shares.len(),
146            });
147        }
148
149        // Check for duplicate indices
150        let mut seen_indices = std::collections::HashSet::new();
151        for share in shares {
152            if !seen_indices.insert(share.index) {
153                return Err(SssError::DuplicateShares);
154            }
155        }
156
157        // We only need threshold number of shares
158        let shares = &shares[0..self.threshold];
159
160        // Evaluate at x = 0 to get the secret
161        let x = BigUint::from(0u32);
162        let mut secret = BigUint::from(0u32);
163
164        for i in 0..shares.len() {
165            let basis = self.lagrange_basis(shares, i, &x);
166            let term = (basis * &shares[i].value) % &self.prime_modulus;
167            secret = (secret + term) % &self.prime_modulus;
168        }
169
170        Ok(secret)
171    }
172
173    /// Creates a random polynomial of degree k-1 where:
174    /// - a₀ is the secret
175    /// - all other coefficients are random
176    /// - all arithmetic is done modulo prime_modulus
177    pub(crate) fn create_polynomial(&self, secret: &BigUint) -> Vec<BigUint> {
178        let mut rng = rand::thread_rng();
179        let mut coefficients = Vec::with_capacity(self.threshold);
180
181        // a₀ = secret
182        coefficients.push(secret.clone());
183
184        // Generate random coefficients a₁ through aₖ₋₁
185        for _ in 1..self.threshold {
186            // Generate a random number in range [0, prime_modulus)
187            let coeff = rng.gen_biguint_range(&BigUint::from(0u32), &self.prime_modulus);
188            coefficients.push(coeff);
189        }
190
191        coefficients
192    }
193
194    /// Evaluates polynomial at point x
195    /// polynomial is represented by its coefficients [a₀, a₁, ..., aₖ₋₁]
196    pub(crate) fn evaluate_polynomial(&self, coefficients: &[BigUint], x: u32) -> BigUint {
197        let mut result = BigUint::from(0u32); // Start with 0
198        let x_big = BigUint::from(x);
199        let mut x_power = BigUint::from(1u32); // Start with 1
200
201        for coeff in coefficients {
202            let term = coeff * &x_power;
203            result += term;
204            x_power *= &x_big;
205        }
206
207        result % &self.prime_modulus
208    }
209
210    /// Calculates a modular multiplicative inverse using Extended Euclidean Algorithm
211    fn mod_inverse(number: &BigUint, modulus: &BigUint) -> Option<BigUint> {
212        let mut s = BigUint::zero();
213        let mut old_s = BigUint::one();
214        let mut t = BigUint::one();
215        let mut old_t = BigUint::zero();
216        let mut r = modulus.clone();
217        let mut old_r = number.clone();
218
219        while !r.is_zero() {
220            let quotient = &old_r / &r;
221
222            // Update r
223            let temp_r = r.clone();
224            r = old_r - &quotient * &r;
225            old_r = temp_r;
226
227            // Update s
228            let temp_s = s.clone();
229            s = if quotient.clone() * &s <= old_s {
230                old_s - quotient.clone() * &s
231            } else {
232                modulus - ((quotient.clone() * &s - &old_s) % modulus)
233            };
234            old_s = temp_s;
235
236            // Update t
237            let temp_t = t.clone();
238            t = if quotient.clone() * &t <= old_t {
239                old_t - quotient * &t
240            } else {
241                modulus - ((quotient * &t - &old_t) % modulus)
242            };
243            old_t = temp_t;
244        }
245
246        if old_r > BigUint::one() {
247            return None; // number and modulus aren't coprime
248        }
249
250        Some(old_s % modulus)
251    }
252
253    /// Calculates the Lagrange basis polynomial li(x) for each share:
254    /// li(x) = ∏(j≠i) (x - xj)/(xi - xj)
255    fn lagrange_basis(&self, shares: &[Share], i: usize, x: &BigUint) -> BigUint {
256        let mut numerator = BigUint::from(1u32);
257        let mut denominator = BigUint::from(1u32);
258        let x_i = BigUint::from(shares[i].index);
259
260        for j in 0..shares.len() {
261            if i != j {
262                let x_j = BigUint::from(shares[j].index);
263
264                // Compute (x - x_j) mod p
265                let term = self.mod_sub(x, &x_j);
266                numerator = (numerator * term) % &self.prime_modulus;
267
268                // Compute (x_i - x_j) mod p
269                let diff = self.mod_sub(&x_i, &x_j);
270                denominator = (denominator * diff) % &self.prime_modulus;
271            }
272        }
273
274        // Calculate modular multiplicative inverse of denominator
275        let denominator_inv = Self::mod_inverse(&denominator, &self.prime_modulus)
276            .expect("shares should have unique indices");
277
278        (numerator * denominator_inv) % &self.prime_modulus
279    }
280
281    /// Helper function: Perform modular subtraction (a - b) mod p
282    fn mod_sub(&self, a: &BigUint, b: &BigUint) -> BigUint {
283        if a >= b {
284            (a - b) % &self.prime_modulus
285        } else {
286            let mut result = &self.prime_modulus - b;
287            result += a;
288            result % &self.prime_modulus
289        }
290    }
291}
292
#[cfg(test)]
mod tests {
    use super::*;
    use std::collections::HashSet;

    /// The 256-bit prime 2^256 - 189, used where a large field is needed.
    fn large_prime() -> BigUint {
        BigUint::parse_bytes(
            b"115792089237316195423570985008687907853269984665640564039457584007913129639747",
            10,
        )
        .unwrap()
    }

    #[test]
    fn test_polynomial_evaluation() {
        let scheme = Scheme::new(3, 5, BigUint::from(17u32)).unwrap();

        // Test polynomial x² + 2x + 3 mod 17
        let coefficients: Vec<BigUint> = vec![3u32, 2u32, 1u32]
            .into_iter()
            .map(BigUint::from)
            .collect();

        assert_eq!(
            scheme.evaluate_polynomial(&coefficients, 1),
            BigUint::from(6u32)
        );

        assert_eq!(
            scheme.evaluate_polynomial(&coefficients, 2),
            BigUint::from(11u32)
        );
    }

    #[test]
    fn test_split_secret() {
        let scheme = Scheme::new(3, 5, BigUint::from(17u32)).unwrap();

        let shares = scheme.split_secret(&BigUint::from(10u32));

        assert_eq!(shares.len(), 5); // Should get 5 shares

        // All share values should be < 17 (prime modulus)
        for share in shares {
            assert!(share.value < BigUint::from(17u32));
        }
    }

    #[test]
    fn test_polynomial_randomness() {
        // Use a large field. With p = 17 the two random coefficient pairs coincide with
        // probability 1/17² ≈ 0.35%, which made the `assert_ne!` below flaky.
        let scheme = Scheme::new(3, 5, large_prime()).unwrap();

        let secret = BigUint::from(10u32);
        let poly1 = scheme.create_polynomial(&secret);
        let poly2 = scheme.create_polynomial(&secret);

        // First coefficient (secret) should be the same
        assert_eq!(poly1[0], poly2[0]);

        // Other coefficients should differ (collision probability ~2⁻⁵¹²)
        assert_ne!(poly1[1..], poly2[1..]);

        // All coefficients should be less than prime modulus
        for coeff in poly1.iter().chain(poly2.iter()) {
            assert!(coeff < &scheme.prime_modulus);
        }
    }

    #[test]
    fn test_secret_reconstruction() {
        let prime = BigUint::from(17u32);
        let scheme = Scheme::new(3, 5, prime).unwrap();

        let secret = BigUint::from(10u32);
        let shares = scheme.split_secret(&secret);

        // Try reconstructing with exactly k shares
        let result = scheme.reconstruct_secret(&shares[0..3]).unwrap();
        assert_eq!(result, secret);

        // Try reconstructing with more than k shares
        let result = scheme.reconstruct_secret(&shares[0..4]).unwrap();
        assert_eq!(result, secret);
    }

    #[test]
    fn test_reconstruction_errors() {
        let prime = BigUint::from(17u32);
        let scheme = Scheme::new(3, 5, prime).unwrap();

        let secret = BigUint::from(10u32);
        let shares = scheme.split_secret(&secret);

        // Try with not enough shares
        assert!(matches!(
            scheme.reconstruct_secret(&shares[0..2]),
            Err(SssError::NotEnoughShares { .. })
        ));

        // Try with duplicate shares
        let mut duplicate_shares = shares[0..3].to_vec();
        duplicate_shares[1] = duplicate_shares[0].clone();
        assert!(matches!(
            scheme.reconstruct_secret(&duplicate_shares),
            Err(SssError::DuplicateShares)
        ));
    }

    #[test]
    fn test_mod_sub() {
        let prime = BigUint::from(17u32);
        let scheme = Scheme::new(3, 5, prime).unwrap();

        // Test cases for modular subtraction
        assert_eq!(
            scheme.mod_sub(&BigUint::from(10u32), &BigUint::from(3u32)),
            BigUint::from(7u32)
        ); // 10 - 3 = 7
        assert_eq!(
            scheme.mod_sub(&BigUint::from(3u32), &BigUint::from(10u32)),
            BigUint::from(10u32)
        ); // (17 - 10 + 3) mod 17 = 10
    }

    #[test]
    fn test_reconstruction_with_different_share_combinations() {
        let prime = BigUint::from(257u32); // Larger prime for better security testing
        let scheme = Scheme::new(3, 5, prime).unwrap();
        let secret = BigUint::from(123u32);
        let shares = scheme.split_secret(&secret);

        // Test ALL possible combinations of k shares
        let mut seen_secrets = HashSet::new();

        // Try every possible combination of 3 shares from the 5 shares
        let combinations = vec![
            vec![0, 1, 2],
            vec![0, 1, 3],
            vec![0, 1, 4],
            vec![0, 2, 3],
            vec![0, 2, 4],
            vec![0, 3, 4],
            vec![1, 2, 3],
            vec![1, 2, 4],
            vec![1, 3, 4],
            vec![2, 3, 4],
        ];

        for combo in combinations {
            let share_subset: Vec<Share> = combo.iter().map(|&i| shares[i].clone()).collect();

            let reconstructed = scheme.reconstruct_secret(&share_subset).unwrap();
            seen_secrets.insert(reconstructed);
        }

        // All reconstructions should yield the same secret
        assert_eq!(seen_secrets.len(), 1);
        assert_eq!(seen_secrets.into_iter().next().unwrap(), secret);
    }

    #[test]
    fn test_insufficient_shares_reveal_nothing() {
        let field_size = 257u32;
        let scheme = Scheme::new(3, 5, BigUint::from(field_size)).unwrap();
        let secret = BigUint::from(123u32);
        let shares = scheme.split_secret(&secret);

        // Take all possible pairs of shares (k-1 shares)
        let pairs = vec![
            vec![0, 1],
            vec![0, 2],
            vec![0, 3],
            vec![0, 4],
            vec![1, 2],
            vec![1, 3],
            vec![1, 4],
            vec![2, 3],
            vec![2, 4],
            vec![3, 4],
        ];

        for pair in pairs {
            let share_pair: Vec<Share> = pair.iter().map(|&i| shares[i].clone()).collect();

            // k-1 shares are rejected outright...
            assert!(matches!(
                scheme.reconstruct_secret(&share_pair),
                Err(SssError::NotEnoughShares { .. })
            ));

            // ...and they carry no information. Complete the pair with a third point at
            // a fresh index (6). Every field element is then a consistent secret for
            // exactly one choice of the third value.
            let mut candidates = HashSet::new();
            for v in 0..field_size {
                let mut completed = share_pair.clone();
                completed.push(Share {
                    index: 6,
                    value: BigUint::from(v),
                });
                candidates.insert(scheme.reconstruct_secret(&completed).unwrap());
            }
            assert_eq!(candidates.len(), field_size as usize);
        }
    }

    #[test]
    fn test_different_threshold_combinations() {
        let prime = BigUint::from(257u32);

        // Test different (k,n) combinations
        let configs = vec![
            (2, 3),  // minimal case
            (3, 5),  // our standard test case
            (5, 8),  // larger case
            (7, 10), // even larger case
        ];

        for (k, n) in configs {
            let scheme = Scheme::new(k, n, prime.clone()).unwrap();
            let secret = BigUint::from(123u32);
            let shares = scheme.split_secret(&secret);

            assert_eq!(shares.len(), n);

            // Should succeed with k shares
            let reconstructed = scheme.reconstruct_secret(&shares[0..k]).unwrap();
            assert_eq!(reconstructed, secret);

            // Should fail with k-1 shares
            assert!(matches!(
                scheme.reconstruct_secret(&shares[0..k - 1]),
                Err(SssError::NotEnoughShares { .. })
            ));
        }
    }

    #[test]
    fn test_edge_case_secrets() {
        let prime = BigUint::from(17u32);
        let scheme = Scheme::new(3, 5, prime.clone()).unwrap();

        // Test secret = 0
        let secret = BigUint::from(0u32);
        let shares = scheme.split_secret(&secret);
        let reconstructed = scheme.reconstruct_secret(&shares[0..3]).unwrap();
        assert_eq!(reconstructed, secret);

        // Test secret = p-1 (largest possible)
        let secret = prime.clone() - BigUint::from(1u32);
        let shares = scheme.split_secret(&secret);
        let reconstructed = scheme.reconstruct_secret(&shares[0..3]).unwrap();
        assert_eq!(reconstructed, secret);

        // Test secret > p (should work with modulo)
        let secret = prime * BigUint::from(2u32);
        let shares = scheme.split_secret(&secret);
        let reconstructed = scheme.reconstruct_secret(&shares[0..3]).unwrap();
        assert_eq!(reconstructed, BigUint::from(0u32)); // since secret mod p = 0
    }

    #[test]
    fn test_all_shares_reconstruction() {
        let prime = BigUint::from(17u32);
        let scheme = Scheme::new(3, 5, prime).unwrap();
        let secret = BigUint::from(10u32);
        let shares = scheme.split_secret(&secret);

        // Try reconstructing with all shares
        let result = scheme.reconstruct_secret(&shares).unwrap();
        assert_eq!(result, secret);
    }

    #[test]
    fn test_invalid_parameters() {
        let prime = BigUint::from(17u32);

        // Test k = 0
        assert!(matches!(
            Scheme::new(0, 5, prime.clone()),
            Err(SssError::InvalidThreshold)
        ));

        // Test k > n
        assert!(matches!(
            Scheme::new(6, 5, prime.clone()),
            Err(SssError::InvalidThreshold)
        ));

        // Test k = n (should work)
        assert!(Scheme::new(5, 5, prime).is_ok());
    }

    #[test]
    fn test_minimum_viable_scheme() {
        let prime = BigUint::from(17u32);
        // Test smallest possible scheme: k=2, n=2
        let scheme = Scheme::new(2, 2, prime).unwrap();
        let secret = BigUint::from(10u32);
        let shares = scheme.split_secret(&secret);

        assert_eq!(shares.len(), 2);
        let reconstructed = scheme.reconstruct_secret(&shares).unwrap();
        assert_eq!(reconstructed, secret);
    }

    #[test]
    fn test_large_numbers() {
        let prime = large_prime();
        let scheme = Scheme::new(3, 5, prime.clone()).unwrap();

        let secret = prime - BigUint::from(1u32);
        let shares = scheme.split_secret(&secret);
        let reconstructed = scheme.reconstruct_secret(&shares[0..3]).unwrap();
        assert_eq!(reconstructed, secret);
    }

    #[test]
    fn test_share_index_range() {
        let prime = BigUint::from(17u32);
        let scheme = Scheme::new(3, 5, prime).unwrap();
        let secret = BigUint::from(10u32);
        let shares = scheme.split_secret(&secret);

        // verify share indices are 1 through n
        for (i, share) in shares.iter().enumerate() {
            assert_eq!(share.index as usize, i + 1);
        }
    }

    #[test]
    fn test_shuffled_shares() {
        use rand::seq::SliceRandom;
        let mut rng = rand::thread_rng();

        let prime = BigUint::from(17u32);
        let scheme = Scheme::new(3, 5, prime).unwrap();
        let secret = BigUint::from(10u32);
        let mut shares = scheme.split_secret(&secret);

        // shuffle the shares
        shares.shuffle(&mut rng);

        // Should still reconstruct correctly
        let reconstructed = scheme.reconstruct_secret(&shares[0..3]).unwrap();
        assert_eq!(reconstructed, secret);
    }
}