// shift_preflight/pipeline.rs

//! Core SHIFT pipeline: inspect → policy → transform → reconstruct.

use std::path::{Component, Path, PathBuf};

use anyhow::{Context, Result};
use serde_json::Value;

use crate::cost::{estimate_tokens, ImageMetrics};
use crate::inspector;
use crate::inspector::MediaFormat;
use crate::mode::{ShiftConfig, SvgMode};
use crate::payload;
use crate::policy;
use crate::report::Report;
use crate::transformer;

15/// Process a payload through the SHIFT pipeline.
16///
17/// Returns the transformed payload and a report of changes.
18pub fn process(payload: &Value, config: &ShiftConfig) -> Result<(Value, Report)> {
19    let mut report = Report::new();
20    report.dry_run = config.dry_run;
21
22    // Detect provider format if not specified
23    let provider_format = payload::detect_provider(payload);
24
25    // Fix #7: Load provider profile from config, not env var
26    let profile = if let Some(ref custom_path) = config.profile_path {
27        // R7: Validate the profile path more thoroughly
28        let path = std::path::Path::new(custom_path);
29
30        // Must have a .json extension
31        match path.extension().and_then(|e| e.to_str()) {
32            Some("json") => {}
33            _ => anyhow::bail!("profile path must have a .json extension"),
34        }
35
36        // Reject path traversal components
37        for component in path.components() {
38            if matches!(component, std::path::Component::ParentDir) {
39                anyhow::bail!("profile path must not contain '..' path traversal");
40            }
41        }
42
43        // Canonicalize to resolve symlinks, then verify the canonical path
44        // ends with .json (symlink to /etc/passwd would fail this)
45        if path.exists() {
46            let canonical = std::fs::canonicalize(path)
47                .with_context(|| "failed to resolve profile path".to_string())?;
48            match canonical.extension().and_then(|e| e.to_str()) {
49                Some("json") => {}
50                _ => anyhow::bail!(
51                    "profile path resolves to a non-JSON file (possible symlink attack)"
52                ),
53            }
54            policy::load_from_file(canonical.to_str().unwrap_or(custom_path))?
55        } else {
56            policy::load_from_file(custom_path)?
57        }
58    } else {
59        policy::load_builtin(&config.provider)?
60    };
61
62    // Get model-specific constraints
63    let model_name = config
64        .model
65        .as_deref()
66        .or_else(|| payload.get("model").and_then(|m| m.as_str()));
67    let constraints = profile.constraints_for(model_name);
68
69    // R8: Extract images with configured safety limits
70    let images = match provider_format {
71        Some("openai") => payload::openai::extract_images_with_limits(payload, &config.limits)?,
72        Some("anthropic") => {
73            payload::anthropic::extract_images_with_limits(payload, &config.limits)?
74        }
75        _ => {
76            // No images found or text-only payload — pass through
77            return Ok((payload.clone(), report));
78        }
79    };
80
81    if images.is_empty() {
82        return Ok((payload.clone(), report));
83    }
84
85    report.images_found = images.len();
86    // Fix #16: Track image byte sizes separately from JSON serialization
87    let original_image_bytes: usize = images.iter().map(|img| img.data.len()).sum();
88    report.original_size = original_image_bytes;
89
90    let total_images = images.len();
91    let mut transformed_images: Vec<(usize, Vec<u8>, String)> = Vec::new();
92
93    for extracted in &images {
94        // Fix #15: Inspect with skip-and-warn on individual failures
95        let meta = match inspector::image::inspect_bytes(&extracted.data) {
96            Ok(m) => m,
97            Err(e) => {
98                report.add_warning(&format!(
99                    "image {}: skipped ({})",
100                    extracted.global_index, e
101                ));
102                // R6: Use the original MIME type from the image reference,
103                // not a hardcoded "image/png" which would mislabel JPEG/WebP/GIF.
104                let original_mime = match &extracted.original_ref {
105                    payload::ImageRef::DataUri { mime_type, .. } => mime_type.clone(),
106                    payload::ImageRef::Base64 { media_type, .. } => media_type.clone(),
107                    payload::ImageRef::Url(_) => "application/octet-stream".to_string(),
108                };
109                let orig_bytes = extracted.data.len();
110                let format_short = mime_to_short(&original_mime);
111                // Record metrics for skipped images (0x0 = unknown dims)
112                report.add_image_metrics(ImageMetrics {
113                    image_index: extracted.global_index,
114                    original_width: 0,
115                    original_height: 0,
116                    transformed_width: 0,
117                    transformed_height: 0,
118                    original_bytes: orig_bytes,
119                    transformed_bytes: orig_bytes,
120                    format_before: format_short.clone(),
121                    format_after: format_short,
122                    tokens_before: estimate_tokens(0, 0),
123                    tokens_after: estimate_tokens(0, 0),
124                });
125                // Push original data through unchanged
126                transformed_images.push((
127                    extracted.global_index,
128                    extracted.data.clone(),
129                    original_mime,
130                ));
131                continue;
132            }
133        };
134
135        // Capture original dimensions for token estimation
136        let orig_w = meta.width;
137        let orig_h = meta.height;
138        let orig_bytes = extracted.data.len();
139        let format_before = meta.format.to_string();
140
141        // Evaluate policy
142        let actions = policy::evaluate(
143            &meta,
144            constraints,
145            config.mode,
146            extracted.global_index,
147            total_images,
148        );
149
150        // Handle SVG mode
151        if meta.format == MediaFormat::Svg {
152            let result = handle_svg(
153                &extracted.data,
154                &meta,
155                &actions,
156                config,
157                extracted.global_index,
158                &mut report,
159            )?;
160
161            // Record metrics for SVG
162            let (_, ref out_data, ref out_mime) = result;
163            let (tw, th) = if out_data.is_empty() {
164                (0, 0)
165            } else {
166                inspector::image::inspect_bytes(out_data)
167                    .map(|m| (m.width, m.height))
168                    .unwrap_or((orig_w, orig_h))
169            };
170            let format_after = mime_to_short(out_mime);
171            report.add_image_metrics(ImageMetrics {
172                image_index: extracted.global_index,
173                original_width: orig_w,
174                original_height: orig_h,
175                transformed_width: tw,
176                transformed_height: th,
177                original_bytes: orig_bytes,
178                transformed_bytes: out_data.len(),
179                format_before: format_before.clone(),
180                format_after,
181                tokens_before: estimate_tokens(orig_w, orig_h),
182                tokens_after: estimate_tokens(tw, th),
183            });
184
185            transformed_images.push(result);
186            continue;
187        }
188
189        // Apply transformations
190        let mut current_data = extracted.data.clone();
191        let mut was_modified = false;
192        let mut output_mime = meta.format.mime_type().to_string();
193        let mut was_dropped = false;
194
195        for action in &actions {
196            match action {
197                policy::Action::Pass => {}
198                policy::Action::Drop { reason } => {
199                    report.add_action(extracted.global_index, "drop", reason);
200                    report.images_dropped += 1;
201                    current_data = Vec::new();
202                    was_modified = true;
203                    was_dropped = true;
204                    break;
205                }
206                _ => {
207                    if !config.dry_run {
208                        let new_data = transformer::transform_image(&current_data, action)?;
209                        let detail = describe_action(action, &meta);
210                        report.add_action(extracted.global_index, action_name(action), &detail);
211                        current_data = new_data;
212                        was_modified = true;
213
214                        // Update mime type based on action
215                        match action {
216                            policy::Action::ConvertFormat { to } => {
217                                output_mime = format!("image/{}", to);
218                            }
219                            policy::Action::Resize { .. } => {
220                                output_mime = "image/png".to_string();
221                            }
222                            policy::Action::Recompress { .. } => {
223                                output_mime = "image/jpeg".to_string();
224                            }
225                            _ => {}
226                        }
227                    } else {
228                        let detail = describe_action(action, &meta);
229                        report.add_action(
230                            extracted.global_index,
231                            &format!("would_{}", action_name(action)),
232                            &detail,
233                        );
234                        was_modified = true;
235                    }
236                }
237            }
238        }
239
240        if was_modified {
241            report.images_modified += 1;
242        }
243
244        // Determine transformed dimensions
245        let (tw, th) = if was_dropped || current_data.is_empty() {
246            (0, 0)
247        } else if was_modified && !config.dry_run {
248            // Re-inspect transformed data to get actual dimensions
249            inspector::image::inspect_bytes(&current_data)
250                .map(|m| (m.width, m.height))
251                .unwrap_or((orig_w, orig_h))
252        } else {
253            // Dry-run or unchanged: estimate from policy actions
254            estimate_dims_from_actions(&actions, orig_w, orig_h)
255        };
256
257        let format_after = mime_to_short(&output_mime);
258        report.add_image_metrics(ImageMetrics {
259            image_index: extracted.global_index,
260            original_width: orig_w,
261            original_height: orig_h,
262            transformed_width: tw,
263            transformed_height: th,
264            original_bytes: orig_bytes,
265            transformed_bytes: current_data.len(),
266            format_before,
267            format_after,
268            tokens_before: estimate_tokens(orig_w, orig_h),
269            tokens_after: estimate_tokens(tw, th),
270        });
271
272        transformed_images.push((extracted.global_index, current_data, output_mime));
273    }
274
275    // Reconstruct the payload
276    let result = if config.dry_run {
277        payload.clone()
278    } else {
279        match provider_format {
280            Some("openai") => payload::openai::reconstruct(payload, &transformed_images)?,
281            Some("anthropic") => payload::anthropic::reconstruct(payload, &transformed_images)?,
282            _ => payload.clone(),
283        }
284    };
285
286    // Fix #16: Track transformed image byte sizes
287    let transformed_image_bytes: usize = transformed_images
288        .iter()
289        .map(|(_, data, _)| data.len())
290        .sum();
291    report.transformed_size = transformed_image_bytes;
292
293    // Finalize aggregate token savings from per-image metrics
294    report.finalize_token_savings();
295
296    Ok((result, report))
297}
298
/// Extract a short format name from a MIME type (e.g. "image/png" -> "png").
///
/// Only a literal `image/` prefix is removed; any other input (for example
/// "application/octet-stream") is returned unchanged.
fn mime_to_short(mime: &str) -> String {
    mime.strip_prefix("image/").unwrap_or(mime).to_string()
}

