shadowforge_lib/adapters/
reconstruction.rs1use crate::domain::correction::decode_shards;
5use crate::domain::errors::ReconstructionError;
6use crate::domain::ports::{ExtractTechnique, Reconstructor};
7use crate::domain::reconstruction::{
8 arrange_shards, count_present, deserialize_shard, validate_shard_count,
9};
10use crate::domain::types::{CoverMedia, Payload};
11
/// Reassembles a payload that was split into error-correction shards
/// (data + parity) and hidden across several cover media.
///
/// The reconstructor only stores the parameters the shard decoder needs;
/// all work happens in `Reconstructor::reconstruct`.
pub struct ReconstructorImpl {
    /// Number of data shards. Reconstruction also requires at least this
    /// many shards to be present before decoding is attempted.
    data_shards: u8,
    /// Number of parity shards produced alongside the data shards.
    parity_shards: u8,
    /// Key handed to the shard decoder (presumably used to authenticate
    /// shards via HMAC — TODO confirm against `decode_shards`).
    hmac_key: Vec<u8>,
    /// Length of the original payload, handed to the shard decoder
    /// (presumably so padding can be trimmed — TODO confirm).
    original_len: usize,
}

impl ReconstructorImpl {
    /// Builds a reconstructor for a payload encoded into `data_shards` data
    /// shards and `parity_shards` parity shards.
    ///
    /// Nothing is validated here. The sum of the two shard counts must fit
    /// in a `u8`; otherwise reconstruction panics when it computes the
    /// total slot count.
    #[must_use]
    pub const fn new(
        data_shards: u8,
        parity_shards: u8,
        hmac_key: Vec<u8>,
        original_len: usize,
    ) -> Self {
        Self {
            hmac_key,
            original_len,
            data_shards,
            parity_shards,
        }
    }
}
47
48impl Reconstructor for ReconstructorImpl {
49 fn reconstruct(
50 &self,
51 covers: Vec<CoverMedia>,
52 extractor: &dyn ExtractTechnique,
53 progress_cb: &dyn Fn(usize, usize),
54 ) -> Result<Payload, ReconstructionError> {
55 let total = covers.len();
56 let total_shards = self.data_shards.strict_add(self.parity_shards);
57
58 let mut shards = Vec::with_capacity(total);
60 for (i, cover) in covers.into_iter().enumerate() {
61 match extractor.extract(&cover) {
62 Ok(payload) => {
63 if let Some(shard) = deserialize_shard(payload.as_bytes()) {
65 shards.push(shard);
66 } else {
67 tracing::warn!(
68 cover_index = i,
69 "could not deserialize shard, treating as missing"
70 );
71 }
72 }
73 Err(e) => {
74 tracing::warn!(cover_index = i, error = %e, "extraction failed, treating as missing shard");
77 }
78 }
79 progress_cb(i.strict_add(1), total);
80 }
81
82 let slots = arrange_shards(shards, total_shards);
84 let present = count_present(&slots);
85
86 validate_shard_count(present, usize::from(self.data_shards))?;
88
89 let recovered = decode_shards(
91 &slots,
92 self.data_shards,
93 self.parity_shards,
94 &self.hmac_key,
95 self.original_len,
96 )
97 .map_err(|source| ReconstructionError::CorrectionFailed { source })?;
98
99 Ok(Payload::from_bytes(recovered.to_vec()))
100 }
101}
102
#[cfg(test)]
mod tests {
    use super::*;
    use crate::domain::correction::encode_shards;
    use crate::domain::errors::StegoError;
    use crate::domain::ports::EmbedTechnique;
    use crate::domain::reconstruction::serialize_shard;
    use crate::domain::types::{Capacity, CoverMedia, CoverMediaKind, StegoTechnique};
    use bytes::Bytes;
    use std::cell::{Cell, RefCell};

    type TestResult = Result<(), Box<dyn std::error::Error>>;

    /// Size of the little-endian `u32` length header that `MockEmbedder`
    /// writes in front of each payload.
    const LEN_HEADER: usize = 4;

    /// Embedder stub: appends a 4-byte little-endian length header followed
    /// by the payload bytes to the end of the cover data.
    struct MockEmbedder;

    impl EmbedTechnique for MockEmbedder {
        fn technique(&self) -> StegoTechnique {
            StegoTechnique::LsbImage
        }

