1use crate::Result;
2use crate::cli::display_ai_usage;
3use crate::error::SubXError;
4use crate::services::ai::AiUsageStats;
5use crate::services::ai::{
6 AIProvider, AnalysisRequest, ConfidenceScore, MatchResult, VerificationRequest,
7};
8use async_trait::async_trait;
9use reqwest::Client;
10use serde_json::{Value, json};
11use std::time::Duration;
12use tokio::time;
13
14#[derive(Debug)]
16pub struct OpenRouterClient {
17 client: Client,
18 api_key: String,
19 model: String,
20 temperature: f32,
21 max_tokens: u32,
22 retry_attempts: u32,
23 retry_delay_ms: u64,
24 base_url: String,
25}
26
27impl OpenRouterClient {
28 pub fn new(
30 api_key: String,
31 model: String,
32 temperature: f32,
33 max_tokens: u32,
34 retry_attempts: u32,
35 retry_delay_ms: u64,
36 ) -> Self {
37 Self::new_with_base_url_and_timeout(
38 api_key,
39 model,
40 temperature,
41 max_tokens,
42 retry_attempts,
43 retry_delay_ms,
44 "https://openrouter.ai/api/v1".to_string(),
45 120,
46 )
47 }
48
49 #[allow(clippy::too_many_arguments)]
51 pub fn new_with_base_url_and_timeout(
52 api_key: String,
53 model: String,
54 temperature: f32,
55 max_tokens: u32,
56 retry_attempts: u32,
57 retry_delay_ms: u64,
58 base_url: String,
59 request_timeout_seconds: u64,
60 ) -> Self {
61 let client = Client::builder()
62 .timeout(Duration::from_secs(request_timeout_seconds))
63 .build()
64 .expect("Failed to create HTTP client");
65
66 Self {
67 client,
68 api_key,
69 model,
70 temperature,
71 max_tokens,
72 retry_attempts,
73 retry_delay_ms,
74 base_url: base_url.trim_end_matches('/').to_string(),
75 }
76 }
77
78 pub fn from_config(config: &crate::config::AIConfig) -> crate::Result<Self> {
80 let api_key = config
81 .api_key
82 .as_ref()
83 .ok_or_else(|| SubXError::config("Missing OpenRouter API Key"))?;
84
85 Self::validate_base_url(&config.base_url)?;
87
88 Ok(Self::new_with_base_url_and_timeout(
89 api_key.clone(),
90 config.model.clone(),
91 config.temperature,
92 config.max_tokens,
93 config.retry_attempts,
94 config.retry_delay_ms,
95 config.base_url.clone(),
96 config.request_timeout_seconds,
97 ))
98 }
99
100 fn validate_base_url(url: &str) -> crate::Result<()> {
102 use url::Url;
103 let parsed =
104 Url::parse(url).map_err(|e| SubXError::config(format!("Invalid base URL: {}", e)))?;
105
106 if !matches!(parsed.scheme(), "http" | "https") {
107 return Err(SubXError::config(
108 "Base URL must use http or https protocol".to_string(),
109 ));
110 }
111
112 if parsed.host().is_none() {
113 return Err(SubXError::config(
114 "Base URL must contain a valid hostname".to_string(),
115 ));
116 }
117
118 Ok(())
119 }
120
121 async fn chat_completion(&self, messages: Vec<Value>) -> Result<String> {
122 let request_body = json!({
123 "model": self.model,
124 "messages": messages,
125 "temperature": self.temperature,
126 "max_tokens": self.max_tokens,
127 });
128
129 let request = self
130 .client
131 .post(format!("{}/chat/completions", self.base_url))
132 .header("Authorization", format!("Bearer {}", self.api_key))
133 .header("Content-Type", "application/json")
134 .header("HTTP-Referer", "https://github.com/jim60105/subx-cli")
135 .header("X-Title", "Subx")
136 .json(&request_body);
137
138 let response = self.make_request_with_retry(request).await?;
139
140 if !response.status().is_success() {
141 let status = response.status();
142 let error_text = response.text().await?;
143 return Err(SubXError::AiService(format!(
144 "OpenRouter API error {}: {}",
145 status, error_text
146 )));
147 }
148
149 let response_json: Value = response.json().await?;
150 let content = response_json["choices"][0]["message"]["content"]
151 .as_str()
152 .ok_or_else(|| SubXError::AiService("Invalid API response format".to_string()))?;
153
154 if let Some(usage_obj) = response_json.get("usage") {
156 if let (Some(p), Some(c), Some(t)) = (
157 usage_obj.get("prompt_tokens").and_then(Value::as_u64),
158 usage_obj.get("completion_tokens").and_then(Value::as_u64),
159 usage_obj.get("total_tokens").and_then(Value::as_u64),
160 ) {
161 let stats = AiUsageStats {
162 model: self.model.clone(),
163 prompt_tokens: p as u32,
164 completion_tokens: c as u32,
165 total_tokens: t as u32,
166 };
167 display_ai_usage(&stats);
168 }
169 }
170
171 Ok(content.to_string())
172 }
173
174 async fn make_request_with_retry(
175 &self,
176 request: reqwest::RequestBuilder,
177 ) -> reqwest::Result<reqwest::Response> {
178 let mut attempts = 0;
179 loop {
180 match request.try_clone().unwrap().send().await {
181 Ok(resp) => {
182 if resp.status().is_server_error() && (attempts as u32) < self.retry_attempts {
184 attempts += 1;
185 log::warn!(
186 "Request attempt {} failed with status {}. Retrying in {}ms...",
187 attempts,
188 resp.status(),
189 self.retry_delay_ms
190 );
191 time::sleep(Duration::from_millis(self.retry_delay_ms)).await;
192 continue;
193 }
194 if attempts > 0 {
195 log::info!("Request succeeded after {} retry attempts", attempts);
196 }
197 return Ok(resp);
198 }
199 Err(e) if (attempts as u32) < self.retry_attempts => {
200 attempts += 1;
201 log::warn!(
202 "Request attempt {} failed: {}. Retrying in {}ms...",
203 attempts,
204 e,
205 self.retry_delay_ms
206 );
207
208 if e.is_timeout() {
209 log::warn!(
210 "This appears to be a timeout error. If this persists, consider increasing 'ai.request_timeout_seconds' in your configuration."
211 );
212 }
213
214 time::sleep(Duration::from_millis(self.retry_delay_ms)).await;
215 continue;
216 }
217 Err(e) => {
218 log::error!(
219 "Request failed after {} attempts. Final error: {}",
220 attempts + 1,
221 e
222 );
223
224 if e.is_timeout() {
225 log::error!(
226 "AI service error: Request timed out after multiple attempts. \
227 This usually indicates network connectivity issues or server overload. \
228 Try increasing 'ai.request_timeout_seconds' configuration. \
229 Hint: check network connection and API service status"
230 );
231 } else if e.is_connect() {
232 log::error!(
233 "AI service error: Connection failed. \
234 Hint: check network connection and API base URL settings"
235 );
236 }
237
238 return Err(e);
239 }
240 }
241 }
242 }
243}
244
245#[async_trait]
246impl AIProvider for OpenRouterClient {
247 async fn analyze_content(&self, request: AnalysisRequest) -> Result<MatchResult> {
248 let prompt = self.build_analysis_prompt(&request);
249 let messages = vec![
250 json!({"role": "system", "content": "You are a professional subtitle matching assistant that can analyze the correspondence between video and subtitle files."}),
251 json!({"role": "user", "content": prompt}),
252 ];
253 let response = self.chat_completion(messages).await?;
254 self.parse_match_result(&response)
255 }
256
257 async fn verify_match(&self, verification: VerificationRequest) -> Result<ConfidenceScore> {
258 let prompt = self.build_verification_prompt(&verification);
259 let messages = vec![
260 json!({"role": "system", "content": "Please evaluate the confidence level of subtitle matching and provide a score between 0-1."}),
261 json!({"role": "user", "content": prompt}),
262 ];
263 let response = self.chat_completion(messages).await?;
264 self.parse_confidence_score(&response)
265 }
266}
267
268impl OpenRouterClient {
270 pub fn build_analysis_prompt(&self, request: &AnalysisRequest) -> String {
272 let mut prompt = String::new();
273 prompt.push_str("Please analyze the matching relationship between the following video and subtitle files. Each file has a unique ID that you must use in your response.\n\n");
274
275 prompt.push_str("Video files:\n");
276 for video in &request.video_files {
277 prompt.push_str(&format!("- {}\n", video));
278 }
279
280 prompt.push_str("\nSubtitle files:\n");
281 for subtitle in &request.subtitle_files {
282 prompt.push_str(&format!("- {}\n", subtitle));
283 }
284
285 if !request.content_samples.is_empty() {
286 prompt.push_str("\nSubtitle content preview:\n");
287 for sample in &request.content_samples {
288 prompt.push_str(&format!("File: {}\n", sample.filename));
289 prompt.push_str(&format!("Content: {}\n\n", sample.content_preview));
290 }
291 }
292
293 prompt.push_str(
294 "Please provide matching suggestions based on filename patterns, content similarity, and other factors.\n\
295 Response format must be JSON using the file IDs:\n\
296 {\n\
297 \"matches\": [\n\
298 {\n\
299 \"video_file_id\": \"file_abc123456789abcd\",\n\
300 \"subtitle_file_id\": \"file_def456789abcdef0\",\n\
301 \"confidence\": 0.95,\n\
302 \"match_factors\": [\"filename_similarity\", \"content_correlation\"]\n\
303 }\n\
304 ],\n\
305 \"confidence\": 0.9,\n\
306 \"reasoning\": \"Explanation for the matching decisions\"\n\
307 }",
308 );
309
310 prompt
311 }
312
313 pub fn parse_match_result(&self, response: &str) -> Result<MatchResult> {
315 let json_start = response.find('{').unwrap_or(0);
316 let json_end = response.rfind('}').map(|i| i + 1).unwrap_or(response.len());
317 let json_str = &response[json_start..json_end];
318
319 serde_json::from_str(json_str)
320 .map_err(|e| SubXError::AiService(format!("AI response parsing failed: {}", e)))
321 }
322
323 pub fn build_verification_prompt(&self, request: &VerificationRequest) -> String {
325 let mut prompt = String::new();
326 prompt.push_str(
327 "Please evaluate the confidence level based on the following matching information:\n",
328 );
329 prompt.push_str(&format!("Video file: {}\n", request.video_file));
330 prompt.push_str(&format!("Subtitle file: {}\n", request.subtitle_file));
331 prompt.push_str("Matching factors:\n");
332 for factor in &request.match_factors {
333 prompt.push_str(&format!("- {}\n", factor));
334 }
335 prompt.push_str(
336 "\nPlease respond in JSON format as follows:\n{\n \"score\": 0.9,\n \"factors\": [\"...\"]\n}",
337 );
338 prompt
339 }
340
341 pub fn parse_confidence_score(&self, response: &str) -> Result<ConfidenceScore> {
343 let json_start = response.find('{').unwrap_or(0);
344 let json_end = response.rfind('}').map(|i| i + 1).unwrap_or(response.len());
345 let json_str = &response[json_start..json_end];
346
347 serde_json::from_str(json_str)
348 .map_err(|e| SubXError::AiService(format!("AI confidence parsing failed: {}", e)))
349 }
350}
351#[cfg(test)]
352mod tests {
353 use super::*;
354 use mockall::mock;
355 use serde_json::json;
356 use wiremock::matchers::{header, method, path};
357 use wiremock::{Mock, MockServer, ResponseTemplate};
358
    // Declares a mockall mock called `AIClient` that implements `AIProvider`
    // with the same two async methods as the real client.
    // NOTE(review): by mockall convention the generated type should be
    // `MockAIClient`; it is not referenced by any test visible in this
    // module — confirm whether it is still needed.
    mock! {
        AIClient {}

        #[async_trait]
        impl AIProvider for AIClient {
            async fn analyze_content(&self, request: AnalysisRequest) -> crate::Result<MatchResult>;
            async fn verify_match(&self, verification: VerificationRequest) -> crate::Result<ConfidenceScore>;
        }
    }
368
369 #[tokio::test]
370 async fn test_openrouter_client_creation() {
371 let client = OpenRouterClient::new(
372 "test-key".into(),
373 "deepseek/deepseek-r1-0528:free".into(),
374 0.5,
375 1000,
376 2,
377 100,
378 );
379 assert_eq!(client.api_key, "test-key");
380 assert_eq!(client.model, "deepseek/deepseek-r1-0528:free");
381 assert_eq!(client.temperature, 0.5);
382 assert_eq!(client.max_tokens, 1000);
383 assert_eq!(client.retry_attempts, 2);
384 assert_eq!(client.retry_delay_ms, 100);
385 assert_eq!(client.base_url, "https://openrouter.ai/api/v1");
386 }
387
388 #[tokio::test]
389 async fn test_openrouter_client_creation_with_custom_base_url() {
390 let client = OpenRouterClient::new_with_base_url_and_timeout(
391 "test-key".into(),
392 "deepseek/deepseek-r1-0528:free".into(),
393 0.3,
394 2000,
395 3,
396 200,
397 "https://custom-openrouter.ai/api/v1".into(),
398 60,
399 );
400 assert_eq!(client.base_url, "https://custom-openrouter.ai/api/v1");
401 }
402
403 #[tokio::test]
404 async fn test_chat_completion_success() {
405 let server = MockServer::start().await;
406 Mock::given(method("POST"))
407 .and(path("/chat/completions"))
408 .and(header("authorization", "Bearer test-key"))
409 .and(header(
410 "HTTP-Referer",
411 "https://github.com/jim60105/subx-cli",
412 ))
413 .and(header("X-Title", "Subx"))
414 .respond_with(ResponseTemplate::new(200).set_body_json(json!({
415 "choices": [{"message": {"content": "test response content"}}],
416 "usage": { "prompt_tokens": 10, "completion_tokens": 5, "total_tokens": 15 }
417 })))
418 .mount(&server)
419 .await;
420
421 let mut client = OpenRouterClient::new(
422 "test-key".into(),
423 "deepseek/deepseek-r1-0528:free".into(),
424 0.3,
425 1000,
426 1,
427 0,
428 );
429 client.base_url = server.uri();
430
431 let messages = vec![json!({"role":"user","content":"test"})];
432 let resp = client.chat_completion(messages).await.unwrap();
433 assert_eq!(resp, "test response content");
434 }
435
436 #[tokio::test]
437 async fn test_chat_completion_error_handling() {
438 let server = MockServer::start().await;
439 Mock::given(method("POST"))
440 .and(path("/chat/completions"))
441 .respond_with(ResponseTemplate::new(401).set_body_json(json!({
442 "error": {"message":"Invalid API key"}
443 })))
444 .mount(&server)
445 .await;
446
447 let mut client = OpenRouterClient::new(
448 "bad-key".into(),
449 "deepseek/deepseek-r1-0528:free".into(),
450 0.3,
451 1000,
452 1,
453 0,
454 );
455 client.base_url = server.uri();
456
457 let messages = vec![json!({"role":"user","content":"test"})];
458 let result = client.chat_completion(messages).await;
459 assert!(result.is_err());
460 assert!(
461 result
462 .err()
463 .unwrap()
464 .to_string()
465 .contains("OpenRouter API error 401")
466 );
467 }
468
469 #[tokio::test]
470 async fn test_retry_mechanism() {
471 let server = MockServer::start().await;
472
473 Mock::given(method("POST"))
475 .and(path("/chat/completions"))
476 .respond_with(ResponseTemplate::new(500))
477 .up_to_n_times(1)
478 .mount(&server)
479 .await;
480
481 Mock::given(method("POST"))
482 .and(path("/chat/completions"))
483 .respond_with(ResponseTemplate::new(200).set_body_json(json!({
484 "choices": [{"message": {"content": "success after retry"}}]
485 })))
486 .mount(&server)
487 .await;
488
489 let mut client = OpenRouterClient::new(
490 "test-key".into(),
491 "deepseek/deepseek-r1-0528:free".into(),
492 0.3,
493 1000,
494 2, 50, );
497 client.base_url = server.uri();
498
499 let messages = vec![json!({"role":"user","content":"test"})];
500 let result = client.chat_completion(messages).await.unwrap();
501 assert_eq!(result, "success after retry");
502 }
503
504 #[test]
505 fn test_openrouter_client_from_config() {
506 let config = crate::config::AIConfig {
507 provider: "openrouter".to_string(),
508 api_key: Some("test-key".to_string()),
509 model: "deepseek/deepseek-r1-0528:free".to_string(),
510 base_url: "https://openrouter.ai/api/v1".to_string(),
511 max_sample_length: 500,
512 temperature: 0.7,
513 max_tokens: 2000,
514 retry_attempts: 3,
515 retry_delay_ms: 150,
516 request_timeout_seconds: 120,
517 api_version: None,
518 };
519
520 let client = OpenRouterClient::from_config(&config).unwrap();
521 assert_eq!(client.api_key, "test-key");
522 assert_eq!(client.model, "deepseek/deepseek-r1-0528:free");
523 assert_eq!(client.temperature, 0.7);
524 assert_eq!(client.max_tokens, 2000);
525 assert_eq!(client.retry_attempts, 3);
526 assert_eq!(client.retry_delay_ms, 150);
527 }
528
529 #[test]
530 fn test_openrouter_client_from_config_missing_api_key() {
531 let config = crate::config::AIConfig {
532 provider: "openrouter".to_string(),
533 api_key: None,
534 model: "deepseek/deepseek-r1-0528:free".to_string(),
535 base_url: "https://openrouter.ai/api/v1".to_string(),
536 max_sample_length: 500,
537 temperature: 0.3,
538 max_tokens: 1000,
539 retry_attempts: 2,
540 retry_delay_ms: 100,
541 request_timeout_seconds: 30,
542 api_version: None,
543 };
544
545 let result = OpenRouterClient::from_config(&config);
546 assert!(result.is_err());
547 assert!(
548 result
549 .err()
550 .unwrap()
551 .to_string()
552 .contains("Missing OpenRouter API Key")
553 );
554 }
555
556 #[test]
557 fn test_openrouter_client_from_config_invalid_base_url() {
558 let config = crate::config::AIConfig {
559 provider: "openrouter".to_string(),
560 api_key: Some("test-key".to_string()),
561 model: "deepseek/deepseek-r1-0528:free".to_string(),
562 base_url: "ftp://invalid.url".to_string(),
563 max_sample_length: 500,
564 temperature: 0.3,
565 max_tokens: 1000,
566 retry_attempts: 2,
567 retry_delay_ms: 100,
568 request_timeout_seconds: 30,
569 api_version: None,
570 };
571
572 let result = OpenRouterClient::from_config(&config);
573 assert!(result.is_err());
574 assert!(
575 result
576 .err()
577 .unwrap()
578 .to_string()
579 .contains("must use http or https protocol")
580 );
581 }
582
583 #[test]
584 fn test_prompt_building_and_parsing() {
585 let client = OpenRouterClient::new(
586 "test-key".into(),
587 "deepseek/deepseek-r1-0528:free".into(),
588 0.1,
589 1000,
590 0,
591 0,
592 );
593 let request = AnalysisRequest {
594 video_files: vec!["video1.mp4".into()],
595 subtitle_files: vec!["subtitle1.srt".into()],
596 content_samples: vec![],
597 };
598
599 let prompt = client.build_analysis_prompt(&request);
600 assert!(prompt.contains("video1.mp4"));
601 assert!(prompt.contains("subtitle1.srt"));
602 assert!(prompt.contains("JSON"));
603
604 let json_response = r#"{ "matches": [], "confidence":0.9, "reasoning":"test reason" }"#;
605 let match_result = client.parse_match_result(json_response).unwrap();
606 assert_eq!(match_result.confidence, 0.9);
607 assert_eq!(match_result.reasoning, "test reason");
608 }
609}