// subx_cli/services/ai/openai.rs

1use crate::Result;
2use crate::cli::display_ai_usage;
3use crate::error::SubXError;
4use crate::services::ai::AiUsageStats;
5use crate::services::ai::{
6    AIProvider, AnalysisRequest, ConfidenceScore, MatchResult, VerificationRequest,
7};
8use async_trait::async_trait;
9use reqwest::Client;
10use serde_json::Value;
11use serde_json::json;
12use std::time::Duration;
13use tokio::time;
14
/// OpenAI client implementation
///
/// Holds the HTTP client plus all request parameters and the retry policy
/// used when calling the OpenAI chat-completions API.
#[derive(Debug)]
pub struct OpenAIClient {
    /// Reused HTTP client; its request timeout is fixed at construction time.
    client: Client,
    /// API key sent as a `Bearer` token in the `Authorization` header.
    api_key: String,
    /// Model identifier placed in the request body (e.g. "gpt-4.1-mini").
    model: String,
    /// Sampling temperature forwarded verbatim to the API.
    temperature: f32,
    /// Upper bound on generated tokens per request (`max_tokens` field).
    max_tokens: u32,
    /// Number of additional attempts after a failed transport-level request.
    retry_attempts: u32,
    /// Fixed delay between retry attempts, in milliseconds.
    retry_delay_ms: u64,
    /// API root (default "https://api.openai.com/v1"); stored without a
    /// trailing slash so endpoint paths can be appended with a single '/'.
    base_url: String,
}
27
// Mock testing: OpenAIClient with AIProvider interface
#[cfg(test)]
mod tests {
    use super::*;
    use mockall::{mock, predicate::eq};
    use serde_json::json;
    use wiremock::matchers::{header, method, path};
    use wiremock::{Mock, MockServer, ResponseTemplate};

    // Generates `MockAIClient`, a mockall-based stand-in implementing
    // `AIProvider`, so provider-consuming code can be tested without HTTP.
    mock! {
        AIClient {}

        #[async_trait]
        impl AIProvider for AIClient {
            async fn analyze_content(&self, request: AnalysisRequest) -> crate::Result<MatchResult>;
            async fn verify_match(&self, verification: VerificationRequest) -> crate::Result<ConfidenceScore>;
        }
    }

    // `new` must store every constructor argument verbatim.
    #[tokio::test]
    async fn test_openai_client_creation() {
        let client = OpenAIClient::new("test-key".into(), "gpt-4.1-mini".into(), 0.5, 1000, 2, 100);
        assert_eq!(client.api_key, "test-key");
        assert_eq!(client.model, "gpt-4.1-mini");
        assert_eq!(client.temperature, 0.5);
        assert_eq!(client.max_tokens, 1000);
        assert_eq!(client.retry_attempts, 2);
        assert_eq!(client.retry_delay_ms, 100);
    }

    // Happy path: a 200 response carrying `choices[0].message.content`
    // is returned as the plain reply string.
    #[tokio::test]
    async fn test_chat_completion_success() {
        let server = MockServer::start().await;
        Mock::given(method("POST"))
            .and(path("/chat/completions"))
            .and(header("authorization", "Bearer test-key"))
            .respond_with(ResponseTemplate::new(200).set_body_json(json!({
                "choices": [{"message": {"content": "test response content"}}]
            })))
            .mount(&server)
            .await;
        let mut client =
            OpenAIClient::new("test-key".into(), "gpt-4.1-mini".into(), 0.3, 1000, 1, 0);
        // Point the client at the local wiremock server instead of api.openai.com.
        client.base_url = server.uri();
        let messages = vec![json!({"role":"user","content":"test"})];
        let resp = client.chat_completion(messages).await.unwrap();
        assert_eq!(resp, "test response content");
    }

