Skip to main content

stygian_graph/adapters/ai/
openai.rs

1//! OpenAI (ChatGPT) AI provider adapter
2//!
3//! Implements the `AIProvider` port using OpenAI's Chat Completions API.
//! Supports GPT-4o, GPT-4, and o1-series models. Requests native JSON mode
//! (`response_format: json_object`) and embeds the target schema in the prompt
//! for structured extraction.
6//!
7//! # Example
8//!
9//! ```no_run
10//! use stygian_graph::adapters::ai::openai::{OpenAIProvider, OpenAIConfig};
11//! use stygian_graph::ports::AIProvider;
12//! use serde_json::json;
13//!
14//! # tokio::runtime::Runtime::new().unwrap().block_on(async {
15//! let provider = OpenAIProvider::new("sk-...".to_string());
16//! let schema = json!({"type": "object", "properties": {"title": {"type": "string"}}});
17//! // let result = provider.extract("<html>Hello</html>".to_string(), schema).await.unwrap();
18//! # });
19//! ```
20
21use std::time::Duration;
22
23use async_trait::async_trait;
24use futures::stream::{self, BoxStream};
25use reqwest::Client;
26use serde_json::{Value, json};
27
28use crate::domain::error::{ProviderError, Result, StygianError};
29use crate::ports::{AIProvider, ProviderCapabilities};
30
/// Model identifier used when the caller does not choose one explicitly
const DEFAULT_MODEL: &str = "gpt-4o";

/// URL of the `OpenAI` Chat Completions endpoint that extraction requests are POSTed to
const API_URL: &str = "https://api.openai.com/v1/chat/completions";
36
/// Configuration for the `OpenAI` provider
///
/// `Debug` is implemented by hand rather than derived so that the API key is
/// never written to logs or panic messages; every other field is printed as usual.
#[derive(Clone)]
pub struct OpenAIConfig {
    /// `OpenAI` API key (secret; redacted in `Debug` output)
    pub api_key: String,
    /// Model identifier
    pub model: String,
    /// Maximum response tokens
    pub max_tokens: u32,
    /// Request timeout
    pub timeout: Duration,
}

impl std::fmt::Debug for OpenAIConfig {
    /// Format the configuration with the API key replaced by a placeholder.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("OpenAIConfig")
            // Deliberately not `self.api_key`: credentials must not leak via `{:?}`.
            .field("api_key", &"[REDACTED]")
            .field("model", &self.model)
            .field("max_tokens", &self.max_tokens)
            .field("timeout", &self.timeout)
            .finish()
    }
}
49
50impl OpenAIConfig {
51    /// Create config with API key and defaults
52    #[must_use]
53    pub fn new(api_key: String) -> Self {
54        Self {
55            api_key,
56            model: DEFAULT_MODEL.to_string(),
57            max_tokens: 4096,
58            timeout: Duration::from_mins(2),
59        }
60    }
61
62    /// Override model
63    #[must_use]
64    pub fn with_model(mut self, model: impl Into<String>) -> Self {
65        self.model = model.into();
66        self
67    }
68}
69
70/// `OpenAI` provider adapter
71///
72/// Uses `response_format: json_object` + function calling to enforce schema.
73pub struct OpenAIProvider {
74    config: OpenAIConfig,
75    client: Client,
76}
77
78impl OpenAIProvider {
79    /// Create with API key and defaults
80    ///
81    /// # Example
82    ///
83    /// ```no_run
84    /// use stygian_graph::adapters::ai::openai::OpenAIProvider;
85    /// let p = OpenAIProvider::new("sk-...".to_string());
86    /// ```
87    #[must_use]
88    pub fn new(api_key: String) -> Self {
89        Self::with_config(OpenAIConfig::new(api_key))
90    }
91
92    /// Create with custom configuration
93    ///
94    /// # Panics
95    ///
96    /// Panics if the underlying HTTP client fails to build. With `rustls` as the
97    /// TLS backend this is unreachable in practice (build only fails when no TLS
98    /// backend is configured).
99    ///
100    /// # Example
101    ///
102    /// ```no_run
103    /// use stygian_graph::adapters::ai::openai::{OpenAIProvider, OpenAIConfig};
104    /// let config = OpenAIConfig::new("sk-...".to_string()).with_model("gpt-4");
105    /// let p = OpenAIProvider::with_config(config);
106    /// ```
107    #[must_use]
108    pub fn with_config(config: OpenAIConfig) -> Self {
109        // SAFETY: TLS backend (rustls) is always available; build() only fails if no TLS backend.
110        #[allow(clippy::expect_used)]
111        let client = Client::builder()
112            .timeout(config.timeout)
113            .build()
114            .expect("Failed to build HTTP client");
115        Self { config, client }
116    }
117
118    fn build_body(&self, content: &str, schema: &Value) -> Value {
119        let system = "You are a precise data extraction assistant. \
120            Extract structured data from the provided content matching the given JSON schema. \
121            Return ONLY valid JSON matching the schema, no extra text.";
122
123        let user_msg = format!(
124            "Schema: {}\n\nContent:\n{}",
125            serde_json::to_string(schema).unwrap_or_default(),
126            content
127        );
128
129        json!({
130            "model": self.config.model,
131            "max_tokens": self.config.max_tokens,
132            "response_format": {"type": "json_object"},
133            "messages": [
134                {"role": "system", "content": system},
135                {"role": "user", "content": user_msg}
136            ]
137        })
138    }
139
140    fn parse_response(response: &Value) -> Result<Value> {
141        let text = response
142            .pointer("/choices/0/message/content")
143            .and_then(Value::as_str)
144            .ok_or_else(|| {
145                StygianError::Provider(ProviderError::ApiError(
146                    "No content in OpenAI response".to_string(),
147                ))
148            })?;
149
150        serde_json::from_str(text).map_err(|e| {
151            StygianError::Provider(ProviderError::ApiError(format!(
152                "Failed to parse OpenAI JSON response: {e}"
153            )))
154        })
155    }
156
157    fn map_http_error(status: u16, body: &str) -> StygianError {
158        match status {
159            401 => StygianError::Provider(ProviderError::InvalidCredentials),
160            429 => StygianError::Provider(ProviderError::ApiError(format!(
161                "OpenAI rate limited: {body}"
162            ))),
163            _ => StygianError::Provider(ProviderError::ApiError(format!("HTTP {status}: {body}"))),
164        }
165    }
166}
167
168#[async_trait]
169impl AIProvider for OpenAIProvider {
170    async fn extract(&self, content: String, schema: Value) -> Result<Value> {
171        let body = self.build_body(&content, &schema);
172
173        let response = self
174            .client
175            .post(API_URL)
176            .header("Authorization", format!("Bearer {}", &self.config.api_key))
177            .header("Content-Type", "application/json")
178            .json(&body)
179            .send()
180            .await
181            .map_err(|e| {
182                StygianError::Provider(ProviderError::ApiError(format!(
183                    "OpenAI request failed: {e}"
184                )))
185            })?;
186
187        let status = response.status().as_u16();
188        let text = response
189            .text()
190            .await
191            .map_err(|e| StygianError::Provider(ProviderError::ApiError(e.to_string())))?;
192
193        if status != 200 {
194            return Err(Self::map_http_error(status, &text));
195        }
196
197        let json_val: Value = serde_json::from_str(&text)
198            .map_err(|e| StygianError::Provider(ProviderError::ApiError(e.to_string())))?;
199
200        Self::parse_response(&json_val)
201    }
202
203    async fn stream_extract(
204        &self,
205        content: String,
206        schema: Value,
207    ) -> Result<BoxStream<'static, Result<Value>>> {
208        let result = self.extract(content, schema).await;
209        Ok(Box::pin(stream::once(async move { result })))
210    }
211
212    fn capabilities(&self) -> ProviderCapabilities {
213        ProviderCapabilities {
214            streaming: true,
215            vision: true,
216            tool_use: true,
217            json_mode: true,
218        }
219    }
220
221    fn name(&self) -> &'static str {
222        "openai"
223    }
224}
225
#[cfg(test)]
mod tests {
    use super::*;
    use serde_json::json;

