Skip to main content

shift_preflight/payload/
openai.rs

1//! OpenAI chat completion message format parser.
2//!
3//! OpenAI uses `image_url` content parts:
4//! ```json
5//! {
6//!   "type": "image_url",
7//!   "image_url": {
8//!     "url": "data:image/png;base64,iVBOR..." // or "https://..."
9//!   }
10//! }
11//! ```
12
13use anyhow::{Context, Result};
14use serde_json::Value;
15
16use super::{ExtractedImage, ImageRef};
17use crate::inspector::{decode_base64_image, image::fetch_url_safe};
18use crate::mode::SafetyLimits;
19
20/// Extract all images from an OpenAI-format payload.
21pub fn extract_images(payload: &Value) -> Result<Vec<ExtractedImage>> {
22    extract_images_with_limits(payload, &SafetyLimits::default())
23}
24
25/// Extract images with explicit safety limits.
26pub fn extract_images_with_limits(
27    payload: &Value,
28    limits: &SafetyLimits,
29) -> Result<Vec<ExtractedImage>> {
30    let mut images = Vec::new();
31    let mut global_index = 0;
32
33    let messages = payload
34        .get("messages")
35        .and_then(|m| m.as_array())
36        .context("payload missing 'messages' array")?;
37
38    for (msg_idx, message) in messages.iter().enumerate() {
39        let content = match message.get("content") {
40            Some(Value::Array(arr)) => arr,
41            _ => continue,
42        };
43
44        for (part_idx, part) in content.iter().enumerate() {
45            let part_type = part.get("type").and_then(|t| t.as_str()).unwrap_or("");
46            if part_type != "image_url" {
47                continue;
48            }
49
50            // Fix #8: Cap total images extracted
51            if global_index >= limits.max_images_extract {
52                break;
53            }
54
55            let url = part
56                .get("image_url")
57                .and_then(|iu| iu.get("url"))
58                .and_then(|u| u.as_str())
59                .context("image_url part missing url field")?;
60
61            let (data, image_ref) = if url.starts_with("data:") {
62                let (bytes, mime_hint) = decode_base64_image(url)?;
63                let mime = mime_hint.unwrap_or_else(|| "image/png".to_string());
64                let b64 = url.find(',').map(|pos| &url[pos + 1..]).unwrap_or(url);
65                (
66                    bytes,
67                    ImageRef::DataUri {
68                        mime_type: mime,
69                        base64: b64.to_string(),
70                    },
71                )
72            } else if url.starts_with("http://") || url.starts_with("https://") {
73                // Fix #1, #3: Use safe URL fetcher with SSRF prevention and size limits
74                let bytes = fetch_url_safe(url, limits)?;
75                (bytes, ImageRef::Url(url.to_string()))
76            } else {
77                anyhow::bail!(
78                    "unsupported image_url format: {}",
79                    &url[..url.len().min(50)]
80                );
81            };
82
83            images.push(ExtractedImage {
84                message_index: msg_idx,
85                content_index: part_idx,
86                data,
87                original_ref: image_ref,
88                global_index,
89            });
90            global_index += 1;
91        }
92    }
93
94    Ok(images)
95}
96
97/// Reconstruct an OpenAI payload with transformed image data.
98///
99/// Takes the original payload and a list of (global_index, new_data, new_mime) tuples.
100/// Images with empty data are dropped from the payload.
101pub fn reconstruct(payload: &Value, transformed: &[(usize, Vec<u8>, String)]) -> Result<Value> {
102    use base64::Engine;
103    let engine = base64::engine::general_purpose::STANDARD;
104
105    let mut result = payload.clone();
106    let messages = result
107        .get_mut("messages")
108        .and_then(|m| m.as_array_mut())
109        .context("payload missing 'messages' array")?;
110
111    let mut global_index = 0;
112
113    for message in messages.iter_mut() {
114        let content = match message.get_mut("content") {
115            Some(Value::Array(arr)) => arr,
116            _ => continue,
117        };
118
119        // Collect indices to remove (for dropped images)
120        let mut to_remove = Vec::new();
121
122        for (part_idx, part) in content.iter_mut().enumerate() {
123            let part_type = part.get("type").and_then(|t| t.as_str()).unwrap_or("");
124            if part_type != "image_url" {
125                continue;
126            }
127
128            // Find this image in the transformed list
129            if let Some((_idx, new_data, new_mime)) =
130                transformed.iter().find(|(idx, _, _)| *idx == global_index)
131            {
132                if new_data.is_empty() {
133                    // Image was dropped
134                    to_remove.push(part_idx);
135                } else {
136                    // Replace with new data
137                    let b64 = engine.encode(new_data);
138                    let data_uri = format!("data:{};base64,{}", new_mime, b64);
139                    if let Some(image_url) = part.get_mut("image_url") {
140                        image_url["url"] = Value::String(data_uri);
141                    }
142                }
143            }
144
145            global_index += 1;
146        }
147
148        // Remove dropped images (reverse order to preserve indices)
149        for idx in to_remove.into_iter().rev() {
150            content.remove(idx);
151        }
152    }
153
154    Ok(result)
155}
156
#[cfg(test)]
mod tests {
    use super::*;
    use serde_json::json;

