Skip to main content

shift_preflight/
pipeline.rs

1//! Core SHIFT pipeline: inspect → policy → transform → reconstruct.
2
3use anyhow::{Context, Result};
4use serde_json::Value;
5
6use crate::cost::{estimate_tokens, ImageMetrics};
7use crate::inspector;
8use crate::inspector::MediaFormat;
9use crate::mode::{ShiftConfig, SvgMode};
10use crate::payload;
11use crate::policy;
12use crate::report::Report;
13use crate::transformer;
14
15/// Process a payload through the SHIFT pipeline.
16///
17/// Returns the transformed payload and a report of changes.
18pub fn process(payload: &Value, config: &ShiftConfig) -> Result<(Value, Report)> {
19    let mut report = Report::new();
20    report.dry_run = config.dry_run;
21
22    // Detect provider format if not specified
23    let provider_format = payload::detect_provider(payload);
24
25    // Fix #7: Load provider profile from config, not env var
26    let profile = if let Some(ref custom_path) = config.profile_path {
27        // R7: Validate the profile path more thoroughly
28        let path = std::path::Path::new(custom_path);
29
30        // Must have a .json extension
31        match path.extension().and_then(|e| e.to_str()) {
32            Some("json") => {}
33            _ => anyhow::bail!("profile path must have a .json extension"),
34        }
35
36        // Reject path traversal components
37        for component in path.components() {
38            if matches!(component, std::path::Component::ParentDir) {
39                anyhow::bail!("profile path must not contain '..' path traversal");
40            }
41        }
42
43        // Canonicalize to resolve symlinks, then verify the canonical path
44        // ends with .json (symlink to /etc/passwd would fail this)
45        if path.exists() {
46            let canonical = std::fs::canonicalize(path)
47                .with_context(|| "failed to resolve profile path".to_string())?;
48            match canonical.extension().and_then(|e| e.to_str()) {
49                Some("json") => {}
50                _ => anyhow::bail!(
51                    "profile path resolves to a non-JSON file (possible symlink attack)"
52                ),
53            }
54            policy::load_from_file(canonical.to_str().unwrap_or(custom_path))?
55        } else {
56            policy::load_from_file(custom_path)?
57        }
58    } else {
59        policy::load_builtin(&config.provider)?
60    };
61
62    // Get model-specific constraints
63    let model_name = config
64        .model
65        .as_deref()
66        .or_else(|| payload.get("model").and_then(|m| m.as_str()));
67    let constraints = profile.constraints_for(model_name);
68
69    // R8: Extract images with configured safety limits
70    let images = match provider_format {
71        Some("openai") => payload::openai::extract_images_with_limits(payload, &config.limits)?,
72        Some("anthropic") => {
73            payload::anthropic::extract_images_with_limits(payload, &config.limits)?
74        }
75        _ => {
76            // No images found or text-only payload — pass through
77            return Ok((payload.clone(), report));
78        }
79    };
80
81    if images.is_empty() {
82        return Ok((payload.clone(), report));
83    }
84
85    report.images_found = images.len();
86    // Fix #16: Track image byte sizes separately from JSON serialization
87    let original_image_bytes: usize = images.iter().map(|img| img.data.len()).sum();
88    report.original_size = original_image_bytes;
89
90    let total_images = images.len();
91    let mut transformed_images: Vec<(usize, Vec<u8>, String)> = Vec::new();
92
93    for extracted in &images {
94        // Fix #15: Inspect with skip-and-warn on individual failures
95        let meta = match inspector::image::inspect_bytes(&extracted.data) {
96            Ok(m) => m,
97            Err(e) => {
98                report.add_warning(&format!(
99                    "image {}: skipped ({})",
100                    extracted.global_index, e
101                ));
102                // R6: Use the original MIME type from the image reference,
103                // not a hardcoded "image/png" which would mislabel JPEG/WebP/GIF.
104                let original_mime = match &extracted.original_ref {
105                    payload::ImageRef::DataUri { mime_type, .. } => mime_type.clone(),
106                    payload::ImageRef::Base64 { media_type, .. } => media_type.clone(),
107                    payload::ImageRef::Url(_) => "application/octet-stream".to_string(),
108                };
109                let orig_bytes = extracted.data.len();
110                let format_short = mime_to_short(&original_mime);
111                // Record metrics for skipped images (0x0 = unknown dims)
112                report.add_image_metrics(ImageMetrics {
113                    image_index: extracted.global_index,
114                    original_width: 0,
115                    original_height: 0,
116                    transformed_width: 0,
117                    transformed_height: 0,
118                    original_bytes: orig_bytes,
119                    transformed_bytes: orig_bytes,
120                    format_before: format_short.clone(),
121                    format_after: format_short,
122                    tokens_before: estimate_tokens(0, 0),
123                    tokens_after: estimate_tokens(0, 0),
124                });
125                // Push original data through unchanged
126                transformed_images.push((
127                    extracted.global_index,
128                    extracted.data.clone(),
129                    original_mime,
130                ));
131                continue;
132            }
133        };
134
135        // Capture original dimensions for token estimation
136        let orig_w = meta.width;
137        let orig_h = meta.height;
138        let orig_bytes = extracted.data.len();
139        let format_before = meta.format.to_string();
140
141        // Evaluate policy
142        let actions = policy::evaluate(
143            &meta,
144            constraints,
145            config.mode,
146            extracted.global_index,
147            total_images,
148        );
149
150        // Handle SVG mode
151        if meta.format == MediaFormat::Svg {
152            let result = handle_svg(
153                &extracted.data,
154                &meta,
155                &actions,
156                config,
157                extracted.global_index,
158                &mut report,
159            )?;
160
161            // Record metrics for SVG
162            let (_, ref out_data, ref out_mime) = result;
163            let (tw, th) = if out_data.is_empty() {
164                // SVG was dropped (source mode)
165                (0, 0)
166            } else if config.dry_run {
167                // Dry-run: estimate target dims from policy actions so we
168                // can preview token savings without actually rasterizing.
169                estimate_dims_from_actions(&actions, orig_w, orig_h)
170            } else {
171                inspector::image::inspect_bytes(out_data)
172                    .map(|m| (m.width, m.height))
173                    .unwrap_or((orig_w, orig_h))
174            };
175            let format_after = if config.dry_run && !out_data.is_empty() {
176                // In dry-run the data is still SVG, but we'd produce PNG
177                "png".to_string()
178            } else {
179                mime_to_short(out_mime)
180            };
181            report.add_image_metrics(ImageMetrics {
182                image_index: extracted.global_index,
183                original_width: orig_w,
184                original_height: orig_h,
185                transformed_width: tw,
186                transformed_height: th,
187                original_bytes: orig_bytes,
188                transformed_bytes: out_data.len(),
189                format_before: format_before.clone(),
190                format_after,
191                tokens_before: estimate_tokens(orig_w, orig_h),
192                tokens_after: estimate_tokens(tw, th),
193            });
194
195            transformed_images.push(result);
196            continue;
197        }
198
199        // Apply transformations
200        let mut current_data = extracted.data.clone();
201        let mut was_modified = false;
202        let mut output_mime = meta.format.mime_type().to_string();
203        let mut was_dropped = false;
204
205        for action in &actions {
206            match action {
207                policy::Action::Pass => {}
208                policy::Action::Drop { reason } => {
209                    report.add_action(extracted.global_index, "drop", reason);
210                    report.images_dropped += 1;
211                    current_data = Vec::new();
212                    was_modified = true;
213                    was_dropped = true;
214                    break;
215                }
216                _ => {
217                    if !config.dry_run {
218                        let new_data = transformer::transform_image(&current_data, action)?;
219                        let detail = describe_action(action, &meta);
220                        report.add_action(extracted.global_index, action_name(action), &detail);
221                        current_data = new_data;
222                        was_modified = true;
223
224                        // Update mime type based on action
225                        match action {
226                            policy::Action::ConvertFormat { to } => {
227                                output_mime = format!("image/{}", to);
228                            }
229                            policy::Action::Resize { .. } => {
230                                output_mime = "image/png".to_string();
231                            }
232                            policy::Action::Recompress { .. } => {
233                                output_mime = "image/jpeg".to_string();
234                            }
235                            _ => {}
236                        }
237                    } else {
238                        let detail = describe_action(action, &meta);
239                        report.add_action(
240                            extracted.global_index,
241                            &format!("would_{}", action_name(action)),
242                            &detail,
243                        );
244                        was_modified = true;
245                    }
246                }
247            }
248        }
249
250        if was_modified {
251            report.images_modified += 1;
252        }
253
254        // Determine transformed dimensions
255        let (tw, th) = if was_dropped || current_data.is_empty() {
256            (0, 0)
257        } else if was_modified && !config.dry_run {
258            // Re-inspect transformed data to get actual dimensions
259            inspector::image::inspect_bytes(&current_data)
260                .map(|m| (m.width, m.height))
261                .unwrap_or((orig_w, orig_h))
262        } else {
263            // Dry-run or unchanged: estimate from policy actions
264            estimate_dims_from_actions(&actions, orig_w, orig_h)
265        };
266
267        let format_after = mime_to_short(&output_mime);
268        report.add_image_metrics(ImageMetrics {
269            image_index: extracted.global_index,
270            original_width: orig_w,
271            original_height: orig_h,
272            transformed_width: tw,
273            transformed_height: th,
274            original_bytes: orig_bytes,
275            transformed_bytes: current_data.len(),
276            format_before,
277            format_after,
278            tokens_before: estimate_tokens(orig_w, orig_h),
279            tokens_after: estimate_tokens(tw, th),
280        });
281
282        transformed_images.push((extracted.global_index, current_data, output_mime));
283    }
284
285    // Reconstruct the payload
286    let result = if config.dry_run {
287        payload.clone()
288    } else {
289        match provider_format {
290            Some("openai") => payload::openai::reconstruct(payload, &transformed_images)?,
291            Some("anthropic") => payload::anthropic::reconstruct(payload, &transformed_images)?,
292            _ => payload.clone(),
293        }
294    };
295
296    // Fix #16: Track transformed image byte sizes
297    let transformed_image_bytes: usize = transformed_images
298        .iter()
299        .map(|(_, data, _)| data.len())
300        .sum();
301    report.transformed_size = transformed_image_bytes;
302
303    // Finalize aggregate token savings from per-image metrics
304    report.finalize_token_savings();
305
306    Ok((result, report))
307}
308
/// Reduce a MIME type to its short format name by removing a leading
/// `image/` (e.g. "image/png" -> "png").
///
/// Anything that does not start with `image/` is returned unchanged.
fn mime_to_short(mime: &str) -> String {
    match mime.strip_prefix("image/") {
        Some(subtype) => String::from(subtype),
        None => String::from(mime),
    }
}
313
314/// Estimate target dimensions from policy actions (for dry-run reporting).
315fn estimate_dims_from_actions(actions: &[policy::Action], orig_w: u32, orig_h: u32) -> (u32, u32) {
316    for action in actions {
317        match action {
318            policy::Action::Resize {
319                target_width,
320                target_height,
321            } => return (*target_width, *target_height),
322            policy::Action::RasterizeSvg {
323                target_width,
324                target_height,
325            } => return (*target_width, *target_height),
326            policy::Action::Drop { .. } => return (0, 0),
327            _ => {}
328        }
329    }
330    (orig_w, orig_h)
331}
332
333/// Handle SVG images according to the configured SvgMode.
334fn handle_svg(
335    data: &[u8],
336    meta: &inspector::ImageMetadata,
337    actions: &[policy::Action],
338    config: &ShiftConfig,
339    global_index: usize,
340    report: &mut Report,
341) -> Result<(usize, Vec<u8>, String)> {
342    match config.svg_mode {
343        SvgMode::Raster => {
344            // Rasterize SVG to PNG
345            if config.dry_run {
346                let detail = format!("would rasterize {}x{} SVG to PNG", meta.width, meta.height);
347                report.add_action(global_index, "would_rasterize_svg", &detail);
348                report.images_modified += 1;
349                return Ok((global_index, data.to_vec(), "image/svg+xml".to_string()));
350            }
351
352            // Find the rasterize action to get target dims
353            let (tw, th) = actions
354                .iter()
355                .find_map(|a| match a {
356                    policy::Action::RasterizeSvg {
357                        target_width,
358                        target_height,
359                    } => Some((*target_width, *target_height)),
360                    _ => None,
361                })
362                .unwrap_or((meta.width.max(256), meta.height.max(256)));
363
364            let svg_text = std::str::from_utf8(data).context("SVG is not valid UTF-8")?;
365            let png_data = transformer::rasterize_svg(svg_text, tw, th)?;
366
367            report.add_action(
368                global_index,
369                "rasterize_svg",
370                &format!(
371                    "SVG ({}x{}) -> PNG ({}x{})",
372                    meta.width, meta.height, tw, th
373                ),
374            );
375            report.svgs_rasterized += 1;
376            report.images_modified += 1;
377
378            Ok((global_index, png_data, "image/png".to_string()))
379        }
380
381        SvgMode::Source => {
382            // Fix #5: SVG Source mode drops the image and records it as dropped.
383            // The image block is removed from the payload. In the future, we could
384            // inject the SVG XML as a text content block, but for now we drop + warn.
385            report.add_action(
386                global_index,
387                "svg_dropped_as_source",
388                &format!(
389                    "SVG ({}x{}) removed (source mode: SVG not supported by provider)",
390                    meta.width, meta.height
391                ),
392            );
393            report.images_dropped += 1;
394            report.add_warning(
395                "SVG source mode dropped an image. Consider --svg-mode raster for provider compatibility.",
396            );
397
398            Ok((global_index, Vec::new(), "text/plain".to_string()))
399        }
400
401        SvgMode::Hybrid => {
402            // Rasterize but the caller could also add SVG source as text
403            if config.dry_run {
404                report.add_action(
405                    global_index,
406                    "would_rasterize_svg_hybrid",
407                    &format!(
408                        "would rasterize {}x{} SVG (hybrid mode)",
409                        meta.width, meta.height
410                    ),
411                );
412                report.images_modified += 1;
413                return Ok((global_index, data.to_vec(), "image/svg+xml".to_string()));
414            }
415
416            let (tw, th) = actions
417                .iter()
418                .find_map(|a| match a {
419                    policy::Action::RasterizeSvg {
420                        target_width,
421                        target_height,
422                    } => Some((*target_width, *target_height)),
423                    _ => None,
424                })
425                .unwrap_or((meta.width.max(256), meta.height.max(256)));
426
427            let svg_text = std::str::from_utf8(data).context("SVG is not valid UTF-8")?;
428            let png_data = transformer::rasterize_svg(svg_text, tw, th)?;
429
430            report.add_action(
431                global_index,
432                "rasterize_svg_hybrid",
433                &format!(
434                    "SVG ({}x{}) -> PNG ({}x{}) + source retained",
435                    meta.width, meta.height, tw, th
436                ),
437            );
438            report.svgs_rasterized += 1;
439            report.images_modified += 1;
440
441            Ok((global_index, png_data, "image/png".to_string()))
442        }
443    }
444}
445
446fn action_name(action: &policy::Action) -> &'static str {
447    match action {
448        policy::Action::Pass => "pass",
449        policy::Action::Resize { .. } => "resize",
450        policy::Action::Recompress { .. } => "recompress",
451        policy::Action::ConvertFormat { .. } => "convert",
452        policy::Action::RasterizeSvg { .. } => "rasterize_svg",
453        policy::Action::Drop { .. } => "drop",
454    }
455}
456
457fn describe_action(action: &policy::Action, meta: &inspector::ImageMetadata) -> String {
458    match action {
459        policy::Action::Pass => "no changes needed".to_string(),
460        policy::Action::Resize {
461            target_width,
462            target_height,
463        } => format!(
464            "{}x{} -> {}x{}",
465            meta.width, meta.height, target_width, target_height
466        ),
467        policy::Action::Recompress { quality } => {
468            format!("recompress at quality {}", quality)
469        }
470        policy::Action::ConvertFormat { to } => {
471            format!("{} -> {}", meta.format, to)
472        }
473        policy::Action::RasterizeSvg {
474            target_width,
475            target_height,
476        } => format!("SVG -> PNG at {}x{}", target_width, target_height),
477        policy::Action::Drop { reason } => reason.clone(),
478    }
479}
480
#[cfg(test)]
mod tests {
    use super::*;
    use crate::mode::DriveMode;
    use serde_json::json;

