shadowforge_lib/adapters/
reconstruction.rs1use crate::domain::errors::ReconstructionError;
5use crate::domain::ports::{ErrorCorrector, ExtractTechnique, Reconstructor};
6use crate::domain::reconstruction::{
7 arrange_shards, count_present, deserialize_shard, validate_shard_count,
8};
9use crate::domain::types::{CoverMedia, Payload};
10
11pub struct ReconstructorImpl {
19 data_shards: u8,
21 parity_shards: u8,
23 original_len: usize,
25 corrector: Box<dyn ErrorCorrector>,
27}
28
29impl ReconstructorImpl {
30 #[must_use]
32 pub fn new(
33 data_shards: u8,
34 parity_shards: u8,
35 original_len: usize,
36 corrector: Box<dyn ErrorCorrector>,
37 ) -> Self {
38 Self {
39 data_shards,
40 parity_shards,
41 original_len,
42 corrector,
43 }
44 }
45}
46
47impl Reconstructor for ReconstructorImpl {
48 fn reconstruct(
49 &self,
50 covers: Vec<CoverMedia>,
51 extractor: &dyn ExtractTechnique,
52 progress_cb: &dyn Fn(usize, usize),
53 ) -> Result<Payload, ReconstructionError> {
54 let total = covers.len();
55 let total_shards = self.data_shards.strict_add(self.parity_shards);
56
57 let mut shards = Vec::with_capacity(total);
59 for (i, cover) in covers.into_iter().enumerate() {
60 match extractor.extract(&cover) {
61 Ok(payload) => {
62 if let Some(shard) = deserialize_shard(payload.as_bytes()) {
64 shards.push(shard);
65 } else {
66 tracing::warn!(
67 cover_index = i,
68 "could not deserialize shard, treating as missing"
69 );
70 }
71 }
72 Err(e) => {
73 tracing::warn!(cover_index = i, error = %e, "extraction failed, treating as missing shard");
76 }
77 }
78 progress_cb(i.strict_add(1), total);
79 }
80
81 let slots = arrange_shards(shards, total_shards);
83 let present = count_present(&slots);
84
85 validate_shard_count(present, usize::from(self.data_shards))?;
87
88 let recovered = self.corrector
90 .decode(&slots, self.data_shards, self.parity_shards)
91 .map_err(|source| ReconstructionError::CorrectionFailed { source })?;
92
93 let payload_bytes = if self.original_len > 0 {
94 recovered
95 .get(..self.original_len)
96 .map_or_else(|| recovered.to_vec(), ToOwned::to_owned)
97 } else {
98 recovered.to_vec()
99 };
100
101 Ok(Payload::from_bytes(payload_bytes))
102 }
103}
104
#[cfg(test)]
mod tests {
    use super::*;
    use crate::domain::errors::StegoError;
    use crate::domain::ports::EmbedTechnique;
    use crate::domain::ports::ErrorCorrector;
    use crate::domain::reconstruction::serialize_shard;
    use crate::domain::types::{Capacity, CoverMedia, CoverMediaKind, StegoTechnique};
    use bytes::Bytes;
    use std::cell::Cell;

    type TestResult = Result<(), Box<dyn std::error::Error>>;

    /// HMAC key used on both the encoding and the decoding side of each test.
    const HMAC_KEY: &[u8] = b"test-hmac-key";

    /// Width in bytes of the little-endian length header the mocks use.
    const LEN_HEADER: usize = 4;

    /// Test embedder: appends a `u32` little-endian length header followed by
    /// the payload bytes to the end of the cover data.
    struct MockEmbedder;

    impl EmbedTechnique for MockEmbedder {
        fn technique(&self) -> StegoTechnique {
            StegoTechnique::LsbImage
        }

        /// Reports the cover's current data length as its capacity.
        fn capacity(&self, cover: &CoverMedia) -> Result<Capacity, StegoError> {
            let bytes = cover.data.len() as u64;
            Ok(Capacity {
                bytes,
                technique: StegoTechnique::LsbImage,
            })
        }

        /// Produces `cover bytes || payload length (u32 LE) || payload bytes`.
        fn embed(&self, cover: CoverMedia, payload: &Payload) -> Result<CoverMedia, StegoError> {
            #[expect(clippy::cast_possible_truncation, reason = "test data < 4 GiB")]
            let len_header = (payload.len() as u32).to_le_bytes();
            let body = payload.as_bytes();

            let mut stego = Vec::with_capacity(cover.data.len() + LEN_HEADER + body.len());
            stego.extend_from_slice(&cover.data);
            stego.extend_from_slice(&len_header);
            stego.extend_from_slice(body);

            Ok(CoverMedia {
                data: Bytes::from(stego),
                ..cover
            })
        }
    }

    /// Test extractor: reads back what `MockEmbedder` appended, given the
    /// length of the untouched cover prefix.
    struct MockExtractor {
        /// Number of leading cover bytes to skip before the length header.
        cover_prefix_len: usize,
    }

    impl ExtractTechnique for MockExtractor {
        fn technique(&self) -> StegoTechnique {
            StegoTechnique::LsbImage
        }

