shadowforge_lib/adapters/
reconstruction.rs1use crate::domain::errors::ReconstructionError;
5use crate::domain::ports::{ErrorCorrector, ExtractTechnique, Reconstructor};
6use crate::domain::reconstruction::{
7 arrange_shards, count_present, deserialize_shard, validate_shard_count,
8};
9use crate::domain::types::{CoverMedia, Payload};
10
11pub struct ReconstructorImpl {
19 data_shards: u8,
21 parity_shards: u8,
23 original_len: usize,
25 corrector: Box<dyn ErrorCorrector>,
27}
28
29impl ReconstructorImpl {
30 #[must_use]
32 pub fn new(
33 data_shards: u8,
34 parity_shards: u8,
35 original_len: usize,
36 corrector: Box<dyn ErrorCorrector>,
37 ) -> Self {
38 Self {
39 data_shards,
40 parity_shards,
41 original_len,
42 corrector,
43 }
44 }
45}
46
47impl Reconstructor for ReconstructorImpl {
48 fn reconstruct(
49 &self,
50 covers: Vec<CoverMedia>,
51 extractor: &dyn ExtractTechnique,
52 progress_cb: &dyn Fn(usize, usize),
53 ) -> Result<Payload, ReconstructionError> {
54 let total = covers.len();
55 let total_shards = self.data_shards.strict_add(self.parity_shards);
56
57 let mut shards = Vec::with_capacity(total);
59 for (i, cover) in covers.into_iter().enumerate() {
60 match extractor.extract(&cover) {
61 Ok(payload) => {
62 if let Some(shard) = deserialize_shard(payload.as_bytes()) {
64 shards.push(shard);
65 } else {
66 tracing::warn!(
67 cover_index = i,
68 "could not deserialize shard, treating as missing"
69 );
70 }
71 }
72 Err(e) => {
73 tracing::warn!(cover_index = i, error = %e, "extraction failed, treating as missing shard");
76 }
77 }
78 progress_cb(i.strict_add(1), total);
79 }
80
81 let slots = arrange_shards(shards, total_shards);
83 let present = count_present(&slots);
84
85 validate_shard_count(present, usize::from(self.data_shards))?;
87
88 let recovered = self
90 .corrector
91 .decode(&slots, self.data_shards, self.parity_shards)
92 .map_err(|source| ReconstructionError::CorrectionFailed { source })?;
93
94 let payload_bytes = if self.original_len > 0 {
95 recovered
96 .get(..self.original_len)
97 .map_or_else(|| recovered.to_vec(), ToOwned::to_owned)
98 } else {
99 recovered.to_vec()
100 };
101
102 Ok(Payload::from_bytes(payload_bytes))
103 }
104}
105
#[cfg(test)]
mod tests {
    use super::*;
    use crate::domain::errors::StegoError;
    use crate::domain::ports::EmbedTechnique;
    use crate::domain::ports::ErrorCorrector;
    use crate::domain::reconstruction::serialize_shard;
    use crate::domain::types::{Capacity, CoverMedia, CoverMediaKind, StegoTechnique};
    use bytes::Bytes;
    use std::cell::Cell;

    type TestResult = Result<(), Box<dyn std::error::Error>>;

    /// Test embedder: appends `payload_len (u32, little-endian) || payload`
    /// after the untouched cover bytes.
    struct MockEmbedder;

    impl EmbedTechnique for MockEmbedder {
        fn technique(&self) -> StegoTechnique {
            StegoTechnique::LsbImage
        }

        fn capacity(&self, cover: &CoverMedia) -> Result<Capacity, StegoError> {
            Ok(Capacity {
                bytes: cover.data.len() as u64,
                technique: StegoTechnique::LsbImage,
            })
        }

        fn embed(&self, cover: CoverMedia, payload: &Payload) -> Result<CoverMedia, StegoError> {
            let mut data = cover.data.to_vec();
            #[expect(clippy::cast_possible_truncation, reason = "test data < 4 GiB")]
            let len = payload.len() as u32;
            data.extend_from_slice(&len.to_le_bytes());
            data.extend_from_slice(payload.as_bytes());
            Ok(CoverMedia {
                kind: cover.kind,
                data: Bytes::from(data),
                metadata: cover.metadata,
            })
        }
    }

