1use crate::Result;
2use crate::cli::display_ai_usage;
3use crate::error::SubXError;
4use crate::services::ai::AiUsageStats;
5use crate::services::ai::{
6 AIProvider, AnalysisRequest, ConfidenceScore, MatchResult, VerificationRequest,
7};
8use async_trait::async_trait;
9use reqwest::Client;
10use serde_json::Value;
11use serde_json::json;
12use std::time::Duration;
13use tokio::time;
14
15#[derive(Debug)]
17pub struct OpenAIClient {
18 client: Client,
19 api_key: String,
20 model: String,
21 temperature: f32,
22 max_tokens: u32,
23 retry_attempts: u32,
24 retry_delay_ms: u64,
25 base_url: String,
26}
27
#[cfg(test)]
mod tests {
    use super::*;
    use mockall::{mock, predicate::eq};
    use serde_json::json;
    use wiremock::matchers::{header, method, path};
    use wiremock::{Mock, MockServer, ResponseTemplate};

    mock! {
        AIClient {}

        #[async_trait]
        impl AIProvider for AIClient {
            async fn analyze_content(&self, request: AnalysisRequest) -> crate::Result<MatchResult>;
            async fn verify_match(&self, verification: VerificationRequest) -> crate::Result<ConfidenceScore>;
        }
    }

    /// Builds an OpenAI `AIConfig` with the given key and base URL; every other
    /// field gets a fixed, valid value.
    fn openai_config(api_key: Option<&str>, base_url: &str) -> crate::config::AIConfig {
        crate::config::AIConfig {
            provider: "openai".to_string(),
            api_key: api_key.map(str::to_string),
            model: "gpt-test".to_string(),
            base_url: base_url.to_string(),
            max_sample_length: 500,
            temperature: 0.7,
            max_tokens: 1000,
            retry_attempts: 2,
            retry_delay_ms: 150,
            request_timeout_seconds: 30,
            api_version: None,
        }
    }

    #[test]
    fn test_openai_client_creation() {
        let client = OpenAIClient::new("test-key".into(), "gpt-4.1-mini".into(), 0.5, 1000, 2, 100);
        assert_eq!(client.api_key, "test-key");
        assert_eq!(client.model, "gpt-4.1-mini");
        assert_eq!(client.temperature, 0.5);
        assert_eq!(client.max_tokens, 1000);
        assert_eq!(client.retry_attempts, 2);
        assert_eq!(client.retry_delay_ms, 100);
    }

    #[test]
    fn test_base_url_trailing_slashes_are_trimmed() {
        let client = OpenAIClient::new_with_base_url(
            "k".into(),
            "m".into(),
            0.1,
            100,
            0,
            0,
            "https://example.com/v1//".into(),
        );
        assert_eq!(client.base_url, "https://example.com/v1");
    }

    #[tokio::test]
    async fn test_chat_completion_success() {
        let server = MockServer::start().await;
        Mock::given(method("POST"))
            .and(path("/chat/completions"))
            .and(header("authorization", "Bearer test-key"))
            .respond_with(ResponseTemplate::new(200).set_body_json(json!({
                "choices": [{"message": {"content": "test response content"}}]
            })))
            .mount(&server)
            .await;
        let mut client =
            OpenAIClient::new("test-key".into(), "gpt-4.1-mini".into(), 0.3, 1000, 1, 0);
        client.base_url = server.uri();
        let messages = vec![json!({"role":"user","content":"test"})];
        let resp = client.chat_completion(messages).await.unwrap();
        assert_eq!(resp, "test response content");
    }

    #[tokio::test]
    async fn test_chat_completion_error() {
        let server = MockServer::start().await;
        Mock::given(method("POST"))
            .and(path("/chat/completions"))
            .respond_with(ResponseTemplate::new(400).set_body_json(json!({
                "error": {"message":"Invalid API key"}
            })))
            .mount(&server)
            .await;
        let mut client =
            OpenAIClient::new("bad-key".into(), "gpt-4.1-mini".into(), 0.3, 1000, 1, 0);
        client.base_url = server.uri();
        let messages = vec![json!({"role":"user","content":"test"})];
        let result = client.chat_completion(messages).await;
        assert!(matches!(result, Err(SubXError::AiService(_))));
    }

    #[tokio::test]
    async fn test_chat_completion_invalid_response_format() {
        let server = MockServer::start().await;
        Mock::given(method("POST"))
            .and(path("/chat/completions"))
            .respond_with(ResponseTemplate::new(200).set_body_json(json!({ "choices": [] })))
            .mount(&server)
            .await;
        let mut client =
            OpenAIClient::new("test-key".into(), "gpt-4.1-mini".into(), 0.3, 1000, 0, 0);
        client.base_url = server.uri();
        let messages = vec![json!({"role":"user","content":"test"})];
        let result = client.chat_completion(messages).await;
        assert!(matches!(result, Err(SubXError::AiService(_))));
    }

