shift_preflight/inspector/
image.rs1use anyhow::{Context, Result};
2
3use super::{decode_base64_image, detect_format, Encoding, ImageMetadata, MediaFormat};
4use crate::mode::SafetyLimits;
5
6pub fn inspect_bytes(data: &[u8]) -> Result<ImageMetadata> {
8 let format = detect_format(data);
9
10 match format {
11 MediaFormat::Svg => inspect_svg(data),
12 _ if format.is_image() => inspect_raster(data, format),
13 _ => anyhow::bail!("not a recognized image format"),
14 }
15}
16
17pub fn inspect_base64(input: &str) -> Result<ImageMetadata> {
19 let (bytes, _mime_hint) = decode_base64_image(input)?;
20 let mut meta = inspect_bytes(&bytes)?;
21 meta.encoding = Encoding::Base64;
22 meta.size_bytes = bytes.len(); Ok(meta)
24}
25
26pub fn inspect_url(url: &str) -> Result<ImageMetadata> {
30 inspect_url_with_limits(url, &SafetyLimits::default())
31}
32
33pub fn inspect_url_with_limits(url: &str, limits: &SafetyLimits) -> Result<ImageMetadata> {
35 validate_url(url)?;
37
38 let bytes = fetch_url_safe(url, limits)?;
39
40 let mut meta = inspect_bytes(&bytes)?;
41 meta.encoding = Encoding::Url(url.to_string());
42 meta.size_bytes = bytes.len();
43 Ok(meta)
44}
45
46fn validate_url(input: &str) -> Result<()> {
55 let parsed = url::Url::parse(input).context("invalid URL")?;
56
57 match parsed.scheme() {
59 "https" | "http" => {}
60 scheme => anyhow::bail!(
61 "unsupported URL scheme '{}': only http/https allowed",
62 scheme
63 ),
64 }
65
66 let host = parsed.host_str().context("URL missing host")?;
67
68 if host == "localhost" || host == "metadata.google.internal" {
70 anyhow::bail!("URL host '{}' is not allowed", host);
71 }
72
73 if host.starts_with("0x") || host.starts_with("0X") {
75 anyhow::bail!("URL host appears to be a hex-encoded IP address");
76 }
77
78 match parsed.host() {
80 Some(url::Host::Ipv4(ip)) => {
81 if is_private_ip(&std::net::IpAddr::V4(ip)) {
82 anyhow::bail!("URL contains a private/loopback IP address");
83 }
84 }
85 Some(url::Host::Ipv6(ip)) => {
86 if is_private_ip(&std::net::IpAddr::V6(ip)) {
87 anyhow::bail!("URL contains a private/loopback IP address");
88 }
89 }
90 Some(url::Host::Domain(_)) => {
91 let port = parsed
95 .port()
96 .unwrap_or(if parsed.scheme() == "https" { 443 } else { 80 });
97 if let Ok(addrs) = std::net::ToSocketAddrs::to_socket_addrs(&(host, port)) {
98 for addr in addrs {
99 if is_private_ip(&addr.ip()) {
100 anyhow::bail!("URL hostname resolves to a private/loopback IP address");
101 }
102 }
103 }
104 }
108 None => {
109 anyhow::bail!("URL has no host");
110 }
111 }
112
113 Ok(())
114}
115
/// Returns `true` if `ip` is an address an outbound image fetch must never reach.
///
/// IPv4 addresses are blocked when they are:
/// - loopback, RFC 1918 private, link-local, broadcast or unspecified;
/// - in `0.0.0.0/8`, multicast `224.0.0.0/4` or reserved `240.0.0.0/4`;
/// - in shared address space `100.64.0.0/10` (carrier-grade NAT; some cloud providers
///   serve instance metadata from this range, e.g. `100.100.100.200`);
/// - in benchmarking space `198.18.0.0/15`.
///
/// IPv6 addresses are blocked when they are:
/// - loopback, unspecified or multicast `ff00::/8`;
/// - link-local `fe80::/10`, deprecated site-local `fec0::/10` or unique-local `fc00::/7`.
///
/// IPv4 addresses embedded in IPv6, either IPv4-mapped (`::ffff:a.b.c.d`) or under the
/// NAT64 well-known prefix (`64:ff9b::/96`), are judged by the embedded IPv4 address.
fn is_private_ip(ip: &std::net::IpAddr) -> bool {
    match ip {
        std::net::IpAddr::V4(v4) => {
            let [first, second, _, _] = v4.octets();
            v4.is_loopback()
                || v4.is_private()
                || v4.is_link_local()
                || v4.is_broadcast()
                || v4.is_unspecified()
                || v4.is_multicast()
                // 0.0.0.0/8 ("this network").
                || first == 0
                // 100.64.0.0/10: second octet 64..=127, i.e. top two bits are `01`.
                || (first == 100 && (second & 0xc0) == 64)
                // 198.18.0.0/15: second octet 18 or 19.
                || (first == 198 && (second & 0xfe) == 18)
                // 240.0.0.0/4 reserved (also covers 255.255.255.255).
                || first >= 240
        }
        std::net::IpAddr::V6(v6) => {
            if let Some(mapped_v4) = v6.to_ipv4_mapped() {
                return is_private_ip(&std::net::IpAddr::V4(mapped_v4));
            }

            let segments = v6.segments();

            // NAT64 (64:ff9b::/96): a translator forwards to the embedded IPv4 address,
            // so judge that address instead.
            if segments[0] == 0x0064
                && segments[1] == 0xff9b
                && segments[2..6].iter().all(|&s| s == 0)
            {
                let embedded = std::net::Ipv4Addr::from(
                    (u32::from(segments[6]) << 16) | u32::from(segments[7]),
                );
                return is_private_ip(&std::net::IpAddr::V4(embedded));
            }

            v6.is_loopback()
                || v6.is_unspecified()
                || v6.is_multicast()
                // fe80::/10 link-local.
                || (segments[0] & 0xffc0) == 0xfe80
                // fec0::/10 deprecated site-local.
                || (segments[0] & 0xffc0) == 0xfec0
                // fc00::/7 unique-local.
                || (segments[0] & 0xfe00) == 0xfc00
        }
    }
}
145
146fn inspect_raster(data: &[u8], detected_format: MediaFormat) -> Result<ImageMetadata> {
148 let reader = image::ImageReader::new(std::io::Cursor::new(data))
149 .with_guessed_format()
150 .context("failed to guess image format")?;
151
152 let (width, height) = reader
153 .into_dimensions()
154 .context("failed to read image dimensions")?;
155
156 let limits = SafetyLimits::default();
158 let pixels = width as u64 * height as u64;
159 if pixels > limits.max_pixels {
160 anyhow::bail!(
161 "image too large: {}x{} ({:.1} megapixels) exceeds limit of {:.0} megapixels",
162 width,
163 height,
164 pixels as f64 / 1_000_000.0,
165 limits.max_pixels as f64 / 1_000_000.0
166 );
167 }
168
169 Ok(ImageMetadata::new(
170 detected_format,
171 width,
172 height,
173 data.len(),
174 Encoding::Raw,
175 ))
176}
177
178fn inspect_svg(data: &[u8]) -> Result<ImageMetadata> {
180 let source = std::str::from_utf8(data).context("SVG is not valid UTF-8")?;
181
182 let (width, height) = parse_svg_dimensions(source);
183
184 let mut meta = ImageMetadata::new(MediaFormat::Svg, width, height, data.len(), Encoding::Raw);
185 meta.svg_source = Some(source.to_string());
186 Ok(meta)
187}
188
189fn parse_svg_dimensions(svg: &str) -> (u32, u32) {
191 let width = extract_svg_attr(svg, "width");
192 let height = extract_svg_attr(svg, "height");
193
194 if let (Some(w), Some(h)) = (width, height) {
195 if w > 0 && h > 0 {
196 return (w, h);
197 }
198 }
199
200 if let Some(vb) = extract_svg_viewbox(svg) {
202 return vb;
203 }
204
205 (300, 150)
207}
208
/// Reads a length attribute (such as `width` or `height`) from the root `<svg>` start
/// tag and converts it to whole CSS pixels.
///
/// Attribute matching:
/// - Only the first `<svg ...>` start tag is examined.
/// - The attribute name must be preceded by whitespace (space, tab or newline), so
///   `stroke-width` is never mistaken for `width`.
/// - Whitespace around `=` is allowed, and the value may be double-quoted,
///   single-quoted or unquoted.
///
/// Value conversion:
/// - Plain numbers and absolute units (`px`, `in`, `cm`, `mm`, `pt`, `pc`) are converted
///   at 96 px per inch and rounded to the nearest pixel.
/// - Relative units (`%`, `em`, `rem`, `ex`, `ch`, `vw`, `vh`, `vmin`, `vmax`, ...),
///   unknown units and unparseable values return `None`, so the caller can fall back to
///   the `viewBox`.
fn extract_svg_attr(svg: &str, attr_name: &str) -> Option<u32> {
    let tag_start = svg.find("<svg")?;
    let tag_end = svg[tag_start..].find('>')? + tag_start;
    let tag = &svg[tag_start..=tag_end];

    let raw_value = svg_attr_raw_value(tag, attr_name)?;
    svg_length_to_px(raw_value)
}