        fn capacity(&self, cover: &CoverMedia) -> Result<Capacity, StegoError> {
            Ok(Capacity {
                bytes: cover.data.len() as u64,
                technique: StegoTechnique::LsbImage,
            })
        }

        fn embed(&self, cover: CoverMedia, payload: &Payload) -> Result<CoverMedia, StegoError> {
            let mut data = cover.data.to_vec();
            #[expect(clippy::cast_possible_truncation, reason = "test data < 4 GiB")]
            let len = payload.len() as u32;
            data.extend_from_slice(&len.to_le_bytes());
            data.extend_from_slice(payload.as_bytes());
            Ok(CoverMedia {
                kind: cover.kind,
                data: Bytes::from(data),
                metadata: cover.metadata,
            })
        }
    }

    /// Extractor stub matching `MockEmbedder`. It skips `cover_prefix_len`
    /// bytes of original cover data, reads the length header and returns
    /// the payload that follows it.
    struct MockExtractor {
        /// Number of original cover bytes preceding the embedded header.
        cover_prefix_len: usize,
    }

    impl ExtractTechnique for MockExtractor {
        fn technique(&self) -> StegoTechnique {
            StegoTechnique::LsbImage
        }

        /// Returns `StegoError::NoPayloadFound` when the header, or the
        /// payload it announces, does not fit in the stego data. A header
        /// announcing zero bytes yields an empty payload instead of an error.
        fn extract(&self, stego: &CoverMedia) -> Result<Payload, StegoError> {
            let data = &stego.data;
            let header_start = self.cover_prefix_len;
            let header_end = header_start
                .checked_add(LEN_HEADER)
                .ok_or(StegoError::NoPayloadFound)?;
            let len_bytes: [u8; LEN_HEADER] = data
                .get(header_start..header_end)
                .ok_or(StegoError::NoPayloadFound)?
                .try_into()
                .map_err(|_| StegoError::NoPayloadFound)?;
            let len = usize::try_from(u32::from_le_bytes(len_bytes))
                .map_err(|_| StegoError::NoPayloadFound)?;
            // Checked so a corrupt length header cannot overflow the range end.
            let payload_end = header_end
                .checked_add(len)
                .ok_or(StegoError::NoPayloadFound)?;
            let payload_data = data
                .get(header_end..payload_end)
                .ok_or(StegoError::NoPayloadFound)?;
            Ok(Payload::from_bytes(payload_data.to_vec()))
        }
    }

    /// Builds a blank PNG-kind cover of `size` zero bytes with no metadata.
    fn make_cover(size: usize) -> CoverMedia {
        CoverMedia {
            kind: CoverMediaKind::PngImage,
            data: Bytes::from(vec![0u8; size]),
            metadata: std::collections::HashMap::new(),
        }
    }

    /// Encodes `payload` into shards and embeds each serialized shard into
    /// its own blank cover of `cover_size` bytes, one cover per shard, in
    /// the order `encode_shards` returns them.
    fn distribute_and_get_covers(
        payload: &[u8],
        data_shards: u8,
        parity_shards: u8,
        hmac_key: &[u8],
        cover_size: usize,
    ) -> Result<Vec<CoverMedia>, Box<dyn std::error::Error>> {
        let shards = encode_shards(payload, data_shards, parity_shards, hmac_key)?;
        let embedder = MockEmbedder;
        let covers = shards
            .iter()
            .map(|shard| {
                let cover = make_cover(cover_size);
                let serialized = serialize_shard(shard);
                let shard_payload = Payload::from_bytes(serialized);
                embedder.embed(cover, &shard_payload)
            })
            .collect::<Result<Vec<_>, _>>()?;
        Ok(covers)
    }

    #[test]
    fn full_recovery_all_shards_present() -> TestResult {
        let original = b"hello reconstruction world!";
        let hmac_key = b"test-hmac-key";
        let covers = distribute_and_get_covers(original, 3, 2, hmac_key, 128)?;
        assert_eq!(covers.len(), 5);

        let reconstructor = ReconstructorImpl::new(3, 2, hmac_key.to_vec(), original.len());
        let extractor = MockExtractor {
            cover_prefix_len: 128,
        };
        let progress_calls = Cell::new(0usize);
        let result = reconstructor.reconstruct(covers, &extractor, &|_done, _total| {
            progress_calls.set(progress_calls.get().strict_add(1));
        })?;

        assert_eq!(result.as_bytes(), original);
        assert_eq!(progress_calls.get(), 5);
        Ok(())
    }

