// shift_preflight/transformer/image.rs

1use anyhow::{Context, Result};
2use image::codecs::jpeg::JpegEncoder;
3use image::codecs::png::PngEncoder;
4use image::imageops::FilterType;
5use image::{DynamicImage, ImageEncoder};
6
7use crate::policy::Action;
8
/// JPEG quality applied when an operation re-encodes to JPEG without carrying
/// a quality of its own (resizing a JPEG input, converting to JPEG).
///
/// `Action::Recompress` does not use this value; it passes its own quality.
const DEFAULT_JPEG_QUALITY: u8 = 85;
12
13/// Apply a transformation action to raw image bytes.
14///
15/// Returns the transformed image bytes.
16pub fn transform_image(data: &[u8], action: &Action) -> Result<Vec<u8>> {
17    match action {
18        Action::Pass => Ok(data.to_vec()),
19
20        Action::Resize {
21            target_width,
22            target_height,
23        } => resize_image(data, *target_width, *target_height),
24
25        Action::Recompress { quality } => recompress_jpeg(data, *quality),
26
27        Action::ConvertFormat { to } => convert_format(data, to),
28
29        Action::RasterizeSvg {
30            target_width,
31            target_height,
32        } => {
33            // SVG data should be passed as the raw SVG text bytes
34            let svg_text = std::str::from_utf8(data).context("SVG data is not valid UTF-8")?;
35            rasterize_svg(svg_text, *target_width, *target_height)
36        }
37
38        Action::Drop { .. } => {
39            // Dropping returns empty — caller handles removal from payload
40            Ok(Vec::new())
41        }
42    }
43}
44
45/// Load an image from memory with a pixel budget to prevent decompression bombs.
46///
47/// R5: Propagates dimension-read errors instead of silently falling through
48/// to an unguarded decode. If we can't read dimensions from the header,
49/// we reject the image rather than risk a decompression bomb.
50fn load_image_safe(data: &[u8]) -> Result<DynamicImage> {
51    use crate::mode::SafetyLimits;
52
53    let limits = SafetyLimits::default();
54
55    let reader = image::ImageReader::new(std::io::Cursor::new(data))
56        .with_guessed_format()
57        .context("failed to guess image format")?;
58
59    // R5: Propagate the error — don't silently skip the budget check
60    let (w, h) = reader
61        .into_dimensions()
62        .context("failed to read image dimensions (cannot verify pixel budget)")?;
63
64    let pixels = w as u64 * h as u64;
65    if pixels > limits.max_pixels {
66        anyhow::bail!(
67            "image decompression blocked: {}x{} ({:.0} megapixels) exceeds {:.0} megapixel safety limit",
68            w,
69            h,
70            pixels as f64 / 1_000_000.0,
71            limits.max_pixels as f64 / 1_000_000.0
72        );
73    }
74
75    // Now do the full decode — we know it's within pixel budget
76    image::load_from_memory(data).context("failed to decode image")
77}
78
79/// Resize an image to fit within target dimensions, preserving aspect ratio.
80///
81/// Re-encodes in the same format as the input. JPEG inputs stay JPEG (quality 85),
82/// avoiding the size inflation that would occur from converting to PNG.
83/// Non-JPEG inputs (PNG, GIF, WebP, etc.) are encoded as PNG for lossless safety.
84fn resize_image(data: &[u8], target_width: u32, target_height: u32) -> Result<Vec<u8>> {
85    let input_format = crate::inspector::detect_format(data);
86    let img = load_image_safe(data)?;
87
88    let resized = img.resize(target_width, target_height, FilterType::Lanczos3);
89
90    match input_format {
91        crate::inspector::MediaFormat::Jpeg => encode_jpeg(&resized, DEFAULT_JPEG_QUALITY),
92        _ => encode_png(&resized),
93    }
94}
95
96/// Recompress an image as JPEG at the given quality.
97fn recompress_jpeg(data: &[u8], quality: u8) -> Result<Vec<u8>> {
98    let img = load_image_safe(data)?;
99    encode_jpeg(&img, quality)
100}
101
102/// Convert an image to a different format.
103fn convert_format(data: &[u8], to: &str) -> Result<Vec<u8>> {
104    let img = load_image_safe(data)?;
105
106    match to {
107        "png" => encode_png(&img),
108        "jpeg" | "jpg" => encode_jpeg(&img, DEFAULT_JPEG_QUALITY),
109        _ => anyhow::bail!("unsupported target format: {}", to),
110    }
111}
112
113/// Encode a DynamicImage as PNG.
114fn encode_png(img: &DynamicImage) -> Result<Vec<u8>> {
115    let rgba = img.to_rgba8();
116    let mut buf = Vec::new();
117    let encoder = PngEncoder::new(&mut buf);
118    encoder
119        .write_image(
120            rgba.as_raw(),
121            rgba.width(),
122            rgba.height(),
123            image::ExtendedColorType::Rgba8,
124        )
125        .context("failed to encode PNG")?;
126    Ok(buf)
127}
128
129/// Encode a DynamicImage as JPEG at the given quality.
130fn encode_jpeg(img: &DynamicImage, quality: u8) -> Result<Vec<u8>> {
131    let rgb = img.to_rgb8();
132    let mut buf = Vec::new();
133    let encoder = JpegEncoder::new_with_quality(&mut buf, quality);
134    encoder
135        .write_image(
136            rgb.as_raw(),
137            rgb.width(),
138            rgb.height(),
139            image::ExtendedColorType::Rgb8,
140        )
141        .context("failed to encode JPEG")?;
142    Ok(buf)
143}
144
145/// Rasterize SVG text to PNG at the given dimensions.
146pub fn rasterize_svg(svg_text: &str, target_width: u32, target_height: u32) -> Result<Vec<u8>> {
147    use resvg::tiny_skia;
148    use resvg::usvg;
149
150    let options = usvg::Options::default();
151    let tree = usvg::Tree::from_str(svg_text, &options).context("failed to parse SVG")?;
152
153    let size = tree.size();
154    let (svg_w, svg_h) = (size.width(), size.height());
155
156    // Calculate scale to fit within target dimensions
157    let scale_x = target_width as f32 / svg_w;
158    let scale_y = target_height as f32 / svg_h;
159    let scale = scale_x.min(scale_y);
160
161    let pixel_w = (svg_w * scale).ceil() as u32;
162    let pixel_h = (svg_h * scale).ceil() as u32;
163
164    // R9: Pixel budget for SVG rasterization (shared via SafetyLimits)
165    let limits = crate::mode::SafetyLimits::default();
166    let pixel_count = pixel_w as u64 * pixel_h as u64;
167    if pixel_count > limits.max_pixels {
168        anyhow::bail!(
169            "SVG rasterization blocked: {}x{} exceeds {:.0} megapixel safety limit",
170            pixel_w,
171            pixel_h,
172            limits.max_pixels as f64 / 1_000_000.0
173        );
174    }
175
176    let mut pixmap = tiny_skia::Pixmap::new(pixel_w.max(1), pixel_h.max(1))
177        .context("failed to create pixmap")?;
178
179    let transform = tiny_skia::Transform::from_scale(scale, scale);
180    resvg::render(&tree, transform, &mut pixmap.as_mut());
181
182    // Convert pixmap to PNG
183    let png_data = pixmap
184        .encode_png()
185        .context("failed to encode rasterized SVG as PNG")?;
186    Ok(png_data)
187}
188
#[cfg(test)]
mod tests {
    use super::*;
    use crate::inspector::{detect_format, MediaFormat};