304/// Estimate target dimensions from policy actions (for dry-run reporting).
305fn estimate_dims_from_actions(actions: &[policy::Action], orig_w: u32, orig_h: u32) -> (u32, u32) {
306    for action in actions {
307        match action {
308            policy::Action::Resize {
309                target_width,
310                target_height,
311            } => return (*target_width, *target_height),
312            policy::Action::RasterizeSvg {
313                target_width,
314                target_height,
315            } => return (*target_width, *target_height),
316            policy::Action::Drop { .. } => return (0, 0),
317            _ => {}
318        }
319    }
320    (orig_w, orig_h)
321}
322
323/// Handle SVG images according to the configured SvgMode.
324fn handle_svg(
325    data: &[u8],
326    meta: &inspector::ImageMetadata,
327    actions: &[policy::Action],
328    config: &ShiftConfig,
329    global_index: usize,
330    report: &mut Report,
331) -> Result<(usize, Vec<u8>, String)> {
332    match config.svg_mode {
333        SvgMode::Raster => {
334            // Rasterize SVG to PNG
335            if config.dry_run {
336                let detail = format!("would rasterize {}x{} SVG to PNG", meta.width, meta.height);
337                report.add_action(global_index, "would_rasterize_svg", &detail);
338                report.images_modified += 1;
339                return Ok((global_index, data.to_vec(), "image/svg+xml".to_string()));
340            }
341
342            // Find the rasterize action to get target dims
343            let (tw, th) = actions
344                .iter()
345                .find_map(|a| match a {
346                    policy::Action::RasterizeSvg {
347                        target_width,
348                        target_height,
349                    } => Some((*target_width, *target_height)),
350                    _ => None,
351                })
352                .unwrap_or((meta.width.max(256), meta.height.max(256)));
353
354            let svg_text = std::str::from_utf8(data).context("SVG is not valid UTF-8")?;
355            let png_data = transformer::rasterize_svg(svg_text, tw, th)?;
356
357            report.add_action(
358                global_index,
359                "rasterize_svg",
360                &format!(
361                    "SVG ({}x{}) -> PNG ({}x{})",
362                    meta.width, meta.height, tw, th
363                ),
364            );
365            report.svgs_rasterized += 1;
366            report.images_modified += 1;
367
368            Ok((global_index, png_data, "image/png".to_string()))
369        }
370
371        SvgMode::Source => {
372            // Fix #5: SVG Source mode drops the image and records it as dropped.
373            // The image block is removed from the payload. In the future, we could
374            // inject the SVG XML as a text content block, but for now we drop + warn.
375            report.add_action(
376                global_index,
377                "svg_dropped_as_source",
378                &format!(
379                    "SVG ({}x{}) removed (source mode: SVG not supported by provider)",
380                    meta.width, meta.height
381                ),
382            );
383            report.images_dropped += 1;
384            report.add_warning(
385                "SVG source mode dropped an image. Consider --svg-mode raster for provider compatibility.",
386            );
387
388            Ok((global_index, Vec::new(), "text/plain".to_string()))
389        }
390
391        SvgMode::Hybrid => {
392            // Rasterize but the caller could also add SVG source as text
393            if config.dry_run {
394                report.add_action(
395                    global_index,
396                    "would_rasterize_svg_hybrid",
397                    &format!(
398                        "would rasterize {}x{} SVG (hybrid mode)",
399                        meta.width, meta.height
400                    ),
401                );
402                report.images_modified += 1;
403                return Ok((global_index, data.to_vec(), "image/svg+xml".to_string()));
404            }
405
406            let (tw, th) = actions
407                .iter()
408                .find_map(|a| match a {
409                    policy::Action::RasterizeSvg {
410                        target_width,
411                        target_height,
412                    } => Some((*target_width, *target_height)),
413                    _ => None,
414                })
415                .unwrap_or((meta.width.max(256), meta.height.max(256)));
416
417            let svg_text = std::str::from_utf8(data).context("SVG is not valid UTF-8")?;
418            let png_data = transformer::rasterize_svg(svg_text, tw, th)?;
419
420            report.add_action(
421                global_index,
422                "rasterize_svg_hybrid",
423                &format!(
424                    "SVG ({}x{}) -> PNG ({}x{}) + source retained",
425                    meta.width, meta.height, tw, th
426                ),
427            );
428            report.svgs_rasterized += 1;
429            report.images_modified += 1;
430
431            Ok((global_index, png_data, "image/png".to_string()))
432        }
433    }
434}
435
436fn action_name(action: &policy::Action) -> &'static str {
437    match action {
438        policy::Action::Pass => "pass",
439        policy::Action::Resize { .. } => "resize",
440        policy::Action::Recompress { .. } => "recompress",
441        policy::Action::ConvertFormat { .. } => "convert",
442        policy::Action::RasterizeSvg { .. } => "rasterize_svg",
443        policy::Action::Drop { .. } => "drop",
444    }
445}
446
447fn describe_action(action: &policy::Action, meta: &inspector::ImageMetadata) -> String {
448    match action {
449        policy::Action::Pass => "no changes needed".to_string(),
450        policy::Action::Resize {
451            target_width,
452            target_height,
453        } => format!(
454            "{}x{} -> {}x{}",
455            meta.width, meta.height, target_width, target_height
456        ),
457        policy::Action::Recompress { quality } => {
458            format!("recompress at quality {}", quality)
459        }
460        policy::Action::ConvertFormat { to } => {
461            format!("{} -> {}", meta.format, to)
462        }
463        policy::Action::RasterizeSvg {
464            target_width,
465            target_height,
466        } => format!("SVG -> PNG at {}x{}", target_width, target_height),
467        policy::Action::Drop { reason } => reason.clone(),
468    }
469}
470
#[cfg(test)]
mod tests {
    use super::*;
    use crate::mode::DriveMode;
    use serde_json::json;

