subx_cli/services/ai/
openai.rs

1use crate::Result;
2use crate::cli::display_ai_usage;
3use crate::error::SubXError;
4use crate::services::ai::AiUsageStats;
5use crate::services::ai::{
6    AIProvider, AnalysisRequest, ConfidenceScore, MatchResult, VerificationRequest,
7};
8use async_trait::async_trait;
9use reqwest::Client;
10use serde_json::Value;
11use serde_json::json;
12use std::time::Duration;
13
14use crate::services::ai::prompts::{PromptBuilder, ResponseParser};
15use crate::services::ai::retry::HttpRetryClient;
16
17/// OpenAI client implementation
18#[derive(Debug)]
19pub struct OpenAIClient {
20    client: Client,
21    api_key: String,
22    model: String,
23    temperature: f32,
24    max_tokens: u32,
25    retry_attempts: u32,
26    retry_delay_ms: u64,
27    base_url: String,
28}
29
// Both traits are implemented with empty bodies, so the client relies entirely
// on the default methods those traits provide.
impl PromptBuilder for OpenAIClient {}
impl ResponseParser for OpenAIClient {}
// Exposes the retry settings stored on the client to the retry trait.
impl HttpRetryClient for OpenAIClient {
    /// Returns the retry-attempt count stored on this client.
    fn retry_attempts(&self) -> u32 {
        self.retry_attempts
    }
    /// Returns the retry delay, in milliseconds, stored on this client.
    fn retry_delay_ms(&self) -> u64 {
        self.retry_delay_ms
    }
}
40
// Unit tests: the HTTP path runs against a local wiremock server, and the
// `AIProvider` interface is exercised through a mockall-generated mock.
#[cfg(test)]
mod tests {
    use super::*;
    use mockall::{mock, predicate::eq};
    use serde_json::json;
    use wiremock::matchers::{header, method, path};
    use wiremock::{Mock, MockServer, ResponseTemplate};

    // Generates `MockAIClient`, a mock implementation of `AIProvider`.
    mock! {
        AIClient {}

        #[async_trait]
        impl AIProvider for AIClient {
            async fn analyze_content(&self, request: AnalysisRequest) -> crate::Result<MatchResult>;
            async fn verify_match(&self, verification: VerificationRequest) -> crate::Result<ConfidenceScore>;
        }
    }

    /// Builds an OpenAI `AIConfig` whose only varying parts are the API key
    /// and the base URL; every other field uses a fixed test value.
    fn openai_config(api_key: Option<&str>, base_url: &str) -> crate::config::AIConfig {
        crate::config::AIConfig {
            provider: "openai".to_string(),
            api_key: api_key.map(str::to_string),
            model: "gpt-test".to_string(),
            base_url: base_url.to_string(),
            max_sample_length: 500,
            temperature: 0.7,
            max_tokens: 2000,
            retry_attempts: 2,
            retry_delay_ms: 150,
            request_timeout_seconds: 60,
            api_version: None,
        }
    }

    /// Builds a client pointed at the given mock server, with a single
    /// attempt and no retry delay so failures surface immediately.
    fn client_for(server: &MockServer, api_key: &str) -> OpenAIClient {
        OpenAIClient::new_with_base_url(
            api_key.into(),
            "gpt-4.1-mini".into(),
            0.3,
            1000,
            1,
            0,
            server.uri(),
        )
    }

    // Nothing here awaits, so a plain synchronous test is sufficient.
    #[test]
    fn test_openai_client_creation() {
        let client = OpenAIClient::new("test-key".into(), "gpt-4.1-mini".into(), 0.5, 1000, 2, 100);
        assert_eq!(client.api_key, "test-key");
        assert_eq!(client.model, "gpt-4.1-mini");
        assert_eq!(client.temperature, 0.5);
        assert_eq!(client.max_tokens, 1000);
        assert_eq!(client.retry_attempts, 2);
        assert_eq!(client.retry_delay_ms, 100);
        // `new` must fall back to the public OpenAI endpoint.
        assert_eq!(client.base_url, "https://api.openai.com/v1");
    }

    #[test]
    fn test_base_url_trailing_slash_is_trimmed() {
        let client = OpenAIClient::new_with_base_url(
            "k".into(),
            "m".into(),
            0.1,
            10,
            0,
            0,
            "https://example.com/v1/".to_string(),
        );
        // Trimming keeps `{base_url}/chat/completions` free of a double slash.
        assert_eq!(client.base_url, "https://example.com/v1");
    }

    #[tokio::test]
    async fn test_chat_completion_success() {
        let server = MockServer::start().await;
        Mock::given(method("POST"))
            .and(path("/chat/completions"))
            .and(header("authorization", "Bearer test-key"))
            .respond_with(ResponseTemplate::new(200).set_body_json(json!({
                "choices": [{"message": {"content": "test response content"}}]
            })))
            .mount(&server)
            .await;

        let client = client_for(&server, "test-key");
        let reply = client
            .chat_completion(vec![json!({"role":"user","content":"test"})])
            .await
            .unwrap();
        assert_eq!(reply, "test response content");
    }

