Skip to main content

shadowforge_lib/adapters/
media.rs

1//! Image and audio codec adapters for cover media.
2
3use std::collections::HashMap;
4use std::fs::File;
5use std::io::BufWriter;
6use std::path::Path;
7
8use bytes::Bytes;
9use hound::{WavReader, WavSpec, WavWriter};
10use image::{DynamicImage, ImageFormat};
11
12use crate::domain::errors::MediaError;
13use crate::domain::ports::MediaLoader;
14use crate::domain::types::{CoverMedia, CoverMediaKind};
15
// ─── Metadata keys ──────────────────────────────────────────────────────────
//
// Keys the loaders in this module write into `CoverMedia.metadata` (and, where
// needed, read back when saving). All values are stored as strings.

/// Image width in pixels, as a decimal string.
const KEY_WIDTH: &str = "width";
/// Image height in pixels, as a decimal string.
const KEY_HEIGHT: &str = "height";
/// `Debug` name of the detected `image::ImageFormat` (e.g. `Png`).
const KEY_FORMAT: &str = "format";
/// WAV sample rate, copied from the source file's `hound::WavSpec`.
const KEY_SAMPLE_RATE: &str = "sample_rate";
/// WAV channel count, copied from the source file's `hound::WavSpec`.
const KEY_CHANNELS: &str = "channels";
/// WAV bit depth, copied from the source file's `hound::WavSpec`.
const KEY_BITS_PER_SAMPLE: &str = "bits_per_sample";

/// GIF palette key.
///
/// TODO(T13): Extract GIF palette for palette stego.
#[expect(dead_code, reason = "will be used in T13 for palette stego")]
const KEY_PALETTE: &str = "palette";

