1use crate::Result;
2use crate::cli::display_ai_usage;
3use crate::error::SubXError;
4use crate::services::ai::AiUsageStats;
5use crate::services::ai::{
6 AIProvider, AnalysisRequest, ConfidenceScore, MatchResult, VerificationRequest,
7};
8use async_trait::async_trait;
9use reqwest::Client;
10use serde_json::Value;
11use serde_json::json;
12use std::time::Duration;
13use tokio::time;
14
15#[derive(Debug)]
17pub struct OpenAIClient {
18 client: Client,
19 api_key: String,
20 model: String,
21 temperature: f32,
22 max_tokens: u32,
23 retry_attempts: u32,
24 retry_delay_ms: u64,
25 base_url: String,
26}
27
// Unit tests for `OpenAIClient`: constructors, `chat_completion` against a
// local wiremock HTTP server, prompt building/parsing, and config validation.
#[cfg(test)]
mod tests {
    use super::*;
    use mockall::{mock, predicate::eq};
    use serde_json::json;
    use wiremock::matchers::{header, method, path};
    use wiremock::{Mock, MockServer, ResponseTemplate};

    // mockall generates `MockAIClient`, an `AIProvider` implementation whose
    // behavior is scripted per test (no network access involved).
    mock! {
        AIClient {}

        #[async_trait]
        impl AIProvider for AIClient {
            async fn analyze_content(&self, request: AnalysisRequest) -> crate::Result<MatchResult>;
            async fn verify_match(&self, verification: VerificationRequest) -> crate::Result<ConfidenceScore>;
        }
    }

    // `new` stores every constructor argument unchanged.
    #[tokio::test]
    async fn test_openai_client_creation() {
        let client = OpenAIClient::new("test-key".into(), "gpt-4.1-mini".into(), 0.5, 1000, 2, 100);
        assert_eq!(client.api_key, "test-key");
        assert_eq!(client.model, "gpt-4.1-mini");
        assert_eq!(client.temperature, 0.5);
        assert_eq!(client.max_tokens, 1000);
        assert_eq!(client.retry_attempts, 2);
        assert_eq!(client.retry_delay_ms, 100);
    }

    // A 200 response carrying `choices[0].message.content` yields that text.
    // The mock only matches when the bearer `Authorization` header is sent.
    #[tokio::test]
    async fn test_chat_completion_success() {
        let server = MockServer::start().await;
        Mock::given(method("POST"))
            .and(path("/chat/completions"))
            .and(header("authorization", "Bearer test-key"))
            .respond_with(ResponseTemplate::new(200).set_body_json(json!({
                "choices": [{"message": {"content": "test response content"}}]
            })))
            .mount(&server)
            .await;
        let mut client =
            OpenAIClient::new("test-key".into(), "gpt-4.1-mini".into(), 0.3, 1000, 1, 0);
        // Redirect the client from the public endpoint to the local mock server.
        client.base_url = server.uri();
        let messages = vec![json!({"role":"user","content":"test"})];
        let resp = client.chat_completion(messages).await.unwrap();
        assert_eq!(resp, "test response content");
    }

    // A non-2xx (400) response must be reported as an error.
    #[tokio::test]
    async fn test_chat_completion_error() {
        let server = MockServer::start().await;
        Mock::given(method("POST"))
            .and(path("/chat/completions"))
            .respond_with(ResponseTemplate::new(400).set_body_json(json!({
                "error": {"message":"Invalid API key"}
            })))
            .mount(&server)
            .await;
        let mut client =
            OpenAIClient::new("bad-key".into(), "gpt-4.1-mini".into(), 0.3, 1000, 1, 0);
        client.base_url = server.uri();
        let messages = vec![json!({"role":"user","content":"test"})];
        let result = client.chat_completion(messages).await;
        assert!(result.is_err());
    }

    // The mock expects exactly one `analyze_content` call with this request
    // and returns the scripted `MatchResult`.
    #[tokio::test]
    async fn test_analyze_content_with_mock() {
        let mut mock = MockAIClient::new();
        let req = AnalysisRequest {
            video_files: vec!["v.mp4".into()],
            subtitle_files: vec!["s.srt".into()],
            content_samples: vec![],
        };
        let expected = MatchResult {
            matches: vec![],
            confidence: 0.5,
            reasoning: "OK".into(),
        };
        mock.expect_analyze_content()
            .with(eq(req.clone()))
            .times(1)
            .returning(move |_| Ok(expected.clone()));
        let res = mock.analyze_content(req.clone()).await.unwrap();
        assert_eq!(res.confidence, 0.5);
    }

