1pub mod image;
2
3pub mod audio;
5pub mod document;
6pub mod video;
7
8use serde::{Deserialize, Serialize};
9
10use crate::mode::SafetyLimits;
11
/// File/container formats that this module can name.
///
/// Values are (de)serialized through serde with lowercased variant names
/// (`rename_all = "lowercase"`), so `WebP` is written as `"webp"`.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
#[serde(rename_all = "lowercase")]
pub enum MediaFormat {
    /// PNG image (`image/png`).
    Png,
    /// JPEG image (`image/jpeg`).
    Jpeg,
    /// GIF image (`image/gif`).
    Gif,
    /// WebP image (`image/webp`).
    WebP,
    /// SVG image (`image/svg+xml`).
    Svg,
    /// BMP image (`image/bmp`).
    Bmp,
    /// TIFF image (`image/tiff`).
    Tiff,
    /// MP4 video (`video/mp4`). Not produced by `detect_format` in this file.
    Mp4,
    /// MP3 audio (`audio/mpeg`). Not produced by `detect_format` in this file.
    Mp3,
    /// WAV audio (`audio/wav`). Not produced by `detect_format` in this file.
    Wav,
    /// PDF document (`application/pdf`).
    Pdf,
    /// Anything that could not be identified (`application/octet-stream`).
    Unknown,
}
30
31impl MediaFormat {
32 pub fn mime_type(&self) -> &'static str {
34 match self {
35 MediaFormat::Png => "image/png",
36 MediaFormat::Jpeg => "image/jpeg",
37 MediaFormat::Gif => "image/gif",
38 MediaFormat::WebP => "image/webp",
39 MediaFormat::Svg => "image/svg+xml",
40 MediaFormat::Bmp => "image/bmp",
41 MediaFormat::Tiff => "image/tiff",
42 MediaFormat::Mp4 => "video/mp4",
43 MediaFormat::Mp3 => "audio/mpeg",
44 MediaFormat::Wav => "audio/wav",
45 MediaFormat::Pdf => "application/pdf",
46 MediaFormat::Unknown => "application/octet-stream",
47 }
48 }
49
50 pub fn is_provider_safe(&self) -> bool {
52 matches!(
53 self,
54 MediaFormat::Png | MediaFormat::Jpeg | MediaFormat::Gif | MediaFormat::WebP
55 )
56 }
57
58 pub fn is_image(&self) -> bool {
60 matches!(
61 self,
62 MediaFormat::Png
63 | MediaFormat::Jpeg
64 | MediaFormat::Gif
65 | MediaFormat::WebP
66 | MediaFormat::Svg
67 | MediaFormat::Bmp
68 | MediaFormat::Tiff
69 )
70 }
71}
72
73impl std::fmt::Display for MediaFormat {
74 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
75 match self {
76 MediaFormat::Png => write!(f, "png"),
77 MediaFormat::Jpeg => write!(f, "jpeg"),
78 MediaFormat::Gif => write!(f, "gif"),
79 MediaFormat::WebP => write!(f, "webp"),
80 MediaFormat::Svg => write!(f, "svg"),
81 MediaFormat::Bmp => write!(f, "bmp"),
82 MediaFormat::Tiff => write!(f, "tiff"),
83 MediaFormat::Mp4 => write!(f, "mp4"),
84 MediaFormat::Mp3 => write!(f, "mp3"),
85 MediaFormat::Wav => write!(f, "wav"),
86 MediaFormat::Pdf => write!(f, "pdf"),
87 MediaFormat::Unknown => write!(f, "unknown"),
88 }
89 }
90}
91
/// Describes how a media payload is represented.
///
/// Variant names are lowercased when (de)serialized through serde
/// (`rename_all = "lowercase"`).
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
#[serde(rename_all = "lowercase")]
pub enum Encoding {
    /// Payload is base64 text.
    Base64,
    /// Payload is referenced by the contained string.
    /// NOTE(review): presumably a URL, going by the variant name; confirm with callers.
    Url(String),
    /// NOTE(review): presumably undecoded raw bytes; confirm with callers.
    Raw,
}
103
/// Descriptive information about a single image.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ImageMetadata {
    /// Format of the image.
    pub format: MediaFormat,
    /// Image width. NOTE(review): presumably pixels, given `megapixels`; confirm.
    pub width: u32,
    /// Image height. NOTE(review): presumably pixels, given `megapixels`; confirm.
    pub height: u32,
    /// Size in bytes as supplied by the caller.
    /// NOTE(review): unclear whether this is the encoded or decoded size; confirm.
    pub size_bytes: usize,
    /// How the payload is represented.
    pub encoding: Encoding,
    /// `width * height / 1_000_000`, as computed by `ImageMetadata::new`.
    pub megapixels: f64,
    /// Left as `None` by `ImageMetadata::new`.
    /// NOTE(review): presumably holds the SVG markup for SVG images; confirm.
    pub svg_source: Option<String>,
}
117
118impl ImageMetadata {
119 pub fn new(
120 format: MediaFormat,
121 width: u32,
122 height: u32,
123 size_bytes: usize,
124 encoding: Encoding,
125 ) -> Self {
126 let megapixels = (width as f64 * height as f64) / 1_000_000.0;
127 ImageMetadata {
128 format,
129 width,
130 height,
131 size_bytes,
132 encoding,
133 megapixels,
134 svg_source: None,
135 }
136 }
137
138 pub fn max_dim(&self) -> u32 {
140 self.width.max(self.height)
141 }
142}
143
144pub fn detect_format(data: &[u8]) -> MediaFormat {
146 if data.len() < 4 {
147 return MediaFormat::Unknown;
148 }
149
150 if data.starts_with(&[0x89, 0x50, 0x4E, 0x47]) {
152 return MediaFormat::Png;
153 }
154
155 if data.starts_with(&[0xFF, 0xD8, 0xFF]) {
157 return MediaFormat::Jpeg;
158 }
159
160 if data.starts_with(b"GIF87a") || data.starts_with(b"GIF89a") {
162 return MediaFormat::Gif;
163 }
164
165 if data.len() >= 12 && data.starts_with(b"RIFF") && &data[8..12] == b"WEBP" {
167 return MediaFormat::WebP;
168 }
169
170 if data.starts_with(b"BM") {
172 return MediaFormat::Bmp;
173 }
174
175 if data.starts_with(&[0x49, 0x49, 0x2A, 0x00]) || data.starts_with(&[0x4D, 0x4D, 0x00, 0x2A]) {
177 return MediaFormat::Tiff;
178 }
179
180 if is_svg(data) {
182 return MediaFormat::Svg;
183 }
184
185 if data.starts_with(b"%PDF") {
187 return MediaFormat::Pdf;
188 }
189
190 MediaFormat::Unknown
191}
192
/// Heuristically decides whether `data` looks like an SVG document.
///
/// The buffer is interpreted as UTF-8 text. If the whole buffer is not valid
/// UTF-8, only its first 1024 bytes are inspected instead. After skipping
/// leading whitespace, the text must start with `<?xml` or `<svg` and contain
/// `<svg` somewhere.
///
/// Returns `false` when the inspected bytes are not valid UTF-8 or do not
/// match that shape.
fn is_svg(data: &[u8]) -> bool {
    /// Number of leading bytes examined when the full buffer is not valid UTF-8.
    const SNIFF_WINDOW: usize = 1024;

    let text = match std::str::from_utf8(data) {
        Ok(s) => s,
        Err(_) => {
            let end = data.len().min(SNIFF_WINDOW);
            match std::str::from_utf8(&data[..end]) {
                Ok(s) => s,
                // `error_len() == None` means the input ended in the middle of
                // a multi-byte character. When that happens because the window
                // itself cut the buffer there, the bytes before it are still
                // valid text, so sniff those rather than rejecting the buffer.
                Err(e) if end < data.len() && e.error_len().is_none() => {
                    std::str::from_utf8(&data[..e.valid_up_to()]).unwrap_or_default()
                }
                Err(_) => return false,
            }
        }
    };

    let trimmed = text.trim_start();
    // An XML declaration alone is not enough; an `<svg` tag must be present.
    (trimmed.starts_with("<?xml") || trimmed.starts_with("<svg")) && trimmed.contains("<svg")
}
217
218pub fn decode_base64_image(input: &str) -> anyhow::Result<(Vec<u8>, Option<String>)> {
227 decode_base64_image_with_limits(input, &SafetyLimits::default())
228}
229
230pub fn decode_base64_image_with_limits(
232 input: &str,
233 limits: &SafetyLimits,
234) -> anyhow::Result<(Vec<u8>, Option<String>)> {
235 use base64::engine::general_purpose;
236 use base64::Engine;
237
238 let (b64_data, mime_hint) = if let Some(rest) = input.strip_prefix("data:") {
239 if let Some(comma_pos) = rest.find(',') {
241 let header = &rest[..comma_pos];
242 let data = &rest[comma_pos + 1..];
243 let mime = header.split(';').next().map(|s| s.to_string());
244 (data, mime)
245 } else {
246 (rest, None)
247 }
248 } else {
249 (input, None)
250 };
251
252 if b64_data.len() > limits.max_base64_bytes {
254 anyhow::bail!(
255 "base64 input too large: {} bytes exceeds limit of {} bytes",
256 b64_data.len(),
257 limits.max_base64_bytes
258 );
259 }
260
261 let engine = general_purpose::STANDARD;
263
264 let cleaned: String = b64_data.chars().filter(|c| !c.is_whitespace()).collect();
266
267 let bytes = engine
269 .decode(&cleaned)
270 .or_else(|_| general_purpose::STANDARD_NO_PAD.decode(&cleaned))?;
271
272 Ok((bytes, mime_hint))
273}
274
/// Unit tests for format detection, `MediaFormat` helpers, base64 decoding
/// and `ImageMetadata`.
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_detect_png() {
        let data = [0x89, 0x50, 0x4E, 0x47, 0x0D, 0x0A, 0x1A, 0x0A];
        assert_eq!(detect_format(&data), MediaFormat::Png);
    }

    #[test]
    fn test_detect_jpeg() {
        let data = [0xFF, 0xD8, 0xFF, 0xE0, 0x00, 0x10];
        assert_eq!(detect_format(&data), MediaFormat::Jpeg);
    }

    #[test]
    fn test_detect_gif() {
        assert_eq!(detect_format(b"GIF89a..."), MediaFormat::Gif);
        assert_eq!(detect_format(b"GIF87a..."), MediaFormat::Gif);
    }

    #[test]
    fn test_detect_webp() {
        let mut data = Vec::new();
        data.extend_from_slice(b"RIFF");
        // Four placeholder bytes for the RIFF chunk size.
        data.extend_from_slice(&[0x00; 4]);
        data.extend_from_slice(b"WEBP");
        assert_eq!(detect_format(&data), MediaFormat::WebP);
    }

    #[test]
    fn test_detect_bmp() {
        let data = b"BM\x00\x00\x00\x00";
        assert_eq!(detect_format(data), MediaFormat::Bmp);
    }

    #[test]
    fn test_detect_tiff() {
        // Little-endian ("II") and big-endian ("MM") headers.
        let little_endian = [0x49, 0x49, 0x2A, 0x00, 0x08, 0x00, 0x00, 0x00];
        let big_endian = [0x4D, 0x4D, 0x00, 0x2A, 0x00, 0x00, 0x00, 0x08];
        assert_eq!(detect_format(&little_endian), MediaFormat::Tiff);
        assert_eq!(detect_format(&big_endian), MediaFormat::Tiff);
    }

    #[test]
    fn test_detect_svg_with_xml_declaration() {
        let data =
            b"<?xml version=\"1.0\"?><svg xmlns=\"http://www.w3.org/2000/svg\"><rect/></svg>";
        assert_eq!(detect_format(data), MediaFormat::Svg);
    }

    #[test]
    fn test_detect_svg_bare() {
        let data = b"<svg xmlns=\"http://www.w3.org/2000/svg\" width=\"100\" height=\"100\"><circle/></svg>";
        assert_eq!(detect_format(data), MediaFormat::Svg);
    }

    #[test]
    fn test_detect_svg_with_whitespace() {
        let data = b"  \n  <svg xmlns=\"http://www.w3.org/2000/svg\"><rect/></svg>";
        assert_eq!(detect_format(data), MediaFormat::Svg);
    }

    #[test]
    fn test_detect_pdf() {
        assert_eq!(detect_format(b"%PDF-1.4 ..."), MediaFormat::Pdf);
    }

    #[test]
    fn test_detect_unknown() {
        assert_eq!(detect_format(b"random data here"), MediaFormat::Unknown);
    }

    #[test]
    fn test_detect_too_short() {
        assert_eq!(detect_format(b"ab"), MediaFormat::Unknown);
    }

    #[test]
    fn test_media_format_mime() {
        assert_eq!(MediaFormat::Png.mime_type(), "image/png");
        assert_eq!(MediaFormat::Jpeg.mime_type(), "image/jpeg");
        assert_eq!(MediaFormat::Svg.mime_type(), "image/svg+xml");
    }

    #[test]
    fn test_media_format_display() {
        assert_eq!(MediaFormat::Png.to_string(), "png");
        assert_eq!(MediaFormat::WebP.to_string(), "webp");
        assert_eq!(MediaFormat::Unknown.to_string(), "unknown");
    }

    #[test]
    fn test_media_format_is_provider_safe() {
        assert!(MediaFormat::Png.is_provider_safe());
        assert!(MediaFormat::Jpeg.is_provider_safe());
        assert!(MediaFormat::Gif.is_provider_safe());
        assert!(MediaFormat::WebP.is_provider_safe());
        assert!(!MediaFormat::Svg.is_provider_safe());
        assert!(!MediaFormat::Bmp.is_provider_safe());
        assert!(!MediaFormat::Tiff.is_provider_safe());
    }

    #[test]
    fn test_media_format_is_image() {
        assert!(MediaFormat::Png.is_image());
        assert!(MediaFormat::Svg.is_image());
        assert!(!MediaFormat::Mp4.is_image());
        assert!(!MediaFormat::Pdf.is_image());
    }

    #[test]
    fn test_decode_base64_data_uri() {
        use base64::Engine;
        let raw = vec![0x89, 0x50, 0x4E, 0x47];
        let encoded = base64::engine::general_purpose::STANDARD.encode(&raw);
        let uri = format!("data:image/png;base64,{}", encoded);

        let (bytes, mime) = decode_base64_image(&uri).unwrap();
        assert_eq!(bytes, raw);
        assert_eq!(mime.unwrap(), "image/png");
    }

    #[test]
    fn test_decode_base64_raw() {
        use base64::Engine;
        let raw = vec![0xFF, 0xD8, 0xFF];
        let encoded = base64::engine::general_purpose::STANDARD.encode(&raw);

        let (bytes, mime) = decode_base64_image(&encoded).unwrap();
        assert_eq!(bytes, raw);
        assert!(mime.is_none());
    }

    #[test]
    fn test_decode_base64_ignores_whitespace() {
        use base64::Engine;
        let raw = vec![0x89, 0x50, 0x4E, 0x47];
        let encoded = base64::engine::general_purpose::STANDARD.encode(&raw);
        // Simulate a line-wrapped payload.
        let (head, tail) = encoded.split_at(4);
        let wrapped = format!("{}\n {}", head, tail);

        let (bytes, mime) = decode_base64_image(&wrapped).unwrap();
        assert_eq!(bytes, raw);
        assert!(mime.is_none());
    }

    #[test]
    fn test_decode_base64_size_limit() {
        let limits = SafetyLimits {
            max_base64_bytes: 100,
            ..Default::default()
        };
        let big_input = "A".repeat(200);
        let result = decode_base64_image_with_limits(&big_input, &limits);
        assert!(result.is_err());
        assert!(result.unwrap_err().to_string().contains("too large"));
    }

    #[test]
    fn test_decode_base64_unpadded() {
        use base64::Engine;
        // Five bytes do not encode to a multiple of four characters, so the
        // unpadded form differs from the padded one.
        let raw = vec![0x89, 0x50, 0x4E, 0x47, 0x0D];
        let encoded_nopad = base64::engine::general_purpose::STANDARD_NO_PAD.encode(&raw);
        assert!(!encoded_nopad.contains('='));

        let (bytes, _) = decode_base64_image(&encoded_nopad).unwrap();
        assert_eq!(bytes, raw);
    }

    #[test]
    fn test_image_metadata() {
        let meta = ImageMetadata::new(MediaFormat::Png, 1920, 1080, 500_000, Encoding::Base64);
        assert_eq!(meta.max_dim(), 1920);
        assert!((meta.megapixels - 2.0736).abs() < 0.001);
    }
}