    /// Encode a blank RGBA image of the given size as PNG bytes.
    fn blank_png_bytes(width: u32, height: u32) -> Vec<u8> {
        use image::ImageEncoder;
        let pixels = image::RgbaImage::new(width, height);
        let mut encoded = Vec::new();
        image::codecs::png::PngEncoder::new(&mut encoded)
            .write_image(
                pixels.as_raw(),
                width,
                height,
                image::ExtendedColorType::Rgba8,
            )
            .unwrap();
        encoded
    }

    /// Standard-alphabet base64 of `bytes`.
    fn to_base64(bytes: &[u8]) -> String {
        use base64::Engine;
        base64::engine::general_purpose::STANDARD.encode(bytes)
    }

    /// Blank PNG wrapped in a `data:` URI, as used by OpenAI-style payloads.
    fn make_png_data_uri(width: u32, height: u32) -> String {
        format!(
            "data:image/png;base64,{}",
            to_base64(&blank_png_bytes(width, height))
        )
    }

    /// Blank PNG as bare base64, as used by Anthropic-style payloads.
    fn make_anthropic_png_base64(width: u32, height: u32) -> String {
        to_base64(&blank_png_bytes(width, height))
    }

    /// OpenAI chat payload whose single user message holds exactly one image.
    fn openai_single_image_payload(data_uri: &str) -> Value {
        json!({
            "model": "gpt-4o",
            "messages": [{
                "role": "user",
                "content": [
                    {"type": "image_url", "image_url": {"url": data_uri}}
                ]
            }]
        })
    }

