Skip to main content

visual_rubric/
vision.rs

//! OpenAI-compatible vision API client for the vision extraction stage.
//!
//! Sends base64-encoded images to a vision model (e.g. Qwen VL via
//! llama-swap) at `POST /v1/chat/completions` and returns the model's
//! text output as a structured JSON description for downstream rubric
//! scoring.
7
8use crate::PoolError;
9
/// Configuration for calling an OpenAI-compatible vision API.
///
/// `Debug` is implemented by hand (rather than derived) so that the bearer
/// token held in `api_key` is never written to logs, test output, or panic
/// messages when a config value is formatted with `{:?}`.
#[derive(Clone, PartialEq, Eq)]
pub struct VisionApiConfig {
    /// Base URL of the API, e.g. `"http://localhost:8013"`.
    pub url: String,
    /// Model name, e.g. `"qwen3-vl-8b"`.
    pub model: String,
    /// Optional Bearer token for API authentication.
    pub api_key: Option<String>,
}

impl std::fmt::Debug for VisionApiConfig {
    /// Formats the config like the derived `Debug` would, except that a
    /// present `api_key` is shown as `Some("<redacted>")` instead of its value.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("VisionApiConfig")
            .field("url", &self.url)
            .field("model", &self.model)
            // Keep the Some/None distinction visible without exposing the secret.
            .field("api_key", &self.api_key.as_ref().map(|_| "<redacted>"))
            .finish()
    }
}
20
21/// Calls an OpenAI-compatible vision API with a base64-encoded PNG and a
22/// text question, returning the model's response text.
23///
24/// The request is a `POST {url}/v1/chat/completions` with the image
25/// embedded as a `data:image/png;base64,...` URL in a multimodal message.
26///
27/// # Errors
28///
29/// Returns [`PoolError::VisionApi`] for HTTP, JSON, or unexpected response
30/// shape failures.
31pub fn call_vision_api(
32    b64_png: &str,
33    question: &str,
34    config: &VisionApiConfig,
35) -> Result<String, PoolError> {
36    let client = reqwest::blocking::Client::builder()
37        .timeout(std::time::Duration::from_secs(300))
38        .build()
39        .map_err(|e| PoolError::VisionApi(format!("build http client: {e}")))?;
40
41    let url = format!("{}/v1/chat/completions", config.url.trim_end_matches('/'));
42
43    let body = serde_json::json!({
44        "model": config.model,
45        "messages": [{
46            "role": "user",
47            "content": [
48                {
49                    "type": "image_url",
50                    "image_url": {
51                        "url": format!("data:image/png;base64,{b64_png}")
52                    }
53                },
54                {
55                    "type": "text",
56                    "text": question
57                }
58            ]
59        }],
60        "max_tokens": 4096,
61    });
62
63    let mut request = client.post(&url).json(&body);
64    if let Some(ref api_key) = config.api_key {
65        request = request.header("Authorization", format!("Bearer {api_key}"));
66    }
67
68    let response = request
69        .send()
70        .map_err(|e| PoolError::VisionApi(format!("vision request: {e}")))?;
71
72    let status = response.status();
73    if !status.is_success() {
74        let text = response.text().unwrap_or_default();
75        return Err(PoolError::VisionApi(format!("vision API {status}: {text}")));
76    }
77
78    let result: serde_json::Value = response
79        .json()
80        .map_err(|e| PoolError::VisionApi(format!("vision response json: {e}")))?;
81
82    let content = result["choices"][0]["message"]["content"]
83        .as_str()
84        .ok_or_else(|| {
85            PoolError::VisionApi("vision response missing choices[0].message.content".to_string())
86        })?
87        .to_string();
88
89    Ok(content)
90}
91
#[cfg(test)]
mod tests {
    use super::*;

    /// Calling an endpoint on a port nothing listens on must surface a
    /// non-empty `PoolError::VisionApi` rather than succeeding or panicking.
    #[test]
    fn vision_api_rejects_bad_url() {
        let unreachable = VisionApiConfig {
            url: String::from("http://127.0.0.1:1"),
            model: String::from("test-model"),
            api_key: None,
        };

        let outcome = call_vision_api("fake-base64", "test question", &unreachable);

        assert!(outcome.is_err(), "expected error for unreachable port");
        match outcome {
            Err(PoolError::VisionApi(msg)) => {
                assert!(!msg.is_empty(), "error message should not be empty");
            }
            other => panic!("expected VisionApi error, got {other:?}"),
        }
    }
}