/// Returns the raw (entity-undecoded) value of attribute `name` in the start tag `tag`,
/// or `None` if the attribute is absent or its quoted value is unterminated.
fn svg_attr_raw_value<'a>(tag: &'a str, name: &str) -> Option<&'a str> {
    if name.is_empty() {
        return None;
    }

    let mut search_from = 0;
    while let Some(offset) = tag[search_from..].find(name) {
        let name_start = search_from + offset;
        search_from = name_start + name.len();

        // A real attribute name follows whitespace; this rejects e.g. `stroke-width`.
        let preceded_by_space =
            matches!(tag[..name_start].chars().next_back(), Some(c) if c.is_whitespace());
        if !preceded_by_space {
            continue;
        }

        // XML allows whitespace around `=`. Anything else means `name` was only a
        // prefix of a longer attribute name (e.g. `widths`).
        let Some(after_eq) = tag[search_from..].trim_start().strip_prefix('=') else {
            continue;
        };
        let after_eq = after_eq.trim_start();

        return if let Some(rest) = after_eq.strip_prefix('"') {
            rest.find('"').map(|end| &rest[..end])
        } else if let Some(rest) = after_eq.strip_prefix('\'') {
            rest.find('\'').map(|end| &rest[..end])
        } else {
            let end = after_eq
                .find(|c: char| c.is_whitespace() || c == '>' || c == '/')
                .unwrap_or(after_eq.len());
            Some(&after_eq[..end])
        };
    }
    None
}

