stygian_graph/adapters/ai/
openai.rs1use std::time::Duration;
22
23use async_trait::async_trait;
24use futures::stream::{self, BoxStream};
25use reqwest::Client;
26use serde_json::{Value, json};
27
28use crate::domain::error::{ProviderError, Result, StygianError};
29use crate::ports::{AIProvider, ProviderCapabilities};
30
/// Model used by [`OpenAIConfig::new`] unless overridden with [`OpenAIConfig::with_model`].
const DEFAULT_MODEL: &str = "gpt-4o";

/// Default upper bound on generated tokens per request.
const DEFAULT_MAX_TOKENS: u32 = 4096;

/// Default HTTP timeout: two minutes.
///
/// Spelled with `Duration::from_secs` (long-stable and `const`) rather than the much
/// newer `Duration::from_mins`, so the crate builds on older toolchains too.
const DEFAULT_TIMEOUT: Duration = Duration::from_secs(120);

/// OpenAI Chat Completions endpoint that every extraction request is posted to.
const API_URL: &str = "https://api.openai.com/v1/chat/completions";

/// Configuration for the OpenAI provider.
///
/// `Debug` is implemented by hand (instead of derived) so the API key is never
/// written to logs or panic messages.
#[derive(Clone)]
pub struct OpenAIConfig {
    /// Secret API key, sent as a bearer token.
    pub api_key: String,
    /// Model identifier placed in the request body (e.g. `"gpt-4o"`).
    pub model: String,
    /// Upper bound on tokens the model may generate for one response.
    pub max_tokens: u32,
    /// Timeout applied to the HTTP client built from this configuration.
    pub timeout: Duration,
}

impl std::fmt::Debug for OpenAIConfig {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("OpenAIConfig")
            // Never print the secret; show a fixed placeholder instead.
            .field("api_key", &"<redacted>")
            .field("model", &self.model)
            .field("max_tokens", &self.max_tokens)
            .field("timeout", &self.timeout)
            .finish()
    }
}

impl OpenAIConfig {
    /// Create a configuration for `api_key` with the default model, token budget
    /// (4096) and timeout (two minutes).
    #[must_use]
    pub fn new(api_key: String) -> Self {
        Self {
            api_key,
            model: DEFAULT_MODEL.to_string(),
            max_tokens: DEFAULT_MAX_TOKENS,
            timeout: DEFAULT_TIMEOUT,
        }
    }

    /// Replace the model identifier.
    #[must_use]
    pub fn with_model(mut self, model: impl Into<String>) -> Self {
        self.model = model.into();
        self
    }

    /// Replace the per-response generated-token limit.
    #[must_use]
    pub fn with_max_tokens(mut self, max_tokens: u32) -> Self {
        self.max_tokens = max_tokens;
        self
    }

    /// Replace the HTTP timeout used when the provider builds its client.
    #[must_use]
    pub fn with_timeout(mut self, timeout: Duration) -> Self {
        self.timeout = timeout;
        self
    }
}
69
70pub struct OpenAIProvider {
74 config: OpenAIConfig,
75 client: Client,
76}
77
78impl OpenAIProvider {
79 #[must_use]
88 pub fn new(api_key: String) -> Self {
89 Self::with_config(OpenAIConfig::new(api_key))
90 }
91
92 #[must_use]
108 pub fn with_config(config: OpenAIConfig) -> Self {
109 #[allow(clippy::expect_used)]
111 let client = Client::builder()
112 .timeout(config.timeout)
113 .build()
114 .expect("Failed to build HTTP client");
115 Self { config, client }
116 }
117
118 fn build_body(&self, content: &str, schema: &Value) -> Value {
119 let system = "You are a precise data extraction assistant. \
120 Extract structured data from the provided content matching the given JSON schema. \
121 Return ONLY valid JSON matching the schema, no extra text.";
122
123 let user_msg = format!(
124 "Schema: {}\n\nContent:\n{}",
125 serde_json::to_string(schema).unwrap_or_default(),
126 content
127 );
128
129 json!({
130 "model": self.config.model,
131 "max_tokens": self.config.max_tokens,
132 "response_format": {"type": "json_object"},
133 "messages": [
134 {"role": "system", "content": system},
135 {"role": "user", "content": user_msg}
136 ]
137 })
138 }
139
140 fn parse_response(response: &Value) -> Result<Value> {
141 let text = response
142 .pointer("/choices/0/message/content")
143 .and_then(Value::as_str)
144 .ok_or_else(|| {
145 StygianError::Provider(ProviderError::ApiError(
146 "No content in OpenAI response".to_string(),
147 ))
148 })?;
149
150 serde_json::from_str(text).map_err(|e| {
151 StygianError::Provider(ProviderError::ApiError(format!(
152 "Failed to parse OpenAI JSON response: {e}"
153 )))
154 })
155 }
156
157 fn map_http_error(status: u16, body: &str) -> StygianError {
158 match status {
159 401 => StygianError::Provider(ProviderError::InvalidCredentials),
160 429 => StygianError::Provider(ProviderError::ApiError(format!(
161 "OpenAI rate limited: {body}"
162 ))),
163 _ => StygianError::Provider(ProviderError::ApiError(format!("HTTP {status}: {body}"))),
164 }
165 }
166}
167
168#[async_trait]
169impl AIProvider for OpenAIProvider {
170 async fn extract(&self, content: String, schema: Value) -> Result<Value> {
171 let body = self.build_body(&content, &schema);
172
173 let response = self
174 .client
175 .post(API_URL)
176 .header("Authorization", format!("Bearer {}", &self.config.api_key))
177 .header("Content-Type", "application/json")
178 .json(&body)
179 .send()
180 .await
181 .map_err(|e| {
182 StygianError::Provider(ProviderError::ApiError(format!(
183 "OpenAI request failed: {e}"
184 )))
185 })?;
186
187 let status = response.status().as_u16();
188 let text = response
189 .text()
190 .await
191 .map_err(|e| StygianError::Provider(ProviderError::ApiError(e.to_string())))?;
192
193 if status != 200 {
194 return Err(Self::map_http_error(status, &text));
195 }
196
197 let json_val: Value = serde_json::from_str(&text)
198 .map_err(|e| StygianError::Provider(ProviderError::ApiError(e.to_string())))?;
199
200 Self::parse_response(&json_val)
201 }
202
203 async fn stream_extract(
204 &self,
205 content: String,
206 schema: Value,
207 ) -> Result<BoxStream<'static, Result<Value>>> {
208 let result = self.extract(content, schema).await;
209 Ok(Box::pin(stream::once(async move { result })))
210 }
211
212 fn capabilities(&self) -> ProviderCapabilities {
213 ProviderCapabilities {
214 streaming: true,
215 vision: true,
216 tool_use: true,
217 json_mode: true,
218 }
219 }
220
221 fn name(&self) -> &'static str {
222 "openai"
223 }
224}
225
#[cfg(test)]
mod tests {
    use super::*;
    use serde_json::json;

