subx_cli/services/ai/
openai.rs

1use crate::Result;
2use crate::cli::display_ai_usage;
3use crate::error::SubXError;
4use crate::services::ai::AiUsageStats;
5use crate::services::ai::{
6    AIProvider, AnalysisRequest, ConfidenceScore, MatchResult, VerificationRequest,
7};
8use async_trait::async_trait;
9use reqwest::Client;
10use serde_json::Value;
11use serde_json::json;
12use std::time::Duration;
13use tokio::time;
14
/// OpenAI client implementation
///
/// Holds the HTTP client plus all request parameters needed to call the
/// OpenAI chat-completions endpoint, including retry configuration.
#[derive(Debug)]
pub struct OpenAIClient {
    /// HTTP client used for all API calls; built with a request timeout.
    client: Client,
    /// Bearer token sent in the `Authorization` header.
    api_key: String,
    /// Model identifier placed in the request body (e.g. "gpt-4.1-mini").
    model: String,
    /// Sampling temperature forwarded verbatim to the API.
    temperature: f32,
    /// Upper bound on generated tokens per completion.
    max_tokens: u32,
    /// How many times a failed request is retried before giving up.
    retry_attempts: u32,
    /// Delay between retry attempts, in milliseconds.
    retry_delay_ms: u64,
    /// API base URL, stored without a trailing slash
    /// (default "https://api.openai.com/v1"; overridable in tests).
    base_url: String,
}
27
// Mock testing: OpenAIClient with AIProvider interface
#[cfg(test)]
mod tests {
    use super::*;
    use mockall::{mock, predicate::eq};
    use serde_json::json;
    use wiremock::matchers::{header, method, path};
    use wiremock::{Mock, MockServer, ResponseTemplate};

    // Generates `MockAIClient`: a mockall-backed stand-in implementing the
    // `AIProvider` trait, for tests that don't need a real HTTP server.
    mock! {
        AIClient {}

        #[async_trait]
        impl AIProvider for AIClient {
            async fn analyze_content(&self, request: AnalysisRequest) -> crate::Result<MatchResult>;
            async fn verify_match(&self, verification: VerificationRequest) -> crate::Result<ConfidenceScore>;
        }
    }

    /// Constructor should store each argument in the matching field.
    #[tokio::test]
    async fn test_openai_client_creation() {
        let client = OpenAIClient::new("test-key".into(), "gpt-4.1-mini".into(), 0.5, 1000, 2, 100);
        assert_eq!(client.api_key, "test-key");
        assert_eq!(client.model, "gpt-4.1-mini");
        assert_eq!(client.temperature, 0.5);
        assert_eq!(client.max_tokens, 1000);
        assert_eq!(client.retry_attempts, 2);
        assert_eq!(client.retry_delay_ms, 100);
    }

    /// Happy path: a 200 response with a `choices[0].message.content` field
    /// should be unwrapped to the plain content string.
    #[tokio::test]
    async fn test_chat_completion_success() {
        let server = MockServer::start().await;
        Mock::given(method("POST"))
            .and(path("/chat/completions"))
            // Also verifies the bearer token is forwarded in the header.
            .and(header("authorization", "Bearer test-key"))
            .respond_with(ResponseTemplate::new(200).set_body_json(json!({
                "choices": [{"message": {"content": "test response content"}}]
            })))
            .mount(&server)
            .await;
        let mut client =
            OpenAIClient::new("test-key".into(), "gpt-4.1-mini".into(), 0.3, 1000, 1, 0);
        // Redirect the client at the wiremock server instead of api.openai.com.
        client.base_url = server.uri();
        let messages = vec![json!({"role":"user","content":"test"})];
        let resp = client.chat_completion(messages).await.unwrap();
        assert_eq!(resp, "test response content");
    }