    /// Test extractor: inverse of [`MockEmbedder`].
    struct MockExtractor {
        /// Number of original cover bytes that precede the length prefix.
        cover_prefix_len: usize,
    }

    impl ExtractTechnique for MockExtractor {
        fn technique(&self) -> StegoTechnique {
            StegoTechnique::LsbImage
        }

        fn extract(&self, stego: &CoverMedia) -> Result<Payload, StegoError> {
            let data = &stego.data;
            // A bare cover (nothing appended) has no payload.
            if data.len() <= self.cover_prefix_len + 4 {
                return Err(StegoError::NoPayloadFound);
            }
            let offset = self.cover_prefix_len;
            let len_bytes: [u8; 4] = data
                .get(offset..offset + 4)
                .ok_or(StegoError::NoPayloadFound)?
                .try_into()
                .map_err(|_| StegoError::NoPayloadFound)?;
            let len = u32::from_le_bytes(len_bytes) as usize;
            let start = offset + 4;
            let payload_data = data
                .get(start..start + len)
                .ok_or(StegoError::NoPayloadFound)?;
            Ok(Payload::from_bytes(payload_data.to_vec()))
        }
    }

    /// Builds a zero-filled PNG-kind cover of `size` bytes.
    fn make_cover(size: usize) -> CoverMedia {
        CoverMedia {
            kind: CoverMediaKind::PngImage,
            data: Bytes::from(vec![0u8; size]),
            metadata: std::collections::HashMap::new(),
        }
    }

    /// Builds the same `RsErrorCorrector` used on the encode side, boxed for
    /// [`ReconstructorImpl::new`].
    fn rs_corrector(hmac_key: &[u8]) -> Box<dyn ErrorCorrector> {
        Box::new(crate::adapters::correction::RsErrorCorrector::new(
            hmac_key.to_vec(),
        ))
    }

    /// Encodes `payload` into shards and embeds each serialized shard into its
    /// own `cover_size`-byte cover, returned in the order `encode` produced them.
    fn distribute_and_get_covers(
        payload: &[u8],
        data_shards: u8,
        parity_shards: u8,
        hmac_key: &[u8],
        cover_size: usize,
    ) -> Result<Vec<CoverMedia>, Box<dyn std::error::Error>> {
        let corrector = crate::adapters::correction::RsErrorCorrector::new(hmac_key.to_vec());
        let shards = corrector.encode(payload, data_shards, parity_shards)?;
        let embedder = MockEmbedder;
        let covers = shards
            .iter()
            .map(|shard| {
                let cover = make_cover(cover_size);
                let serialized = serialize_shard(shard);
                let shard_payload = Payload::from_bytes(serialized);
                embedder.embed(cover, &shard_payload)
            })
            .collect::<Result<Vec<_>, _>>()?;
        Ok(covers)
    }

    #[test]
    fn full_recovery_all_shards_present() -> TestResult {
        let original = b"hello reconstruction world!";
        let hmac_key = b"test-hmac-key";
        let covers = distribute_and_get_covers(original, 3, 2, hmac_key, 128)?;
        assert_eq!(covers.len(), 5);

        let reconstructor = ReconstructorImpl::new(3, 2, original.len(), rs_corrector(hmac_key));
        let extractor = MockExtractor {
            cover_prefix_len: 128,
        };
        let progress_calls = Cell::new(0usize);
        let result = reconstructor.reconstruct(covers, &extractor, &|_done, _total| {
            progress_calls.set(progress_calls.get().strict_add(1));
        })?;

        assert_eq!(result.as_bytes(), original);
        assert_eq!(progress_calls.get(), 5);
        Ok(())
    }

    #[test]
    fn partial_recovery_minimum_shards() -> TestResult {
        let original = b"partial recovery test payload";
        let hmac_key = b"test-hmac-key";
        let mut covers = distribute_and_get_covers(original, 3, 2, hmac_key, 128)?;
        assert_eq!(covers.len(), 5);

        // Drop the last two covers, presumably the parity shards, leaving
        // exactly `data_shards` shards.
        covers.remove(4);
        covers.remove(3);

        let reconstructor = ReconstructorImpl::new(3, 2, original.len(), rs_corrector(hmac_key));
        let extractor = MockExtractor {
            cover_prefix_len: 128,
        };
        let result = reconstructor.reconstruct(covers, &extractor, &|_, _| {})?;

        assert_eq!(result.as_bytes(), original);
        Ok(())
    }