    #[test]
    fn test_name() {
        assert_eq!(OpenAIProvider::new("k".to_string()).name(), "openai");
    }

    #[test]
    fn test_capabilities() {
        let caps = OpenAIProvider::new("k".to_string()).capabilities();
        assert!(caps.json_mode);
        assert!(caps.streaming);
    }

    #[test]
    fn test_build_body_contains_json_format() {
        let p = OpenAIProvider::new("k".to_string());
        let body = p.build_body("content", &json!({"type": "object"}));
        assert_eq!(
            body.get("response_format")
                .and_then(|rf| rf.get("type"))
                .and_then(Value::as_str),
            Some("json_object")
        );
    }

    #[test]
    fn test_build_body_embeds_model_schema_and_content() {
        let p = OpenAIProvider::with_config(
            OpenAIConfig::new("k".to_string()).with_model("gpt-4-turbo"),
        );
        let body = p.build_body("page text", &json!({"type": "object"}));
        assert_eq!(body.get("model").and_then(Value::as_str), Some("gpt-4-turbo"));
        assert_eq!(
            body.pointer("/messages/0/role").and_then(Value::as_str),
            Some("system")
        );
        let user = body
            .pointer("/messages/1/content")
            .and_then(Value::as_str)
            .unwrap_or_default();
        assert!(user.contains("page text"));
        assert!(user.contains(r#"{"type":"object"}"#));
    }

    #[test]
    fn test_parse_response_valid() -> Result<()> {
        let resp = json!({
            "choices": [{"message": {"content": "{\"title\": \"Hello\"}"}}]
        });
        let val = OpenAIProvider::parse_response(&resp)?;
        assert_eq!(val.get("title").and_then(Value::as_str), Some("Hello"));
        Ok(())
    }

    #[test]
    fn test_parse_response_invalid_json() {
        let resp = json!({"choices": [{"message": {"content": "not json"}}]});
        assert!(OpenAIProvider::parse_response(&resp).is_err());
    }

    #[test]
    fn test_parse_response_null_content() {
        let resp = json!({"choices": [{"message": {"content": null}}]});
        assert!(OpenAIProvider::parse_response(&resp).is_err());
    }

    #[test]
    fn test_parse_response_missing_choices() {
        let resp = json!({"id": "chatcmpl-abc"});
        assert!(OpenAIProvider::parse_response(&resp).is_err());
    }

    #[test]
    fn test_map_http_error_401() {
        assert!(matches!(
            OpenAIProvider::map_http_error(401, ""),
            StygianError::Provider(ProviderError::InvalidCredentials)
        ));
    }

    #[test]
    fn test_map_http_error_429() {
        let err = OpenAIProvider::map_http_error(429, "too many");
        assert!(
            matches!(err, StygianError::Provider(ProviderError::ApiError(ref msg)) if msg.contains("rate limited"))
        );
    }

    #[test]
    fn test_map_http_error_server_error() {
        let err = OpenAIProvider::map_http_error(500, "internal");
        assert!(
            matches!(err, StygianError::Provider(ProviderError::ApiError(ref msg)) if msg.contains("500"))
        );
    }

    #[test]
    fn test_map_http_error_other_status_includes_body() {
        let err = OpenAIProvider::map_http_error(404, "model not found");
        assert!(
            matches!(err, StygianError::Provider(ProviderError::ApiError(ref msg)) if msg.contains("404") && msg.contains("model not found"))
        );
    }

    #[test]
    fn test_config_with_model() {
        let cfg = OpenAIConfig::new("key".to_string()).with_model("gpt-4-turbo");
        assert_eq!(cfg.model, "gpt-4-turbo");
    }
}