/// Converts an SVG/CSS length such as `"200"`, `"200px"` or `"210mm"` to whole pixels.
/// Returns `None` for relative or unknown units and unparseable numbers.
fn svg_length_to_px(value: &str) -> Option<u32> {
    let value = value.trim();
    let number_len = value
        .find(|c: char| !(c.is_ascii_digit() || c == '.'))
        .unwrap_or(value.len());
    let (number, unit) = value.split_at(number_len);
    let number: f64 = number.parse().ok()?;

    // Absolute units at the CSS reference ratio of 96 px per inch.
    let px_per_unit = match unit.trim().to_ascii_lowercase().as_str() {
        "" | "px" => 1.0,
        "in" => 96.0,
        "cm" => 96.0 / 2.54,
        "mm" => 96.0 / 25.4,
        "pt" => 96.0 / 72.0,
        "pc" => 16.0,
        // Relative units depend on rendering context we do not have.
        _ => return None,
    };

    Some((number * px_per_unit).round() as u32)
}
259
/// Reads the `viewBox` of the root `<svg>` start tag and returns its width and height.
///
/// Parsing rules:
/// - The attribute name must be preceded by whitespace; whitespace around `=` is
///   allowed.
/// - The value must be quoted and hold exactly four finite numbers (`min-x min-y width
///   height`), separated by whitespace and/or commas. Anything else is treated as if no
///   `viewBox` were present.
/// - Width and height are truncated to whole pixels and must be at least 1. A sub-pixel
///   or non-positive `viewBox` therefore yields `None` rather than a zero-sized result.
fn extract_svg_viewbox(svg: &str) -> Option<(u32, u32)> {
    const ATTR: &str = "viewBox";

    let tag_start = svg.find("<svg")?;
    let tag_end = svg[tag_start..].find('>')? + tag_start;
    let tag = &svg[tag_start..=tag_end];

    // Find `viewBox` as a whole attribute name followed by `=`.
    let mut search_from = 0;
    let after_eq = loop {
        let name_start = search_from + tag[search_from..].find(ATTR)?;
        search_from = name_start + ATTR.len();

        let preceded_by_space =
            matches!(tag[..name_start].chars().next_back(), Some(c) if c.is_whitespace());
        if !preceded_by_space {
            continue;
        }
        if let Some(rest) = tag[search_from..].trim_start().strip_prefix('=') {
            break rest.trim_start();
        }
    };

    let value = if let Some(rest) = after_eq.strip_prefix('"') {
        &rest[..rest.find('"')?]
    } else if let Some(rest) = after_eq.strip_prefix('\'') {
        &rest[..rest.find('\'')?]
    } else {
        return None;
    };

    // Any unparseable token invalidates the whole viewBox instead of being skipped.
    let numbers = value
        .split(|c: char| c.is_whitespace() || c == ',')
        .filter(|token| !token.is_empty())
        .map(|token| token.parse::<f64>().ok())
        .collect::<Option<Vec<f64>>>()?;
    let &[_, _, width, height] = numbers.as_slice() else {
        return None;
    };

    // `as u32` truncates; requiring >= 1 guarantees a non-zero result and also rejects
    // negative, NaN and infinite values.
    let to_px = |len: f64| (len.is_finite() && len >= 1.0).then_some(len as u32);
    Some((to_px(width)?, to_px(height)?))
}
293
294pub fn fetch_url_safe(url: &str, limits: &SafetyLimits) -> Result<Vec<u8>> {
300 validate_url(url)?;
301
302 let agent = ureq::Agent::new_with_config(
305 ureq::config::Config::builder()
306 .redirect_auth_headers(ureq::config::RedirectAuthHeaders::Never)
307 .max_redirects(0)
308 .timeout_global(Some(std::time::Duration::from_secs(30)))
309 .build(),
310 );
311
312 let response = agent
313 .get(url)
314 .call()
315 .with_context(|| "failed to fetch image from URL".to_string())?;
316
317 let status = response.status().as_u16();
318
319 if (300..400).contains(&status) {
320 anyhow::bail!(
321 "image URL returned a redirect (HTTP {}); redirects are disabled for security",
322 status
323 );
324 }
325
326 if status != 200 {
327 anyhow::bail!("failed to fetch image: HTTP {}", status);
328 }
329
330 if let Some(cl) = response.headers().get("content-length") {
332 if let Ok(size) = cl.to_str().unwrap_or("").parse::<usize>() {
333 if size > limits.max_download_bytes {
334 anyhow::bail!(
335 "image URL Content-Length ({} bytes) exceeds limit of {} bytes",
336 size,
337 limits.max_download_bytes
338 );
339 }
340 }
341 }
342
343 use std::io::Read;
346 let max = limits.max_download_bytes;
347 let mut body = response.into_body();
348 let mut buf = Vec::new();
349 let bytes_read = body
350 .as_reader()
351 .take((max + 1) as u64)
352 .read_to_end(&mut buf)
353 .context("failed to read image response body")?;
354
355 if bytes_read > max {
356 anyhow::bail!(
357 "downloaded image too large: read at least {} bytes, exceeds limit of {} bytes",
358 bytes_read,
359 max
360 );
361 }
362
363 Ok(buf)
364}
365
#[cfg(test)]
mod tests {
    use super::*;

