Skip to main content

shadowforge_lib/domain/correction/
mod.rs

1//! Reed-Solomon K-of-N erasure coding with HMAC integrity.
2//!
3//! All functions are pure — no I/O, no file system, no network.
4
5use bytes::Bytes;
6use hmac::{Hmac, Mac};
7use reed_solomon_erasure::galois_8::ReedSolomon;
8use sha2::Sha256;
9
10use crate::domain::errors::CorrectionError;
11use crate::domain::types::Shard;
12
13type HmacSha256 = Hmac<Sha256>;
14
15/// Compute HMAC-SHA-256 tag for a shard.
16///
17/// Tag covers: `index || total || data`
18fn compute_hmac_tag(
19    hmac_key: &[u8],
20    index: u8,
21    total: u8,
22    data: &[u8],
23) -> Result<[u8; 32], CorrectionError> {
24    let mut mac =
25        HmacSha256::new_from_slice(hmac_key).map_err(|_| CorrectionError::InvalidParameters {
26            reason: "invalid HMAC key length".into(),
27        })?;
28    mac.update(&[index]);
29    mac.update(&[total]);
30    mac.update(data);
31    Ok(mac.finalize().into_bytes().into())
32}
33
34/// Verify HMAC tag for a shard.
35///
36/// Returns `true` if the tag is valid.
37fn verify_hmac_tag(hmac_key: &[u8], shard: &Shard) -> Result<bool, CorrectionError> {
38    use subtle::ConstantTimeEq;
39
40    let expected = compute_hmac_tag(hmac_key, shard.index, shard.total, &shard.data)?;
41    Ok(expected.ct_eq(&shard.hmac_tag).into())
42}
43
44/// Encode data into Reed-Solomon shards with HMAC tags.
45///
46/// # Errors
47/// Returns [`CorrectionError::InvalidParameters`] if shard counts are invalid,
48/// or [`CorrectionError::ReedSolomonError`] if encoding fails.
49pub fn encode_shards(
50    data: &[u8],
51    data_shards: u8,
52    parity_shards: u8,
53    hmac_key: &[u8],
54) -> Result<Vec<Shard>, CorrectionError> {
55    if data_shards == 0 {
56        return Err(CorrectionError::InvalidParameters {
57            reason: "data_shards must be > 0".into(),
58        });
59    }
60    if parity_shards == 0 {
61        return Err(CorrectionError::InvalidParameters {
62            reason: "parity_shards must be > 0".into(),
63        });
64    }
65
66    let total_shards = data_shards.strict_add(parity_shards);
67    let shard_size = (data
68        .len()
69        .strict_add(usize::from(data_shards).strict_sub(1)))
70        / usize::from(data_shards);
71
72    // Pad data to fit evenly into data_shards
73    let total_size = shard_size.strict_mul(usize::from(data_shards));
74    let mut padded = vec![0u8; total_size];
75    padded
76        .get_mut(..data.len())
77        .ok_or_else(|| CorrectionError::InvalidParameters {
78            reason: "data length exceeds padded buffer".into(),
79        })?
80        .copy_from_slice(data);
81
82    // Create Reed-Solomon encoder
83    let rs =
84        ReedSolomon::new(usize::from(data_shards), usize::from(parity_shards)).map_err(|e| {
85            CorrectionError::ReedSolomonError {
86                reason: e.to_string(),
87            }
88        })?;
89
90    // Split into chunks
91    let mut chunks: Vec<Vec<u8>> = padded.chunks(shard_size).map(<[u8]>::to_vec).collect();
92
93    // Add parity shards
94    chunks.resize(usize::from(total_shards), vec![0u8; shard_size]);
95
96    // Encode
97    rs.encode(&mut chunks)
98        .map_err(|e| CorrectionError::ReedSolomonError {
99            reason: e.to_string(),
100        })?;
101
102    // Create shards with HMAC tags
103    let shards = chunks
104        .into_iter()
105        .enumerate()
106        .map(|(i, data)| {
107            #[expect(clippy::cast_possible_truncation, reason = "total_shards is u8")]
108            let index = i as u8;
109            let hmac_tag = compute_hmac_tag(hmac_key, index, total_shards, &data)?;
110            Ok(Shard {
111                index,
112                total: total_shards,
113                data,
114                hmac_tag,
115            })
116        })
117        .collect::<Result<Vec<Shard>, CorrectionError>>()?;
118
119    Ok(shards)
120}
121
122/// Decode Reed-Solomon shards back to original data.
123///
124/// Accepts partial shard sets (some may be `None`). Requires at least
125/// `data_shards` valid shards with passing HMAC tags.
126///
127/// # Errors
128/// Returns [`CorrectionError::InsufficientShards`] if not enough valid shards,
129/// [`CorrectionError::HmacMismatch`] if HMAC verification fails, or
130/// [`CorrectionError::ReedSolomonError`] if decoding fails.
131pub fn decode_shards(
132    shards: &[Option<Shard>],
133    data_shards: u8,
134    parity_shards: u8,
135    hmac_key: &[u8],
136    original_len: usize,
137) -> Result<Bytes, CorrectionError> {
138    let total_shards = data_shards.strict_add(parity_shards);
139
140    if shards.len() != usize::from(total_shards) {
141        return Err(CorrectionError::InvalidParameters {
142            reason: format!("expected {} shards, got {}", total_shards, shards.len()),
143        });
144    }
145
146    // Verify HMAC tags for all present shards
147    for shard in shards.iter().flatten() {
148        if !verify_hmac_tag(hmac_key, shard)? {
149            return Err(CorrectionError::HmacMismatch { index: shard.index });
150        }
151    }
152
153    // Count valid shards
154    let valid_count = shards.iter().filter(|s| s.is_some()).count();
155    if valid_count < usize::from(data_shards) {
156        return Err(CorrectionError::InsufficientShards {
157            needed: usize::from(data_shards),
158            available: valid_count,
159        });
160    }
161
162    // Create Reed-Solomon decoder
163    let rs =
164        ReedSolomon::new(usize::from(data_shards), usize::from(parity_shards)).map_err(|e| {
165            CorrectionError::ReedSolomonError {
166                reason: e.to_string(),
167            }
168        })?;
169
170    // Convert to Option<Vec<u8>> for RS decoder
171    let mut chunks: Vec<Option<Vec<u8>>> = shards
172        .iter()
173        .map(|opt| opt.as_ref().map(|s| s.data.clone()))
174        .collect();
175
176    // Decode
177    rs.reconstruct(&mut chunks)
178        .map_err(|e| CorrectionError::ReedSolomonError {
179            reason: e.to_string(),
180        })?;
181
182    // Extract data shards
183    let mut recovered = Vec::new();
184    for chunk in chunks.iter().take(usize::from(data_shards)).flatten() {
185        recovered.extend_from_slice(chunk);
186    }
187
188    // Trim to original length
189    recovered.truncate(original_len);
190
191    Ok(Bytes::from(recovered))
192}
193
#[cfg(test)]
mod tests {
    use super::*;