    #[tokio::test]
    async fn test_chat_completion_error() {
        let server = MockServer::start().await;
        Mock::given(method("POST"))
            .and(path("/chat/completions"))
            .respond_with(ResponseTemplate::new(400).set_body_json(json!({
                "error": {"message":"Invalid API key"}
            })))
            .mount(&server)
            .await;

        let client = client_for(&server, "bad-key");
        let outcome = client
            .chat_completion(vec![json!({"role":"user","content":"test"})])
            .await;
        assert!(outcome.is_err());
    }

    #[tokio::test]
    async fn test_analyze_content_with_mock() {
        let mut mock = MockAIClient::new();
        let req = AnalysisRequest {
            video_files: vec!["v.mp4".into()],
            subtitle_files: vec!["s.srt".into()],
            content_samples: vec![],
        };
        let expected = MatchResult {
            matches: vec![],
            confidence: 0.5,
            reasoning: "OK".into(),
        };
        mock.expect_analyze_content()
            .with(eq(req.clone()))
            .times(1)
            .returning(move |_| Ok(expected.clone()));

        let res = mock.analyze_content(req).await.unwrap();
        assert_eq!(res.confidence, 0.5);
    }

    #[test]
    fn test_prompt_building_and_parsing() {
        let client = OpenAIClient::new("k".into(), "m".into(), 0.1, 1000, 0, 0);
        let request = AnalysisRequest {
            video_files: vec!["F1.mp4".into()],
            subtitle_files: vec!["S1.srt".into()],
            content_samples: vec![],
        };

        let prompt = client.build_analysis_prompt(&request);
        for expected_fragment in ["F1.mp4", "S1.srt", "JSON"] {
            assert!(prompt.contains(expected_fragment));
        }

        let json_resp = r#"{ "matches": [], "confidence":0.9, "reasoning":"r" }"#;
        let mr = client.parse_match_result(json_resp).unwrap();
        assert_eq!(mr.confidence, 0.9);
    }

    #[test]
    fn test_openai_client_from_config() {
        let config = openai_config(Some("test-key"), "https://custom.openai.com/v1");
        let client = OpenAIClient::from_config(&config).unwrap();
        assert_eq!(client.api_key, "test-key");
        assert_eq!(client.model, "gpt-test");
        assert_eq!(client.temperature, 0.7);
        assert_eq!(client.max_tokens, 2000);
        // The remaining settings must be carried over from the config as well.
        assert_eq!(client.retry_attempts, 2);
        assert_eq!(client.retry_delay_ms, 150);
        assert_eq!(client.base_url, "https://custom.openai.com/v1");
    }

    #[test]
    fn test_openai_client_from_config_missing_api_key() {
        let config = openai_config(None, "https://custom.openai.com/v1");
        let err = OpenAIClient::from_config(&config).unwrap_err();
        assert!(err.to_string().contains("Missing OpenAI API Key"));
    }

