Skip to main content

shadowforge_lib/adapters/
reconstruction.rs

1//! Adapter implementing the [`Reconstructor`] port for K-of-N shard
2//! reassembly with full verification chain.
3
4use crate::domain::errors::ReconstructionError;
5use crate::domain::ports::{ErrorCorrector, ExtractTechnique, Reconstructor};
6use crate::domain::reconstruction::{
7    arrange_shards, count_present, deserialize_shard, validate_shard_count,
8};
9use crate::domain::types::{CoverMedia, Payload};
10
11/// Concrete [`Reconstructor`] implementation.
12///
13/// Reconstruction chain:
14/// 1. Extract shard data from each stego cover
15/// 2. Arrange shards by index into slots
16/// 3. Validate minimum shard count (K of N)
17/// 4. RS-decode to recover original payload
18pub struct ReconstructorImpl {
19    /// Number of data shards (K).
20    data_shards: u8,
21    /// Number of parity shards (M).
22    parity_shards: u8,
23    /// Original payload length for RS trim.
24    original_len: usize,
25    /// Error corrector for K-of-N decoding.
26    corrector: Box<dyn ErrorCorrector>,
27}
28
29impl ReconstructorImpl {
30    /// Create a new reconstructor with the given shard parameters and error corrector.
31    #[must_use]
32    pub fn new(
33        data_shards: u8,
34        parity_shards: u8,
35        original_len: usize,
36        corrector: Box<dyn ErrorCorrector>,
37    ) -> Self {
38        Self {
39            data_shards,
40            parity_shards,
41            original_len,
42            corrector,
43        }
44    }
45}
46
47impl Reconstructor for ReconstructorImpl {
48    fn reconstruct(
49        &self,
50        covers: Vec<CoverMedia>,
51        extractor: &dyn ExtractTechnique,
52        progress_cb: &dyn Fn(usize, usize),
53    ) -> Result<Payload, ReconstructionError> {
54        let total = covers.len();
55        let total_shards = self.data_shards.strict_add(self.parity_shards);
56
57        // Step 1: Extract shard data from each cover
58        let mut shards = Vec::with_capacity(total);
59        for (i, cover) in covers.into_iter().enumerate() {
60            match extractor.extract(&cover) {
61                Ok(payload) => {
62                    // Deserialize from the embedded binary format
63                    if let Some(shard) = deserialize_shard(payload.as_bytes()) {
64                        shards.push(shard);
65                    } else {
66                        tracing::warn!(
67                            cover_index = i,
68                            "could not deserialize shard, treating as missing"
69                        );
70                    }
71                }
72                Err(e) => {
73                    // Treat extraction failure as missing shard
74                    // (within parity budget, reconstruction may still succeed)
75                    tracing::warn!(cover_index = i, error = %e, "extraction failed, treating as missing shard");
76                }
77            }
78            progress_cb(i.strict_add(1), total);
79        }
80
81        // Step 2: Arrange by index
82        let slots = arrange_shards(shards, total_shards);
83        let present = count_present(&slots);
84
85        // Step 3: Validate minimum count
86        validate_shard_count(present, usize::from(self.data_shards))?;
87
88        // Step 4: RS-decode via the ErrorCorrector port.
89        let recovered = self
90            .corrector
91            .decode(&slots, self.data_shards, self.parity_shards)
92            .map_err(|source| ReconstructionError::CorrectionFailed { source })?;
93
94        let payload_bytes = if self.original_len > 0 {
95            recovered
96                .get(..self.original_len)
97                .map_or_else(|| recovered.to_vec(), ToOwned::to_owned)
98        } else {
99            recovered.to_vec()
100        };
101
102        Ok(Payload::from_bytes(payload_bytes))
103    }
104}
105
106#[cfg(test)]
107mod tests {
108    use super::*;
109    use crate::domain::errors::StegoError;
110    use crate::domain::ports::EmbedTechnique;
111    use crate::domain::ports::ErrorCorrector;
112    use crate::domain::reconstruction::serialize_shard;
113    use crate::domain::types::{Capacity, CoverMedia, CoverMediaKind, StegoTechnique};
114    use bytes::Bytes;
115    use std::cell::Cell;
116
117    type TestResult = Result<(), Box<dyn std::error::Error>>;
118
119    /// Mock embedder: prepends a 4-byte length header then payload.
120    struct MockEmbedder;
121
122    impl EmbedTechnique for MockEmbedder {
123        fn technique(&self) -> StegoTechnique {
124            StegoTechnique::LsbImage
125        }
126
127        fn capacity(&self, cover: &CoverMedia) -> Result<Capacity, StegoError> {
128            Ok(Capacity {
129                bytes: cover.data.len() as u64,
130                technique: StegoTechnique::LsbImage,
131            })
132        }
133
134        fn embed(&self, cover: CoverMedia, payload: &Payload) -> Result<CoverMedia, StegoError> {
135            let mut data = cover.data.to_vec();
136            #[expect(clippy::cast_possible_truncation, reason = "test data < 4 GiB")]
137            let len = payload.len() as u32;
138            data.extend_from_slice(&len.to_le_bytes());
139            data.extend_from_slice(payload.as_bytes());
140            Ok(CoverMedia {
141                kind: cover.kind,
142                data: Bytes::from(data),
143                metadata: cover.metadata,
144            })
145        }
146    }
147
148    /// Mock extractor: reads length-prefixed payload after cover prefix.
149    struct MockExtractor {
150        cover_prefix_len: usize,
151    }
152
153    impl ExtractTechnique for MockExtractor {
154        fn technique(&self) -> StegoTechnique {
155            StegoTechnique::LsbImage
156        }
157
158        fn extract(&self, stego: &CoverMedia) -> Result<Payload, StegoError> {
159            let data = &stego.data;
160            if data.len() <= self.cover_prefix_len + 4 {
161                return Err(StegoError::NoPayloadFound);
162            }
163            let offset = self.cover_prefix_len;
164            let len_bytes: [u8; 4] = data
165                .get(offset..offset + 4)
166                .ok_or(StegoError::NoPayloadFound)?
167                .try_into()
168                .map_err(|_| StegoError::NoPayloadFound)?;
169            let len = u32::from_le_bytes(len_bytes) as usize;
170            let start = offset + 4;
171            let payload_data = data
172                .get(start..start + len)
173                .ok_or(StegoError::NoPayloadFound)?;
174            Ok(Payload::from_bytes(payload_data.to_vec()))
175        }
176    }
177
178    fn make_cover(size: usize) -> CoverMedia {
179        CoverMedia {
180            kind: CoverMediaKind::PngImage,
181            data: Bytes::from(vec![0u8; size]),
182            metadata: std::collections::HashMap::new(),
183        }
184    }
185
186    /// Helper: encode payload into shards, embed each in a cover.
187    fn distribute_and_get_covers(
188        payload: &[u8],
189        data_shards: u8,
190        parity_shards: u8,
191        hmac_key: &[u8],
192        cover_size: usize,
193    ) -> Result<Vec<CoverMedia>, Box<dyn std::error::Error>> {
194        let corrector = crate::adapters::correction::RsErrorCorrector::new(hmac_key.to_vec());
195        let shards = corrector.encode(payload, data_shards, parity_shards)?;
196        let embedder = MockEmbedder;
197        let covers = shards
198            .iter()
199            .map(|shard| {
200                let cover = make_cover(cover_size);
201                let serialized = serialize_shard(shard);
202                let shard_payload = Payload::from_bytes(serialized);
203                embedder.embed(cover, &shard_payload)
204            })
205            .collect::<Result<Vec<_>, _>>()?;
206        Ok(covers)
207    }
208
209    #[test]
210    fn full_recovery_all_shards_present() -> TestResult {
211        let original = b"hello reconstruction world!";
212        let hmac_key = b"test-hmac-key";
213        let covers = distribute_and_get_covers(original, 3, 2, hmac_key, 128)?;
214        assert_eq!(covers.len(), 5);
215
216        let corrector: Box<dyn ErrorCorrector> = Box::new(
217            crate::adapters::correction::RsErrorCorrector::new(hmac_key.to_vec()),
218        );
219        let reconstructor = ReconstructorImpl::new(3, 2, original.len(), corrector);
220        let extractor = MockExtractor {
221            cover_prefix_len: 128,
222        };
223        let progress_calls = Cell::new(0usize);
224        let result = reconstructor.reconstruct(covers, &extractor, &|_done, _total| {
225            progress_calls.set(progress_calls.get().strict_add(1));
226        })?;
227
228        assert_eq!(result.as_bytes(), original);
229        assert_eq!(progress_calls.get(), 5);
230        Ok(())
231    }
232
233    #[test]
234    fn partial_recovery_minimum_shards() -> TestResult {
235        let original = b"partial recovery test payload";
236        let hmac_key = b"test-hmac-key";
237        let mut covers = distribute_and_get_covers(original, 3, 2, hmac_key, 128)?;
238        assert_eq!(covers.len(), 5);
239
240        // Drop 2 parity shards (keep exactly data_shards = 3)
241        covers.remove(4);
242        covers.remove(3);
243
244        let corrector: Box<dyn ErrorCorrector> = Box::new(
245            crate::adapters::correction::RsErrorCorrector::new(hmac_key.to_vec()),
246        );
247        let reconstructor = ReconstructorImpl::new(3, 2, original.len(), corrector);
248        let extractor = MockExtractor {
249            cover_prefix_len: 128,
250        };
251        let result = reconstructor.reconstruct(covers, &extractor, &|_, _| {})?;
252
253        assert_eq!(result.as_bytes(), original);
254        Ok(())
255    }
256
257    #[test]
258    fn insufficient_shards_returns_error() -> TestResult {
259        let original = b"not enough shards";
260        let hmac_key = b"test-hmac-key";
261        let mut covers = distribute_and_get_covers(original, 3, 2, hmac_key, 128)?;
262
263        // Drop 3 shards (only 2 remain, but need 3)
264        covers.remove(4);
265        covers.remove(3);
266        covers.remove(2);
267
268        let corrector: Box<dyn ErrorCorrector> = Box::new(
269            crate::adapters::correction::RsErrorCorrector::new(hmac_key.to_vec()),
270        );
271        let reconstructor = ReconstructorImpl::new(3, 2, original.len(), corrector);
272        let extractor = MockExtractor {
273            cover_prefix_len: 128,
274        };
275        let result = reconstructor.reconstruct(covers, &extractor, &|_, _| {});
276
277        assert!(result.is_err());
278        Ok(())
279    }
280
281    #[test]
282    fn progress_callback_called_correctly() -> TestResult {
283        let original = b"track progress";
284        let hmac_key = b"test-hmac-key";
285        let covers = distribute_and_get_covers(original, 2, 1, hmac_key, 64)?;
286        let total_covers = covers.len();
287
288        let corrector: Box<dyn ErrorCorrector> = Box::new(
289            crate::adapters::correction::RsErrorCorrector::new(hmac_key.to_vec()),
290        );
291        let reconstructor = ReconstructorImpl::new(2, 1, original.len(), corrector);
292        let extractor = MockExtractor {
293            cover_prefix_len: 64,
294        };
295
296        let progress_log = std::cell::RefCell::new(Vec::new());
297        let result = reconstructor.reconstruct(covers, &extractor, &|done, total| {
298            progress_log.borrow_mut().push((done, total));
299        })?;
300
301        assert_eq!(result.as_bytes(), original);
302        let log = progress_log.borrow();
303        assert_eq!(log.len(), total_covers);
304        for (i, &(done, total)) in log.iter().enumerate() {
305            assert_eq!(done, i + 1);
306            assert_eq!(total, total_covers);
307        }
308        Ok(())
309    }
310}