Skip to main content

shadowforge_lib/adapters/
reconstruction.rs

1//! Adapter implementing the [`Reconstructor`] port for K-of-N shard
2//! reassembly with full verification chain.
3
4use crate::domain::correction::decode_shards;
5use crate::domain::errors::ReconstructionError;
6use crate::domain::ports::{ExtractTechnique, Reconstructor};
7use crate::domain::reconstruction::{
8    arrange_shards, count_present, deserialize_shard, validate_shard_count,
9};
10use crate::domain::types::{CoverMedia, Payload};
11
/// [`Reconstructor`] adapter that reassembles a payload from K-of-N shards.
///
/// Reconstruction chain, as driven by `reconstruct`:
/// 1. Extract shard bytes from every stego cover.
/// 2. Place the shards into slots according to their index.
/// 3. Check that at least K (`data_shards`) slots are occupied.
/// 4. Reed–Solomon decode the slots back into the original payload.
pub struct ReconstructorImpl {
    /// Key handed to the shard decoder for HMAC verification of shards.
    hmac_key: Vec<u8>,
    /// Byte length of the original payload; used to trim the RS output.
    original_len: usize,
    /// K: number of data shards, also the minimum number of shards required.
    data_shards: u8,
    /// M: number of parity shards.
    parity_shards: u8,
}

impl ReconstructorImpl {
    /// Builds a reconstructor for a `data_shards + parity_shards` shard layout.
    ///
    /// The arguments are stored as-is; no validation happens here. A layout
    /// whose K + M overflows `u8` is only detected (by panic) when
    /// reconstruction runs.
    #[must_use]
    pub const fn new(
        data_shards: u8,
        parity_shards: u8,
        hmac_key: Vec<u8>,
        original_len: usize,
    ) -> Self {
        Self {
            hmac_key,
            original_len,
            data_shards,
            parity_shards,
        }
    }
}
47
48impl Reconstructor for ReconstructorImpl {
49    fn reconstruct(
50        &self,
51        covers: Vec<CoverMedia>,
52        extractor: &dyn ExtractTechnique,
53        progress_cb: &dyn Fn(usize, usize),
54    ) -> Result<Payload, ReconstructionError> {
55        let total = covers.len();
56        let total_shards = self.data_shards.strict_add(self.parity_shards);
57
58        // Step 1: Extract shard data from each cover
59        let mut shards = Vec::with_capacity(total);
60        for (i, cover) in covers.into_iter().enumerate() {
61            match extractor.extract(&cover) {
62                Ok(payload) => {
63                    // Deserialize from the embedded binary format
64                    if let Some(shard) = deserialize_shard(payload.as_bytes()) {
65                        shards.push(shard);
66                    } else {
67                        tracing::warn!(
68                            cover_index = i,
69                            "could not deserialize shard, treating as missing"
70                        );
71                    }
72                }
73                Err(e) => {
74                    // Treat extraction failure as missing shard
75                    // (within parity budget, reconstruction may still succeed)
76                    tracing::warn!(cover_index = i, error = %e, "extraction failed, treating as missing shard");
77                }
78            }
79            progress_cb(i.strict_add(1), total);
80        }
81
82        // Step 2: Arrange by index
83        let slots = arrange_shards(shards, total_shards);
84        let present = count_present(&slots);
85
86        // Step 3: Validate minimum count
87        validate_shard_count(present, usize::from(self.data_shards))?;
88
89        // Step 4: RS-decode
90        let recovered = decode_shards(
91            &slots,
92            self.data_shards,
93            self.parity_shards,
94            &self.hmac_key,
95            self.original_len,
96        )
97        .map_err(|source| ReconstructionError::CorrectionFailed { source })?;
98
99        Ok(Payload::from_bytes(recovered.to_vec()))
100    }
101}
102
#[cfg(test)]
mod tests {
    use super::*;
    use crate::domain::correction::encode_shards;
    use crate::domain::errors::StegoError;
    use crate::domain::ports::EmbedTechnique;
    use crate::domain::reconstruction::serialize_shard;
    use crate::domain::types::{Capacity, CoverMedia, CoverMediaKind, StegoTechnique};
    use bytes::Bytes;
    use std::cell::{Cell, RefCell};

    type TestResult = Result<(), Box<dyn std::error::Error>>;

    /// Size of the little-endian `u32` length header written by [`MockEmbedder`].
    const LEN_HEADER_BYTES: usize = 4;

    /// Mock embedder: appends a 4-byte little-endian length header followed
    /// by the payload bytes to the end of the cover data.
    struct MockEmbedder;

