Skip to main content

shift_preflight/inspector/
mod.rs

1pub mod image;
2
3// v2 modality stubs
4pub mod audio;
5pub mod document;
6pub mod video;
7
8use serde::{Deserialize, Serialize};
9
10use crate::mode::SafetyLimits;
11
12/// Detected format of a media input.
13#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
14#[serde(rename_all = "lowercase")]
15pub enum MediaFormat {
16    Png,
17    Jpeg,
18    Gif,
19    WebP,
20    Svg,
21    Bmp,
22    Tiff,
23    // Future
24    Mp4,
25    Mp3,
26    Wav,
27    Pdf,
28    Unknown,
29}
30
31impl MediaFormat {
32    /// Returns the MIME type string for this format.
33    pub fn mime_type(&self) -> &'static str {
34        match self {
35            MediaFormat::Png => "image/png",
36            MediaFormat::Jpeg => "image/jpeg",
37            MediaFormat::Gif => "image/gif",
38            MediaFormat::WebP => "image/webp",
39            MediaFormat::Svg => "image/svg+xml",
40            MediaFormat::Bmp => "image/bmp",
41            MediaFormat::Tiff => "image/tiff",
42            MediaFormat::Mp4 => "video/mp4",
43            MediaFormat::Mp3 => "audio/mpeg",
44            MediaFormat::Wav => "audio/wav",
45            MediaFormat::Pdf => "application/pdf",
46            MediaFormat::Unknown => "application/octet-stream",
47        }
48    }
49
50    /// Whether this format is a raster image supported by most providers.
51    pub fn is_provider_safe(&self) -> bool {
52        matches!(
53            self,
54            MediaFormat::Png | MediaFormat::Jpeg | MediaFormat::Gif | MediaFormat::WebP
55        )
56    }
57
58    /// Whether this is an image format (raster or vector).
59    pub fn is_image(&self) -> bool {
60        matches!(
61            self,
62            MediaFormat::Png
63                | MediaFormat::Jpeg
64                | MediaFormat::Gif
65                | MediaFormat::WebP
66                | MediaFormat::Svg
67                | MediaFormat::Bmp
68                | MediaFormat::Tiff
69        )
70    }
71}
72
73impl std::fmt::Display for MediaFormat {
74    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
75        match self {
76            MediaFormat::Png => write!(f, "png"),
77            MediaFormat::Jpeg => write!(f, "jpeg"),
78            MediaFormat::Gif => write!(f, "gif"),
79            MediaFormat::WebP => write!(f, "webp"),
80            MediaFormat::Svg => write!(f, "svg"),
81            MediaFormat::Bmp => write!(f, "bmp"),
82            MediaFormat::Tiff => write!(f, "tiff"),
83            MediaFormat::Mp4 => write!(f, "mp4"),
84            MediaFormat::Mp3 => write!(f, "mp3"),
85            MediaFormat::Wav => write!(f, "wav"),
86            MediaFormat::Pdf => write!(f, "pdf"),
87            MediaFormat::Unknown => write!(f, "unknown"),
88        }
89    }
90}
91
92/// How the image is encoded in the payload.
93#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
94#[serde(rename_all = "lowercase")]
95pub enum Encoding {
96    /// Base64-encoded inline data (data: URI or raw base64)
97    Base64,
98    /// URL reference (https://...)
99    Url(String),
100    /// Raw bytes (not from a payload, e.g. from file)
101    Raw,
102}
103
104/// Metadata extracted from inspecting a media input.
105#[derive(Debug, Clone, Serialize, Deserialize)]
106pub struct ImageMetadata {
107    pub format: MediaFormat,
108    pub width: u32,
109    pub height: u32,
110    pub size_bytes: usize,
111    pub encoding: Encoding,
112    /// Megapixels (width * height / 1_000_000)
113    pub megapixels: f64,
114    /// For SVG: the raw SVG source text
115    pub svg_source: Option<String>,
116}
117
118impl ImageMetadata {
119    pub fn new(
120        format: MediaFormat,
121        width: u32,
122        height: u32,
123        size_bytes: usize,
124        encoding: Encoding,
125    ) -> Self {
126        let megapixels = (width as f64 * height as f64) / 1_000_000.0;
127        ImageMetadata {
128            format,
129            width,
130            height,
131            size_bytes,
132            encoding,
133            megapixels,
134            svg_source: None,
135        }
136    }
137
138    /// The larger dimension.
139    pub fn max_dim(&self) -> u32 {
140        self.width.max(self.height)
141    }
142}
143
144/// Detect format from raw bytes using magic bytes.
145pub fn detect_format(data: &[u8]) -> MediaFormat {
146    if data.len() < 4 {
147        return MediaFormat::Unknown;
148    }
149
150    // PNG: 89 50 4E 47
151    if data.starts_with(&[0x89, 0x50, 0x4E, 0x47]) {
152        return MediaFormat::Png;
153    }
154
155    // JPEG: FF D8 FF
156    if data.starts_with(&[0xFF, 0xD8, 0xFF]) {
157        return MediaFormat::Jpeg;
158    }
159
160    // GIF: GIF87a or GIF89a
161    if data.starts_with(b"GIF87a") || data.starts_with(b"GIF89a") {
162        return MediaFormat::Gif;
163    }
164
165    // WebP: RIFF....WEBP
166    if data.len() >= 12 && data.starts_with(b"RIFF") && &data[8..12] == b"WEBP" {
167        return MediaFormat::WebP;
168    }
169
170    // BMP: BM
171    if data.starts_with(b"BM") {
172        return MediaFormat::Bmp;
173    }
174
175    // TIFF: II or MM
176    if data.starts_with(&[0x49, 0x49, 0x2A, 0x00]) || data.starts_with(&[0x4D, 0x4D, 0x00, 0x2A]) {
177        return MediaFormat::Tiff;
178    }
179
180    // SVG: look for XML/SVG markers in text
181    if is_svg(data) {
182        return MediaFormat::Svg;
183    }
184
185    // PDF: %PDF
186    if data.starts_with(b"%PDF") {
187        return MediaFormat::Pdf;
188    }
189
190    MediaFormat::Unknown
191}
192
/// Check whether `data` looks like an SVG document.
///
/// The bytes are read as UTF-8 text. After leading whitespace is skipped, the
/// text must start with either an XML declaration (`<?xml`) or an `<svg` tag,
/// and must contain `<svg` somewhere.
///
/// If the whole buffer is not valid UTF-8, only the first 1 KiB is inspected.
/// When that 1 KiB cut lands in the middle of a multi-byte character, the
/// partial character is dropped rather than rejecting the input. Any other
/// invalid UTF-8 inside the inspected window yields `false`.
fn is_svg(data: &[u8]) -> bool {
    // Number of leading bytes inspected when the full buffer is not valid UTF-8.
    const SNIFF_WINDOW: usize = 1024;

    let text = match std::str::from_utf8(data) {
        Ok(s) => s,
        Err(_) => {
            let window = &data[..data.len().min(SNIFF_WINDOW)];
            match std::str::from_utf8(window) {
                Ok(s) => s,
                // `error_len() == None` means the window merely ends inside a
                // multi-byte character. If that is because we truncated the
                // buffer, keep the valid prefix instead of giving up.
                Err(e) if window.len() < data.len() && e.error_len().is_none() => {
                    match std::str::from_utf8(&window[..e.valid_up_to()]) {
                        Ok(s) => s,
                        Err(_) => return false,
                    }
                }
                Err(_) => return false,
            }
        }
    };

    let trimmed = text.trim_start();
    // Accept an XML declaration or a bare <svg tag as the opener, but in both
    // cases require an <svg element to actually be present.
    (trimmed.starts_with("<?xml") || trimmed.starts_with("<svg")) && trimmed.contains("<svg")
}
217
218/// Decode a base64 data URI or raw base64 string to bytes.
219///
220/// Handles formats:
221/// - `data:image/png;base64,iVBOR...`
222/// - `iVBOR...` (raw base64)
223///
224/// Enforces a size limit (default 30 MB base64 input) to prevent OOM.
225/// Uses a tolerant decoder that accepts both padded and unpadded base64.
226pub fn decode_base64_image(input: &str) -> anyhow::Result<(Vec<u8>, Option<String>)> {
227    decode_base64_image_with_limits(input, &SafetyLimits::default())
228}
229
230/// Decode base64 with explicit safety limits.
231pub fn decode_base64_image_with_limits(
232    input: &str,
233    limits: &SafetyLimits,
234) -> anyhow::Result<(Vec<u8>, Option<String>)> {
235    use base64::engine::general_purpose;
236    use base64::Engine;
237
238    let (b64_data, mime_hint) = if let Some(rest) = input.strip_prefix("data:") {
239        // data:image/png;base64,iVBOR...
240        if let Some(comma_pos) = rest.find(',') {
241            let header = &rest[..comma_pos];
242            let data = &rest[comma_pos + 1..];
243            let mime = header.split(';').next().map(|s| s.to_string());
244            (data, mime)
245        } else {
246            (rest, None)
247        }
248    } else {
249        (input, None)
250    };
251
252    // Fix #9: Check base64 input size before allocating
253    if b64_data.len() > limits.max_base64_bytes {
254        anyhow::bail!(
255            "base64 input too large: {} bytes exceeds limit of {} bytes",
256            b64_data.len(),
257            limits.max_base64_bytes
258        );
259    }
260
261    // Fix #23: Use tolerant engine that accepts padded and unpadded base64
262    let engine = general_purpose::STANDARD;
263
264    // Strip whitespace/newlines from base64
265    let cleaned: String = b64_data.chars().filter(|c| !c.is_whitespace()).collect();
266
267    // Try standard (padded) first, then no-pad
268    let bytes = engine
269        .decode(&cleaned)
270        .or_else(|_| general_purpose::STANDARD_NO_PAD.decode(&cleaned))?;
271
272    Ok((bytes, mime_hint))
273}
274
#[cfg(test)]
mod tests {
    use super::*;
    use base64::engine::general_purpose::{STANDARD, STANDARD_NO_PAD};
    use base64::Engine;