    // The analysis prompt mentions both file names and "JSON", and a minimal
    // JSON reply parses into a `MatchResult`.
    #[test]
    fn test_prompt_building_and_parsing() {
        let client = OpenAIClient::new("k".into(), "m".into(), 0.1, 1000, 0, 0);
        let request = AnalysisRequest {
            video_files: vec!["F1.mp4".into()],
            subtitle_files: vec!["S1.srt".into()],
            content_samples: vec![],
        };
        let prompt = client.build_analysis_prompt(&request);
        assert!(prompt.contains("F1.mp4"));
        assert!(prompt.contains("S1.srt"));
        assert!(prompt.contains("JSON"));
        let json_resp = r#"{ "matches": [], "confidence":0.9, "reasoning":"r" }"#;
        let mr = client.parse_match_result(json_resp).unwrap();
        assert_eq!(mr.confidence, 0.9);
    }

    // `from_config` copies the key, model, temperature and max_tokens.
    #[test]
    fn test_openai_client_from_config() {
        let config = crate::config::AIConfig {
            provider: "openai".to_string(),
            api_key: Some("test-key".to_string()),
            model: "gpt-test".to_string(),
            base_url: "https://custom.openai.com/v1".to_string(),
            temperature: 0.7,
            max_tokens: 2000,
            retry_attempts: 2,
            retry_delay_ms: 150,
            max_sample_length: 500,
            request_timeout_seconds: 60,
        };
        let client = OpenAIClient::from_config(&config).unwrap();
        assert_eq!(client.api_key, "test-key");
        assert_eq!(client.model, "gpt-test");
        assert_eq!(client.temperature, 0.7);
        assert_eq!(client.max_tokens, 2000);
    }

