shadowforge_lib/domain/correction/
mod.rs1use bytes::Bytes;
6use hmac::{Hmac, Mac};
7use reed_solomon_erasure::galois_8::ReedSolomon;
8use sha2::Sha256;
9
10use crate::domain::errors::CorrectionError;
11use crate::domain::types::Shard;
12
13type HmacSha256 = Hmac<Sha256>;
14
15fn compute_hmac_tag(
19 hmac_key: &[u8],
20 index: u8,
21 total: u8,
22 data: &[u8],
23) -> Result<[u8; 32], CorrectionError> {
24 let mut mac =
25 HmacSha256::new_from_slice(hmac_key).map_err(|_| CorrectionError::InvalidParameters {
26 reason: "invalid HMAC key length".into(),
27 })?;
28 mac.update(&[index]);
29 mac.update(&[total]);
30 mac.update(data);
31 Ok(mac.finalize().into_bytes().into())
32}
33
34fn verify_hmac_tag(hmac_key: &[u8], shard: &Shard) -> Result<bool, CorrectionError> {
38 use subtle::ConstantTimeEq;
39
40 let expected = compute_hmac_tag(hmac_key, shard.index, shard.total, &shard.data)?;
41 Ok(expected.ct_eq(&shard.hmac_tag).into())
42}
43
44pub fn encode_shards(
50 data: &[u8],
51 data_shards: u8,
52 parity_shards: u8,
53 hmac_key: &[u8],
54) -> Result<Vec<Shard>, CorrectionError> {
55 if data_shards == 0 {
56 return Err(CorrectionError::InvalidParameters {
57 reason: "data_shards must be > 0".into(),
58 });
59 }
60 if parity_shards == 0 {
61 return Err(CorrectionError::InvalidParameters {
62 reason: "parity_shards must be > 0".into(),
63 });
64 }
65
66 let total_shards = data_shards.strict_add(parity_shards);
67 let shard_size = (data
68 .len()
69 .strict_add(usize::from(data_shards).strict_sub(1)))
70 / usize::from(data_shards);
71
72 let total_size = shard_size.strict_mul(usize::from(data_shards));
74 let mut padded = vec![0u8; total_size];
75 padded
76 .get_mut(..data.len())
77 .ok_or_else(|| CorrectionError::InvalidParameters {
78 reason: "data length exceeds padded buffer".into(),
79 })?
80 .copy_from_slice(data);
81
82 let rs =
84 ReedSolomon::new(usize::from(data_shards), usize::from(parity_shards)).map_err(|e| {
85 CorrectionError::ReedSolomonError {
86 reason: e.to_string(),
87 }
88 })?;
89
90 let mut chunks: Vec<Vec<u8>> = padded.chunks(shard_size).map(<[u8]>::to_vec).collect();
92
93 chunks.resize(usize::from(total_shards), vec![0u8; shard_size]);
95
96 rs.encode(&mut chunks)
98 .map_err(|e| CorrectionError::ReedSolomonError {
99 reason: e.to_string(),
100 })?;
101
102 let shards = chunks
104 .into_iter()
105 .enumerate()
106 .map(|(i, data)| {
107 #[expect(clippy::cast_possible_truncation, reason = "total_shards is u8")]
108 let index = i as u8;
109 let hmac_tag = compute_hmac_tag(hmac_key, index, total_shards, &data)?;
110 Ok(Shard {
111 index,
112 total: total_shards,
113 data,
114 hmac_tag,
115 })
116 })
117 .collect::<Result<Vec<Shard>, CorrectionError>>()?;
118
119 Ok(shards)
120}
121
122pub fn decode_shards(
132 shards: &[Option<Shard>],
133 data_shards: u8,
134 parity_shards: u8,
135 hmac_key: &[u8],
136 original_len: usize,
137) -> Result<Bytes, CorrectionError> {
138 let total_shards = data_shards.strict_add(parity_shards);
139
140 if shards.len() != usize::from(total_shards) {
141 return Err(CorrectionError::InvalidParameters {
142 reason: format!("expected {} shards, got {}", total_shards, shards.len()),
143 });
144 }
145
146 for shard in shards.iter().flatten() {
148 if !verify_hmac_tag(hmac_key, shard)? {
149 return Err(CorrectionError::HmacMismatch { index: shard.index });
150 }
151 }
152
153 let valid_count = shards.iter().filter(|s| s.is_some()).count();
155 if valid_count < usize::from(data_shards) {
156 return Err(CorrectionError::InsufficientShards {
157 needed: usize::from(data_shards),
158 available: valid_count,
159 });
160 }
161
162 let rs =
164 ReedSolomon::new(usize::from(data_shards), usize::from(parity_shards)).map_err(|e| {
165 CorrectionError::ReedSolomonError {
166 reason: e.to_string(),
167 }
168 })?;
169
170 let mut chunks: Vec<Option<Vec<u8>>> = shards
172 .iter()
173 .map(|opt| opt.as_ref().map(|s| s.data.clone()))
174 .collect();
175
176 rs.reconstruct(&mut chunks)
178 .map_err(|e| CorrectionError::ReedSolomonError {
179 reason: e.to_string(),
180 })?;
181
182 let mut recovered = Vec::new();
184 for chunk in chunks.iter().take(usize::from(data_shards)).flatten() {
185 recovered.extend_from_slice(chunk);
186 }
187
188 recovered.truncate(original_len);
190
191 Ok(Bytes::from(recovered))
192}
193
#[cfg(test)]
mod tests {
    use super::*;