    // ---- detect_format ----

    #[test]
    fn test_detect_png() {
        // Full 8-byte PNG signature.
        let signature = b"\x89PNG\r\n\x1a\n";
        assert_eq!(detect_format(signature), MediaFormat::Png);
    }

    #[test]
    fn test_detect_jpeg() {
        // SOI marker followed by the start of an APP0 segment.
        let header = b"\xFF\xD8\xFF\xE0\x00\x10";
        assert_eq!(detect_format(header), MediaFormat::Jpeg);
    }

    #[test]
    fn test_detect_gif() {
        // Both GIF versions must be recognised.
        let gif89 = detect_format(b"GIF89a...");
        let gif87 = detect_format(b"GIF87a...");
        assert_eq!(gif89, MediaFormat::Gif);
        assert_eq!(gif87, MediaFormat::Gif);
    }

    #[test]
    fn test_detect_webp() {
        // "RIFF" + 4-byte size placeholder + "WEBP".
        let mut container = b"RIFF".to_vec();
        container.extend_from_slice(&[0x00; 4]);
        container.extend_from_slice(b"WEBP");
        assert_eq!(detect_format(&container), MediaFormat::WebP);
    }

    #[test]
    fn test_detect_bmp() {
        assert_eq!(detect_format(b"BM\x00\x00\x00\x00"), MediaFormat::Bmp);
    }