    /// Build a `data:image/png;base64,...` URI holding a blank 100x100 RGBA PNG.
    fn make_png_data_uri() -> String {
        use base64::Engine;
        let img = image::RgbaImage::new(100, 100);
        let mut buf = Vec::new();
        let encoder = image::codecs::png::PngEncoder::new(&mut buf);
        image::ImageEncoder::write_image(
            encoder,
            img.as_raw(),
            100,
            100,
            image::ExtendedColorType::Rgba8,
        )
        .unwrap();
        let b64 = base64::engine::general_purpose::STANDARD.encode(&buf);
        format!("data:image/png;base64,{}", b64)
    }

    #[test]
    fn test_extract_single_image() {
        let data_uri = make_png_data_uri();
        let payload = json!({
            "model": "gpt-4o",
            "messages": [{
                "role": "user",
                "content": [
                    {"type": "text", "text": "What's in this image?"},
                    {"type": "image_url", "image_url": {"url": data_uri}}
                ]
            }]
        });

        let images = extract_images(&payload).unwrap();
        assert_eq!(images.len(), 1);
        assert_eq!(images[0].message_index, 0);
        assert_eq!(images[0].content_index, 1);
        assert_eq!(images[0].global_index, 0);
        assert!(!images[0].data.is_empty());
    }

    #[test]
    fn test_extract_multiple_images() {
        let data_uri = make_png_data_uri();
        let payload = json!({
            "model": "gpt-4o",
            "messages": [{
                "role": "user",
                "content": [
                    {"type": "image_url", "image_url": {"url": data_uri.clone()}},
                    {"type": "text", "text": "Compare these:"},
                    {"type": "image_url", "image_url": {"url": data_uri}}
                ]
            }]
        });

        let images = extract_images(&payload).unwrap();
        assert_eq!(images.len(), 2);
        assert_eq!(images[0].global_index, 0);
        assert_eq!(images[1].global_index, 1);
    }

    #[test]
    fn test_extract_no_images() {
        let payload = json!({
            "model": "gpt-4o",
            "messages": [{
                "role": "user",
                "content": "Hello, no images here"
            }]
        });

        let images = extract_images(&payload).unwrap();
        assert!(images.is_empty());
    }

    #[test]
    fn test_extract_across_messages() {
        let data_uri = make_png_data_uri();
        let payload = json!({
            "model": "gpt-4o",
            "messages": [
                {
                    "role": "user",
                    "content": [
                        {"type": "image_url", "image_url": {"url": data_uri.clone()}}
                    ]
                },
                {"role": "assistant", "content": "I see an image."},
                {
                    "role": "user",
                    "content": [
                        {"type": "image_url", "image_url": {"url": data_uri}}
                    ]
                }
            ]
        });

        let images = extract_images(&payload).unwrap();
        assert_eq!(images.len(), 2);
        assert_eq!(images[0].message_index, 0);
        assert_eq!(images[1].message_index, 2);
    }

