
subx_cli/services/ai/openai.rs

use crate::Result;
use crate::cli::display_ai_usage;
use crate::error::SubXError;
use crate::services::ai::AiUsageStats;
use crate::services::ai::{
    AIProvider, AnalysisRequest, ConfidenceScore, MatchResult, VerificationRequest,
};
use async_trait::async_trait;
use reqwest::Client;
use serde_json::Value;
use serde_json::json;
use std::time::Duration;

use crate::services::ai::prompts::{PromptBuilder, ResponseParser};
use crate::services::ai::retry::HttpRetryClient;

/// OpenAI client implementation
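///
/// # Examples
///
/// A minimal construction sketch; the import path and all values are
/// illustrative, so the block is marked `ignore`:
///
/// ```ignore
/// use subx_cli::services::ai::openai::OpenAIClient;
///
/// let client = OpenAIClient::new(
///     "sk-example".into(),   // API key
///     "gpt-4.1-mini".into(), // model
///     0.3,                   // temperature
///     1000,                  // max_tokens
///     3,                     // retry_attempts
///     500,                   // retry_delay_ms
/// );
/// ```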
pub struct OpenAIClient {
    client: Client,
    api_key: String,
    model: String,
    temperature: f32,
    max_tokens: u32,
    retry_attempts: u32,
    retry_delay_ms: u64,
    base_url: String,
}

impl std::fmt::Debug for OpenAIClient {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("OpenAIClient")
            .field("client", &self.client)
            .field("api_key", &"[REDACTED]")
            .field("model", &self.model)
            .field("temperature", &self.temperature)
            .field("max_tokens", &self.max_tokens)
            .field("retry_attempts", &self.retry_attempts)
            .field("retry_delay_ms", &self.retry_delay_ms)
            .field("base_url", &self.base_url)
            .finish()
    }
}

impl PromptBuilder for OpenAIClient {}
impl ResponseParser for OpenAIClient {}
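// `HttpRetryClient` (from services::ai::retry) presumably supplies the
// `make_request_with_retry` helper used in `chat_completion` below; this
// impl only provides the retry tuning values that helper reads.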
impl HttpRetryClient for OpenAIClient {
    fn retry_attempts(&self) -> u32 {
        self.retry_attempts
    }
    fn retry_delay_ms(&self) -> u64 {
        self.retry_delay_ms
    }
}

// Mock-based tests exercising OpenAIClient through the AIProvider interface
#[cfg(test)]
mod tests {
    use super::*;
    use mockall::{mock, predicate::eq};
    use serde_json::json;
    use wiremock::matchers::{header, method, path};
    use wiremock::{Mock, MockServer, ResponseTemplate};

    mock! {
        AIClient {}

        #[async_trait]
        impl AIProvider for AIClient {
            async fn analyze_content(&self, request: AnalysisRequest) -> crate::Result<MatchResult>;
            async fn verify_match(&self, verification: VerificationRequest) -> crate::Result<ConfidenceScore>;
        }
    }

    #[tokio::test]
    async fn test_openai_client_creation() {
        let client = OpenAIClient::new("test-key".into(), "gpt-4.1-mini".into(), 0.5, 1000, 2, 100);
        assert_eq!(client.api_key, "test-key");
        assert_eq!(client.model, "gpt-4.1-mini");
        assert_eq!(client.temperature, 0.5);
        assert_eq!(client.max_tokens, 1000);
        assert_eq!(client.retry_attempts, 2);
        assert_eq!(client.retry_delay_ms, 100);
    }

    #[tokio::test]
    async fn test_chat_completion_success() {
        let server = MockServer::start().await;
        Mock::given(method("POST"))
            .and(path("/chat/completions"))
            .and(header("authorization", "Bearer test-key"))
            .respond_with(ResponseTemplate::new(200).set_body_json(json!({
                "choices": [{"message": {"content": "test response content"}}]
            })))
            .mount(&server)
            .await;
        let mut client =
            OpenAIClient::new("test-key".into(), "gpt-4.1-mini".into(), 0.3, 1000, 1, 0);
        client.base_url = server.uri();
        let messages = vec![json!({"role":"user","content":"test"})];
        let resp = client.chat_completion(messages).await.unwrap();
        assert_eq!(resp, "test response content");
    }

    #[tokio::test]
    async fn test_chat_completion_error() {
        let server = MockServer::start().await;
        Mock::given(method("POST"))
            .and(path("/chat/completions"))
            .respond_with(ResponseTemplate::new(400).set_body_json(json!({
                "error": {"message":"Invalid API key"}
            })))
            .mount(&server)
            .await;
        let mut client =
            OpenAIClient::new("bad-key".into(), "gpt-4.1-mini".into(), 0.3, 1000, 1, 0);
        client.base_url = server.uri();
        let messages = vec![json!({"role":"user","content":"test"})];
        let result = client.chat_completion(messages).await;
        assert!(result.is_err());
    }

    #[tokio::test]
    async fn test_analyze_content_with_mock() {
        let mut mock = MockAIClient::new();
        let req = AnalysisRequest {
            video_files: vec!["v.mp4".into()],
            subtitle_files: vec!["s.srt".into()],
            content_samples: vec![],
        };
        let expected = MatchResult {
            matches: vec![],
            confidence: 0.5,
            reasoning: "OK".into(),
        };
        mock.expect_analyze_content()
            .with(eq(req.clone()))
            .times(1)
            .returning(move |_| Ok(expected.clone()));
        let res = mock.analyze_content(req.clone()).await.unwrap();
        assert_eq!(res.confidence, 0.5);
    }

    #[test]
    fn test_prompt_building_and_parsing() {
        let client = OpenAIClient::new("k".into(), "m".into(), 0.1, 1000, 0, 0);
        let request = AnalysisRequest {
            video_files: vec!["F1.mp4".into()],
            subtitle_files: vec!["S1.srt".into()],
            content_samples: vec![],
        };
        let prompt = client.build_analysis_prompt(&request);
        assert!(prompt.contains("F1.mp4"));
        assert!(prompt.contains("S1.srt"));
        assert!(prompt.contains("JSON"));
        let json_resp = r#"{ "matches": [], "confidence":0.9, "reasoning":"r" }"#;
        let mr = client.parse_match_result(json_resp).unwrap();
        assert_eq!(mr.confidence, 0.9);
    }

    #[test]
    fn test_openai_client_from_config() {
        let config = crate::config::AIConfig {
            provider: "openai".to_string(),
            api_key: Some("test-key".to_string()),
            model: "gpt-test".to_string(),
            base_url: "https://custom.openai.com/v1".to_string(),
            max_sample_length: 500,
            temperature: 0.7,
            max_tokens: 2000,
            retry_attempts: 2,
            retry_delay_ms: 150,
            request_timeout_seconds: 60,
            api_version: None,
        };
        let client = OpenAIClient::from_config(&config).unwrap();
        assert_eq!(client.api_key, "test-key");
        assert_eq!(client.model, "gpt-test");
        assert_eq!(client.temperature, 0.7);
        assert_eq!(client.max_tokens, 2000);
    }

    #[test]
    fn test_openai_client_from_config_invalid_base_url() {
        let config = crate::config::AIConfig {
            provider: "openai".to_string(),
            api_key: Some("test-key".to_string()),
            model: "gpt-test".to_string(),
            base_url: "ftp://invalid.url".to_string(),
            max_sample_length: 500,
            temperature: 0.7,
            max_tokens: 1000,
            retry_attempts: 2,
            retry_delay_ms: 150,
            request_timeout_seconds: 30,
            api_version: None,
        };
        let err = OpenAIClient::from_config(&config).unwrap_err();
        // Non-http/https schemes should produce a protocol error message
        assert!(
            err.to_string()
                .contains("Base URL must use http or https protocol")
        );
    }
}

impl OpenAIClient {
    /// Create a new OpenAIClient using the default base_url
    pub fn new(
        api_key: String,
        model: String,
        temperature: f32,
        max_tokens: u32,
        retry_attempts: u32,
        retry_delay_ms: u64,
    ) -> Self {
        Self::new_with_base_url(
            api_key,
            model,
            temperature,
            max_tokens,
            retry_attempts,
            retry_delay_ms,
            "https://api.openai.com/v1".to_string(),
        )
    }

    /// Create a new OpenAIClient with custom base_url support
    pub fn new_with_base_url(
        api_key: String,
        model: String,
        temperature: f32,
        max_tokens: u32,
        retry_attempts: u32,
        retry_delay_ms: u64,
        base_url: String,
    ) -> Self {
        // Use the default 30-second timeout for backward compatibility
        Self::new_with_base_url_and_timeout(
            api_key,
            model,
            temperature,
            max_tokens,
            retry_attempts,
            retry_delay_ms,
            base_url,
            30,
        )
    }

    /// Create a new OpenAIClient with custom base_url and timeout support
    #[allow(clippy::too_many_arguments)]
    pub fn new_with_base_url_and_timeout(
        api_key: String,
        model: String,
        temperature: f32,
        max_tokens: u32,
        retry_attempts: u32,
        retry_delay_ms: u64,
        base_url: String,
        request_timeout_seconds: u64,
    ) -> Self {
        let client = Client::builder()
            .timeout(Duration::from_secs(request_timeout_seconds))
            .build()
            .expect("Failed to create HTTP client");
        Self {
            client,
            api_key,
            model,
            temperature,
            max_tokens,
            retry_attempts,
            retry_delay_ms,
            base_url: base_url.trim_end_matches('/').to_string(),
        }
    }

    /// Create client from unified configuration
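    ///
    /// # Examples
    ///
    /// A sketch mirroring the config tests above; the `AIConfig` literal is
    /// abbreviated and the values are illustrative, so the block is marked
    /// `ignore`:
    ///
    /// ```ignore
    /// let config = crate::config::AIConfig {
    ///     provider: "openai".to_string(),
    ///     api_key: Some("sk-example".to_string()),
    ///     // ... remaining fields as in the tests above ...
    /// };
    /// let client = OpenAIClient::from_config(&config)?;
    /// ```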
    pub fn from_config(config: &crate::config::AIConfig) -> crate::Result<Self> {
        let api_key = config
            .api_key
            .as_ref()
            .ok_or_else(|| crate::error::SubXError::config("Missing OpenAI API Key"))?;

        // Validate base URL format
        Self::validate_base_url(&config.base_url)?;
        crate::services::ai::security::warn_on_insecure_http_str(&config.base_url, api_key);

        Ok(Self::new_with_base_url_and_timeout(
            api_key.clone(),
            config.model.clone(),
            config.temperature,
            config.max_tokens,
            config.retry_attempts,
            config.retry_delay_ms,
            config.base_url.clone(),
            config.request_timeout_seconds,
        ))
    }

    /// Validate base URL format
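    ///
    /// Accepts only `http`/`https` URLs with a hostname, e.g.
    /// `https://api.openai.com/v1`; schemes such as `ftp://` are rejected
    /// (see `test_openai_client_from_config_invalid_base_url` above).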
    fn validate_base_url(url: &str) -> crate::Result<()> {
        use url::Url;
        let parsed = Url::parse(url)
            .map_err(|e| crate::error::SubXError::config(format!("Invalid base URL: {}", e)))?;

        if !matches!(parsed.scheme(), "http" | "https") {
            return Err(crate::error::SubXError::config(
                "Base URL must use http or https protocol".to_string(),
            ));
        }

        if parsed.host().is_none() {
            return Err(crate::error::SubXError::config(
                "Base URL must contain a valid hostname".to_string(),
            ));
        }

        Ok(())
    }

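    /// Send a chat-completion request and return the content of the first
    /// choice. Responses are capped at 10 MiB, enforced both via the
    /// `Content-Length` header and during the bounded chunked read below.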
    async fn chat_completion(&self, messages: Vec<serde_json::Value>) -> Result<String> {
        let request_body = json!({
            "model": self.model,
            "messages": messages,
            "temperature": self.temperature,
            "max_tokens": self.max_tokens,
        });

        let request = self
            .client
            .post(format!("{}/chat/completions", self.base_url))
            .header("Authorization", format!("Bearer {}", self.api_key))
            .header("Content-Type", "application/json")
            .json(&request_body);
        let mut response = self.make_request_with_retry(request).await?;

        const MAX_AI_RESPONSE_BYTES: u64 = 10 * 1024 * 1024; // 10 MiB
        if let Some(len) = response.content_length() {
            if len > MAX_AI_RESPONSE_BYTES {
                return Err(SubXError::AiService(format!(
                    "AI response too large: {} bytes (limit: {} bytes)",
                    len, MAX_AI_RESPONSE_BYTES
                )));
            }
        }

        if !response.status().is_success() {
            let status = response.status();
            let error_text = response.text().await?;
            let safe_body = crate::services::ai::error_sanitizer::sanitize_url_in_error(
                &crate::services::ai::error_sanitizer::truncate_error_body(
                    &error_text,
                    crate::services::ai::error_sanitizer::DEFAULT_ERROR_BODY_MAX_LEN,
                ),
            );
            return Err(SubXError::AiService(format!(
                "OpenAI API error {}: {}",
                status, safe_body
            )));
        }

        // Bounded chunked read to guard against oversized responses when
        // content_length() is not reported by the server.
        let mut body = Vec::new();
        while let Some(chunk) = response.chunk().await? {
            body.extend_from_slice(&chunk);
            if body.len() as u64 > MAX_AI_RESPONSE_BYTES {
                return Err(SubXError::AiService(format!(
                    "AI response too large: {} bytes read (limit: {} bytes)",
                    body.len(),
                    MAX_AI_RESPONSE_BYTES
                )));
            }
        }
        let response_json: Value = serde_json::from_slice(&body)
            .map_err(|e| SubXError::AiService(format!("Failed to parse AI response: {}", e)))?;
        let content = response_json["choices"][0]["message"]["content"]
            .as_str()
            .ok_or_else(|| SubXError::AiService("Invalid API response format".to_string()))?;

        // Parse token-usage statistics and display them
        if let Some(usage_obj) = response_json.get("usage") {
            if let (Some(p), Some(c), Some(t)) = (
                usage_obj.get("prompt_tokens").and_then(Value::as_u64),
                usage_obj.get("completion_tokens").and_then(Value::as_u64),
                usage_obj.get("total_tokens").and_then(Value::as_u64),
            ) {
                let stats = AiUsageStats {
                    model: self.model.clone(),
                    prompt_tokens: p as u32,
                    completion_tokens: c as u32,
                    total_tokens: t as u32,
                };
                display_ai_usage(&stats);
            }
        }

        Ok(content.to_string())
    }
}

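// AIProvider entry points: both methods build a prompt via PromptBuilder,
// send it through chat_completion, and parse the reply via ResponseParser.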
#[async_trait]
impl AIProvider for OpenAIClient {
    async fn analyze_content(&self, request: AnalysisRequest) -> Result<MatchResult> {
        let prompt = self.build_analysis_prompt(&request);
        let messages = vec![
            json!({"role": "system", "content": Self::get_analysis_system_message()}),
            json!({"role": "user", "content": prompt}),
        ];
        let response = self.chat_completion(messages).await?;
        self.parse_match_result(&response)
    }

    async fn verify_match(&self, verification: VerificationRequest) -> Result<ConfidenceScore> {
        let prompt = self.build_verification_prompt(&verification);
        let messages = vec![
            json!({"role": "system", "content": Self::get_verification_system_message()}),
            json!({"role": "user", "content": prompt}),
        ];
        let response = self.chat_completion(messages).await?;
        self.parse_confidence_score(&response)
    }
}