    /// A non-2xx status from the API must surface as an `Err`, not a panic.
    #[tokio::test]
    async fn test_chat_completion_error() {
        let server = MockServer::start().await;
        Mock::given(method("POST"))
            .and(path("/chat/completions"))
            .respond_with(ResponseTemplate::new(400).set_body_json(json!({
                "error": {"message":"Invalid API key"}
            })))
            .mount(&server)
            .await;
        let mut client =
            OpenAIClient::new("bad-key".into(), "gpt-4.1-mini".into(), 0.3, 1000, 1, 0);
        client.base_url = server.uri();
        let messages = vec![json!({"role":"user","content":"test"})];
        let result = client.chat_completion(messages).await;
        assert!(result.is_err());
    }

    /// Exercises the `AIProvider` trait through the mockall mock: the
    /// expectation must match the exact request and be consumed exactly once.
    #[tokio::test]
    async fn test_analyze_content_with_mock() {
        let mut mock = MockAIClient::new();
        let req = AnalysisRequest {
            video_files: vec!["v.mp4".into()],
            subtitle_files: vec!["s.srt".into()],
            content_samples: vec![],
        };
        let expected = MatchResult {
            matches: vec![],
            confidence: 0.5,
            reasoning: "OK".into(),
        };
        mock.expect_analyze_content()
            .with(eq(req.clone()))
            .times(1)
            .returning(move |_| Ok(expected.clone()));
        let res = mock.analyze_content(req.clone()).await.unwrap();
        assert_eq!(res.confidence, 0.5);
    }

    /// Round-trip sanity check: the built prompt mentions every input file
    /// and asks for JSON, and a well-formed JSON reply parses back out.
    #[test]
    fn test_prompt_building_and_parsing() {
        let client = OpenAIClient::new("k".into(), "m".into(), 0.1, 1000, 0, 0);
        let request = AnalysisRequest {
            video_files: vec!["F1.mp4".into()],
            subtitle_files: vec!["S1.srt".into()],
            content_samples: vec![],
        };
        let prompt = client.build_analysis_prompt(&request);
        assert!(prompt.contains("F1.mp4"));
        assert!(prompt.contains("S1.srt"));
        assert!(prompt.contains("JSON"));
        let json_resp = r#"{ "matches": [], "confidence":0.9, "reasoning":"r" }"#;
        let mr = client.parse_match_result(json_resp).unwrap();
        assert_eq!(mr.confidence, 0.9);
    }

    /// `from_config` should transfer every relevant config field to the client.
    #[test]
    fn test_openai_client_from_config() {
        let config = crate::config::AIConfig {
            provider: "openai".to_string(),
            api_key: Some("test-key".to_string()),
            model: "gpt-test".to_string(),
            base_url: "https://custom.openai.com/v1".to_string(),
            max_sample_length: 500,
            temperature: 0.7,
            max_tokens: 2000,
            retry_attempts: 2,
            retry_delay_ms: 150,
            request_timeout_seconds: 60,
            api_version: None,
        };
        let client = OpenAIClient::from_config(&config).unwrap();
        assert_eq!(client.api_key, "test-key");
        assert_eq!(client.model, "gpt-test");
        assert_eq!(client.temperature, 0.7);
        assert_eq!(client.max_tokens, 2000);
    }