    /// Encodes a blank (all-zero) RGBA PNG of the given size.
    fn make_png(width: u32, height: u32) -> Vec<u8> {
        let img = image::RgbaImage::new(width, height);
        let mut buf = Vec::new();
        let encoder = image::codecs::png::PngEncoder::new(&mut buf);
        image::ImageEncoder::write_image(
            encoder,
            img.as_raw(),
            width,
            height,
            image::ExtendedColorType::Rgba8,
        )
        .unwrap();
        buf
    }

    /// Encodes a blank 8-bit grayscale PNG. One byte per pixel keeps very large
    /// fixtures a quarter of the size of the RGBA equivalent.
    fn make_gray_png(width: u32, height: u32) -> Vec<u8> {
        let img = image::GrayImage::new(width, height);
        let mut buf = Vec::new();
        let encoder = image::codecs::png::PngEncoder::new(&mut buf);
        image::ImageEncoder::write_image(
            encoder,
            img.as_raw(),
            width,
            height,
            image::ExtendedColorType::L8,
        )
        .unwrap();
        buf
    }

    /// Encodes a blank RGB JPEG of the given size at quality 80.
    fn make_jpeg(width: u32, height: u32) -> Vec<u8> {
        let img = image::RgbImage::new(width, height);
        let mut buf = Vec::new();
        let mut encoder = image::codecs::jpeg::JpegEncoder::new_with_quality(&mut buf, 80);
        encoder
            .encode(img.as_raw(), width, height, image::ExtendedColorType::Rgb8)
            .unwrap();
        buf
    }