    /// Encode a blank RGBA image of the given size as PNG bytes.
    fn blank_png_bytes(width: u32, height: u32) -> Vec<u8> {
        let img = image::RgbaImage::new(width, height);
        let mut buf = Vec::new();
        let encoder = image::codecs::png::PngEncoder::new(&mut buf);
        image::ImageEncoder::write_image(
            encoder,
            img.as_raw(),
            width,
            height,
            image::ExtendedColorType::Rgba8,
        )
        .unwrap();
        buf
    }

    /// Blank PNG as a raw base64 string (Anthropic-style `source.data`).
    fn make_anthropic_png_base64(width: u32, height: u32) -> String {
        use base64::Engine;
        base64::engine::general_purpose::STANDARD.encode(blank_png_bytes(width, height))
    }

    /// Blank PNG as a `data:` URI (OpenAI-style `image_url.url`).
    fn make_png_data_uri(width: u32, height: u32) -> String {
        format!(
            "data:image/png;base64,{}",
            make_anthropic_png_base64(width, height)
        )
    }

    /// OpenAI chat payload whose single user message carries one image.
    fn openai_single_image_payload(url: &str) -> Value {
        json!({
            "model": "gpt-4o",
            "messages": [{
                "role": "user",
                "content": [
                    {"type": "image_url", "image_url": {"url": url}}
                ]
            }]
        })
    }