    #[tokio::test]
    async fn test_analyze_content_with_mock() {
        let mut mock = MockAIClient::new();
        let req = AnalysisRequest {
            video_files: vec!["v.mp4".into()],
            subtitle_files: vec!["s.srt".into()],
            content_samples: vec![],
        };
        let expected = MatchResult {
            matches: vec![],
            confidence: 0.5,
            reasoning: "OK".into(),
        };
        mock.expect_analyze_content()
            .with(eq(req.clone()))
            .times(1)
            .returning(move |_| Ok(expected.clone()));
        let res = mock.analyze_content(req.clone()).await.unwrap();
        assert_eq!(res.confidence, 0.5);
    }

    #[test]
    fn test_prompt_building_and_parsing() {
        let client = OpenAIClient::new("k".into(), "m".into(), 0.1, 1000, 0, 0);
        let request = AnalysisRequest {
            video_files: vec!["F1.mp4".into()],
            subtitle_files: vec!["S1.srt".into()],
            content_samples: vec![],
        };
        let prompt = client.build_analysis_prompt(&request);
        assert!(prompt.contains("F1.mp4"));
        assert!(prompt.contains("S1.srt"));
        assert!(prompt.contains("JSON"));
        let json_resp = r#"{ "matches": [], "confidence":0.9, "reasoning":"r" }"#;
        let mr = client.parse_match_result(json_resp).unwrap();
        assert_eq!(mr.confidence, 0.9);
    }

    #[test]
    fn test_openai_client_from_config() {
        let config = crate::config::AIConfig {
            provider: "openai".to_string(),
            api_key: Some("test-key".to_string()),
            model: "gpt-test".to_string(),
            base_url: "https://custom.openai.com/v1".to_string(),
            max_sample_length: 500,
            temperature: 0.7,
            max_tokens: 2000,
            retry_attempts: 2,
            retry_delay_ms: 150,
            request_timeout_seconds: 60,
            api_version: None,
        };
        let client = OpenAIClient::from_config(&config).unwrap();
        assert_eq!(client.api_key, "test-key");
        assert_eq!(client.model, "gpt-test");
        assert_eq!(client.temperature, 0.7);
        assert_eq!(client.max_tokens, 2000);
    }

    #[test]
    fn test_openai_client_from_config_invalid_base_url() {
        let config = openai_config(Some("test-key"), "ftp://invalid.url");
        let err = OpenAIClient::from_config(&config).unwrap_err();
        assert!(
            err.to_string()
                .contains("Base URL must use http or https protocol")
        );
    }