    #[test]
    fn test_text_only_passthrough() {
        let payload = json!({
            "model": "gpt-4o",
            "messages": [{"role": "user", "content": "Hello"}]
        });
        let (result, report) = process(&payload, &ShiftConfig::default()).unwrap();
        assert_eq!(result, payload);
        assert_eq!(report.images_found, 0);
        assert!(!report.has_changes());
    }

    #[test]
    fn test_small_image_passthrough() {
        let data_uri = make_png_data_uri(640, 480);
        let payload = json!({
            "model": "gpt-4o",
            "messages": [{
                "role": "user",
                "content": [
                    {"type": "text", "text": "What's this?"},
                    {"type": "image_url", "image_url": {"url": data_uri}}
                ]
            }]
        });
        let (_result, report) = process(&payload, &ShiftConfig::default()).unwrap();
        assert_eq!(report.images_found, 1);
    }

    #[test]
    fn test_oversized_image_resized_openai() {
        let payload = openai_single_image_payload(&make_png_data_uri(4000, 3000));
        let config = ShiftConfig {
            provider: "openai".to_string(),
            mode: DriveMode::Balanced,
            ..Default::default()
        };
        let (_result, report) = process(&payload, &config).unwrap();
        assert_eq!(report.images_found, 1);
        assert!(report.has_changes());
        assert!(report.actions.iter().any(|a| a.action == "resize"));
    }