    /// Encode an all-zero RGBA canvas of the given size as PNG bytes.
    fn make_test_png(width: u32, height: u32) -> Vec<u8> {
        let canvas = image::RgbaImage::new(width, height);
        let mut png_bytes = Vec::new();
        PngEncoder::new(&mut png_bytes)
            .write_image(canvas.as_raw(), width, height, image::ExtendedColorType::Rgba8)
            .unwrap();
        png_bytes
    }

    /// Encode an all-zero RGB canvas of the given size as JPEG bytes at quality 90.
    fn make_test_jpeg(width: u32, height: u32) -> Vec<u8> {
        let canvas = image::RgbImage::new(width, height);
        let mut jpeg_bytes = Vec::new();
        JpegEncoder::new_with_quality(&mut jpeg_bytes, 90)
            .encode(canvas.as_raw(), width, height, image::ExtendedColorType::Rgb8)
            .unwrap();
        jpeg_bytes
    }

    /// Build a `Resize` action targeting a `side` x `side` box.
    fn square_resize(side: u32) -> Action {
        Action::Resize {
            target_width: side,
            target_height: side,
        }
    }

    /// Decode `encoded`, then check it fits within `max_w` x `max_h` and keeps
    /// the aspect ratio of a `src_w` x `src_h` source (tolerance 0.02).
    fn assert_fits_preserving_aspect(encoded: &[u8], max_w: u32, max_h: u32, src_w: u32, src_h: u32) {
        let decoded = image::load_from_memory(encoded).unwrap();
        assert!(decoded.width() <= max_w);
        assert!(decoded.height() <= max_h);

        let source_ratio = f64::from(src_w) / f64::from(src_h);
        let output_ratio = f64::from(decoded.width()) / f64::from(decoded.height());
        assert!(
            (source_ratio - output_ratio).abs() < 0.02,
            "aspect ratio should be preserved"
        );
    }