    // A non-http(s) base URL is rejected with a descriptive config error.
    #[test]
    fn test_openai_client_from_config_invalid_base_url() {
        let config = crate::config::AIConfig {
            provider: "openai".to_string(),
            api_key: Some("test-key".to_string()),
            model: "gpt-test".to_string(),
            base_url: "ftp://invalid.url".to_string(),
            temperature: 0.7,
            max_tokens: 1000,
            retry_attempts: 2,
            retry_delay_ms: 150,
            max_sample_length: 500,
            request_timeout_seconds: 30,
        };
        let err = OpenAIClient::from_config(&config).unwrap_err();
        assert!(
            err.to_string()
                .contains("Base URL must use http or https protocol")
        );
    }
}
176
177impl OpenAIClient {
178 pub fn new(
180 api_key: String,
181 model: String,
182 temperature: f32,
183 max_tokens: u32,
184 retry_attempts: u32,
185 retry_delay_ms: u64,
186 ) -> Self {
187 Self::new_with_base_url(
188 api_key,
189 model,
190 temperature,
191 max_tokens,
192 retry_attempts,
193 retry_delay_ms,
194 "https://api.openai.com/v1".to_string(),
195 )
196 }
197
198 pub fn new_with_base_url(
200 api_key: String,
201 model: String,
202 temperature: f32,
203 max_tokens: u32,
204 retry_attempts: u32,
205 retry_delay_ms: u64,
206 base_url: String,
207 ) -> Self {
208 Self::new_with_base_url_and_timeout(
210 api_key,
211 model,
212 temperature,
213 max_tokens,
214 retry_attempts,
215 retry_delay_ms,
216 base_url,
217 30,
218 )
219 }
220
221 #[allow(clippy::too_many_arguments)]
223 pub fn new_with_base_url_and_timeout(
224 api_key: String,
225 model: String,
226 temperature: f32,
227 max_tokens: u32,
228 retry_attempts: u32,
229 retry_delay_ms: u64,
230 base_url: String,
231 request_timeout_seconds: u64,
232 ) -> Self {
233 let client = Client::builder()
234 .timeout(Duration::from_secs(request_timeout_seconds))
235 .build()
236 .expect("Failed to create HTTP client");
237 Self {
238 client,
239 api_key,
240 model,
241 temperature,
242 max_tokens,
243 retry_attempts,
244 retry_delay_ms,
245 base_url: base_url.trim_end_matches('/').to_string(),
246 }
247 }
248
249 pub fn from_config(config: &crate::config::AIConfig) -> crate::Result<Self> {
251 let api_key = config
252 .api_key
253 .as_ref()
254 .ok_or_else(|| crate::error::SubXError::config("Missing OpenAI API Key"))?;
255
256 Self::validate_base_url(&config.base_url)?;
258
259 Ok(Self::new_with_base_url_and_timeout(
260 api_key.clone(),
261 config.model.clone(),
262 config.temperature,
263 config.max_tokens,
264 config.retry_attempts,
265 config.retry_delay_ms,
266 config.base_url.clone(),
267 config.request_timeout_seconds,
268 ))
269 }
270
271 fn validate_base_url(url: &str) -> crate::Result<()> {
273 use url::Url;
274 let parsed = Url::parse(url)
275 .map_err(|e| crate::error::SubXError::config(format!("Invalid base URL: {}", e)))?;
276
277 if !matches!(parsed.scheme(), "http" | "https") {
278 return Err(crate::error::SubXError::config(
279 "Base URL must use http or https protocol".to_string(),
280 ));
281 }
282
283 if parsed.host().is_none() {
284 return Err(crate::error::SubXError::config(
285 "Base URL must contain a valid hostname".to_string(),
286 ));
287 }
288
289 Ok(())
290 }
291
292 async fn chat_completion(&self, messages: Vec<serde_json::Value>) -> Result<String> {
293 let request_body = json!({
294 "model": self.model,
295 "messages": messages,
296 "temperature": self.temperature,
297 "max_tokens": self.max_tokens,
298 });
299
300 let request = self
301 .client
302 .post(format!("{}/chat/completions", self.base_url))
303 .header("Authorization", format!("Bearer {}", self.api_key))
304 .header("Content-Type", "application/json")
305 .json(&request_body);
306 let response = self.make_request_with_retry(request).await?;
307
308 if !response.status().is_success() {
309 let status = response.status();
310 let error_text = response.text().await?;
311 return Err(SubXError::AiService(format!(
312 "OpenAI API error {}: {}",
313 status, error_text
314 )));
315 }
316
317 let response_json: Value = response.json().await?;
318 let content = response_json["choices"][0]["message"]["content"]
319 .as_str()
320 .ok_or_else(|| SubXError::AiService("Invalid API response format".to_string()))?;
321
322 if let Some(usage_obj) = response_json.get("usage") {
324 if let (Some(p), Some(c), Some(t)) = (
325 usage_obj.get("prompt_tokens").and_then(Value::as_u64),
326 usage_obj.get("completion_tokens").and_then(Value::as_u64),
327 usage_obj.get("total_tokens").and_then(Value::as_u64),
328 ) {
329 let stats = AiUsageStats {
330 model: self.model.clone(),
331 prompt_tokens: p as u32,
332 completion_tokens: c as u32,
333 total_tokens: t as u32,
334 };
335 display_ai_usage(&stats);
336 }
337 }
338
339 Ok(content.to_string())
340 }
341}
342
343#[async_trait]
344impl AIProvider for OpenAIClient {
345 async fn analyze_content(&self, request: AnalysisRequest) -> Result<MatchResult> {
346 let prompt = self.build_analysis_prompt(&request);
347 let messages = vec![
348 json!({"role": "system", "content": "You are a professional subtitle matching assistant that can analyze the correspondence between video and subtitle files."}),
349 json!({"role": "user", "content": prompt}),
350 ];
351 let response = self.chat_completion(messages).await?;
352 self.parse_match_result(&response)
353 }
354
355 async fn verify_match(&self, verification: VerificationRequest) -> Result<ConfidenceScore> {
356 let prompt = self.build_verification_prompt(&verification);
357 let messages = vec![
358 json!({"role": "system", "content": "Please evaluate the confidence level of subtitle matching and provide a score between 0-1."}),
359 json!({"role": "user", "content": prompt}),
360 ];
361 let response = self.chat_completion(messages).await?;
362 self.parse_confidence_score(&response)
363 }
364}
365
366impl OpenAIClient {
367 async fn make_request_with_retry(
368 &self,
369 request: reqwest::RequestBuilder,
370 ) -> reqwest::Result<reqwest::Response> {
371 let mut attempts = 0;
372 loop {
373 match request.try_clone().unwrap().send().await {
374 Ok(resp) => {
375 if attempts > 0 {
376 log::info!("Request succeeded after {} retry attempts", attempts);
377 }
378 return Ok(resp);
379 }
380 Err(e) if (attempts as u32) < self.retry_attempts => {
381 attempts += 1;
382 log::warn!(
383 "Request attempt {} failed: {}. Retrying in {}ms...",
384 attempts,
385 e,
386 self.retry_delay_ms
387 );
388
389 if e.is_timeout() {
391 log::warn!(
392 "This appears to be a timeout error. If this persists, consider increasing 'ai.request_timeout_seconds' in your configuration."
393 );
394 }
395
396 time::sleep(Duration::from_millis(self.retry_delay_ms)).await;
397 continue;
398 }
399 Err(e) => {
400 log::error!(
401 "Request failed after {} attempts. Final error: {}",
402 attempts + 1,
403 e
404 );
405
406 if e.is_timeout() {
408 log::error!(
409 "AI service error: Request timed out after multiple attempts. \
410 This usually indicates network connectivity issues or server overload. \
411 Try increasing 'ai.request_timeout_seconds' configuration. \
412 Hint: check network connection and API service status"
413 );
414 } else if e.is_connect() {
415 log::error!(
416 "AI service error: Connection failed. \
417 Hint: check network connection and API base URL settings"
418 );
419 }
420
421 return Err(e);
422 }
423 }
424 }
425 }
426}