    #[test]
    fn test_oversized_image_resized_anthropic() {
        let b64 = make_anthropic_png_base64(4000, 3000);
        let payload = json!({
            "model": "claude-sonnet-4-20250514",
            "messages": [{
                "role": "user",
                "content": [{
                    "type": "image",
                    "source": {"type": "base64", "media_type": "image/png", "data": b64}
                }]
            }]
        });
        let config = ShiftConfig {
            provider: "anthropic".to_string(),
            mode: DriveMode::Balanced,
            ..Default::default()
        };
        let (_result, report) = process(&payload, &config).unwrap();
        assert_eq!(report.images_found, 1);
        assert!(report.has_changes());
    }

    #[test]
    fn test_dry_run_no_modifications() {
        let payload = openai_single_image_payload(&make_png_data_uri(4000, 3000));
        let config = ShiftConfig {
            dry_run: true,
            ..Default::default()
        };
        let (result, report) = process(&payload, &config).unwrap();
        // Dry run should not modify the payload
        assert_eq!(result, payload);
        // But should report what would happen
        assert!(report.has_changes());
        assert!(report.dry_run);
        assert!(report
            .actions
            .iter()
            .any(|a| a.action.starts_with("would_")));
    }

    #[test]
    fn test_svg_rasterization_in_openai_payload() {
        let svg = r#"<svg xmlns="http://www.w3.org/2000/svg" width="200" height="100"><rect width="200" height="100" fill="red"/></svg>"#;
        let data_uri = format!("data:image/svg+xml;base64,{}", to_base64(svg.as_bytes()));

        let payload = openai_single_image_payload(&data_uri);
        let config = ShiftConfig {
            svg_mode: SvgMode::Raster,
            ..Default::default()
        };
        let (_result, report) = process(&payload, &config).unwrap();
        assert_eq!(report.svgs_rasterized, 1);
        assert!(report.actions.iter().any(|a| a.action == "rasterize_svg"));
    }

    #[test]
    fn test_economy_mode_aggressive() {
        // 1500px image — within OpenAI limits but economy mode will downscale
        let payload = openai_single_image_payload(&make_png_data_uri(1500, 1000));
        let config = ShiftConfig {
            mode: DriveMode::Economy,
            ..Default::default()
        };
        let (_result, report) = process(&payload, &config).unwrap();
        assert!(report.has_changes());
    }

    #[test]
    fn test_performance_mode_minimal() {
        // 1500px image — within limits, performance mode should pass
        let payload = openai_single_image_payload(&make_png_data_uri(1500, 1000));
        let config = ShiftConfig {
            mode: DriveMode::Performance,
            ..Default::default()
        };
        let (_result, report) = process(&payload, &config).unwrap();
        // Performance mode should not modify images within limits
        assert!(!report.has_changes() || report.images_modified == 0);
    }
}