    #[test]
    fn test_resize_png() {
        let source = make_test_png(4000, 3000);
        let output = transform_image(&source, &square_resize(2048)).unwrap();

        // Still PNG after the resize.
        assert_eq!(
            detect_format(&output),
            MediaFormat::Png,
            "resized PNG should remain PNG"
        );
        assert_fits_preserving_aspect(&output, 2048, 2048, 4000, 3000);
    }

    #[test]
    fn test_resize_preserves_jpeg_format() {
        let source = make_test_jpeg(4000, 3000);
        let output = transform_image(&source, &square_resize(2048)).unwrap();

        // Still JPEG after the resize, not silently converted to PNG.
        assert_eq!(
            detect_format(&output),
            MediaFormat::Jpeg,
            "resized JPEG should remain JPEG, not be converted to PNG"
        );

        // Shrinking a JPEG must not grow the file.
        assert!(
            output.len() <= source.len(),
            "resized JPEG ({} bytes) should not be larger than original ({} bytes)",
            output.len(),
            source.len()
        );

        assert_fits_preserving_aspect(&output, 2048, 2048, 4000, 3000);
    }

    #[test]
    fn test_recompress_jpeg() {
        let source = make_test_jpeg(1000, 800);
        let output = transform_image(&source, &Action::Recompress { quality: 50 }).unwrap();

        // Lower quality yields a file no larger than the original, still JPEG.
        assert!(output.len() <= source.len());
        assert_eq!(detect_format(&output), MediaFormat::Jpeg);
    }

    #[test]
    fn test_convert_bmp_to_png() {
        let red = image::RgbImage::from_pixel(100, 100, image::Rgb([255, 0, 0]));
        let mut bmp_bytes = Vec::new();
        red.write_to(&mut std::io::Cursor::new(&mut bmp_bytes), image::ImageFormat::Bmp)
            .unwrap();

        let to_png = Action::ConvertFormat {
            to: "png".to_string(),
        };
        let output = transform_image(&bmp_bytes, &to_png).unwrap();
        assert_eq!(detect_format(&output), MediaFormat::Png);
    }

