Skip to main content

shadowforge_lib/adapters/
correction.rs

1//! Reed-Solomon error correction adapter.
2
3use bytes::Bytes;
4
5use crate::domain::correction::{decode_shards, encode_shards};
6use crate::domain::errors::CorrectionError;
7use crate::domain::ports::ErrorCorrector;
8use crate::domain::types::Shard;
9
/// Reed-Solomon error correction adapter with HMAC-tagged shards.
///
/// Implements the [`ErrorCorrector`] port using the `reed-solomon-erasure` crate.
/// Each shard is tagged with HMAC-SHA-256 for integrity verification.
///
/// The [`Debug`](std::fmt::Debug) representation deliberately redacts the HMAC
/// key so the secret cannot leak through `{:?}` logging or panic messages.
pub struct RsErrorCorrector {
    /// HMAC key for shard integrity tags. Never printed by `Debug`.
    hmac_key: Vec<u8>,
}

impl std::fmt::Debug for RsErrorCorrector {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        // A derived impl would print the raw key bytes; emit a placeholder instead.
        f.debug_struct("RsErrorCorrector")
            .field("hmac_key", &format_args!("<redacted>"))
            .finish()
    }
}
19
20impl RsErrorCorrector {
21    /// Create a new Reed-Solomon error corrector with the given HMAC key.
22    ///
23    /// # Panics
24    /// Panics if `hmac_key` is empty.
25    #[must_use]
26    pub fn new(hmac_key: Vec<u8>) -> Self {
27        assert!(!hmac_key.is_empty(), "HMAC key must not be empty");
28        Self { hmac_key }
29    }
30}
31
32impl ErrorCorrector for RsErrorCorrector {
33    fn encode(
34        &self,
35        data: &[u8],
36        data_shards: u8,
37        parity_shards: u8,
38    ) -> Result<Vec<Shard>, CorrectionError> {
39        encode_shards(data, data_shards, parity_shards, &self.hmac_key)
40    }
41
42    fn decode(
43        &self,
44        shards: &[Option<Shard>],
45        data_shards: u8,
46        parity_shards: u8,
47    ) -> Result<Bytes, CorrectionError> {
48        // Calculate original data length from the first available shard
49        let first_shard = shards.iter().find_map(|opt| opt.as_ref()).ok_or_else(|| {
50            CorrectionError::InsufficientShards {
51                needed: usize::from(data_shards),
52                available: 0,
53            }
54        })?;
55
56        // Original length is encoded in the first data shard's metadata
57        // For now, we'll reconstruct all data and let the caller handle trimming
58        // In a real implementation, this would be in a shard metadata field
59        let shard_size = first_shard.data.len();
60        let total_data_size = shard_size.strict_mul(usize::from(data_shards));
61
62        decode_shards(
63            shards,
64            data_shards,
65            parity_shards,
66            &self.hmac_key,
67            total_data_size,
68        )
69    }
70}
71
72// ─── Tests ────────────────────────────────────────────────────────────────────
73
#[cfg(test)]
mod tests {
    use super::*;

    type TestResult = Result<(), Box<dyn std::error::Error>>;

    /// Payload shared by the round-trip tests.
    const PANGRAM: &[u8] = b"The quick brown fox jumps over the lazy dog";

    /// Build a corrector with a fixed 32-byte test key.
    fn corrector() -> RsErrorCorrector {
        RsErrorCorrector::new(b"test_key_32_bytes_long_padding!!".to_vec())
    }

    /// Wrap every encoded shard as present.
    fn all_present(shards: Vec<Shard>) -> Vec<Option<Shard>> {
        shards.into_iter().map(Some).collect()
    }

    /// Mark the shard at `index` as missing, returning an error (not panicking)
    /// if the index is out of range.
    fn drop_shard(shards: &mut [Option<Shard>], index: usize) -> TestResult {
        *shards.get_mut(index).ok_or("out of bounds")? = None;
        Ok(())
    }

    #[test]
    fn test_rs_error_corrector_roundtrip() -> TestResult {
        let corrector = corrector();
        let opt_shards = all_present(corrector.encode(PANGRAM, 10, 5)?);

        let recovered = corrector.decode(&opt_shards, 10, 5)?;

        // Recovered data may be padded, so check prefix
        assert!(recovered.starts_with(PANGRAM));
        Ok(())
    }

    #[test]
    fn test_rs_error_corrector_with_missing_shards() -> TestResult {
        let corrector = corrector();
        let mut opt_shards = all_present(corrector.encode(PANGRAM, 10, 5)?);

        // Drop 5 shards — exactly the parity budget.
        for index in [0, 3, 7, 10, 13] {
            drop_shard(&mut opt_shards, index)?;
        }

        let recovered = corrector.decode(&opt_shards, 10, 5)?;
        assert!(recovered.starts_with(PANGRAM));
        Ok(())
    }

    #[test]
    fn test_rs_error_corrector_insufficient_shards() -> TestResult {
        let corrector = corrector();
        let mut opt_shards = all_present(corrector.encode(b"test data", 10, 5)?);

        // Drop 6 shards (one more than the parity budget).
        for index in 0..6 {
            drop_shard(&mut opt_shards, index)?;
        }

        let result = corrector.decode(&opt_shards, 10, 5);
        assert!(matches!(
            result,
            Err(CorrectionError::InsufficientShards { .. })
        ));
        Ok(())
    }

    #[test]
    fn test_rs_error_corrector_no_shards_present() {
        let corrector = corrector();

        // Empty input: there is no shard to take a size from.
        let empty = corrector.decode(&[], 10, 5);
        assert!(matches!(
            empty,
            Err(CorrectionError::InsufficientShards { needed: 10, available: 0 })
        ));

        // Every slot missing hits the same early-return path.
        let all_missing: Vec<Option<Shard>> = (0..15).map(|_| None).collect();
        let missing = corrector.decode(&all_missing, 10, 5);
        assert!(matches!(
            missing,
            Err(CorrectionError::InsufficientShards { needed: 10, available: 0 })
        ));
    }

    #[test]
    #[should_panic(expected = "HMAC key must not be empty")]
    fn test_rs_error_corrector_rejects_empty_key() {
        let _corrector = RsErrorCorrector::new(Vec::new());
    }
}