// shamir_algorithm/lib.rs
1/// # Shamir Secret Sharing Library
2///
3/// This library implements Shamir's Secret Sharing algorithm, allowing a secret to be split into multiple shares
4/// such that a minimum number of shares are required to reconstruct the original secret.
5/// The implementation uses Galois Field arithmetic over GF(256) for the polynomial operations.
6///
7/// # Example
8///
9/// ```rust
10/// use shamir_rust::ShamirSS;
11/// use std::collections::BTreeMap;
12///
13/// let secret = b"Hello, world!";
14/// let n = 5;
15/// let k = 3;
16///
17/// let shares = ShamirSS::split(n, k, secret.to_vec()).unwrap();
18///
19/// // Use at least k shares to reconstruct
20/// let mut parts = BTreeMap::new();
21/// for i in 1..=k {
22///     parts.insert(i, shares[&i].clone());
23/// }
24///
25/// let reconstructed = ShamirSS::join(parts).unwrap();
26/// assert_eq!(reconstructed, secret);
27/// ```
28mod tables;
29use std::{
30    collections::{BTreeMap, HashSet},
31    fmt::Debug,
32};
33use rand::distr::StandardUniform;
34use rand::Rng;
35use tables::{EXP, LOG};
36
37/// Shamir Secret Sharing implementation.
38/// This struct provides static methods for splitting and joining secrets using Shamir's algorithm.
39#[derive(Debug, Clone)]
40pub struct ShamirSS;
41
42impl ShamirSS {
43    /// Splits a secret into `n` shares, requiring at least `k` shares to reconstruct the secret.
44    ///
45    /// # Arguments
46    ///
47    /// * `n` - Total shares to generate (up to 255).
48    /// * `k` - Threshold needed to reconstruct (must be > 1).
49    /// * `secret` - The secret data as a byte vector.
50    pub fn split(n: i32, k: i32, secret: Vec<u8>) -> Result<BTreeMap<i32, Vec<u8>>, String> {
51        if k <= 1 { return Err("Threshold k must be greater than 1".to_string()); }
52        if n < k { return Err("Total shares n must be greater than or equal to k".to_string()); }
53        if n > 255 { return Err("Total shares n cannot exceed 255".to_string()); }
54        if secret.is_empty() { return Err("Secret cannot be empty".to_string()); }
55
56        let seclen = secret.len();
57        let mut values: Vec<Vec<u8>> = vec![vec![0u8; seclen]; n as usize];
58        let degree = k - 1;
59
60        for (i, &byte) in secret.iter().enumerate() {
61            let p = GFC256::generate(degree, byte);
62            for x in 1..=n {
63                values[(x - 1) as usize][i] = GFC256::eval(&p, x as u8);
64            }
65        }
66
67        let mut parts = BTreeMap::new();
68        for i in 1..=n {
69            parts.insert(i, values[(i - 1) as usize].clone());
70        }
71
72        Ok(parts)
73    }
74
75    /// Reconstructs the original secret from a set of shares.
76    pub fn join(parts: BTreeMap<i32, Vec<u8>>) -> Result<Vec<u8>, String> {
77        if parts.is_empty() {
78            return Err("No parts provided".to_string());
79        }
80
81        let lengths: HashSet<usize> = parts.values().map(|v| v.len()).collect();
82        if lengths.len() != 1 {
83            return Err("Varying lengths of part values".to_string());
84        }
85
86        let secret_len = *lengths.iter().next().unwrap();
87        let mut secret = vec![0u8; secret_len];
88
89        for i in 0..secret_len {
90            let points: Vec<Vec<u8>> = parts.iter()
91                .map(|(&idx, data)| vec![idx as u8, data[i]])
92                .collect();
93
94            secret[i] = GFC256::interpolate(points);
95        }
96
97        Ok(secret)
98    }
99}
100
101/// Galois Field operations over GF(256).
102struct GFC256;
103
104impl GFC256 {
105    #[inline]
106    fn add(a: u8, b: u8) -> u8 { a ^ b }
107
108    #[inline]
109    fn sub(a: u8, b: u8) -> u8 { a ^ b }
110
111    fn mul(a: u8, b: u8) -> u8 {
112        if a == 0 || b == 0 { return 0; }
113        let log_sum = LOG[a as usize] as usize + LOG[b as usize] as usize;
114        EXP[log_sum % 255]
115    }
116
117    fn div(a: u8, b: u8) -> u8 {
118        if b == 0 { panic!("Division by zero in GF(256)"); }
119        if a == 0 { return 0; }
120        let log_diff = (LOG[a as usize] as i32 - LOG[b as usize] as i32 + 255) % 255;
121        EXP[log_diff as usize]
122    }
123
124    fn eval(p: &[u8], x: u8) -> u8 {
125        let mut result = 0u8;
126        for &coeff in p.iter().rev() {
127            result = Self::add(Self::mul(result, x), coeff);
128        }
129        result
130    }
131
132    fn generate(degree: i32, secret_byte: u8) -> Vec<u8> {
133        let mut rng = rand::rng();
134        let mut p = vec![0u8; (degree + 1) as usize];
135        p[0] = secret_byte;
136        for i in p.iter_mut().take(degree as usize + 1).skip(1) {
137            *i = rng.sample(StandardUniform);
138        }
139        // Ensure the leading coefficient is non-zero to maintain the degree
140        while p[degree as usize] == 0 {
141            p[degree as usize] = rng.sample(StandardUniform);
142        }
143        p
144    }
145
146    fn interpolate(points: Vec<Vec<u8>>) -> u8 {
147        let mut y = 0u8;
148        let len = points.len();
149        for i in 0..len {
150            let mut li = 1u8;
151            for j in 0..len {
152                if i != j {
153                    let num = points[j][0];
154                    let den = Self::sub(points[i][0], points[j][0]);
155                    li = Self::mul(li, Self::div(num, den));
156                }
157            }
158            y = Self::add(y, Self::mul(li, points[i][1]));
159        }
160        y
161    }
162}
163
164
#[cfg(test)]
mod tests {
    use super::*;