    // A non-success HTTP status (here 400) must surface as an `Err`.
    #[tokio::test]
    async fn test_chat_completion_error() {
        let server = MockServer::start().await;
        Mock::given(method("POST"))
            .and(path("/chat/completions"))
            .respond_with(ResponseTemplate::new(400).set_body_json(json!({
                "error": {"message":"Invalid API key"}
            })))
            .mount(&server)
            .await;
        let mut client =
            OpenAIClient::new("bad-key".into(), "gpt-4.1-mini".into(), 0.3, 1000, 1, 0);
        client.base_url = server.uri();
        let messages = vec![json!({"role":"user","content":"test"})];
        let result = client.chat_completion(messages).await;
        assert!(result.is_err());
    }

    // Exercises the generated mock: the expectation matches on the exact
    // request value and returns a canned `MatchResult`.
    #[tokio::test]
    async fn test_analyze_content_with_mock() {
        let mut mock = MockAIClient::new();
        let req = AnalysisRequest {
            video_files: vec!["v.mp4".into()],
            subtitle_files: vec!["s.srt".into()],
            content_samples: vec![],
        };
        let expected = MatchResult {
            matches: vec![],
            confidence: 0.5,
            reasoning: "OK".into(),
        };
        mock.expect_analyze_content()
            .with(eq(req.clone()))
            .times(1)
            .returning(move |_| Ok(expected.clone()));
        let res = mock.analyze_content(req.clone()).await.unwrap();
        assert_eq!(res.confidence, 0.5);
    }

    // The analysis prompt must mention both file lists and ask for JSON;
    // `parse_match_result` must round-trip a JSON reply into `MatchResult`.
    #[test]
    fn test_prompt_building_and_parsing() {
        let client = OpenAIClient::new("k".into(), "m".into(), 0.1, 1000, 0, 0);
        let request = AnalysisRequest {
            video_files: vec!["F1.mp4".into()],
            subtitle_files: vec!["S1.srt".into()],
            content_samples: vec![],
        };
        let prompt = client.build_analysis_prompt(&request);
        assert!(prompt.contains("F1.mp4"));
        assert!(prompt.contains("S1.srt"));
        assert!(prompt.contains("JSON"));
        let json_resp = r#"{ "matches": [], "confidence":0.9, "reasoning":"r" }"#;
        let mr = client.parse_match_result(json_resp).unwrap();
        assert_eq!(mr.confidence, 0.9);
    }

    // `from_config` copies config values through to the client fields.
    #[test]
    fn test_openai_client_from_config() {
        let config = crate::config::AIConfig {
            provider: "openai".to_string(),
            api_key: Some("test-key".to_string()),
            model: "gpt-test".to_string(),
            base_url: "https://custom.openai.com/v1".to_string(),
            temperature: 0.7,
            max_tokens: 2000,
            retry_attempts: 2,
            retry_delay_ms: 150,
            max_sample_length: 500,
            request_timeout_seconds: 60,
        };
        let client = OpenAIClient::from_config(&config).unwrap();
        assert_eq!(client.api_key, "test-key");
        assert_eq!(client.model, "gpt-test");
        assert_eq!(client.temperature, 0.7);
        assert_eq!(client.max_tokens, 2000);
    }