    #[test]
    fn test_detect_svg_with_xml_declaration() {
        let doc = br#"<?xml version="1.0"?><svg xmlns="http://www.w3.org/2000/svg"><rect/></svg>"#;
        assert_eq!(detect_format(doc), MediaFormat::Svg);
    }

    #[test]
    fn test_detect_svg_bare() {
        let doc =
            br#"<svg xmlns="http://www.w3.org/2000/svg" width="100" height="100"><circle/></svg>"#;
        assert_eq!(detect_format(doc), MediaFormat::Svg);
    }

    #[test]
    fn test_detect_svg_with_whitespace() {
        // Leading spaces and a newline before the root element.
        let doc = "  \n  <svg xmlns=\"http://www.w3.org/2000/svg\"><rect/></svg>".as_bytes();
        assert_eq!(detect_format(doc), MediaFormat::Svg);
    }

    #[test]
    fn test_detect_pdf() {
        let detected = detect_format(b"%PDF-1.4 ...");
        assert_eq!(detected, MediaFormat::Pdf);
    }

    #[test]
    fn test_detect_unknown() {
        let detected = detect_format(b"random data here");
        assert_eq!(detected, MediaFormat::Unknown);
    }

    #[test]
    fn test_detect_too_short() {
        // Fewer than four bytes can never be classified.
        let detected = detect_format(b"ab");
        assert_eq!(detected, MediaFormat::Unknown);
    }