    #[test]
    fn test_inspect_png() {
        let data = make_png(640, 480);
        let meta = inspect_bytes(&data).unwrap();
        assert_eq!(meta.format, MediaFormat::Png);
        assert_eq!(meta.width, 640);
        assert_eq!(meta.height, 480);
        assert_eq!(meta.max_dim(), 640);
    }

    #[test]
    fn test_inspect_jpeg() {
        let data = make_jpeg(1920, 1080);
        let meta = inspect_bytes(&data).unwrap();
        assert_eq!(meta.format, MediaFormat::Jpeg);
        assert_eq!(meta.width, 1920);
        assert_eq!(meta.height, 1080);
    }

    #[test]
    fn test_inspect_svg_with_dimensions() {
        let svg =
            r#"<svg xmlns="http://www.w3.org/2000/svg" width="200" height="100"><rect/></svg>"#;
        let meta = inspect_bytes(svg.as_bytes()).unwrap();
        assert_eq!(meta.format, MediaFormat::Svg);
        assert_eq!(meta.width, 200);
        assert_eq!(meta.height, 100);
        assert!(meta.svg_source.is_some());
    }

    #[test]
    fn test_inspect_svg_with_viewbox() {
        let svg = r#"<svg xmlns="http://www.w3.org/2000/svg" viewBox="0 0 800 600"><rect/></svg>"#;
        let meta = inspect_bytes(svg.as_bytes()).unwrap();
        assert_eq!(meta.format, MediaFormat::Svg);
        assert_eq!(meta.width, 800);
        assert_eq!(meta.height, 600);
    }

    #[test]
    fn test_inspect_svg_with_xml_declaration() {
        let svg = r#"<?xml version="1.0" encoding="UTF-8"?>
<svg xmlns="http://www.w3.org/2000/svg" width="500" height="300">
  <circle cx="250" cy="150" r="100"/>
</svg>"#;
        let meta = inspect_bytes(svg.as_bytes()).unwrap();
        assert_eq!(meta.format, MediaFormat::Svg);
        assert_eq!(meta.width, 500);
        assert_eq!(meta.height, 300);
    }