    #[test]
    fn test_text_only_passthrough() {
        let payload = json!({
            "model": "gpt-4o",
            "messages": [{"role": "user", "content": "Hello"}]
        });
        let config = ShiftConfig::default();
        let (result, report) = process(&payload, &config).unwrap();
        assert_eq!(result, payload);
        assert_eq!(report.images_found, 0);
        assert!(!report.has_changes());
    }

    #[test]
    fn test_small_image_passthrough() {
        let data_uri = make_png_data_uri(640, 480);
        let payload = json!({
            "model": "gpt-4o",
            "messages": [{
                "role": "user",
                "content": [
                    {"type": "text", "text": "What's this?"},
                    {"type": "image_url", "image_url": {"url": data_uri}}
                ]
            }]
        });
        let config = ShiftConfig::default();
        let (_result, report) = process(&payload, &config).unwrap();
        assert_eq!(report.images_found, 1);
    }

    #[test]
    fn test_oversized_image_resized_openai() {
        let payload = openai_single_image_payload(&make_png_data_uri(4000, 3000));
        let config = ShiftConfig {
            provider: "openai".to_string(),
            mode: DriveMode::Balanced,
            ..Default::default()
        };
        let (_result, report) = process(&payload, &config).unwrap();
        assert_eq!(report.images_found, 1);
        assert!(report.has_changes());
        assert!(report.actions.iter().any(|a| a.action == "resize"));
    }