    type TestResult = Result<(), Box<dyn std::error::Error>>;

    // NOTE(review): the literal is 31 bytes long, not 32 as its text suggests.
    const HMAC_KEY: &[u8] = b"test_hmac_key_32_bytes_long_!!!";

    /// Shard layout shared by the decode tests: 10 data + 5 parity.
    const DATA_SHARDS: u8 = 10;
    const PARITY_SHARDS: u8 = 5;

    /// Wrap every shard in `Some`, i.e. a shard set with nothing missing.
    fn all_present(shards: Vec<Shard>) -> Vec<Option<Shard>> {
        shards.into_iter().map(Some).collect()
    }

    /// Mark the shard at `position` as missing.
    fn drop_shard(slots: &mut [Option<Shard>], position: usize) -> TestResult {
        *slots.get_mut(position).ok_or("out of bounds")? = None;
        Ok(())
    }

    /// Encoding then decoding a complete shard set returns the input.
    #[test]
    fn test_encode_decode_roundtrip() -> TestResult {
        let payload = b"The quick brown fox jumps over the lazy dog";

        let encoded = encode_shards(payload, DATA_SHARDS, PARITY_SHARDS, HMAC_KEY)?;
        assert_eq!(encoded.len(), 15);

        let slots = all_present(encoded);
        let decoded = decode_shards(
            &slots,
            DATA_SHARDS,
            PARITY_SHARDS,
            HMAC_KEY,
            payload.len(),
        )?;

        assert_eq!(decoded.as_ref(), payload);
        Ok(())
    }