    #[test]
    fn recovery_with_missing_leading_shards_uses_parity() -> TestResult {
        let original = b"parity must fill the gap";
        let hmac_key = b"test-hmac-key";
        let mut covers = distribute_and_get_covers(original, 3, 2, hmac_key, 128)?;
        assert_eq!(covers.len(), 5);

        // Drop the first two covers instead of the last two, so (assuming
        // `encode` emits data shards first) decoding must rebuild data from parity.
        covers.remove(1);
        covers.remove(0);

        let reconstructor = ReconstructorImpl::new(3, 2, original.len(), rs_corrector(hmac_key));
        let extractor = MockExtractor {
            cover_prefix_len: 128,
        };
        let result = reconstructor.reconstruct(covers, &extractor, &|_, _| {})?;

        assert_eq!(result.as_bytes(), original);
        Ok(())
    }

    #[test]
    fn failed_extraction_is_treated_as_missing_shard() -> TestResult {
        let original = b"one cover lost its payload";
        let hmac_key = b"test-hmac-key";
        let mut covers = distribute_and_get_covers(original, 3, 2, hmac_key, 128)?;

        // A bare cover has nothing appended, so MockExtractor returns
        // `NoPayloadFound` for it; reconstruction must skip it, not abort.
        *covers.first_mut().ok_or("expected at least one cover")? = make_cover(128);

        let reconstructor = ReconstructorImpl::new(3, 2, original.len(), rs_corrector(hmac_key));
        let extractor = MockExtractor {
            cover_prefix_len: 128,
        };
        let progress_calls = Cell::new(0usize);
        let result = reconstructor.reconstruct(covers, &extractor, &|_, _| {
            progress_calls.set(progress_calls.get().strict_add(1));
        })?;

        assert_eq!(result.as_bytes(), original);
        // Progress is still reported for the cover whose extraction failed.
        assert_eq!(progress_calls.get(), 5);
        Ok(())
    }

    #[test]
    fn insufficient_shards_returns_error() -> TestResult {
        let original = b"not enough shards";
        let hmac_key = b"test-hmac-key";
        let mut covers = distribute_and_get_covers(original, 3, 2, hmac_key, 128)?;

        // Keep only two shards: one fewer than `data_shards`.
        covers.remove(4);
        covers.remove(3);
        covers.remove(2);

        let reconstructor = ReconstructorImpl::new(3, 2, original.len(), rs_corrector(hmac_key));
        let extractor = MockExtractor {
            cover_prefix_len: 128,
        };
        let result = reconstructor.reconstruct(covers, &extractor, &|_, _| {});

        assert!(result.is_err());
        Ok(())
    }

    #[test]
    fn no_covers_returns_error_without_progress() {
        let reconstructor = ReconstructorImpl::new(3, 2, 0, rs_corrector(b"test-hmac-key"));
        let extractor = MockExtractor {
            cover_prefix_len: 128,
        };
        let progress_calls = Cell::new(0usize);
        let result = reconstructor.reconstruct(Vec::new(), &extractor, &|_, _| {
            progress_calls.set(progress_calls.get().strict_add(1));
        });

        assert!(result.is_err());
        assert_eq!(progress_calls.get(), 0);
    }

    #[test]
    fn progress_callback_called_correctly() -> TestResult {
        let original = b"track progress";
        let hmac_key = b"test-hmac-key";
        let covers = distribute_and_get_covers(original, 2, 1, hmac_key, 64)?;
        let total_covers = covers.len();

        let reconstructor = ReconstructorImpl::new(2, 1, original.len(), rs_corrector(hmac_key));
        let extractor = MockExtractor {
            cover_prefix_len: 64,
        };

        let progress_log = std::cell::RefCell::new(Vec::new());
        let result = reconstructor.reconstruct(covers, &extractor, &|done, total| {
            progress_log.borrow_mut().push((done, total));
        })?;

        assert_eq!(result.as_bytes(), original);
        let log = progress_log.borrow();
        assert_eq!(log.len(), total_covers);
        for (i, &(done, total)) in log.iter().enumerate() {
            assert_eq!(done, i + 1);
            assert_eq!(total, total_covers);
        }
        Ok(())
    }
}