    // ---- MediaFormat helpers ----

    #[test]
    fn test_media_format_mime() {
        let expectations = [
            (MediaFormat::Png, "image/png"),
            (MediaFormat::Jpeg, "image/jpeg"),
            (MediaFormat::Svg, "image/svg+xml"),
        ];
        for (format, mime) in expectations.iter() {
            assert_eq!(format.mime_type(), *mime);
        }
    }

    #[test]
    fn test_media_format_is_provider_safe() {
        let safe = [
            MediaFormat::Png,
            MediaFormat::Jpeg,
            MediaFormat::Gif,
            MediaFormat::WebP,
        ];
        for format in safe.iter() {
            assert!(format.is_provider_safe());
        }

        let not_safe = [MediaFormat::Svg, MediaFormat::Bmp, MediaFormat::Tiff];
        for format in not_safe.iter() {
            assert!(!format.is_provider_safe());
        }
    }

    #[test]
    fn test_media_format_is_image() {
        for format in [MediaFormat::Png, MediaFormat::Svg].iter() {
            assert!(format.is_image());
        }
        for format in [MediaFormat::Mp4, MediaFormat::Pdf].iter() {
            assert!(!format.is_image());
        }
    }

    // ---- base64 decoding ----

    #[test]
    fn test_decode_base64_data_uri() {
        // First four bytes of the PNG signature.
        let raw = vec![0x89, 0x50, 0x4E, 0x47];
        let uri = format!("data:image/png;base64,{}", STANDARD.encode(&raw));

        let (bytes, mime) = decode_base64_image(&uri).unwrap();
        assert_eq!(bytes, raw);
        assert_eq!(mime.unwrap(), "image/png");
    }

    #[test]
    fn test_decode_base64_raw() {
        // JPEG start-of-image bytes, encoded without any data-URI wrapper.
        let raw = vec![0xFF, 0xD8, 0xFF];
        let encoded = STANDARD.encode(&raw);

        let (bytes, mime) = decode_base64_image(&encoded).unwrap();
        assert_eq!(bytes, raw);
        assert!(mime.is_none());
    }

    #[test]
    fn test_decode_base64_size_limit() {
        let limits = SafetyLimits {
            max_base64_bytes: 100,
            ..Default::default()
        };
        // 200 characters against a 100-byte limit.
        let oversized = "A".repeat(200);

        let result = decode_base64_image_with_limits(&oversized, &limits);
        assert!(result.is_err());
        let message = result.unwrap_err().to_string();
        assert!(message.contains("too large"));
    }

    #[test]
    fn test_decode_base64_unpadded() {
        // Five bytes encode to "iVBORw0=" with padding and "iVBORw0" without.
        let raw = vec![0x89, 0x50, 0x4E, 0x47, 0x0D];
        let encoded_nopad = STANDARD_NO_PAD.encode(&raw);
        assert!(!encoded_nopad.contains('='));

        let (bytes, _) = decode_base64_image(&encoded_nopad).unwrap();
        assert_eq!(bytes, raw);
    }

    // ---- ImageMetadata ----

    #[test]
    fn test_image_metadata() {
        let meta = ImageMetadata::new(MediaFormat::Png, 1920, 1080, 500_000, Encoding::Base64);
        assert_eq!(meta.max_dim(), 1920);
        // 1920 * 1080 = 2_073_600 pixels, i.e. about 2.0736 MP.
        assert!((meta.megapixels - 2.0736).abs() < 0.001);
    }
}