    impl EmbedTechnique for MockEmbedder {
        fn technique(&self) -> StegoTechnique {
            StegoTechnique::LsbImage
        }

        fn capacity(&self, cover: &CoverMedia) -> Result<Capacity, StegoError> {
            Ok(Capacity {
                bytes: cover.data.len() as u64,
                technique: StegoTechnique::LsbImage,
            })
        }

        fn embed(&self, cover: CoverMedia, payload: &Payload) -> Result<CoverMedia, StegoError> {
            let mut data = cover.data.to_vec();
            #[expect(clippy::cast_possible_truncation, reason = "test data < 4 GiB")]
            let len = payload.len() as u32;
            data.extend_from_slice(&len.to_le_bytes());
            data.extend_from_slice(payload.as_bytes());
            Ok(CoverMedia {
                kind: cover.kind,
                data: Bytes::from(data),
                metadata: cover.metadata,
            })
        }
    }

    /// Mock extractor: the inverse of [`MockEmbedder`].
    ///
    /// Skips the first `cover_prefix_len` bytes (the original cover data), then
    /// reads a little-endian `u32` length header followed by that many payload
    /// bytes. Any out-of-range read yields [`StegoError::NoPayloadFound`].
    struct MockExtractor {
        /// Size of the untouched cover data that precedes the length header.
        cover_prefix_len: usize,
    }

    impl ExtractTechnique for MockExtractor {
        fn technique(&self) -> StegoTechnique {
            StegoTechnique::LsbImage
        }

        fn extract(&self, stego: &CoverMedia) -> Result<Payload, StegoError> {
            let data = &stego.data;
            // The length header comes from the stego data, so every offset is
            // computed with checked arithmetic instead of a bare `+`.
            let header_start = self.cover_prefix_len;
            let header_end = header_start
                .checked_add(LEN_HEADER_BYTES)
                .ok_or(StegoError::NoPayloadFound)?;
            let len_bytes: [u8; LEN_HEADER_BYTES] = data
                .get(header_start..header_end)
                .ok_or(StegoError::NoPayloadFound)?
                .try_into()
                .map_err(|_| StegoError::NoPayloadFound)?;
            let len = usize::try_from(u32::from_le_bytes(len_bytes))
                .map_err(|_| StegoError::NoPayloadFound)?;
            let payload_end = header_end
                .checked_add(len)
                .ok_or(StegoError::NoPayloadFound)?;
            // A zero-length payload is valid: once the header is in bounds,
            // `header_end..header_end` is an empty, in-bounds range.
            let payload_data = data
                .get(header_end..payload_end)
                .ok_or(StegoError::NoPayloadFound)?;
            Ok(Payload::from_bytes(payload_data.to_vec()))
        }
    }

    /// Builds a zero-filled PNG-kind cover of `size` bytes with no metadata.
    fn make_cover(size: usize) -> CoverMedia {
        CoverMedia {
            kind: CoverMediaKind::PngImage,
            data: Bytes::from(vec![0u8; size]),
            metadata: std::collections::HashMap::new(),
        }
    }

    /// Encodes `payload` into shards with `encode_shards` and serializes each
    /// one. Each serialized shard is embedded into a fresh zero-filled cover of
    /// `cover_size` bytes. The covers are returned in the order that
    /// `encode_shards` produced the shards.
    fn distribute_and_get_covers(
        payload: &[u8],
        data_shards: u8,
        parity_shards: u8,
        hmac_key: &[u8],
        cover_size: usize,
    ) -> Result<Vec<CoverMedia>, Box<dyn std::error::Error>> {
        let shards = encode_shards(payload, data_shards, parity_shards, hmac_key)?;
        let embedder = MockEmbedder;
        let covers = shards
            .iter()
            .map(|shard| {
                let cover = make_cover(cover_size);
                let serialized = serialize_shard(shard);
                let shard_payload = Payload::from_bytes(serialized);
                embedder.embed(cover, &shard_payload)
            })
            .collect::<Result<Vec<_>, _>>()?;
        Ok(covers)
    }

    #[test]
    fn mock_extractor_round_trips_empty_payload() -> TestResult {
        let stego = MockEmbedder.embed(make_cover(16), &Payload::from_bytes(Vec::new()))?;
        let extractor = MockExtractor {
            cover_prefix_len: 16,
        };
        let extracted = extractor.extract(&stego)?;
        assert!(extracted.as_bytes().is_empty());
        Ok(())
    }

