Skip to main content

shadowforge_lib/adapters/
reconstruction.rs

1//! Adapter implementing the [`Reconstructor`] port for K-of-N shard
2//! reassembly with full verification chain.
3
4use crate::domain::errors::ReconstructionError;
5use crate::domain::ports::{ErrorCorrector, ExtractTechnique, Reconstructor};
6use crate::domain::reconstruction::{
7    arrange_shards, count_present, deserialize_shard, validate_shard_count,
8};
9use crate::domain::types::{CoverMedia, Payload};
10
11/// Concrete [`Reconstructor`] implementation.
12///
13/// Reconstruction chain:
14/// 1. Extract shard data from each stego cover
15/// 2. Arrange shards by index into slots
16/// 3. Validate minimum shard count (K of N)
17/// 4. RS-decode to recover original payload
18pub struct ReconstructorImpl {
19    /// Number of data shards (K).
20    data_shards: u8,
21    /// Number of parity shards (M).
22    parity_shards: u8,
23    /// Original payload length for RS trim.
24    original_len: usize,
25    /// Error corrector for K-of-N decoding.
26    corrector: Box<dyn ErrorCorrector>,
27}
28
29impl ReconstructorImpl {
30    /// Create a new reconstructor with the given shard parameters and error corrector.
31    #[must_use]
32    pub fn new(
33        data_shards: u8,
34        parity_shards: u8,
35        original_len: usize,
36        corrector: Box<dyn ErrorCorrector>,
37    ) -> Self {
38        Self {
39            data_shards,
40            parity_shards,
41            original_len,
42            corrector,
43        }
44    }
45}
46
47impl Reconstructor for ReconstructorImpl {
48    fn reconstruct(
49        &self,
50        covers: Vec<CoverMedia>,
51        extractor: &dyn ExtractTechnique,
52        progress_cb: &dyn Fn(usize, usize),
53    ) -> Result<Payload, ReconstructionError> {
54        let total = covers.len();
55        let total_shards = self.data_shards.strict_add(self.parity_shards);
56
57        // Step 1: Extract shard data from each cover
58        let mut shards = Vec::with_capacity(total);
59        for (i, cover) in covers.into_iter().enumerate() {
60            match extractor.extract(&cover) {
61                Ok(payload) => {
62                    // Deserialize from the embedded binary format
63                    if let Some(shard) = deserialize_shard(payload.as_bytes()) {
64                        shards.push(shard);
65                    } else {
66                        tracing::warn!(
67                            cover_index = i,
68                            "could not deserialize shard, treating as missing"
69                        );
70                    }
71                }
72                Err(e) => {
73                    // Treat extraction failure as missing shard
74                    // (within parity budget, reconstruction may still succeed)
75                    tracing::warn!(cover_index = i, error = %e, "extraction failed, treating as missing shard");
76                }
77            }
78            progress_cb(i.strict_add(1), total);
79        }
80
81        // Step 2: Arrange by index
82        let slots = arrange_shards(shards, total_shards);
83        let present = count_present(&slots);
84
85        // Step 3: Validate minimum count
86        validate_shard_count(present, usize::from(self.data_shards))?;
87
88        // Step 4: RS-decode via the ErrorCorrector port.
89        let recovered = self.corrector
90            .decode(&slots, self.data_shards, self.parity_shards)
91            .map_err(|source| ReconstructionError::CorrectionFailed { source })?;
92
93        let payload_bytes = if self.original_len > 0 {
94            recovered
95                .get(..self.original_len)
96                .map_or_else(|| recovered.to_vec(), ToOwned::to_owned)
97        } else {
98            recovered.to_vec()
99        };
100
101        Ok(Payload::from_bytes(payload_bytes))
102    }
103}
104
#[cfg(test)]
mod tests {
    use super::*;
    use crate::domain::errors::StegoError;
    use crate::domain::ports::EmbedTechnique;
    use crate::domain::ports::ErrorCorrector;
    use crate::domain::reconstruction::serialize_shard;
    use crate::domain::types::{Capacity, CoverMedia, CoverMediaKind, StegoTechnique};
    use bytes::Bytes;
    use std::cell::Cell;

    type TestResult = Result<(), Box<dyn std::error::Error>>;