    /// The extraction cap is global: it applies across message boundaries.
    #[test]
    fn test_extract_respects_image_cap() {
        let data_uri = make_png_data_uri();
        let payload = json!({
            "model": "gpt-4o",
            "messages": [
                {
                    "role": "user",
                    "content": [
                        {"type": "image_url", "image_url": {"url": data_uri.clone()}},
                        {"type": "image_url", "image_url": {"url": data_uri.clone()}}
                    ]
                },
                {
                    "role": "user",
                    "content": [
                        {"type": "image_url", "image_url": {"url": data_uri}}
                    ]
                }
            ]
        });

        let mut limits = crate::mode::SafetyLimits::default();
        limits.max_images_extract = 1;
        let images = extract_images_with_limits(&payload, &limits).unwrap();
        assert_eq!(images.len(), 1);
        assert_eq!(images[0].message_index, 0);
        assert_eq!(images[0].global_index, 0);
    }

    #[test]
    fn test_reconstruct_replaces_image() {
        let data_uri = make_png_data_uri();
        let payload = json!({
            "model": "gpt-4o",
            "messages": [{
                "role": "user",
                "content": [
                    {"type": "text", "text": "What's this?"},
                    {"type": "image_url", "image_url": {"url": data_uri}}
                ]
            }]
        });

        // Simulate a transformed image (just some bytes)
        let new_data = vec![0x89, 0x50, 0x4E, 0x47]; // PNG header stub
        let transformed = vec![(0, new_data, "image/png".to_string())];

        let result = reconstruct(&payload, &transformed).unwrap();
        let url = result["messages"][0]["content"][1]["image_url"]["url"]
            .as_str()
            .unwrap();
        assert!(url.starts_with("data:image/png;base64,"));
    }

    #[test]
    fn test_reconstruct_drops_image() {
        let data_uri = make_png_data_uri();
        let payload = json!({
            "model": "gpt-4o",
            "messages": [{
                "role": "user",
                "content": [
                    {"type": "text", "text": "What's this?"},
                    {"type": "image_url", "image_url": {"url": data_uri}}
                ]
            }]
        });

        // Empty data means drop
        let transformed = vec![(0, Vec::new(), "image/png".to_string())];

        let result = reconstruct(&payload, &transformed).unwrap();
        let content = result["messages"][0]["content"].as_array().unwrap();
        // Should only have the text part left
        assert_eq!(content.len(), 1);
        assert_eq!(content[0]["type"], "text");
    }

    /// `global_index` counts image parts across messages. Untargeted images are
    /// left exactly as they were.
    #[test]
    fn test_reconstruct_targets_global_index_across_messages() {
        let data_uri = make_png_data_uri();
        let payload = json!({
            "model": "gpt-4o",
            "messages": [
                {
                    "role": "user",
                    "content": [
                        {"type": "image_url", "image_url": {"url": data_uri.clone()}}
                    ]
                },
                {"role": "assistant", "content": "I see an image."},
                {
                    "role": "user",
                    "content": [
                        {"type": "image_url", "image_url": {"url": data_uri.clone()}}
                    ]
                }
            ]
        });

        let transformed = vec![(1, vec![0xFF, 0xD8, 0xFF], "image/jpeg".to_string())];

        let result = reconstruct(&payload, &transformed).unwrap();
        assert_eq!(
            result["messages"][0]["content"][0]["image_url"]["url"],
            data_uri.as_str()
        );
        let url = result["messages"][2]["content"][0]["image_url"]["url"]
            .as_str()
            .unwrap();
        assert!(url.starts_with("data:image/jpeg;base64,"));
    }

    /// Replacing the URL must not discard sibling keys such as `detail`.
    #[test]
    fn test_reconstruct_preserves_image_url_detail() {
        let data_uri = make_png_data_uri();
        let payload = json!({
            "model": "gpt-4o",
            "messages": [{
                "role": "user",
                "content": [
                    {"type": "image_url", "image_url": {"url": data_uri, "detail": "low"}}
                ]
            }]
        });

        let transformed = vec![(0, vec![0x89, 0x50, 0x4E, 0x47], "image/png".to_string())];

        let result = reconstruct(&payload, &transformed).unwrap();
        let image_url = &result["messages"][0]["content"][0]["image_url"];
        assert_eq!(image_url["detail"], "low");
        assert!(image_url["url"]
            .as_str()
            .unwrap()
            .starts_with("data:image/png;base64,"));
    }
}