1mod tables;
29use std::{
30 collections::{BTreeMap, HashSet},
31 fmt::Debug,
32};
33use rand::distr::StandardUniform;
34use rand::Rng;
35use tables::{EXP, LOG};
36
37#[derive(Debug, Clone)]
40pub struct ShamirSS;
41
42impl ShamirSS {
43 pub fn split(n: i32, k: i32, secret: Vec<u8>) -> Result<BTreeMap<i32, Vec<u8>>, String> {
51 if k <= 1 { return Err("Threshold k must be greater than 1".to_string()); }
52 if n < k { return Err("Total shares n must be greater than or equal to k".to_string()); }
53 if n > 255 { return Err("Total shares n cannot exceed 255".to_string()); }
54 if secret.is_empty() { return Err("Secret cannot be empty".to_string()); }
55
56 let seclen = secret.len();
57 let mut values: Vec<Vec<u8>> = vec![vec![0u8; seclen]; n as usize];
58 let degree = k - 1;
59
60 for (i, &byte) in secret.iter().enumerate() {
61 let p = GFC256::generate(degree, byte);
62 for x in 1..=n {
63 values[(x - 1) as usize][i] = GFC256::eval(&p, x as u8);
64 }
65 }
66
67 let mut parts = BTreeMap::new();
68 for i in 1..=n {
69 parts.insert(i, values[(i - 1) as usize].clone());
70 }
71
72 Ok(parts)
73 }
74
75 pub fn join(parts: BTreeMap<i32, Vec<u8>>) -> Result<Vec<u8>, String> {
77 if parts.is_empty() {
78 return Err("No parts provided".to_string());
79 }
80
81 let lengths: HashSet<usize> = parts.values().map(|v| v.len()).collect();
82 if lengths.len() != 1 {
83 return Err("Varying lengths of part values".to_string());
84 }
85
86 let secret_len = *lengths.iter().next().unwrap();
87 let mut secret = vec![0u8; secret_len];
88
89 for i in 0..secret_len {
90 let points: Vec<Vec<u8>> = parts.iter()
91 .map(|(&idx, data)| vec![idx as u8, data[i]])
92 .collect();
93
94 secret[i] = GFC256::interpolate(points);
95 }
96
97 Ok(secret)
98 }
99}
100
101struct GFC256;
103
104impl GFC256 {
105 #[inline]
106 fn add(a: u8, b: u8) -> u8 { a ^ b }
107
108 #[inline]
109 fn sub(a: u8, b: u8) -> u8 { a ^ b }
110
111 fn mul(a: u8, b: u8) -> u8 {
112 if a == 0 || b == 0 { return 0; }
113 let log_sum = LOG[a as usize] as usize + LOG[b as usize] as usize;
114 EXP[log_sum % 255]
115 }
116
117 fn div(a: u8, b: u8) -> u8 {
118 if b == 0 { panic!("Division by zero in GF(256)"); }
119 if a == 0 { return 0; }
120 let log_diff = (LOG[a as usize] as i32 - LOG[b as usize] as i32 + 255) % 255;
121 EXP[log_diff as usize]
122 }
123
124 fn eval(p: &[u8], x: u8) -> u8 {
125 let mut result = 0u8;
126 for &coeff in p.iter().rev() {
127 result = Self::add(Self::mul(result, x), coeff);
128 }
129 result
130 }
131
132 fn generate(degree: i32, secret_byte: u8) -> Vec<u8> {
133 let mut rng = rand::rng();
134 let mut p = vec![0u8; (degree + 1) as usize];
135 p[0] = secret_byte;
136 for i in p.iter_mut().take(degree as usize + 1).skip(1) {
137 *i = rng.sample(StandardUniform);
138 }
139 while p[degree as usize] == 0 {
141 p[degree as usize] = rng.sample(StandardUniform);
142 }
143 p
144 }
145
146 fn interpolate(points: Vec<Vec<u8>>) -> u8 {
147 let mut y = 0u8;
148 let len = points.len();
149 for i in 0..len {
150 let mut li = 1u8;
151 for j in 0..len {
152 if i != j {
153 let num = points[j][0];
154 let den = Self::sub(points[i][0], points[j][0]);
155 li = Self::mul(li, Self::div(num, den));
156 }
157 }
158 y = Self::add(y, Self::mul(li, points[i][1]));
159 }
160 y
161 }
162}
163
164
#[cfg(test)]
mod tests {
    use super::*;

    const SECRET: &[u8] = b"Hello Shamir Shared Secret!!!!!";

    /// Copies the shares with the given indices out of `shares` into a new map.
    fn pick(shares: &BTreeMap<i32, Vec<u8>>, indices: &[i32]) -> BTreeMap<i32, Vec<u8>> {
        indices
            .iter()
            .map(|idx| (*idx, shares[idx].clone()))
            .collect()
    }

    #[test]
    fn it_works() {
        let num_parts = 5;
        let minimum_parts = 3;

        let shares = ShamirSS::split(num_parts, minimum_parts, SECRET.to_vec())
            .expect("split with valid arguments should succeed");
        assert_eq!(shares.len(), num_parts as usize);
        assert!(shares.values().all(|share| share.len() == SECRET.len()));

        let first: Vec<i32> = (1..=minimum_parts).collect();
        let recovered =
            ShamirSS::join(pick(&shares, &first)).expect("join of k shares should succeed");
        assert_eq!(recovered, SECRET);
    }

    #[test]
    fn every_threshold_subset_recovers_secret() {
        let shares = ShamirSS::split(5, 3, SECRET.to_vec()).expect("split should succeed");
        for a in 1..=5 {
            for b in (a + 1)..=5 {
                for c in (b + 1)..=5 {
                    let recovered = ShamirSS::join(pick(&shares, &[a, b, c]))
                        .expect("join of k shares should succeed");
                    assert_eq!(recovered, SECRET, "shares {a}, {b}, {c}");
                }
            }
        }
    }

    #[test]
    fn all_shares_recover_secret() {
        let shares = ShamirSS::split(5, 3, SECRET.to_vec()).expect("split should succeed");
        let recovered = ShamirSS::join(shares).expect("join of all shares should succeed");
        assert_eq!(recovered, SECRET);
    }

    #[test]
    fn maximum_share_count_is_supported() {
        let shares = ShamirSS::split(255, 2, SECRET.to_vec()).expect("split should succeed");
        assert_eq!(shares.len(), 255);
        let recovered = ShamirSS::join(pick(&shares, &[254, 255]))
            .expect("join of the highest-indexed shares should succeed");
        assert_eq!(recovered, SECRET);
    }

    #[test]
    fn split_rejects_invalid_arguments() {
        assert!(ShamirSS::split(5, 1, SECRET.to_vec()).is_err());
        assert!(ShamirSS::split(2, 3, SECRET.to_vec()).is_err());
        assert!(ShamirSS::split(256, 3, SECRET.to_vec()).is_err());
        assert!(ShamirSS::split(5, 3, Vec::new()).is_err());
    }

    #[test]
    fn join_rejects_malformed_parts() {
        assert!(ShamirSS::join(BTreeMap::new()).is_err());

        let mut uneven = BTreeMap::new();
        uneven.insert(1, vec![1u8, 2, 3]);
        uneven.insert(2, vec![4u8, 5]);
        assert!(ShamirSS::join(uneven).is_err());
    }
}