    type TestResult = Result<(), Box<dyn std::error::Error>>;

    const HMAC_KEY: &[u8] = b"test_hmac_key_32_bytes_long_!!!";

    /// Data/parity split used by every test that round-trips through the codec.
    const DATA_SHARDS: u8 = 10;
    const PARITY_SHARDS: u8 = 5;

    /// Payload used by the successful round-trip tests.
    const FOX: &[u8] = b"The quick brown fox jumps over the lazy dog";

    /// Encodes `payload` with the shared parameters and wraps every shard in
    /// `Some`, giving the slot layout `decode_shards` expects.
    fn encode_as_slots(payload: &[u8]) -> Result<Vec<Option<Shard>>, CorrectionError> {
        let encoded = encode_shards(payload, DATA_SHARDS, PARITY_SHARDS, HMAC_KEY)?;
        Ok(encoded.into_iter().map(Some).collect())
    }

    /// Empties each listed slot to simulate lost shards.
    fn drop_slots(
        slots: &mut [Option<Shard>],
        lost: impl IntoIterator<Item = usize>,
    ) -> TestResult {
        for position in lost {
            *slots.get_mut(position).ok_or("out of bounds")? = None;
        }
        Ok(())
    }

    /// Decodes `slots` with the shared parameters and key.
    fn decode_slots(
        slots: &[Option<Shard>],
        original_len: usize,
    ) -> Result<Bytes, CorrectionError> {
        decode_shards(slots, DATA_SHARDS, PARITY_SHARDS, HMAC_KEY, original_len)
    }

    #[test]
    fn test_encode_decode_roundtrip() -> TestResult {
        let slots = encode_as_slots(FOX)?;
        assert_eq!(slots.len(), 15);

        let recovered = decode_slots(&slots, FOX.len())?;
        assert_eq!(recovered.as_ref(), FOX);
        Ok(())
    }

    #[test]
    fn test_decode_with_missing_shards() -> TestResult {
        let mut slots = encode_as_slots(FOX)?;
        // Lose exactly `PARITY_SHARDS` shards, spread across data and parity.
        drop_slots(&mut slots, [0, 3, 7, 10, 13])?;

        let recovered = decode_slots(&slots, FOX.len())?;
        assert_eq!(recovered.as_ref(), FOX);
        Ok(())
    }

    #[test]
    fn test_decode_insufficient_shards() -> TestResult {
        let payload: &[u8] = b"test data";
        let mut slots = encode_as_slots(payload)?;
        // One more loss than the parity budget allows.
        drop_slots(&mut slots, 0..6)?;

        let outcome = decode_slots(&slots, payload.len());
        assert!(matches!(
            outcome,
            Err(CorrectionError::InsufficientShards { .. })
        ));
        Ok(())
    }

    #[test]
    fn test_decode_hmac_mismatch() -> TestResult {
        let payload: &[u8] = b"test data";
        let mut slots = encode_as_slots(payload)?;

        let first = slots
            .get_mut(0)
            .and_then(Option::as_mut)
            .ok_or("missing shard 0")?;
        *first.data.first_mut().ok_or("empty shard data")? ^= 0xFF;

        let outcome = decode_slots(&slots, payload.len());
        assert!(matches!(outcome, Err(CorrectionError::HmacMismatch { .. })));
        Ok(())
    }

    #[test]
    fn test_encode_zero_data_shards() {
        let outcome = encode_shards(b"test", 0, 5, HMAC_KEY);
        assert!(matches!(
            outcome,
            Err(CorrectionError::InvalidParameters { .. })
        ));
    }

    #[test]
    fn test_encode_zero_parity_shards() {
        let outcome = encode_shards(b"test", 10, 0, HMAC_KEY);
        assert!(matches!(
            outcome,
            Err(CorrectionError::InvalidParameters { .. })
        ));
    }
}