    /// Provider with a dummy key; none of these tests touch the network.
    fn provider() -> OpenAIProvider {
        OpenAIProvider::new("k".to_string())
    }

    /// Minimal Chat Completions response whose first choice carries `content`.
    fn completion_with(content: &str) -> Value {
        json!({"choices": [{"message": {"content": content}}]})
    }

    #[test]
    fn test_name() {
        let adapter = provider();
        assert_eq!(adapter.name(), "openai");
    }

    #[test]
    fn test_capabilities() {
        let advertised = provider().capabilities();
        assert!(advertised.json_mode);
        assert!(advertised.streaming);
    }

    #[test]
    fn test_build_body_contains_json_format() {
        let schema = json!({"type": "object"});
        let request = provider().build_body("content", &schema);
        let format_type = request
            .pointer("/response_format/type")
            .and_then(Value::as_str);
        assert_eq!(format_type, Some("json_object"));
    }

    #[test]
    fn test_parse_response_valid() -> Result<()> {
        let reply = completion_with("{\"title\": \"Hello\"}");
        let extracted = OpenAIProvider::parse_response(&reply)?;
        let title = extracted.pointer("/title").and_then(Value::as_str);
        assert_eq!(title, Some("Hello"));
        Ok(())
    }

    #[test]
    fn test_parse_response_invalid_json() {
        let reply = completion_with("not json");
        let outcome = OpenAIProvider::parse_response(&reply);
        assert!(outcome.is_err());
    }

    #[test]
    fn test_map_http_error_401() {
        let mapped = OpenAIProvider::map_http_error(401, "");
        assert!(matches!(
            mapped,
            StygianError::Provider(ProviderError::InvalidCredentials)
        ));
    }

    #[test]
    fn test_map_http_error_429() {
        let mapped = OpenAIProvider::map_http_error(429, "too many");
        let mentions_rate_limit = matches!(
            &mapped,
            StygianError::Provider(ProviderError::ApiError(msg)) if msg.contains("rate limited")
        );
        assert!(mentions_rate_limit);
    }

    #[test]
    fn test_map_http_error_server_error() {
        let mapped = OpenAIProvider::map_http_error(500, "internal");
        let mentions_status = matches!(
            &mapped,
            StygianError::Provider(ProviderError::ApiError(msg)) if msg.contains("500")
        );
        assert!(mentions_status);
    }

    #[test]
    fn test_parse_response_missing_choices() {
        let reply = json!({"id": "chatcmpl-abc"});
        let outcome = OpenAIProvider::parse_response(&reply);
        assert!(outcome.is_err());
    }

    #[test]
    fn test_config_with_model() {
        let config = OpenAIConfig::new("key".to_string()).with_model("gpt-4-turbo");
        assert_eq!(config.model, "gpt-4-turbo");
    }
}