    /// Schemes other than http/https (here: ftp) must be rejected with the
    /// dedicated protocol error message.
    #[test]
    fn test_openai_client_from_config_invalid_base_url() {
        let config = crate::config::AIConfig {
            provider: "openai".to_string(),
            api_key: Some("test-key".to_string()),
            model: "gpt-test".to_string(),
            base_url: "ftp://invalid.url".to_string(),
            max_sample_length: 500,
            temperature: 0.7,
            max_tokens: 1000,
            retry_attempts: 2,
            retry_delay_ms: 150,
            request_timeout_seconds: 30,
            api_version: None,
        };
        let err = OpenAIClient::from_config(&config).unwrap_err();
        // Non-http/https protocols should return protocol error message
        assert!(
            err.to_string()
                .contains("Base URL must use http or https protocol")
        );
    }
}
178
179impl OpenAIClient {
180    /// Create new OpenAIClient (using default base_url)
181    pub fn new(
182        api_key: String,
183        model: String,
184        temperature: f32,
185        max_tokens: u32,
186        retry_attempts: u32,
187        retry_delay_ms: u64,
188    ) -> Self {
189        Self::new_with_base_url(
190            api_key,
191            model,
192            temperature,
193            max_tokens,
194            retry_attempts,
195            retry_delay_ms,
196            "https://api.openai.com/v1".to_string(),
197        )
198    }
199
200    /// Create a new OpenAIClient with custom base_url support
201    pub fn new_with_base_url(
202        api_key: String,
203        model: String,
204        temperature: f32,
205        max_tokens: u32,
206        retry_attempts: u32,
207        retry_delay_ms: u64,
208        base_url: String,
209    ) -> Self {
210        // Use default 30 second timeout for backward compatibility
211        Self::new_with_base_url_and_timeout(
212            api_key,
213            model,
214            temperature,
215            max_tokens,
216            retry_attempts,
217            retry_delay_ms,
218            base_url,
219            30,
220        )
221    }
222
223    /// Create a new OpenAIClient with custom base_url and timeout support
224    #[allow(clippy::too_many_arguments)]
225    pub fn new_with_base_url_and_timeout(
226        api_key: String,
227        model: String,
228        temperature: f32,
229        max_tokens: u32,
230        retry_attempts: u32,
231        retry_delay_ms: u64,
232        base_url: String,
233        request_timeout_seconds: u64,
234    ) -> Self {
235        let client = Client::builder()
236            .timeout(Duration::from_secs(request_timeout_seconds))
237            .build()
238            .expect("Failed to create HTTP client");
239        Self {
240            client,
241            api_key,
242            model,
243            temperature,
244            max_tokens,
245            retry_attempts,
246            retry_delay_ms,
247            base_url: base_url.trim_end_matches('/').to_string(),
248        }
249    }
250
251    /// Create client from unified configuration
252    pub fn from_config(config: &crate::config::AIConfig) -> crate::Result<Self> {
253        let api_key = config
254            .api_key
255            .as_ref()
256            .ok_or_else(|| crate::error::SubXError::config("Missing OpenAI API Key"))?;
257
258        // Validate base URL format
259        Self::validate_base_url(&config.base_url)?;
260
261        Ok(Self::new_with_base_url_and_timeout(
262            api_key.clone(),
263            config.model.clone(),
264            config.temperature,
265            config.max_tokens,
266            config.retry_attempts,
267            config.retry_delay_ms,
268            config.base_url.clone(),
269            config.request_timeout_seconds,
270        ))
271    }
272
273    /// Validate base URL format
274    fn validate_base_url(url: &str) -> crate::Result<()> {
275        use url::Url;
276        let parsed = Url::parse(url)
277            .map_err(|e| crate::error::SubXError::config(format!("Invalid base URL: {}", e)))?;
278
279        if !matches!(parsed.scheme(), "http" | "https") {
280            return Err(crate::error::SubXError::config(
281                "Base URL must use http or https protocol".to_string(),
282            ));
283        }
284
285        if parsed.host().is_none() {
286            return Err(crate::error::SubXError::config(
287                "Base URL must contain a valid hostname".to_string(),
288            ));
289        }
290
291        Ok(())
292    }
293
294    async fn chat_completion(&self, messages: Vec<serde_json::Value>) -> Result<String> {
295        let request_body = json!({
296            "model": self.model,
297            "messages": messages,
298            "temperature": self.temperature,
299            "max_tokens": self.max_tokens,
300        });
301
302        let request = self
303            .client
304            .post(format!("{}/chat/completions", self.base_url))
305            .header("Authorization", format!("Bearer {}", self.api_key))
306            .header("Content-Type", "application/json")
307            .json(&request_body);
308        let response = self.make_request_with_retry(request).await?;
309
310        if !response.status().is_success() {
311            let status = response.status();
312            let error_text = response.text().await?;
313            return Err(SubXError::AiService(format!(
314                "OpenAI API error {}: {}",
315                status, error_text
316            )));
317        }
318
319        let response_json: Value = response.json().await?;
320        let content = response_json["choices"][0]["message"]["content"]
321            .as_str()
322            .ok_or_else(|| SubXError::AiService("Invalid API response format".to_string()))?;
323
324        // Parse usage statistics and display
325        if let Some(usage_obj) = response_json.get("usage") {
326            if let (Some(p), Some(c), Some(t)) = (
327                usage_obj.get("prompt_tokens").and_then(Value::as_u64),
328                usage_obj.get("completion_tokens").and_then(Value::as_u64),
329                usage_obj.get("total_tokens").and_then(Value::as_u64),
330            ) {
331                let stats = AiUsageStats {
332                    model: self.model.clone(),
333                    prompt_tokens: p as u32,
334                    completion_tokens: c as u32,
335                    total_tokens: t as u32,
336                };
337                display_ai_usage(&stats);
338            }
339        }
340
341        Ok(content.to_string())
342    }
343}
344
345#[async_trait]
346impl AIProvider for OpenAIClient {
347    async fn analyze_content(&self, request: AnalysisRequest) -> Result<MatchResult> {
348        let prompt = self.build_analysis_prompt(&request);
349        let messages = vec![
350            json!({"role": "system", "content": "You are a professional subtitle matching assistant that can analyze the correspondence between video and subtitle files."}),
351            json!({"role": "user", "content": prompt}),
352        ];
353        let response = self.chat_completion(messages).await?;
354        self.parse_match_result(&response)
355    }
356
357    async fn verify_match(&self, verification: VerificationRequest) -> Result<ConfidenceScore> {
358        let prompt = self.build_verification_prompt(&verification);
359        let messages = vec![
360            json!({"role": "system", "content": "Please evaluate the confidence level of subtitle matching and provide a score between 0-1."}),
361            json!({"role": "user", "content": prompt}),
362        ];
363        let response = self.chat_completion(messages).await?;
364        self.parse_confidence_score(&response)
365    }
366}
367
368impl OpenAIClient {
369    async fn make_request_with_retry(
370        &self,
371        request: reqwest::RequestBuilder,
372    ) -> reqwest::Result<reqwest::Response> {
373        let mut attempts = 0;
374        loop {
375            match request.try_clone().unwrap().send().await {
376                Ok(resp) => {
377                    if attempts > 0 {
378                        log::info!("Request succeeded after {} retry attempts", attempts);
379                    }
380                    return Ok(resp);
381                }
382                Err(e) if (attempts as u32) < self.retry_attempts => {
383                    attempts += 1;
384                    log::warn!(
385                        "Request attempt {} failed: {}. Retrying in {}ms...",
386                        attempts,
387                        e,
388                        self.retry_delay_ms
389                    );
390
391                    // Provide specific guidance for timeout errors
392                    if e.is_timeout() {
393                        log::warn!(
394                            "This appears to be a timeout error. If this persists, consider increasing 'ai.request_timeout_seconds' in your configuration."
395                        );
396                    }
397
398                    time::sleep(Duration::from_millis(self.retry_delay_ms)).await;
399                    continue;
400                }
401                Err(e) => {
402                    log::error!(
403                        "Request failed after {} attempts. Final error: {}",
404                        attempts + 1,
405                        e
406                    );
407
408                    // Provide actionable error messages
409                    if e.is_timeout() {
410                        log::error!(
411                            "AI service error: Request timed out after multiple attempts. \
412                            This usually indicates network connectivity issues or server overload. \
413                            Try increasing 'ai.request_timeout_seconds' configuration. \
414                            Hint: check network connection and API service status"
415                        );
416                    } else if e.is_connect() {
417                        log::error!(
418                            "AI service error: Connection failed. \
419                            Hint: check network connection and API base URL settings"
420                        );
421                    }
422
423                    return Err(e);
424                }
425            }
426        }
427    }
428}