    // Non-http(s) schemes must be rejected by base-URL validation.
    #[test]
    fn test_openai_client_from_config_invalid_base_url() {
        let config = crate::config::AIConfig {
            provider: "openai".to_string(),
            api_key: Some("test-key".to_string()),
            model: "gpt-test".to_string(),
            base_url: "ftp://invalid.url".to_string(),
            temperature: 0.7,
            max_tokens: 1000,
            retry_attempts: 2,
            retry_delay_ms: 150,
            max_sample_length: 500,
            request_timeout_seconds: 30,
        };
        let err = OpenAIClient::from_config(&config).unwrap_err();
        // Non-http/https protocols should return protocol error message
        assert!(
            err.to_string()
                .contains("Base URL must use http or https protocol")
        );
    }
}
176
177impl OpenAIClient {
178    /// Create new OpenAIClient (using default base_url)
179    pub fn new(
180        api_key: String,
181        model: String,
182        temperature: f32,
183        max_tokens: u32,
184        retry_attempts: u32,
185        retry_delay_ms: u64,
186    ) -> Self {
187        Self::new_with_base_url(
188            api_key,
189            model,
190            temperature,
191            max_tokens,
192            retry_attempts,
193            retry_delay_ms,
194            "https://api.openai.com/v1".to_string(),
195        )
196    }
197
198    /// Create a new OpenAIClient with custom base_url support
199    pub fn new_with_base_url(
200        api_key: String,
201        model: String,
202        temperature: f32,
203        max_tokens: u32,
204        retry_attempts: u32,
205        retry_delay_ms: u64,
206        base_url: String,
207    ) -> Self {
208        // Use default 30 second timeout for backward compatibility
209        Self::new_with_base_url_and_timeout(
210            api_key,
211            model,
212            temperature,
213            max_tokens,
214            retry_attempts,
215            retry_delay_ms,
216            base_url,
217            30,
218        )
219    }
220
221    /// Create a new OpenAIClient with custom base_url and timeout support
222    #[allow(clippy::too_many_arguments)]
223    pub fn new_with_base_url_and_timeout(
224        api_key: String,
225        model: String,
226        temperature: f32,
227        max_tokens: u32,
228        retry_attempts: u32,
229        retry_delay_ms: u64,
230        base_url: String,
231        request_timeout_seconds: u64,
232    ) -> Self {
233        let client = Client::builder()
234            .timeout(Duration::from_secs(request_timeout_seconds))
235            .build()
236            .expect("Failed to create HTTP client");
237        Self {
238            client,
239            api_key,
240            model,
241            temperature,
242            max_tokens,
243            retry_attempts,
244            retry_delay_ms,
245            base_url: base_url.trim_end_matches('/').to_string(),
246        }
247    }
248
249    /// Create client from unified configuration
250    pub fn from_config(config: &crate::config::AIConfig) -> crate::Result<Self> {
251        let api_key = config
252            .api_key
253            .as_ref()
254            .ok_or_else(|| crate::error::SubXError::config("Missing OpenAI API Key"))?;
255
256        // Validate base URL format
257        Self::validate_base_url(&config.base_url)?;
258
259        Ok(Self::new_with_base_url_and_timeout(
260            api_key.clone(),
261            config.model.clone(),
262            config.temperature,
263            config.max_tokens,
264            config.retry_attempts,
265            config.retry_delay_ms,
266            config.base_url.clone(),
267            config.request_timeout_seconds,
268        ))
269    }
270
271    /// Validate base URL format
272    fn validate_base_url(url: &str) -> crate::Result<()> {
273        use url::Url;
274        let parsed = Url::parse(url)
275            .map_err(|e| crate::error::SubXError::config(format!("Invalid base URL: {}", e)))?;
276
277        if !matches!(parsed.scheme(), "http" | "https") {
278            return Err(crate::error::SubXError::config(
279                "Base URL must use http or https protocol".to_string(),
280            ));
281        }
282
283        if parsed.host().is_none() {
284            return Err(crate::error::SubXError::config(
285                "Base URL must contain a valid hostname".to_string(),
286            ));
287        }
288
289        Ok(())
290    }
291
292    async fn chat_completion(&self, messages: Vec<serde_json::Value>) -> Result<String> {
293        let request_body = json!({
294            "model": self.model,
295            "messages": messages,
296            "temperature": self.temperature,
297            "max_tokens": self.max_tokens,
298        });
299
300        let request = self
301            .client
302            .post(format!("{}/chat/completions", self.base_url))
303            .header("Authorization", format!("Bearer {}", self.api_key))
304            .header("Content-Type", "application/json")
305            .json(&request_body);
306        let response = self.make_request_with_retry(request).await?;
307
308        if !response.status().is_success() {
309            let status = response.status();
310            let error_text = response.text().await?;
311            return Err(SubXError::AiService(format!(
312                "OpenAI API error {}: {}",
313                status, error_text
314            )));
315        }
316
317        let response_json: Value = response.json().await?;
318        let content = response_json["choices"][0]["message"]["content"]
319            .as_str()
320            .ok_or_else(|| SubXError::AiService("Invalid API response format".to_string()))?;
321
322        // Parse usage statistics and display
323        if let Some(usage_obj) = response_json.get("usage") {
324            if let (Some(p), Some(c), Some(t)) = (
325                usage_obj.get("prompt_tokens").and_then(Value::as_u64),
326                usage_obj.get("completion_tokens").and_then(Value::as_u64),
327                usage_obj.get("total_tokens").and_then(Value::as_u64),
328            ) {
329                let stats = AiUsageStats {
330                    model: self.model.clone(),
331                    prompt_tokens: p as u32,
332                    completion_tokens: c as u32,
333                    total_tokens: t as u32,
334                };
335                display_ai_usage(&stats);
336            }
337        }
338
339        Ok(content.to_string())
340    }
341}
342
343#[async_trait]
344impl AIProvider for OpenAIClient {
345    async fn analyze_content(&self, request: AnalysisRequest) -> Result<MatchResult> {
346        let prompt = self.build_analysis_prompt(&request);
347        let messages = vec![
348            json!({"role": "system", "content": "You are a professional subtitle matching assistant that can analyze the correspondence between video and subtitle files."}),
349            json!({"role": "user", "content": prompt}),
350        ];
351        let response = self.chat_completion(messages).await?;
352        self.parse_match_result(&response)
353    }
354
355    async fn verify_match(&self, verification: VerificationRequest) -> Result<ConfidenceScore> {
356        let prompt = self.build_verification_prompt(&verification);
357        let messages = vec![
358            json!({"role": "system", "content": "Please evaluate the confidence level of subtitle matching and provide a score between 0-1."}),
359            json!({"role": "user", "content": prompt}),
360        ];
361        let response = self.chat_completion(messages).await?;
362        self.parse_confidence_score(&response)
363    }
364}
365
366impl OpenAIClient {
367    async fn make_request_with_retry(
368        &self,
369        request: reqwest::RequestBuilder,
370    ) -> reqwest::Result<reqwest::Response> {
371        let mut attempts = 0;
372        loop {
373            match request.try_clone().unwrap().send().await {
374                Ok(resp) => {
375                    if attempts > 0 {
376                        log::info!("Request succeeded after {} retry attempts", attempts);
377                    }
378                    return Ok(resp);
379                }
380                Err(e) if (attempts as u32) < self.retry_attempts => {
381                    attempts += 1;
382                    log::warn!(
383                        "Request attempt {} failed: {}. Retrying in {}ms...",
384                        attempts,
385                        e,
386                        self.retry_delay_ms
387                    );
388
389                    // Provide specific guidance for timeout errors
390                    if e.is_timeout() {
391                        log::warn!(
392                            "This appears to be a timeout error. If this persists, consider increasing 'ai.request_timeout_seconds' in your configuration."
393                        );
394                    }
395
396                    time::sleep(Duration::from_millis(self.retry_delay_ms)).await;
397                    continue;
398                }
399                Err(e) => {
400                    log::error!(
401                        "Request failed after {} attempts. Final error: {}",
402                        attempts + 1,
403                        e
404                    );
405
406                    // Provide actionable error messages
407                    if e.is_timeout() {
408                        log::error!(
409                            "AI service error: Request timed out after multiple attempts. \
410                            This usually indicates network connectivity issues or server overload. \
411                            Try increasing 'ai.request_timeout_seconds' configuration. \
412                            Hint: check network connection and API service status"
413                        );
414                    } else if e.is_connect() {
415                        log::error!(
416                            "AI service error: Connection failed. \
417                            Hint: check network connection and API base URL settings"
418                        );
419                    }
420
421                    return Err(e);
422                }
423            }
424        }
425    }
426}