    #[test]
    fn full_recovery_all_shards_present() -> TestResult {
        let original = b"hello reconstruction world!";
        let hmac_key = b"test-hmac-key";
        let covers = distribute_and_get_covers(original, 3, 2, hmac_key, 128)?;
        assert_eq!(covers.len(), 5);

        let reconstructor = ReconstructorImpl::new(3, 2, hmac_key.to_vec(), original.len());
        let extractor = MockExtractor {
            cover_prefix_len: 128,
        };
        let progress_calls = Cell::new(0usize);
        let result = reconstructor.reconstruct(covers, &extractor, &|_done, _total| {
            progress_calls.set(progress_calls.get().strict_add(1));
        })?;

        assert_eq!(result.as_bytes(), original);
        assert_eq!(progress_calls.get(), 5);
        Ok(())
    }

    #[test]
    fn partial_recovery_minimum_shards() -> TestResult {
        let original = b"partial recovery test payload";
        let hmac_key = b"test-hmac-key";
        let mut covers = distribute_and_get_covers(original, 3, 2, hmac_key, 128)?;
        assert_eq!(covers.len(), 5);

        // Drop the last two covers, keeping exactly data_shards = 3. These
        // are the parity shards if `encode_shards` emits data shards first.
        covers.remove(4);
        covers.remove(3);

        let reconstructor = ReconstructorImpl::new(3, 2, hmac_key.to_vec(), original.len());
        let extractor = MockExtractor {
            cover_prefix_len: 128,
        };
        let result = reconstructor.reconstruct(covers, &extractor, &|_, _| {})?;

        assert_eq!(result.as_bytes(), original);
        Ok(())
    }

    #[test]
    fn unreadable_cover_is_treated_as_missing_shard() -> TestResult {
        let original = b"one cover lost its payload";
        let hmac_key = b"test-hmac-key";
        let mut covers = distribute_and_get_covers(original, 3, 2, hmac_key, 128)?;
        assert_eq!(covers.len(), 5);

        // Replace the first stego cover with a bare cover that has no embedded
        // header. Extraction fails, and the reconstructor must log it and
        // continue with the remaining four shards.
        let first = covers.first_mut().ok_or("expected at least one cover")?;
        *first = make_cover(128);

        let reconstructor = ReconstructorImpl::new(3, 2, hmac_key.to_vec(), original.len());
        let extractor = MockExtractor {
            cover_prefix_len: 128,
        };
        let progress_calls = Cell::new(0usize);
        let result = reconstructor.reconstruct(covers, &extractor, &|_, _| {
            progress_calls.set(progress_calls.get().strict_add(1));
        })?;

        assert_eq!(result.as_bytes(), original);
        // Progress is reported for every cover, including the unreadable one.
        assert_eq!(progress_calls.get(), 5);
        Ok(())
    }

    #[test]
    fn insufficient_shards_returns_error() -> TestResult {
        let original = b"not enough shards";
        let hmac_key = b"test-hmac-key";
        let mut covers = distribute_and_get_covers(original, 3, 2, hmac_key, 128)?;

        // Drop 3 shards: only 2 remain, but 3 are required.
        covers.remove(4);
        covers.remove(3);
        covers.remove(2);

        let reconstructor = ReconstructorImpl::new(3, 2, hmac_key.to_vec(), original.len());
        let extractor = MockExtractor {
            cover_prefix_len: 128,
        };
        let result = reconstructor.reconstruct(covers, &extractor, &|_, _| {});

        assert!(result.is_err());
        Ok(())
    }

    #[test]
    fn progress_callback_called_correctly() -> TestResult {
        let original = b"track progress";
        let hmac_key = b"test-hmac-key";
        let covers = distribute_and_get_covers(original, 2, 1, hmac_key, 64)?;
        let total_covers = covers.len();

        let reconstructor = ReconstructorImpl::new(2, 1, hmac_key.to_vec(), original.len());
        let extractor = MockExtractor {
            cover_prefix_len: 64,
        };

        let progress_log = RefCell::new(Vec::new());
        let result = reconstructor.reconstruct(covers, &extractor, &|done, total| {
            progress_log.borrow_mut().push((done, total));
        })?;

        assert_eq!(result.as_bytes(), original);
        let log = progress_log.borrow();
        assert_eq!(log.len(), total_covers);
        for (i, &(done, total)) in log.iter().enumerate() {
            assert_eq!(done, i + 1);
            assert_eq!(total, total_covers);
        }
        Ok(())
    }
}
294}