    /// Losing exactly `PARITY_SHARDS` shards is still recoverable.
    #[test]
    fn test_decode_with_missing_shards() -> TestResult {
        let payload = b"The quick brown fox jumps over the lazy dog";

        let encoded = encode_shards(payload, DATA_SHARDS, PARITY_SHARDS, HMAC_KEY)?;
        let mut slots = all_present(encoded);

        // Drop 5 shards spread across data and parity positions
        for position in [0, 3, 7, 10, 13] {
            drop_shard(&mut slots, position)?;
        }

        let decoded = decode_shards(
            &slots,
            DATA_SHARDS,
            PARITY_SHARDS,
            HMAC_KEY,
            payload.len(),
        )?;

        assert_eq!(decoded.as_ref(), payload);
        Ok(())
    }

    /// Losing more shards than there is parity must be reported.
    #[test]
    fn test_decode_insufficient_shards() -> TestResult {
        let payload = b"test data";

        let encoded = encode_shards(payload, DATA_SHARDS, PARITY_SHARDS, HMAC_KEY)?;
        let mut slots = all_present(encoded);

        // Drop 6 shards (one more than the parity count)
        for position in 0..6 {
            drop_shard(&mut slots, position)?;
        }

        let outcome = decode_shards(
            &slots,
            DATA_SHARDS,
            PARITY_SHARDS,
            HMAC_KEY,
            payload.len(),
        );
        assert!(matches!(
            outcome,
            Err(CorrectionError::InsufficientShards { .. })
        ));
        Ok(())
    }

    /// A shard whose payload was modified after encoding fails verification.
    #[test]
    fn test_decode_hmac_mismatch() -> TestResult {
        let payload = b"test data";

        let mut encoded = encode_shards(payload, DATA_SHARDS, PARITY_SHARDS, HMAC_KEY)?;

        // Flip every bit of the first byte of shard 0
        let first_byte = encoded
            .get_mut(0)
            .ok_or("missing shard 0")?
            .data
            .first_mut()
            .ok_or("empty shard data")?;
        *first_byte ^= 0xFF;

        let slots = all_present(encoded);
        let outcome = decode_shards(
            &slots,
            DATA_SHARDS,
            PARITY_SHARDS,
            HMAC_KEY,
            payload.len(),
        );
        assert!(matches!(outcome, Err(CorrectionError::HmacMismatch { .. })));
        Ok(())
    }

    /// Zero data shards is rejected up front.
    #[test]
    fn test_encode_zero_data_shards() {
        let outcome = encode_shards(b"test", 0, 5, HMAC_KEY);
        assert!(matches!(
            outcome,
            Err(CorrectionError::InvalidParameters { .. })
        ));
    }

    /// Zero parity shards is rejected up front.
    #[test]
    fn test_encode_zero_parity_shards() {
        let outcome = encode_shards(b"test", 10, 0, HMAC_KEY);
        assert!(matches!(
            outcome,
            Err(CorrectionError::InvalidParameters { .. })
        ));
    }
}