    /// Round-trips a secret through `split` and `join` using exactly the
    /// threshold number of shares (the lowest-indexed ones).
    #[test]
    fn it_works() {
        let secret = b"Hello Shamir Shared Secret!!!!!";
        let num_parts = 5;
        let minimum_parts = 3;

        let keys = ShamirSS::split(num_parts, minimum_parts, secret.to_vec());
        assert!(keys.is_ok());
        let keys = keys.unwrap();

        // Keep only the shares whose index is <= the threshold.
        let parts: BTreeMap<i32, Vec<u8>> = keys
            .into_iter()
            .filter(|(key, _)| *key <= minimum_parts)
            .collect();

        let shared = ShamirSS::join(parts);
        assert!(shared.is_ok());
        assert_eq!(shared.unwrap(), secret.to_vec());
    }

    /// Every subset of exactly `k` shares must reconstruct the secret, and so
    /// must the full set of shares.
    #[test]
    fn any_threshold_subset_reconstructs() {
        let secret = b"threshold".to_vec();
        let shares = ShamirSS::split(5, 3, secret.clone()).unwrap();

        for a in 1..=5_i32 {
            for b in (a + 1)..=5 {
                for c in (b + 1)..=5 {
                    let parts: BTreeMap<i32, Vec<u8>> = [a, b, c]
                        .iter()
                        .map(|i| (*i, shares[i].clone()))
                        .collect();
                    assert_eq!(ShamirSS::join(parts).unwrap(), secret);
                }
            }
        }

        assert_eq!(ShamirSS::join(shares).unwrap(), secret);
    }

    /// `split` rejects each invalid parameter combination with an error.
    #[test]
    fn split_rejects_invalid_parameters() {
        assert!(ShamirSS::split(5, 1, b"s".to_vec()).is_err());
        assert!(ShamirSS::split(2, 3, b"s".to_vec()).is_err());
        assert!(ShamirSS::split(256, 3, b"s".to_vec()).is_err());
        assert!(ShamirSS::split(5, 3, Vec::new()).is_err());
    }

    /// Shares are indexed 1..=n and each has the secret's length, up to n = 255.
    #[test]
    fn split_produces_indexed_shares_of_secret_length() {
        let shares = ShamirSS::split(255, 2, vec![7u8; 4]).unwrap();
        assert_eq!(
            shares.keys().copied().collect::<Vec<i32>>(),
            (1..=255).collect::<Vec<i32>>()
        );
        assert!(shares.values().all(|s| s.len() == 4));
    }

    /// `join` rejects an empty input and shares of differing lengths.
    #[test]
    fn join_rejects_empty_and_mismatched_parts() {
        assert!(ShamirSS::join(BTreeMap::new()).is_err());

        let mut parts = BTreeMap::new();
        parts.insert(1, vec![1u8, 2]);
        parts.insert(2, vec![3u8]);
        assert!(ShamirSS::join(parts).is_err());
    }
}