/// JPEG quantisation-table key.
///
/// TODO(T16): Extract JPEG quant tables for adaptive embedding.
#[expect(dead_code, reason = "will be used in T16 for adaptive embedding")]
const KEY_QUANT_TABLES: &str = "quant_tables";
31
/// Image media loader for PNG, BMP, JPEG and GIF covers.
///
/// `load` decodes the file (codec chosen from its extension) to raw RGBA8
/// pixels stored in `CoverMedia.data`, and records `width`, `height` and
/// `format` in `CoverMedia.metadata`. `save` re-encodes those pixels in the
/// format implied by `CoverMedia.kind`.
///
/// Format-specific metadata (GIF palette, JPEG quantisation tables) is not
/// extracted yet; see the `TODO(T13)` / `TODO(T16)` markers in this module.
#[derive(Debug, Default)]
pub struct ImageMediaLoader;
39
40impl MediaLoader for ImageMediaLoader {
41    fn load(&self, path: &Path) -> Result<CoverMedia, MediaError> {
42        // Detect format from extension
43        let extension = path.extension().and_then(|s| s.to_str()).ok_or_else(|| {
44            MediaError::UnsupportedFormat {
45                extension: "none".to_string(),
46            }
47        })?;
48
49        let format = match extension.to_lowercase().as_str() {
50            "png" => ImageFormat::Png,
51            "bmp" => ImageFormat::Bmp,
52            "jpg" | "jpeg" => ImageFormat::Jpeg,
53            "gif" => ImageFormat::Gif,
54            ext => {
55                return Err(MediaError::UnsupportedFormat {
56                    extension: ext.to_string(),
57                });
58            }
59        };
60
61        // Load image
62        let img = image::open(path).map_err(|e| MediaError::DecodeFailed {
63            reason: e.to_string(),
64        })?;
65
66        let kind = match format {
67            ImageFormat::Png => CoverMediaKind::PngImage,
68            ImageFormat::Bmp => CoverMediaKind::BmpImage,
69            ImageFormat::Jpeg => CoverMediaKind::JpegImage,
70            ImageFormat::Gif => CoverMediaKind::GifImage,
71            _ => unreachable!(),
72        };
73
74        // Convert to RGBA8
75        let rgba = img.to_rgba8();
76        let (width, height) = rgba.dimensions();
77
78        // Build metadata
79        let mut metadata = HashMap::new();
80        metadata.insert(KEY_WIDTH.to_string(), width.to_string());
81        metadata.insert(KEY_HEIGHT.to_string(), height.to_string());
82        metadata.insert(KEY_FORMAT.to_string(), format!("{format:?}"));
83
84        // TODO(T13): Extract GIF palette for palette stego
85        // TODO(T16): Extract JPEG quant tables for adaptive embedding
86
87        Ok(CoverMedia {
88            kind,
89            data: Bytes::from(rgba.into_raw()),
90            metadata,
91        })
92    }
93
94    fn save(&self, media: &CoverMedia, path: &Path) -> Result<(), MediaError> {
95        // Parse metadata
96        let width: u32 = media
97            .metadata
98            .get(KEY_WIDTH)
99            .ok_or_else(|| MediaError::EncodeFailed {
100                reason: "missing width metadata".to_string(),
101            })?
102            .parse()
103            .map_err(|e: std::num::ParseIntError| MediaError::EncodeFailed {
104                reason: e.to_string(),
105            })?;
106
107        let height: u32 = media
108            .metadata
109            .get(KEY_HEIGHT)
110            .ok_or_else(|| MediaError::EncodeFailed {
111                reason: "missing height metadata".to_string(),
112            })?
113            .parse()
114            .map_err(|e: std::num::ParseIntError| MediaError::EncodeFailed {
115                reason: e.to_string(),
116            })?;
117
118        // Reconstruct image from RGBA8 data
119        let img =
120            image::RgbaImage::from_raw(width, height, media.data.to_vec()).ok_or_else(|| {
121                MediaError::EncodeFailed {
122                    reason: "invalid image dimensions or data length".to_string(),
123                }
124            })?;
125
126        let dynamic_img = DynamicImage::ImageRgba8(img);
127
128        // Determine output format from cover media kind
129        let format = match media.kind {
130            CoverMediaKind::PngImage => ImageFormat::Png,
131            CoverMediaKind::BmpImage => ImageFormat::Bmp,
132            CoverMediaKind::JpegImage => ImageFormat::Jpeg,
133            CoverMediaKind::GifImage => ImageFormat::Gif,
134            _ => {
135                return Err(MediaError::EncodeFailed {
136                    reason: format!("unsupported media kind: {:?}", media.kind),
137                });
138            }
139        };
140
141        // Save image
142        dynamic_img
143            .save_with_format(path, format)
144            .map_err(|e| MediaError::EncodeFailed {
145                reason: e.to_string(),
146            })?;
147
148        Ok(())
149    }
150}
151
/// WAV audio loader.
///
/// Decoded samples are kept in `CoverMedia.data` as little-endian `i16` byte
/// pairs; `CoverMedia.metadata` carries the source file's `sample_rate`,
/// `channels` and `bits_per_sample`.
#[derive(Default, Debug)]
pub struct AudioMediaLoader;
158
159impl MediaLoader for AudioMediaLoader {
160    fn load(&self, path: &Path) -> Result<CoverMedia, MediaError> {
161        let reader = WavReader::open(path).map_err(|e| MediaError::DecodeFailed {
162            reason: e.to_string(),
163        })?;
164
165        let spec = reader.spec();
166
167        // Read all samples as i16
168        let samples: Vec<i16> = reader
169            .into_samples::<i16>()
170            .collect::<Result<Vec<_>, _>>()
171            .map_err(|e| MediaError::DecodeFailed {
172                reason: e.to_string(),
173            })?;
174
175        // Convert samples to little-endian bytes
176        let mut data = Vec::with_capacity(samples.len().strict_mul(2));
177        for sample in samples {
178            data.extend_from_slice(&sample.to_le_bytes());
179        }
180
181        // Build metadata
182        let mut metadata = HashMap::new();
183        metadata.insert(KEY_SAMPLE_RATE.to_string(), spec.sample_rate.to_string());
184        metadata.insert(KEY_CHANNELS.to_string(), spec.channels.to_string());
185        metadata.insert(
186            KEY_BITS_PER_SAMPLE.to_string(),
187            spec.bits_per_sample.to_string(),
188        );
189
190        Ok(CoverMedia {
191            kind: CoverMediaKind::WavAudio,
192            data: Bytes::from(data),
193            metadata,
194        })
195    }
196
197    fn save(&self, media: &CoverMedia, path: &Path) -> Result<(), MediaError> {
198        // Parse metadata
199        let sample_rate: u32 = media
200            .metadata
201            .get(KEY_SAMPLE_RATE)
202            .ok_or_else(|| MediaError::EncodeFailed {
203                reason: "missing sample_rate metadata".to_string(),
204            })?
205            .parse()
206            .map_err(|e: std::num::ParseIntError| MediaError::EncodeFailed {
207                reason: e.to_string(),
208            })?;
209
210        let channels: u16 = media
211            .metadata
212            .get(KEY_CHANNELS)
213            .ok_or_else(|| MediaError::EncodeFailed {
214                reason: "missing channels metadata".to_string(),
215            })?
216            .parse()
217            .map_err(|e: std::num::ParseIntError| MediaError::EncodeFailed {
218                reason: e.to_string(),
219            })?;
220
221        let bits_per_sample: u16 = media
222            .metadata
223            .get(KEY_BITS_PER_SAMPLE)
224            .ok_or_else(|| MediaError::EncodeFailed {
225                reason: "missing bits_per_sample metadata".to_string(),
226            })?
227            .parse()
228            .map_err(|e: std::num::ParseIntError| MediaError::EncodeFailed {
229                reason: e.to_string(),
230            })?;
231
232        // Create WAV spec
233        let spec = WavSpec {
234            channels,
235            sample_rate,
236            bits_per_sample,
237            sample_format: hound::SampleFormat::Int,
238        };
239
240        // Create writer
241        let file = File::create(path).map_err(|e| MediaError::IoError {
242            reason: e.to_string(),
243        })?;
244
245        let mut writer =
246            WavWriter::new(BufWriter::new(file), spec).map_err(|e| MediaError::EncodeFailed {
247                reason: e.to_string(),
248            })?;
249
250        // Convert bytes back to i16 samples
251        for chunk in media.data.chunks_exact(2) {
252            if let Ok(pair) = <[u8; 2]>::try_from(chunk) {
253                let sample = i16::from_le_bytes(pair);
254                writer
255                    .write_sample(sample)
256                    .map_err(|e| MediaError::EncodeFailed {
257                        reason: e.to_string(),
258                    })?;
259            }
260        }
261
262        writer.finalize().map_err(|e| MediaError::EncodeFailed {
263            reason: e.to_string(),
264        })?;
265
266        Ok(())
267    }
268}
269
270// ─── Tests ────────────────────────────────────────────────────────────────────
271
#[cfg(test)]
mod tests {
    use super::*;
    use tempfile::tempdir;