        /// Returns the embedded payload, or `StegoError::NoPayloadFound` when
        /// the data holds nothing beyond the prefix and header, or when the
        /// declared payload length runs past the end of the data.
        fn extract(&self, stego: &CoverMedia) -> Result<Payload, StegoError> {
            let header_start = self.cover_prefix_len;
            let payload_start = header_start + LEN_HEADER;
            if stego.data.len() <= payload_start {
                return Err(StegoError::NoPayloadFound);
            }

            let header: [u8; LEN_HEADER] = stego
                .data
                .get(header_start..payload_start)
                .ok_or(StegoError::NoPayloadFound)?
                .try_into()
                .map_err(|_| StegoError::NoPayloadFound)?;
            let payload_len = u32::from_le_bytes(header) as usize;

            stego
                .data
                .get(payload_start..payload_start + payload_len)
                .map(|bytes| Payload::from_bytes(bytes.to_vec()))
                .ok_or(StegoError::NoPayloadFound)
        }
    }

    /// Builds a zero-filled PNG-kind cover of `size` bytes with no metadata.
    fn make_cover(size: usize) -> CoverMedia {
        CoverMedia {
            kind: CoverMediaKind::PngImage,
            data: Bytes::from(vec![0u8; size]),
            metadata: std::collections::HashMap::new(),
        }
    }

    /// Encodes `payload` into shards with `RsErrorCorrector` and embeds each
    /// serialized shard into its own fresh cover of `cover_size` bytes.
    fn distribute_and_get_covers(
        payload: &[u8],
        data_shards: u8,
        parity_shards: u8,
        hmac_key: &[u8],
        cover_size: usize,
    ) -> Result<Vec<CoverMedia>, Box<dyn std::error::Error>> {
        let corrector = crate::adapters::correction::RsErrorCorrector::new(hmac_key.to_vec());
        let shards = corrector.encode(payload, data_shards, parity_shards)?;

        let mut covers = Vec::with_capacity(shards.len());
        for shard in &shards {
            let shard_payload = Payload::from_bytes(serialize_shard(shard));
            covers.push(MockEmbedder.embed(make_cover(cover_size), &shard_payload)?);
        }
        Ok(covers)
    }

    /// Builds the reconstructor under test, backed by an `RsErrorCorrector`
    /// keyed with `HMAC_KEY`.
    fn reconstructor_for(
        data_shards: u8,
        parity_shards: u8,
        original_len: usize,
    ) -> ReconstructorImpl {
        let corrector: Box<dyn ErrorCorrector> = Box::new(
            crate::adapters::correction::RsErrorCorrector::new(HMAC_KEY.to_vec()),
        );
        ReconstructorImpl::new(data_shards, parity_shards, original_len, corrector)
    }

    #[test]
    fn full_recovery_all_shards_present() -> TestResult {
        let original = b"hello reconstruction world!";
        let covers = distribute_and_get_covers(original, 3, 2, HMAC_KEY, 128)?;
        assert_eq!(covers.len(), 5);

        let callback_count = Cell::new(0usize);
        let count_progress = |_done: usize, _total: usize| {
            callback_count.set(callback_count.get().strict_add(1));
        };
        let recovered = reconstructor_for(3, 2, original.len()).reconstruct(
            covers,
            &MockExtractor {
                cover_prefix_len: 128,
            },
            &count_progress,
        )?;

        assert_eq!(recovered.as_bytes(), original);
        assert_eq!(callback_count.get(), 5);
        Ok(())
    }

    #[test]
    fn partial_recovery_minimum_shards() -> TestResult {
        let original = b"partial recovery test payload";
        let mut covers = distribute_and_get_covers(original, 3, 2, HMAC_KEY, 128)?;
        assert_eq!(covers.len(), 5);

        // Drop the last two covers so exactly `data_shards` of them remain.
        for index in [4, 3] {
            covers.remove(index);
        }

        let recovered = reconstructor_for(3, 2, original.len()).reconstruct(
            covers,
            &MockExtractor {
                cover_prefix_len: 128,
            },
            &|_, _| {},
        )?;

        assert_eq!(recovered.as_bytes(), original);
        Ok(())
    }

    #[test]
    fn insufficient_shards_returns_error() -> TestResult {
        let original = b"not enough shards";
        let mut covers = distribute_and_get_covers(original, 3, 2, HMAC_KEY, 128)?;

        // Drop three covers, leaving fewer than `data_shards`.
        for index in [4, 3, 2] {
            covers.remove(index);
        }

        let outcome = reconstructor_for(3, 2, original.len()).reconstruct(
            covers,
            &MockExtractor {
                cover_prefix_len: 128,
            },
            &|_, _| {},
        );

        assert!(outcome.is_err());
        Ok(())
    }

    #[test]
    fn progress_callback_called_correctly() -> TestResult {
        let original = b"track progress";
        let covers = distribute_and_get_covers(original, 2, 1, HMAC_KEY, 64)?;
        let total_covers = covers.len();

        let calls = std::cell::RefCell::new(Vec::new());
        let recovered = reconstructor_for(2, 1, original.len()).reconstruct(
            covers,
            &MockExtractor {
                cover_prefix_len: 64,
            },
            &|done, total| calls.borrow_mut().push((done, total)),
        )?;

        assert_eq!(recovered.as_bytes(), original);

        // One call per cover, counting up from 1, each reporting the same total.
        let expected: Vec<(usize, usize)> = (1..=total_covers)
            .map(|done| (done, total_covers))
            .collect();
        assert_eq!(calls.into_inner(), expected);
        Ok(())
    }
}