    /// Mock embedder: appends a 4-byte little-endian length header, followed
    /// by the payload bytes, after the existing cover data.
    struct MockEmbedder;

    impl EmbedTechnique for MockEmbedder {
        fn technique(&self) -> StegoTechnique {
            StegoTechnique::LsbImage
        }

        fn capacity(&self, cover: &CoverMedia) -> Result<Capacity, StegoError> {
            Ok(Capacity {
                bytes: cover.data.len() as u64,
                technique: StegoTechnique::LsbImage,
            })
        }

        fn embed(&self, cover: CoverMedia, payload: &Payload) -> Result<CoverMedia, StegoError> {
            let mut data = cover.data.to_vec();
            #[expect(clippy::cast_possible_truncation, reason = "test data < 4 GiB")]
            let len = payload.len() as u32;
            data.extend_from_slice(&len.to_le_bytes());
            data.extend_from_slice(payload.as_bytes());
            Ok(CoverMedia {
                kind: cover.kind,
                data: Bytes::from(data),
                metadata: cover.metadata,
            })
        }
    }

    /// Mock extractor: skips `cover_prefix_len` bytes of cover data, then
    /// reads the length-prefixed payload written by [`MockEmbedder`].
    ///
    /// Returns `StegoError::NoPayloadFound` if the data is no longer than
    /// the prefix plus the 4-byte header, or if the header is truncated.
    struct MockExtractor {
        cover_prefix_len: usize,
    }

    impl ExtractTechnique for MockExtractor {
        fn technique(&self) -> StegoTechnique {
            StegoTechnique::LsbImage
        }

        fn extract(&self, stego: &CoverMedia) -> Result<Payload, StegoError> {
            let data = &stego.data;
            if data.len() <= self.cover_prefix_len + 4 {
                return Err(StegoError::NoPayloadFound);
            }
            let offset = self.cover_prefix_len;
            let len_bytes: [u8; 4] = data
                .get(offset..offset + 4)
                .ok_or(StegoError::NoPayloadFound)?
                .try_into()
                .map_err(|_| StegoError::NoPayloadFound)?;
            let len = u32::from_le_bytes(len_bytes) as usize;
            let start = offset + 4;
            let payload_data = data
                .get(start..start + len)
                .ok_or(StegoError::NoPayloadFound)?;
            Ok(Payload::from_bytes(payload_data.to_vec()))
        }
    }

    /// Build a zero-filled PNG cover of `size` bytes with no embedded payload.
    fn make_cover(size: usize) -> CoverMedia {
        CoverMedia {
            kind: CoverMediaKind::PngImage,
            data: Bytes::from(vec![0u8; size]),
            metadata: std::collections::HashMap::new(),
        }
    }

    /// Helper: encode payload into shards, embed each in a cover.
    fn distribute_and_get_covers(
        payload: &[u8],
        data_shards: u8,
        parity_shards: u8,
        hmac_key: &[u8],
        cover_size: usize,
    ) -> Result<Vec<CoverMedia>, Box<dyn std::error::Error>> {
        let corrector = crate::adapters::correction::RsErrorCorrector::new(hmac_key.to_vec());
        let shards = corrector.encode(payload, data_shards, parity_shards)?;
        let embedder = MockEmbedder;
        let covers = shards
            .iter()
            .map(|shard| {
                let cover = make_cover(cover_size);
                let serialized = serialize_shard(shard);
                let shard_payload = Payload::from_bytes(serialized);
                embedder.embed(cover, &shard_payload)
            })
            .collect::<Result<Vec<_>, _>>()?;
        Ok(covers)
    }

    /// Helper: build a reconstructor backed by the real `RsErrorCorrector`,
    /// keyed with `hmac_key` (the key the shards were encoded with).
    fn rs_reconstructor(
        data_shards: u8,
        parity_shards: u8,
        original_len: usize,
        hmac_key: &[u8],
    ) -> ReconstructorImpl {
        let corrector: Box<dyn ErrorCorrector> = Box::new(
            crate::adapters::correction::RsErrorCorrector::new(hmac_key.to_vec()),
        );
        ReconstructorImpl::new(data_shards, parity_shards, original_len, corrector)
    }