    #[test]
    fn test_pass_action() {
        let source = make_test_png(100, 100);
        let output = transform_image(&source, &Action::Pass).unwrap();
        assert_eq!(output, source);
    }

    #[test]
    fn test_drop_action() {
        let source = make_test_png(100, 100);
        let drop_action = Action::Drop {
            reason: "test".into(),
        };
        let output = transform_image(&source, &drop_action).unwrap();
        assert!(output.is_empty());
    }

    #[test]
    fn test_rasterize_svg_simple() {
        let svg = r#"<svg xmlns="http://www.w3.org/2000/svg" width="200" height="100">
            <rect width="200" height="100" fill="red"/>
        </svg>"#;

        let png_bytes = rasterize_svg(svg, 200, 100).unwrap();
        assert!(!png_bytes.is_empty());
        assert_eq!(detect_format(&png_bytes), MediaFormat::Png);

        // A target matching the SVG's own size renders 1:1.
        let decoded = image::load_from_memory(&png_bytes).unwrap();
        assert_eq!(decoded.width(), 200);
        assert_eq!(decoded.height(), 100);
    }

    #[test]
    fn test_rasterize_svg_scaled_down() {
        let svg = r#"<svg xmlns="http://www.w3.org/2000/svg" width="2000" height="1000">
            <circle cx="1000" cy="500" r="400" fill="blue"/>
        </svg>"#;

        let png_bytes = rasterize_svg(svg, 500, 500).unwrap();
        let decoded = image::load_from_memory(&png_bytes).unwrap();
        assert!(decoded.width() <= 500);
        assert!(decoded.height() <= 500);
    }

    #[test]
    fn test_rasterize_svg_with_viewbox() {
        let svg = r#"<svg xmlns="http://www.w3.org/2000/svg" viewBox="0 0 100 100">
            <rect x="10" y="10" width="80" height="80" fill="green"/>
        </svg>"#;

        let png_bytes = rasterize_svg(svg, 256, 256).unwrap();
        assert!(!png_bytes.is_empty());
        assert_eq!(detect_format(&png_bytes), MediaFormat::Png);
    }

    #[test]
    fn test_rasterize_svg_complex() {
        let svg = r#"<?xml version="1.0" encoding="UTF-8"?>
<svg xmlns="http://www.w3.org/2000/svg" width="400" height="300" viewBox="0 0 400 300">
  <defs>
    <linearGradient id="grad" x1="0%" y1="0%" x2="100%" y2="100%">
      <stop offset="0%" style="stop-color:rgb(255,0,0);stop-opacity:1" />
      <stop offset="100%" style="stop-color:rgb(0,0,255);stop-opacity:1" />
    </linearGradient>
  </defs>
  <rect width="400" height="300" fill="url(#grad)"/>
  <circle cx="200" cy="150" r="80" fill="white" opacity="0.5"/>
  <text x="200" y="160" text-anchor="middle" font-size="24" fill="white">SHIFT</text>
</svg>"#;

        let png_bytes = rasterize_svg(svg, 800, 600).unwrap();
        assert!(!png_bytes.is_empty());
        let decoded = image::load_from_memory(&png_bytes).unwrap();
        assert!(decoded.width() > 0);
        assert!(decoded.height() > 0);
    }

    #[test]
    fn test_transform_svg_via_action() {
        let svg = r#"<svg xmlns="http://www.w3.org/2000/svg" width="100" height="100">
            <rect width="100" height="100" fill="red"/>
        </svg>"#;

        let rasterize = Action::RasterizeSvg {
            target_width: 256,
            target_height: 256,
        };
        let output = transform_image(svg.as_bytes(), &rasterize).unwrap();
        assert_eq!(detect_format(&output), MediaFormat::Png);
    }
}
394        };
395        let result = transform_image(svg.as_bytes(), &action).unwrap();
396        assert_eq!(detect_format(&result), MediaFormat::Png);
397    }
398}