    #[test]
    fn test_inspect_svg_viewbox_comma_separated() {
        let svg = r#"<svg xmlns="http://www.w3.org/2000/svg" viewBox="0,0,1024,768"><rect/></svg>"#;
        let meta = inspect_bytes(svg.as_bytes()).unwrap();
        assert_eq!(meta.width, 1024);
        assert_eq!(meta.height, 768);
    }

    #[test]
    fn test_inspect_svg_px_units() {
        let svg =
            r#"<svg xmlns="http://www.w3.org/2000/svg" width="200px" height="150px"><rect/></svg>"#;
        let meta = inspect_bytes(svg.as_bytes()).unwrap();
        assert_eq!(meta.width, 200);
        assert_eq!(meta.height, 150);
    }

    /// Percentage sizes cannot be resolved without a viewport, so the viewBox is used.
    #[test]
    fn test_inspect_svg_percentage_falls_to_viewbox() {
        let svg = r#"<svg xmlns="http://www.w3.org/2000/svg" width="100%" height="100%" viewBox="0 0 4000 3000"><rect/></svg>"#;
        let meta = inspect_bytes(svg.as_bytes()).unwrap();
        assert_eq!(meta.width, 4000);
        assert_eq!(meta.height, 3000);
    }

    #[test]
    fn test_inspect_svg_em_units_falls_to_viewbox() {
        let svg = r#"<svg xmlns="http://www.w3.org/2000/svg" width="10em" height="8em" viewBox="0 0 500 400"><rect/></svg>"#;
        let meta = inspect_bytes(svg.as_bytes()).unwrap();
        assert_eq!(meta.width, 500);
        assert_eq!(meta.height, 400);
    }

    /// `stroke-width` must not be picked up as `width`.
    #[test]
    fn test_inspect_svg_stroke_width_not_confused() {
        let svg = r#"<svg xmlns="http://www.w3.org/2000/svg" stroke-width="3" width="800" height="600"><rect/></svg>"#;
        let meta = inspect_bytes(svg.as_bytes()).unwrap();
        assert_eq!(meta.width, 800);
        assert_eq!(meta.height, 600);
    }

    #[test]
    fn test_inspect_base64_png() {
        use base64::Engine;
        let png_data = make_png(100, 50);
        let encoded = base64::engine::general_purpose::STANDARD.encode(&png_data);
        let data_uri = format!("data:image/png;base64,{}", encoded);

        let meta = inspect_base64(&data_uri).unwrap();
        assert_eq!(meta.format, MediaFormat::Png);
        assert_eq!(meta.width, 100);
        assert_eq!(meta.height, 50);
        assert_eq!(meta.encoding, Encoding::Base64);
    }

    #[test]
    fn test_inspect_base64_raw() {
        use base64::Engine;
        let png_data = make_png(64, 64);
        let encoded = base64::engine::general_purpose::STANDARD.encode(&png_data);

        let meta = inspect_base64(&encoded).unwrap();
        assert_eq!(meta.format, MediaFormat::Png);
        assert_eq!(meta.width, 64);
        assert_eq!(meta.height, 64);
    }

    #[test]
    fn test_inspect_not_an_image() {
        let result = inspect_bytes(b"this is just text, not an image");
        assert!(result.is_err());
    }

    #[test]
    fn test_megapixels() {
        let data = make_png(4000, 3000);
        let meta = inspect_bytes(&data).unwrap();
        assert!((meta.megapixels - 12.0).abs() < 0.001);
    }

    /// Private, loopback and link-local (cloud metadata) IPv4 literals are rejected.
    #[test]
    fn test_validate_url_rejects_private_ip() {
        assert!(validate_url("http://127.0.0.1/image.png").is_err());
        assert!(validate_url("http://10.0.0.1/image.png").is_err());
        assert!(validate_url("http://172.16.0.1/image.png").is_err());
        assert!(validate_url("http://192.168.1.1/image.png").is_err());
        assert!(validate_url("http://169.254.169.254/latest/meta-data/").is_err());
    }