    #[test]
    fn full_recovery_all_shards_present() -> TestResult {
        let original = b"hello reconstruction world!";
        let hmac_key = b"test-hmac-key";
        let covers = distribute_and_get_covers(original, 3, 2, hmac_key, 128)?;
        assert_eq!(covers.len(), 5);

        let reconstructor = rs_reconstructor(3, 2, original.len(), hmac_key);
        let extractor = MockExtractor {
            cover_prefix_len: 128,
        };
        let progress_calls = Cell::new(0usize);
        let result = reconstructor.reconstruct(covers, &extractor, &|_done, _total| {
            progress_calls.set(progress_calls.get().strict_add(1));
        })?;

        assert_eq!(result.as_bytes(), original);
        assert_eq!(progress_calls.get(), 5);
        Ok(())
    }

    #[test]
    fn partial_recovery_minimum_shards() -> TestResult {
        let original = b"partial recovery test payload";
        let hmac_key = b"test-hmac-key";
        let mut covers = distribute_and_get_covers(original, 3, 2, hmac_key, 128)?;
        assert_eq!(covers.len(), 5);

        // Drop 2 parity shards (keep exactly data_shards = 3)
        covers.remove(4);
        covers.remove(3);

        let reconstructor = rs_reconstructor(3, 2, original.len(), hmac_key);
        let extractor = MockExtractor {
            cover_prefix_len: 128,
        };
        let result = reconstructor.reconstruct(covers, &extractor, &|_, _| {})?;

        assert_eq!(result.as_bytes(), original);
        Ok(())
    }

    #[test]
    fn lost_data_shard_is_recovered_from_parity() -> TestResult {
        let original = b"lost data shard recovery";
        let hmac_key = b"test-hmac-key";
        let mut covers = distribute_and_get_covers(original, 3, 2, hmac_key, 128)?;
        assert_eq!(covers.len(), 5);

        // Replace the first cover with one that carries no payload at all. The
        // extractor then fails on it, exercising the "treat as missing" path.
        // Assuming data shards come first, this loses a data shard, which must
        // be rebuilt from parity rather than simply dropped.
        let first = covers
            .first_mut()
            .ok_or("expected at least one cover")?;
        *first = make_cover(128);

        let reconstructor = rs_reconstructor(3, 2, original.len(), hmac_key);
        let extractor = MockExtractor {
            cover_prefix_len: 128,
        };
        let progress_calls = Cell::new(0usize);
        let result = reconstructor.reconstruct(covers, &extractor, &|_, _| {
            progress_calls.set(progress_calls.get().strict_add(1));
        })?;

        assert_eq!(result.as_bytes(), original);
        // The unreadable cover still counts toward progress.
        assert_eq!(progress_calls.get(), 5);
        Ok(())
    }

    #[test]
    fn insufficient_shards_returns_error() -> TestResult {
        let original = b"not enough shards";
        let hmac_key = b"test-hmac-key";
        let mut covers = distribute_and_get_covers(original, 3, 2, hmac_key, 128)?;

        // Drop 3 shards (only 2 remain, but need 3)
        covers.remove(4);
        covers.remove(3);
        covers.remove(2);

        let reconstructor = rs_reconstructor(3, 2, original.len(), hmac_key);
        let extractor = MockExtractor {
            cover_prefix_len: 128,
        };
        let result = reconstructor.reconstruct(covers, &extractor, &|_, _| {});

        assert!(result.is_err());
        Ok(())
    }

    #[test]
    fn progress_callback_called_correctly() -> TestResult {
        let original = b"track progress";
        let hmac_key = b"test-hmac-key";
        let covers = distribute_and_get_covers(original, 2, 1, hmac_key, 64)?;
        let total_covers = covers.len();

        let reconstructor = rs_reconstructor(2, 1, original.len(), hmac_key);
        let extractor = MockExtractor {
            cover_prefix_len: 64,
        };

        let progress_log = std::cell::RefCell::new(Vec::new());
        let result = reconstructor.reconstruct(covers, &extractor, &|done, total| {
            progress_log.borrow_mut().push((done, total));
        })?;

        assert_eq!(result.as_bytes(), original);
        let log = progress_log.borrow();
        assert_eq!(log.len(), total_covers);
        for (i, &(done, total)) in log.iter().enumerate() {
            assert_eq!(done, i + 1);
            assert_eq!(total, total_covers);
        }
        Ok(())
    }
}