    #[test]
    fn test_openai_client_from_config_invalid_base_url() {
        let config = openai_config(Some("test-key"), "ftp://invalid.url");
        let err = OpenAIClient::from_config(&config).unwrap_err();
        // Non-http/https protocols should return protocol error message
        assert!(
            err.to_string()
                .contains("Base URL must use http or https protocol")
        );
    }
}
191
192impl OpenAIClient {
193    /// Create new OpenAIClient (using default base_url)
194    pub fn new(
195        api_key: String,
196        model: String,
197        temperature: f32,
198        max_tokens: u32,
199        retry_attempts: u32,
200        retry_delay_ms: u64,
201    ) -> Self {
202        Self::new_with_base_url(
203            api_key,
204            model,
205            temperature,
206            max_tokens,
207            retry_attempts,
208            retry_delay_ms,
209            "https://api.openai.com/v1".to_string(),
210        )
211    }
212
213    /// Create a new OpenAIClient with custom base_url support
214    pub fn new_with_base_url(
215        api_key: String,
216        model: String,
217        temperature: f32,
218        max_tokens: u32,
219        retry_attempts: u32,
220        retry_delay_ms: u64,
221        base_url: String,
222    ) -> Self {
223        // Use default 30 second timeout for backward compatibility
224        Self::new_with_base_url_and_timeout(
225            api_key,
226            model,
227            temperature,
228            max_tokens,
229            retry_attempts,
230            retry_delay_ms,
231            base_url,
232            30,
233        )
234    }
235
236    /// Create a new OpenAIClient with custom base_url and timeout support
237    #[allow(clippy::too_many_arguments)]
238    pub fn new_with_base_url_and_timeout(
239        api_key: String,
240        model: String,
241        temperature: f32,
242        max_tokens: u32,
243        retry_attempts: u32,
244        retry_delay_ms: u64,
245        base_url: String,
246        request_timeout_seconds: u64,
247    ) -> Self {
248        let client = Client::builder()
249            .timeout(Duration::from_secs(request_timeout_seconds))
250            .build()
251            .expect("Failed to create HTTP client");
252        Self {
253            client,
254            api_key,
255            model,
256            temperature,
257            max_tokens,
258            retry_attempts,
259            retry_delay_ms,
260            base_url: base_url.trim_end_matches('/').to_string(),
261        }
262    }
263
264    /// Create client from unified configuration
265    pub fn from_config(config: &crate::config::AIConfig) -> crate::Result<Self> {
266        let api_key = config
267            .api_key
268            .as_ref()
269            .ok_or_else(|| crate::error::SubXError::config("Missing OpenAI API Key"))?;
270
271        // Validate base URL format
272        Self::validate_base_url(&config.base_url)?;
273
274        Ok(Self::new_with_base_url_and_timeout(
275            api_key.clone(),
276            config.model.clone(),
277            config.temperature,
278            config.max_tokens,
279            config.retry_attempts,
280            config.retry_delay_ms,
281            config.base_url.clone(),
282            config.request_timeout_seconds,
283        ))
284    }
285
286    /// Validate base URL format
287    fn validate_base_url(url: &str) -> crate::Result<()> {
288        use url::Url;
289        let parsed = Url::parse(url)
290            .map_err(|e| crate::error::SubXError::config(format!("Invalid base URL: {}", e)))?;
291
292        if !matches!(parsed.scheme(), "http" | "https") {
293            return Err(crate::error::SubXError::config(
294                "Base URL must use http or https protocol".to_string(),
295            ));
296        }
297
298        if parsed.host().is_none() {
299            return Err(crate::error::SubXError::config(
300                "Base URL must contain a valid hostname".to_string(),
301            ));
302        }
303
304        Ok(())
305    }
306
307    async fn chat_completion(&self, messages: Vec<serde_json::Value>) -> Result<String> {
308        let request_body = json!({
309            "model": self.model,
310            "messages": messages,
311            "temperature": self.temperature,
312            "max_tokens": self.max_tokens,
313        });
314
315        let request = self
316            .client
317            .post(format!("{}/chat/completions", self.base_url))
318            .header("Authorization", format!("Bearer {}", self.api_key))
319            .header("Content-Type", "application/json")
320            .json(&request_body);
321        let response = self.make_request_with_retry(request).await?;
322
323        if !response.status().is_success() {
324            let status = response.status();
325            let error_text = response.text().await?;
326            return Err(SubXError::AiService(format!(
327                "OpenAI API error {}: {}",
328                status, error_text
329            )));
330        }
331
332        let response_json: Value = response.json().await?;
333        let content = response_json["choices"][0]["message"]["content"]
334            .as_str()
335            .ok_or_else(|| SubXError::AiService("Invalid API response format".to_string()))?;
336
337        // Parse usage statistics and display
338        if let Some(usage_obj) = response_json.get("usage") {
339            if let (Some(p), Some(c), Some(t)) = (
340                usage_obj.get("prompt_tokens").and_then(Value::as_u64),
341                usage_obj.get("completion_tokens").and_then(Value::as_u64),
342                usage_obj.get("total_tokens").and_then(Value::as_u64),
343            ) {
344                let stats = AiUsageStats {
345                    model: self.model.clone(),
346                    prompt_tokens: p as u32,
347                    completion_tokens: c as u32,
348                    total_tokens: t as u32,
349                };
350                display_ai_usage(&stats);
351            }
352        }
353
354        Ok(content.to_string())
355    }
356}
357
358#[async_trait]
359impl AIProvider for OpenAIClient {
360    async fn analyze_content(&self, request: AnalysisRequest) -> Result<MatchResult> {
361        let prompt = self.build_analysis_prompt(&request);
362        let messages = vec![
363            json!({"role": "system", "content": Self::get_analysis_system_message()}),
364            json!({"role": "user", "content": prompt}),
365        ];
366        let response = self.chat_completion(messages).await?;
367        self.parse_match_result(&response)
368    }
369
370    async fn verify_match(&self, verification: VerificationRequest) -> Result<ConfidenceScore> {
371        let prompt = self.build_verification_prompt(&verification);
372        let messages = vec![
373            json!({"role": "system", "content": Self::get_verification_system_message()}),
374            json!({"role": "user", "content": prompt}),
375        ];
376        let response = self.chat_completion(messages).await?;
377        self.parse_confidence_score(&response)
378    }
379}