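//! Image transformations for the preflight transformer
//! (`shift_preflight/transformer/image.rs`): resize, JPEG recompression,
//! format conversion, and SVG rasterization.
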
use anyhow::{Context, Result};
use image::codecs::jpeg::JpegEncoder;
use image::codecs::png::PngEncoder;
use image::imageops::FilterType;
use image::{DynamicImage, ImageEncoder};

use crate::policy::Action;

9/// Apply a transformation action to raw image bytes.
10///
11/// Returns the transformed image bytes.
12pub fn transform_image(data: &[u8], action: &Action) -> Result<Vec<u8>> {
13    match action {
14        Action::Pass => Ok(data.to_vec()),
15
16        Action::Resize {
17            target_width,
18            target_height,
19        } => resize_image(data, *target_width, *target_height),
20
21        Action::Recompress { quality } => recompress_jpeg(data, *quality),
22
23        Action::ConvertFormat { to } => convert_format(data, to),
24
25        Action::RasterizeSvg {
26            target_width,
27            target_height,
28        } => {
29            // SVG data should be passed as the raw SVG text bytes
30            let svg_text = std::str::from_utf8(data).context("SVG data is not valid UTF-8")?;
31            rasterize_svg(svg_text, *target_width, *target_height)
32        }
33
34        Action::Drop { .. } => {
35            // Dropping returns empty — caller handles removal from payload
36            Ok(Vec::new())
37        }
38    }
39}
40
41/// Load an image from memory with a pixel budget to prevent decompression bombs.
42///
43/// R5: Propagates dimension-read errors instead of silently falling through
44/// to an unguarded decode. If we can't read dimensions from the header,
45/// we reject the image rather than risk a decompression bomb.
46fn load_image_safe(data: &[u8]) -> Result<DynamicImage> {
47    use crate::mode::SafetyLimits;
48
49    let limits = SafetyLimits::default();
50
51    let reader = image::ImageReader::new(std::io::Cursor::new(data))
52        .with_guessed_format()
53        .context("failed to guess image format")?;
54
55    // R5: Propagate the error — don't silently skip the budget check
56    let (w, h) = reader
57        .into_dimensions()
58        .context("failed to read image dimensions (cannot verify pixel budget)")?;
59
60    let pixels = w as u64 * h as u64;
61    if pixels > limits.max_pixels {
62        anyhow::bail!(
63            "image decompression blocked: {}x{} ({:.0} megapixels) exceeds {:.0} megapixel safety limit",
64            w,
65            h,
66            pixels as f64 / 1_000_000.0,
67            limits.max_pixels as f64 / 1_000_000.0
68        );
69    }
70
71    // Now do the full decode — we know it's within pixel budget
72    image::load_from_memory(data).context("failed to decode image")
73}
74
75/// Resize an image to fit within target dimensions, preserving aspect ratio.
76fn resize_image(data: &[u8], target_width: u32, target_height: u32) -> Result<Vec<u8>> {
77    let img = load_image_safe(data)?;
78
79    let resized = img.resize(target_width, target_height, FilterType::Lanczos3);
80
81    // Encode as PNG (lossless, safe for all providers)
82    encode_png(&resized)
83}
84
85/// Recompress an image as JPEG at the given quality.
86fn recompress_jpeg(data: &[u8], quality: u8) -> Result<Vec<u8>> {
87    let img = load_image_safe(data)?;
88
89    let rgb = img.to_rgb8();
90    let mut buf = Vec::new();
91    let encoder = JpegEncoder::new_with_quality(&mut buf, quality);
92    encoder
93        .write_image(
94            rgb.as_raw(),
95            rgb.width(),
96            rgb.height(),
97            image::ExtendedColorType::Rgb8,
98        )
99        .context("failed to encode JPEG")?;
100
101    Ok(buf)
102}
103
104/// Convert an image to a different format.
105fn convert_format(data: &[u8], to: &str) -> Result<Vec<u8>> {
106    let img = load_image_safe(data)?;
107
108    match to {
109        "png" => encode_png(&img),
110        "jpeg" | "jpg" => {
111            let rgb = img.to_rgb8();
112            let mut buf = Vec::new();
113            let encoder = JpegEncoder::new_with_quality(&mut buf, 85);
114            encoder
115                .write_image(
116                    rgb.as_raw(),
117                    rgb.width(),
118                    rgb.height(),
119                    image::ExtendedColorType::Rgb8,
120                )
121                .context("failed to encode JPEG")?;
122            Ok(buf)
123        }
124        _ => anyhow::bail!("unsupported target format: {}", to),
125    }
126}
127
128/// Encode a DynamicImage as PNG.
129fn encode_png(img: &DynamicImage) -> Result<Vec<u8>> {
130    let rgba = img.to_rgba8();
131    let mut buf = Vec::new();
132    let encoder = PngEncoder::new(&mut buf);
133    encoder
134        .write_image(
135            rgba.as_raw(),
136            rgba.width(),
137            rgba.height(),
138            image::ExtendedColorType::Rgba8,
139        )
140        .context("failed to encode PNG")?;
141    Ok(buf)
142}
143
144/// Rasterize SVG text to PNG at the given dimensions.
145pub fn rasterize_svg(svg_text: &str, target_width: u32, target_height: u32) -> Result<Vec<u8>> {
146    use resvg::tiny_skia;
147    use resvg::usvg;
148
149    let options = usvg::Options::default();
150    let tree = usvg::Tree::from_str(svg_text, &options).context("failed to parse SVG")?;
151
152    let size = tree.size();
153    let (svg_w, svg_h) = (size.width(), size.height());
154
155    // Calculate scale to fit within target dimensions
156    let scale_x = target_width as f32 / svg_w;
157    let scale_y = target_height as f32 / svg_h;
158    let scale = scale_x.min(scale_y);
159
160    let pixel_w = (svg_w * scale).ceil() as u32;
161    let pixel_h = (svg_h * scale).ceil() as u32;
162
163    // R9: Pixel budget for SVG rasterization (shared via SafetyLimits)
164    let limits = crate::mode::SafetyLimits::default();
165    let pixel_count = pixel_w as u64 * pixel_h as u64;
166    if pixel_count > limits.max_pixels {
167        anyhow::bail!(
168            "SVG rasterization blocked: {}x{} exceeds {:.0} megapixel safety limit",
169            pixel_w,
170            pixel_h,
171            limits.max_pixels as f64 / 1_000_000.0
172        );
173    }
174
175    let mut pixmap = tiny_skia::Pixmap::new(pixel_w.max(1), pixel_h.max(1))
176        .context("failed to create pixmap")?;
177
178    let transform = tiny_skia::Transform::from_scale(scale, scale);
179    resvg::render(&tree, transform, &mut pixmap.as_mut());
180
181    // Convert pixmap to PNG
182    let png_data = pixmap
183        .encode_png()
184        .context("failed to encode rasterized SVG as PNG")?;
185    Ok(png_data)
186}
187
#[cfg(test)]
mod tests {
    use super::*;
    use crate::inspector::{detect_format, MediaFormat};

    /// Build an all-zero (transparent black) RGBA PNG of the given size.
    fn make_test_png(width: u32, height: u32) -> Vec<u8> {
        let img = image::RgbaImage::new(width, height);
        let mut buf = Vec::new();
        let encoder = PngEncoder::new(&mut buf);
        encoder
            .write_image(img.as_raw(), width, height, image::ExtendedColorType::Rgba8)
            .unwrap();
        buf
    }

    /// Build an all-black RGB JPEG (quality 90) of the given size.
    fn make_test_jpeg(width: u32, height: u32) -> Vec<u8> {
        let img = image::RgbImage::new(width, height);
        let mut buf = Vec::new();
        let mut encoder = JpegEncoder::new_with_quality(&mut buf, 90);
        encoder
            .encode(img.as_raw(), width, height, image::ExtendedColorType::Rgb8)
            .unwrap();
        buf
    }

    #[test]
    fn test_resize_png() {
        let data = make_test_png(4000, 3000);
        let action = Action::Resize {
            target_width: 2048,
            target_height: 2048,
        };
        let result = transform_image(&data, &action).unwrap();

        // Verify it's still a valid image
        let img = image::load_from_memory(&result).unwrap();
        assert!(img.width() <= 2048);
        assert!(img.height() <= 2048);
        // Verify aspect ratio preserved
        let ratio_orig = 4000.0 / 3000.0;
        let ratio_new = img.width() as f64 / img.height() as f64;
        assert!((ratio_orig - ratio_new).abs() < 0.02);
    }

    #[test]
    fn test_resize_preserves_format_as_png() {
        let data = make_test_png(3000, 2000);
        let action = Action::Resize {
            target_width: 1024,
            target_height: 1024,
        };
        let result = transform_image(&data, &action).unwrap();
        assert_eq!(detect_format(&result), MediaFormat::Png);
    }

    #[test]
    fn test_resize_rejects_non_image_bytes() {
        // No recognisable header: the dimension read must fail (R5), not fall through.
        let action = Action::Resize {
            target_width: 64,
            target_height: 64,
        };
        assert!(transform_image(b"definitely not an image", &action).is_err());
    }

    #[test]
    fn test_recompress_jpeg() {
        let data = make_test_jpeg(1000, 800);
        let original_size = data.len();
        let action = Action::Recompress { quality: 50 };
        let result = transform_image(&data, &action).unwrap();

        // Lower quality should produce a smaller file
        assert!(result.len() <= original_size);
        // Should still be valid JPEG
        assert_eq!(detect_format(&result), MediaFormat::Jpeg);
    }

    #[test]
    fn test_recompress_keeps_dimensions() {
        let data = make_test_jpeg(320, 240);
        let action = Action::Recompress { quality: 40 };
        let result = transform_image(&data, &action).unwrap();
        let img = image::load_from_memory(&result).unwrap();
        assert_eq!((img.width(), img.height()), (320, 240));
    }

    #[test]
    fn test_convert_bmp_to_png() {
        // Create a test image and save as BMP bytes
        let img = image::RgbImage::from_pixel(100, 100, image::Rgb([255, 0, 0]));
        let mut bmp_data = Vec::new();
        let mut cursor = std::io::Cursor::new(&mut bmp_data);
        img.write_to(&mut cursor, image::ImageFormat::Bmp).unwrap();

        let action = Action::ConvertFormat {
            to: "png".to_string(),
        };
        let result = transform_image(&bmp_data, &action).unwrap();
        assert_eq!(detect_format(&result), MediaFormat::Png);
    }

    #[test]
    fn test_convert_unsupported_format_errors() {
        let data = make_test_png(16, 16);
        let action = Action::ConvertFormat {
            to: "gif".to_string(),
        };
        assert!(transform_image(&data, &action).is_err());
    }

    #[test]
    fn test_pass_action() {
        let data = make_test_png(100, 100);
        let action = Action::Pass;
        let result = transform_image(&data, &action).unwrap();
        assert_eq!(result, data);
    }

    #[test]
    fn test_drop_action() {
        let data = make_test_png(100, 100);
        let action = Action::Drop {
            reason: "test".into(),
        };
        let result = transform_image(&data, &action).unwrap();
        assert!(result.is_empty());
    }

    #[test]
    fn test_rasterize_svg_simple() {
        let svg = r#"<svg xmlns="http://www.w3.org/2000/svg" width="200" height="100">
            <rect width="200" height="100" fill="red"/>
        </svg>"#;

        let result = rasterize_svg(svg, 200, 100).unwrap();
        assert!(!result.is_empty());
        assert_eq!(detect_format(&result), MediaFormat::Png);

        // Verify dimensions
        let img = image::load_from_memory(&result).unwrap();
        assert_eq!(img.width(), 200);
        assert_eq!(img.height(), 100);
    }

    #[test]
    fn test_rasterize_svg_scaled_down() {
        let svg = r#"<svg xmlns="http://www.w3.org/2000/svg" width="2000" height="1000">
            <circle cx="1000" cy="500" r="400" fill="blue"/>
        </svg>"#;

        let result = rasterize_svg(svg, 500, 500).unwrap();
        let img = image::load_from_memory(&result).unwrap();
        assert!(img.width() <= 500);
        assert!(img.height() <= 500);
    }

    #[test]
    fn test_rasterize_svg_with_viewbox() {
        let svg = r#"<svg xmlns="http://www.w3.org/2000/svg" viewBox="0 0 100 100">
            <rect x="10" y="10" width="80" height="80" fill="green"/>
        </svg>"#;

        let result = rasterize_svg(svg, 256, 256).unwrap();
        assert!(!result.is_empty());
        assert_eq!(detect_format(&result), MediaFormat::Png);
    }

    #[test]
    fn test_rasterize_svg_complex() {
        let svg = r#"<?xml version="1.0" encoding="UTF-8"?>
<svg xmlns="http://www.w3.org/2000/svg" width="400" height="300" viewBox="0 0 400 300">
  <defs>
    <linearGradient id="grad" x1="0%" y1="0%" x2="100%" y2="100%">
      <stop offset="0%" style="stop-color:rgb(255,0,0);stop-opacity:1" />
      <stop offset="100%" style="stop-color:rgb(0,0,255);stop-opacity:1" />
    </linearGradient>
  </defs>
  <rect width="400" height="300" fill="url(#grad)"/>
  <circle cx="200" cy="150" r="80" fill="white" opacity="0.5"/>
  <text x="200" y="160" text-anchor="middle" font-size="24" fill="white">SHIFT</text>
</svg>"#;

        let result = rasterize_svg(svg, 800, 600).unwrap();
        assert!(!result.is_empty());
        let img = image::load_from_memory(&result).unwrap();
        assert!(img.width() > 0);
        assert!(img.height() > 0);
    }

    #[test]
    fn test_rasterize_svg_rejects_malformed_svg() {
        assert!(rasterize_svg("<svg", 64, 64).is_err());
    }

    #[test]
    fn test_transform_svg_via_action() {
        let svg = r#"<svg xmlns="http://www.w3.org/2000/svg" width="100" height="100">
            <rect width="100" height="100" fill="red"/>
        </svg>"#;

        let action = Action::RasterizeSvg {
            target_width: 256,
            target_height: 256,
        };
        let result = transform_image(svg.as_bytes(), &action).unwrap();
        assert_eq!(detect_format(&result), MediaFormat::Png);
    }

    #[test]
    fn test_transform_svg_rejects_non_utf8() {
        let action = Action::RasterizeSvg {
            target_width: 64,
            target_height: 64,
        };
        assert!(transform_image(&[0xff, 0xfe, 0xfd], &action).is_err());
    }
}