    #[test]
    fn test_openai_client_from_config_missing_api_key() {
        let config = openai_config(None, "https://api.openai.com/v1");
        let err = OpenAIClient::from_config(&config).unwrap_err();
        assert!(err.to_string().contains("Missing OpenAI API Key"));
    }
}
178
179impl OpenAIClient {
180 pub fn new(
182 api_key: String,
183 model: String,
184 temperature: f32,
185 max_tokens: u32,
186 retry_attempts: u32,
187 retry_delay_ms: u64,
188 ) -> Self {
189 Self::new_with_base_url(
190 api_key,
191 model,
192 temperature,
193 max_tokens,
194 retry_attempts,
195 retry_delay_ms,
196 "https://api.openai.com/v1".to_string(),
197 )
198 }
199
200 pub fn new_with_base_url(
202 api_key: String,
203 model: String,
204 temperature: f32,
205 max_tokens: u32,
206 retry_attempts: u32,
207 retry_delay_ms: u64,
208 base_url: String,
209 ) -> Self {
210 Self::new_with_base_url_and_timeout(
212 api_key,
213 model,
214 temperature,
215 max_tokens,
216 retry_attempts,
217 retry_delay_ms,
218 base_url,
219 30,
220 )
221 }
222
223 #[allow(clippy::too_many_arguments)]
225 pub fn new_with_base_url_and_timeout(
226 api_key: String,
227 model: String,
228 temperature: f32,
229 max_tokens: u32,
230 retry_attempts: u32,
231 retry_delay_ms: u64,
232 base_url: String,
233 request_timeout_seconds: u64,
234 ) -> Self {
235 let client = Client::builder()
236 .timeout(Duration::from_secs(request_timeout_seconds))
237 .build()
238 .expect("Failed to create HTTP client");
239 Self {
240 client,
241 api_key,
242 model,
243 temperature,
244 max_tokens,
245 retry_attempts,
246 retry_delay_ms,
247 base_url: base_url.trim_end_matches('/').to_string(),
248 }
249 }
250
251 pub fn from_config(config: &crate::config::AIConfig) -> crate::Result<Self> {
253 let api_key = config
254 .api_key
255 .as_ref()
256 .ok_or_else(|| crate::error::SubXError::config("Missing OpenAI API Key"))?;
257
258 Self::validate_base_url(&config.base_url)?;
260
261 Ok(Self::new_with_base_url_and_timeout(
262 api_key.clone(),
263 config.model.clone(),
264 config.temperature,
265 config.max_tokens,
266 config.retry_attempts,
267 config.retry_delay_ms,
268 config.base_url.clone(),
269 config.request_timeout_seconds,
270 ))
271 }
272
273 fn validate_base_url(url: &str) -> crate::Result<()> {
275 use url::Url;
276 let parsed = Url::parse(url)
277 .map_err(|e| crate::error::SubXError::config(format!("Invalid base URL: {}", e)))?;
278
279 if !matches!(parsed.scheme(), "http" | "https") {
280 return Err(crate::error::SubXError::config(
281 "Base URL must use http or https protocol".to_string(),
282 ));
283 }
284
285 if parsed.host().is_none() {
286 return Err(crate::error::SubXError::config(
287 "Base URL must contain a valid hostname".to_string(),
288 ));
289 }
290
291 Ok(())
292 }
293
294 async fn chat_completion(&self, messages: Vec<serde_json::Value>) -> Result<String> {
295 let request_body = json!({
296 "model": self.model,
297 "messages": messages,
298 "temperature": self.temperature,
299 "max_tokens": self.max_tokens,
300 });
301
302 let request = self
303 .client
304 .post(format!("{}/chat/completions", self.base_url))
305 .header("Authorization", format!("Bearer {}", self.api_key))
306 .header("Content-Type", "application/json")
307 .json(&request_body);
308 let response = self.make_request_with_retry(request).await?;
309
310 if !response.status().is_success() {
311 let status = response.status();
312 let error_text = response.text().await?;
313 return Err(SubXError::AiService(format!(
314 "OpenAI API error {}: {}",
315 status, error_text
316 )));
317 }
318
319 let response_json: Value = response.json().await?;
320 let content = response_json["choices"][0]["message"]["content"]
321 .as_str()
322 .ok_or_else(|| SubXError::AiService("Invalid API response format".to_string()))?;
323
324 if let Some(usage_obj) = response_json.get("usage") {
326 if let (Some(p), Some(c), Some(t)) = (
327 usage_obj.get("prompt_tokens").and_then(Value::as_u64),
328 usage_obj.get("completion_tokens").and_then(Value::as_u64),
329 usage_obj.get("total_tokens").and_then(Value::as_u64),
330 ) {
331 let stats = AiUsageStats {
332 model: self.model.clone(),
333 prompt_tokens: p as u32,
334 completion_tokens: c as u32,
335 total_tokens: t as u32,
336 };
337 display_ai_usage(&stats);
338 }
339 }
340
341 Ok(content.to_string())
342 }
343}
344
345#[async_trait]
346impl AIProvider for OpenAIClient {
347 async fn analyze_content(&self, request: AnalysisRequest) -> Result<MatchResult> {
348 let prompt = self.build_analysis_prompt(&request);
349 let messages = vec![
350 json!({"role": "system", "content": "You are a professional subtitle matching assistant that can analyze the correspondence between video and subtitle files."}),
351 json!({"role": "user", "content": prompt}),
352 ];
353 let response = self.chat_completion(messages).await?;
354 self.parse_match_result(&response)
355 }
356
357 async fn verify_match(&self, verification: VerificationRequest) -> Result<ConfidenceScore> {
358 let prompt = self.build_verification_prompt(&verification);
359 let messages = vec![
360 json!({"role": "system", "content": "Please evaluate the confidence level of subtitle matching and provide a score between 0-1."}),
361 json!({"role": "user", "content": prompt}),
362 ];
363 let response = self.chat_completion(messages).await?;
364 self.parse_confidence_score(&response)
365 }
366}
367
368impl OpenAIClient {
369 async fn make_request_with_retry(
370 &self,
371 request: reqwest::RequestBuilder,
372 ) -> reqwest::Result<reqwest::Response> {
373 let mut attempts = 0;
374 loop {
375 match request.try_clone().unwrap().send().await {
376 Ok(resp) => {
377 if attempts > 0 {
378 log::info!("Request succeeded after {} retry attempts", attempts);
379 }
380 return Ok(resp);
381 }
382 Err(e) if (attempts as u32) < self.retry_attempts => {
383 attempts += 1;
384 log::warn!(
385 "Request attempt {} failed: {}. Retrying in {}ms...",
386 attempts,
387 e,
388 self.retry_delay_ms
389 );
390
391 if e.is_timeout() {
393 log::warn!(
394 "This appears to be a timeout error. If this persists, consider increasing 'ai.request_timeout_seconds' in your configuration."
395 );
396 }
397
398 time::sleep(Duration::from_millis(self.retry_delay_ms)).await;
399 continue;
400 }
401 Err(e) => {
402 log::error!(
403 "Request failed after {} attempts. Final error: {}",
404 attempts + 1,
405 e
406 );
407
408 if e.is_timeout() {
410 log::error!(
411 "AI service error: Request timed out after multiple attempts. \
412 This usually indicates network connectivity issues or server overload. \
413 Try increasing 'ai.request_timeout_seconds' configuration. \
414 Hint: check network connection and API service status"
415 );
416 } else if e.is_connect() {
417 log::error!(
418 "AI service error: Connection failed. \
419 Hint: check network connection and API base URL settings"
420 );
421 }
422
423 return Err(e);
424 }
425 }
426 }
427 }
428}