    #[test]
    fn test_oversized_image_resized_anthropic() {
        let b64 = make_anthropic_png_base64(4000, 3000);
        let payload = json!({
            "model": "claude-sonnet-4-20250514",
            "messages": [{
                "role": "user",
                "content": [{
                    "type": "image",
                    "source": {"type": "base64", "media_type": "image/png", "data": b64}
                }]
            }]
        });
        let config = ShiftConfig {
            provider: "anthropic".to_string(),
            mode: DriveMode::Balanced,
            ..Default::default()
        };
        let (_result, report) = process(&payload, &config).unwrap();
        assert_eq!(report.images_found, 1);
        assert!(report.has_changes());
    }

    #[test]
    fn test_dry_run_no_modifications() {
        let payload = openai_single_image_payload(&make_png_data_uri(4000, 3000));
        let config = ShiftConfig {
            dry_run: true,
            ..Default::default()
        };
        let (result, report) = process(&payload, &config).unwrap();
        // Dry run should not modify the payload
        assert_eq!(result, payload);
        // But should report what would happen
        assert!(report.has_changes());
        assert!(report.dry_run);
        assert!(report
            .actions
            .iter()
            .any(|a| a.action.starts_with("would_")));
    }

    #[test]
    fn test_svg_rasterization_in_openai_payload() {
        use base64::Engine;
        let svg = r#"<svg xmlns="http://www.w3.org/2000/svg" width="200" height="100"><rect width="200" height="100" fill="red"/></svg>"#;
        let b64 = base64::engine::general_purpose::STANDARD.encode(svg.as_bytes());
        let data_uri = format!("data:image/svg+xml;base64,{}", b64);

        let payload = openai_single_image_payload(&data_uri);
        let config = ShiftConfig {
            svg_mode: SvgMode::Raster,
            ..Default::default()
        };
        let (_result, report) = process(&payload, &config).unwrap();
        assert_eq!(report.svgs_rasterized, 1);
        assert!(report.actions.iter().any(|a| a.action == "rasterize_svg"));
    }

    #[test]
    fn test_economy_mode_aggressive() {
        // 1500px image — within OpenAI limits but economy mode will downscale
        let payload = openai_single_image_payload(&make_png_data_uri(1500, 1000));
        let config = ShiftConfig {
            mode: DriveMode::Economy,
            ..Default::default()
        };
        let (_result, report) = process(&payload, &config).unwrap();
        assert!(report.has_changes());
    }

    #[test]
    fn test_performance_mode_minimal() {
        // 1500px image — within limits, performance mode should pass
        let payload = openai_single_image_payload(&make_png_data_uri(1500, 1000));
        let config = ShiftConfig {
            mode: DriveMode::Performance,
            ..Default::default()
        };
        let (_result, report) = process(&payload, &config).unwrap();
        // Performance mode should not modify images within limits
        assert!(!report.has_changes() || report.images_modified == 0);
    }
}
678        let (_result, report) = process(&payload, &config).unwrap();
679        // Performance mode should not modify images within limits
680        assert!(!report.has_changes() || report.images_modified == 0);
681    }
682}