    type TestResult = Result<(), Box<dyn std::error::Error>>;

    /// Builds a cover of `kind` holding `len` zero bytes and the given metadata.
    fn cover(kind: CoverMediaKind, len: usize, metadata: &[(&str, &str)]) -> CoverMedia {
        CoverMedia {
            kind,
            data: Bytes::from(vec![0u8; len]),
            metadata: metadata
                .iter()
                .map(|&(key, value)| (key.to_string(), value.to_string()))
                .collect(),
        }
    }

    /// Returns a metadata value as `&str` for terse assertions.
    fn meta<'a>(media: &'a CoverMedia, key: &str) -> Option<&'a str> {
        media.metadata.get(key).map(String::as_str)
    }

    /// Writes a 16-bit integer PCM WAV at 44.1 kHz with samples `0..count`.
    fn write_wav(path: &Path, channels: u16, count: i16) -> TestResult {
        let spec = WavSpec {
            channels,
            sample_rate: 44100,
            bits_per_sample: 16,
            sample_format: hound::SampleFormat::Int,
        };
        let mut writer = WavWriter::create(path, spec)?;
        for i in 0..count {
            writer.write_sample(i)?;
        }
        writer.finalize()?;
        Ok(())
    }

    #[test]
    fn test_image_loader_png_roundtrip() -> TestResult {
        let loader = ImageMediaLoader;
        let dir = tempdir()?;
        let path = dir.path().join("test.png");

        // A 10x10 opaque white image.
        DynamicImage::ImageRgba8(image::RgbaImage::from_pixel(
            10,
            10,
            image::Rgba([255, 255, 255, 255]),
        ))
        .save(&path)?;

        let media = loader.load(&path)?;
        assert_eq!(media.kind, CoverMediaKind::PngImage);
        assert_eq!(meta(&media, KEY_WIDTH), Some("10"));
        assert_eq!(meta(&media, KEY_HEIGHT), Some("10"));
        assert_eq!(meta(&media, KEY_FORMAT), Some("Png"));

        let out_path = dir.path().join("out.png");
        loader.save(&media, &out_path)?;

        let reloaded = loader.load(&out_path)?;
        assert_eq!(reloaded.data, media.data);
        Ok(())
    }

    #[test]
    fn test_audio_loader_wav_roundtrip() -> TestResult {
        let loader = AudioMediaLoader;
        let dir = tempdir()?;
        let path = dir.path().join("test.wav");
        write_wav(&path, 1, 1000)?;

        let media = loader.load(&path)?;
        assert_eq!(media.kind, CoverMediaKind::WavAudio);
        assert_eq!(meta(&media, KEY_SAMPLE_RATE), Some("44100"));
        assert_eq!(meta(&media, KEY_CHANNELS), Some("1"));
        assert_eq!(meta(&media, KEY_BITS_PER_SAMPLE), Some("16"));
        assert_eq!(media.data.len(), 2000);

        let out_path = dir.path().join("out.wav");
        loader.save(&media, &out_path)?;

        let reloaded = loader.load(&out_path)?;
        assert_eq!(reloaded.data, media.data);
        Ok(())
    }

    #[test]
    fn test_audio_loader_stereo_roundtrip() -> TestResult {
        let loader = AudioMediaLoader;
        let dir = tempdir()?;
        let path = dir.path().join("stereo.wav");
        // 1000 interleaved samples = 500 stereo frames.
        write_wav(&path, 2, 1000)?;

        let media = loader.load(&path)?;
        assert_eq!(meta(&media, KEY_CHANNELS), Some("2"));

        let out_path = dir.path().join("stereo_out.wav");
        loader.save(&media, &out_path)?;

        let reloaded = loader.load(&out_path)?;
        assert_eq!(reloaded.data, media.data);
        assert_eq!(reloaded.metadata, media.metadata);
        Ok(())
    }

    #[test]
    fn test_image_loader_unsupported_format() {
        let result = ImageMediaLoader.load(Path::new("test.xyz"));
        assert!(matches!(result, Err(MediaError::UnsupportedFormat { .. })));
    }

    #[test]
    fn test_image_loader_no_extension() {
        let result = ImageMediaLoader.load(Path::new("test"));
        assert!(matches!(result, Err(MediaError::UnsupportedFormat { .. })));
    }

    #[test]
    fn test_image_loader_bmp_roundtrip() -> TestResult {
        let loader = ImageMediaLoader;
        let dir = tempdir()?;
        let path = dir.path().join("test.bmp");

        DynamicImage::ImageRgba8(image::RgbaImage::from_pixel(
            5,
            5,
            image::Rgba([128, 64, 32, 255]),
        ))
        .save(&path)?;

        let media = loader.load(&path)?;
        assert_eq!(media.kind, CoverMediaKind::BmpImage);
        assert_eq!(meta(&media, KEY_WIDTH), Some("5"));
        assert_eq!(meta(&media, KEY_HEIGHT), Some("5"));
        assert_eq!(meta(&media, KEY_FORMAT), Some("Bmp"));

        let out_path = dir.path().join("out.bmp");
        loader.save(&media, &out_path)?;
        let reloaded = loader.load(&out_path)?;
        assert_eq!(reloaded.data, media.data);
        Ok(())
    }

    #[test]
    fn test_image_loader_jpeg_can_load() -> TestResult {
        let loader = ImageMediaLoader;
        let dir = tempdir()?;
        let path = dir.path().join("test.jpg");

        DynamicImage::ImageRgba8(image::RgbaImage::from_pixel(
            8,
            8,
            image::Rgba([200, 100, 50, 255]),
        ))
        .save(&path)?;

        let media = loader.load(&path)?;
        assert_eq!(media.kind, CoverMediaKind::JpegImage);
        assert_eq!(meta(&media, KEY_WIDTH), Some("8"));
        assert_eq!(meta(&media, KEY_HEIGHT), Some("8"));
        Ok(())
    }

    #[test]
    fn test_image_save_unsupported_kind() -> TestResult {
        let dir = tempdir()?;
        let media = cover(
            CoverMediaKind::WavAudio,
            100,
            &[(KEY_WIDTH, "10"), (KEY_HEIGHT, "10")],
        );
        let result = ImageMediaLoader.save(&media, &dir.path().join("test.wav"));
        assert!(matches!(result, Err(MediaError::EncodeFailed { .. })));
        Ok(())
    }

    #[test]
    fn test_image_save_missing_width() -> TestResult {
        let dir = tempdir()?;
        let media = cover(CoverMediaKind::PngImage, 100, &[(KEY_HEIGHT, "10")]);
        let result = ImageMediaLoader.save(&media, &dir.path().join("test.png"));
        assert!(matches!(result, Err(MediaError::EncodeFailed { .. })));
        Ok(())
    }

    #[test]
    fn test_image_save_missing_height() -> TestResult {
        let dir = tempdir()?;
        let media = cover(CoverMediaKind::PngImage, 100, &[(KEY_WIDTH, "10")]);
        let result = ImageMediaLoader.save(&media, &dir.path().join("test.png"));
        assert!(matches!(result, Err(MediaError::EncodeFailed { .. })));
        Ok(())
    }

    #[test]
    fn test_audio_save_missing_sample_rate() -> TestResult {
        let dir = tempdir()?;
        let media = cover(
            CoverMediaKind::WavAudio,
            100,
            &[(KEY_CHANNELS, "1"), (KEY_BITS_PER_SAMPLE, "16")],
        );
        let result = AudioMediaLoader.save(&media, &dir.path().join("test.wav"));
        assert!(matches!(result, Err(MediaError::EncodeFailed { .. })));
        Ok(())
    }

    #[test]
    fn test_audio_save_missing_channels() -> TestResult {
        let dir = tempdir()?;
        let media = cover(
            CoverMediaKind::WavAudio,
            100,
            &[(KEY_SAMPLE_RATE, "44100"), (KEY_BITS_PER_SAMPLE, "16")],
        );
        let result = AudioMediaLoader.save(&media, &dir.path().join("test.wav"));
        assert!(matches!(result, Err(MediaError::EncodeFailed { .. })));
        Ok(())
    }

    #[test]
    fn test_audio_save_missing_bits_per_sample() -> TestResult {
        let dir = tempdir()?;
        let media = cover(
            CoverMediaKind::WavAudio,
            100,
            &[(KEY_SAMPLE_RATE, "44100"), (KEY_CHANNELS, "1")],
        );
        let result = AudioMediaLoader.save(&media, &dir.path().join("test.wav"));
        assert!(matches!(result, Err(MediaError::EncodeFailed { .. })));
        Ok(())
    }

    #[test]
    fn test_image_load_nonexistent_file() {
        let result = ImageMediaLoader.load(Path::new("/nonexistent/path/image.png"));
        assert!(matches!(result, Err(MediaError::DecodeFailed { .. })));
    }

    #[test]
    fn test_audio_load_nonexistent_file() {
        let result = AudioMediaLoader.load(Path::new("/nonexistent/path/audio.wav"));
        assert!(matches!(result, Err(MediaError::DecodeFailed { .. })));
    }
}