    #[test]
    fn partial_recovery_minimum_shards() -> TestResult {
        let original = b"partial recovery test payload";
        let hmac_key = b"test-hmac-key";
        let mut covers = distribute_and_get_covers(original, 3, 2, hmac_key, 128)?;
        assert_eq!(covers.len(), 5);

        // Drop the two trailing covers, leaving exactly `data_shards` shards.
        covers.remove(4);
        covers.remove(3);

        let reconstructor = ReconstructorImpl::new(3, 2, hmac_key.to_vec(), original.len());
        let extractor = MockExtractor {
            cover_prefix_len: 128,
        };
        let result = reconstructor.reconstruct(covers, &extractor, &|_, _| {})?;

        assert_eq!(result.as_bytes(), original);
        Ok(())
    }

    #[test]
    fn partial_recovery_missing_leading_shards() -> TestResult {
        let original = b"recover with leading shards lost";
        let hmac_key = b"test-hmac-key";
        let mut covers = distribute_and_get_covers(original, 3, 2, hmac_key, 128)?;
        assert_eq!(covers.len(), 5);

        // Drop the two leading covers. Assuming `encode_shards` emits data
        // shards first, decoding must now rebuild data from parity.
        covers.remove(1);
        covers.remove(0);

        let reconstructor = ReconstructorImpl::new(3, 2, hmac_key.to_vec(), original.len());
        let extractor = MockExtractor {
            cover_prefix_len: 128,
        };
        let result = reconstructor.reconstruct(covers, &extractor, &|_, _| {})?;

        assert_eq!(result.as_bytes(), original);
        Ok(())
    }

    #[test]
    fn insufficient_shards_returns_error() -> TestResult {
        let original = b"not enough shards";
        let hmac_key = b"test-hmac-key";
        let mut covers = distribute_and_get_covers(original, 3, 2, hmac_key, 128)?;

        // Only two of the five shards remain, below `data_shards = 3`.
        covers.remove(4);
        covers.remove(3);
        covers.remove(2);

        let reconstructor = ReconstructorImpl::new(3, 2, hmac_key.to_vec(), original.len());
        let extractor = MockExtractor {
            cover_prefix_len: 128,
        };
        let result = reconstructor.reconstruct(covers, &extractor, &|_, _| {});

        assert!(result.is_err());
        Ok(())
    }

    #[test]
    fn failed_extraction_is_treated_as_missing_shard() -> TestResult {
        let original = b"one cover lost its payload";
        let hmac_key = b"test-hmac-key";
        let mut covers = distribute_and_get_covers(original, 3, 2, hmac_key, 128)?;
        assert_eq!(covers.len(), 5);

        // A pristine cover carries no header, so extraction fails for it;
        // the other four shards are still enough to decode.
        let slot = covers
            .get_mut(1)
            .ok_or("expected at least two covers")?;
        *slot = make_cover(128);

        let reconstructor = ReconstructorImpl::new(3, 2, hmac_key.to_vec(), original.len());
        let extractor = MockExtractor {
            cover_prefix_len: 128,
        };
        let progress_calls = Cell::new(0usize);
        let result = reconstructor.reconstruct(covers, &extractor, &|_, _| {
            progress_calls.set(progress_calls.get().strict_add(1));
        })?;

        assert_eq!(result.as_bytes(), original);
        // Progress is still reported for the cover whose extraction failed.
        assert_eq!(progress_calls.get(), 5);
        Ok(())
    }

    #[test]
    fn progress_callback_called_correctly() -> TestResult {
        let original = b"track progress";
        let hmac_key = b"test-hmac-key";
        let covers = distribute_and_get_covers(original, 2, 1, hmac_key, 64)?;
        let total_covers = covers.len();

        let reconstructor = ReconstructorImpl::new(2, 1, hmac_key.to_vec(), original.len());
        let extractor = MockExtractor {
            cover_prefix_len: 64,
        };

        let progress_log = RefCell::new(Vec::new());
        let result = reconstructor.reconstruct(covers, &extractor, &|done, total| {
            progress_log.borrow_mut().push((done, total));
        })?;

        assert_eq!(result.as_bytes(), original);
        let log = progress_log.borrow();
        assert_eq!(log.len(), total_covers);
        for (i, &(done, total)) in log.iter().enumerate() {
            assert_eq!(done, i + 1);
            assert_eq!(total, total_covers);
        }
        Ok(())
    }
}