    #[test]
    fn test_validate_url_rejects_localhost() {
        assert!(validate_url("http://localhost/image.png").is_err());
        assert!(validate_url("http://localhost:8080/secret").is_err());
    }

    #[test]
    fn test_validate_url_rejects_ipv6_loopback() {
        assert!(validate_url("http://[::1]/image.png").is_err());
    }

    #[test]
    fn test_validate_url_rejects_file_scheme() {
        assert!(validate_url("file:///etc/passwd").is_err());
    }

    #[test]
    fn test_validate_url_rejects_hex_ip() {
        assert!(validate_url("http://0x7f000001/image.png").is_err());
    }

    /// Public hosts are accepted. `validate_url` ignores DNS resolution failures, so this
    /// also passes without network access.
    #[test]
    fn test_validate_url_allows_public() {
        assert!(validate_url("https://example.com/image.png").is_ok());
        assert!(validate_url("https://cdn.openai.com/image.png").is_ok());
    }

    /// IPv4-mapped IPv6 literals must not bypass the IPv4 private-range checks.
    #[test]
    fn test_validate_url_rejects_ipv4_mapped_ipv6() {
        assert!(validate_url("http://[::ffff:127.0.0.1]/image.png").is_err());
        assert!(validate_url("http://[::ffff:10.0.0.1]/image.png").is_err());
        assert!(validate_url("http://[::ffff:169.254.169.254]/image.png").is_err());
        assert!(validate_url("http://[::ffff:192.168.1.1]/image.png").is_err());
    }

    #[test]
    fn test_is_private_ip_ipv4_mapped_ipv6() {
        use std::net::{IpAddr, Ipv6Addr};
        let mapped_loopback: Ipv6Addr = "::ffff:127.0.0.1".parse().unwrap();
        assert!(is_private_ip(&IpAddr::V6(mapped_loopback)));
        let mapped_private: Ipv6Addr = "::ffff:10.0.0.1".parse().unwrap();
        assert!(is_private_ip(&IpAddr::V6(mapped_private)));
        let mapped_public: Ipv6Addr = "::ffff:8.8.8.8".parse().unwrap();
        assert!(!is_private_ip(&IpAddr::V6(mapped_public)));
    }

    /// `localhost` is rejected by the explicit host-name check before any DNS lookup, so
    /// this overlaps `test_validate_url_rejects_localhost`.
    #[test]
    fn test_validate_url_resolves_hostname_localhost() {
        assert!(validate_url("http://localhost/image.png").is_err());
    }

    /// A typical 12 MP photo is well inside the pixel budget.
    #[test]
    fn test_normal_image_passes_pixel_budget() {
        let data = make_png(4000, 3000);
        let meta = inspect_bytes(&data).unwrap();
        assert_eq!(meta.width, 4000);
    }

    /// An image just over the default 100 MP budget is rejected from its header alone.
    #[test]
    fn test_pixel_budget_rejects_oversized() {
        use crate::mode::SafetyLimits;

        assert_eq!(SafetyLimits::default().max_pixels, 100_000_000);

        // 10_001 x 10_000 = 100_010_000 pixels, one row over the limit. A grayscale
        // fixture needs ~100 MB instead of the ~400 MB an RGBA buffer would.
        let data = make_gray_png(10_001, 10_000);
        let err = match inspect_bytes(&data) {
            Ok(_) => panic!("image over the pixel budget must be rejected"),
            Err(err) => err,
        };
        assert!(
            err.to_string().contains("image too large"),
            "unexpected error: {:#}",
            err
        );
    }

    /// A non-positive viewBox is ignored and the CSS default size is used.
    #[test]
    fn test_svg_viewbox_negative_dims_fallback() {
        let svg =
            r#"<svg xmlns="http://www.w3.org/2000/svg" viewBox="0 0 -100 -100"><rect/></svg>"#;
        let meta = inspect_bytes(svg.as_bytes()).unwrap();
        assert_eq!(meta.width, 300);
        assert_eq!(meta.height, 150);
    }
}