1use crate::Result;
2use crate::cli::display_ai_usage;
3use crate::error::SubXError;
4use crate::services::ai::AiUsageStats;
5use crate::services::ai::{
6 AIProvider, AnalysisRequest, ConfidenceScore, MatchResult, VerificationRequest,
7};
8use async_trait::async_trait;
9use reqwest::Client;
10use serde_json::{Value, json};
11use std::time::Duration;
12use tokio::time;
13
14use crate::services::ai::prompts::{PromptBuilder, ResponseParser};
15use crate::services::ai::retry::HttpRetryClient;
16
17#[derive(Debug)]
19pub struct OpenRouterClient {
20 client: Client,
21 api_key: String,
22 model: String,
23 temperature: f32,
24 max_tokens: u32,
25 retry_attempts: u32,
26 retry_delay_ms: u64,
27 base_url: String,
28}
29
30impl PromptBuilder for OpenRouterClient {}
31impl ResponseParser for OpenRouterClient {}
32impl HttpRetryClient for OpenRouterClient {
33 fn retry_attempts(&self) -> u32 {
34 self.retry_attempts
35 }
36 fn retry_delay_ms(&self) -> u64 {
37 self.retry_delay_ms
38 }
39}
40
41impl OpenRouterClient {
42 pub fn new(
44 api_key: String,
45 model: String,
46 temperature: f32,
47 max_tokens: u32,
48 retry_attempts: u32,
49 retry_delay_ms: u64,
50 ) -> Self {
51 Self::new_with_base_url_and_timeout(
52 api_key,
53 model,
54 temperature,
55 max_tokens,
56 retry_attempts,
57 retry_delay_ms,
58 "https://openrouter.ai/api/v1".to_string(),
59 120,
60 )
61 }
62
63 #[allow(clippy::too_many_arguments)]
65 pub fn new_with_base_url_and_timeout(
66 api_key: String,
67 model: String,
68 temperature: f32,
69 max_tokens: u32,
70 retry_attempts: u32,
71 retry_delay_ms: u64,
72 base_url: String,
73 request_timeout_seconds: u64,
74 ) -> Self {
75 let client = Client::builder()
76 .timeout(Duration::from_secs(request_timeout_seconds))
77 .build()
78 .expect("Failed to create HTTP client");
79
80 Self {
81 client,
82 api_key,
83 model,
84 temperature,
85 max_tokens,
86 retry_attempts,
87 retry_delay_ms,
88 base_url: base_url.trim_end_matches('/').to_string(),
89 }
90 }
91
92 pub fn from_config(config: &crate::config::AIConfig) -> crate::Result<Self> {
94 let api_key = config
95 .api_key
96 .as_ref()
97 .ok_or_else(|| SubXError::config("Missing OpenRouter API Key"))?;
98
99 Self::validate_base_url(&config.base_url)?;
101
102 Ok(Self::new_with_base_url_and_timeout(
103 api_key.clone(),
104 config.model.clone(),
105 config.temperature,
106 config.max_tokens,
107 config.retry_attempts,
108 config.retry_delay_ms,
109 config.base_url.clone(),
110 config.request_timeout_seconds,
111 ))
112 }
113
114 fn validate_base_url(url: &str) -> crate::Result<()> {
116 use url::Url;
117 let parsed =
118 Url::parse(url).map_err(|e| SubXError::config(format!("Invalid base URL: {}", e)))?;
119
120 if !matches!(parsed.scheme(), "http" | "https") {
121 return Err(SubXError::config(
122 "Base URL must use http or https protocol".to_string(),
123 ));
124 }
125
126 if parsed.host().is_none() {
127 return Err(SubXError::config(
128 "Base URL must contain a valid hostname".to_string(),
129 ));
130 }
131
132 Ok(())
133 }
134
135 async fn chat_completion(&self, messages: Vec<Value>) -> Result<String> {
136 let request_body = json!({
137 "model": self.model,
138 "messages": messages,
139 "temperature": self.temperature,
140 "max_tokens": self.max_tokens,
141 });
142
143 let request = self
144 .client
145 .post(format!("{}/chat/completions", self.base_url))
146 .header("Authorization", format!("Bearer {}", self.api_key))
147 .header("Content-Type", "application/json")
148 .header("HTTP-Referer", "https://github.com/jim60105/subx-cli")
149 .header("X-Title", "Subx")
150 .json(&request_body);
151
152 let response = self.make_request_with_retry(request).await?;
153
154 if !response.status().is_success() {
155 let status = response.status();
156 let error_text = response.text().await?;
157 return Err(SubXError::AiService(format!(
158 "OpenRouter API error {}: {}",
159 status, error_text
160 )));
161 }
162
163 let response_json: Value = response.json().await?;
164 let content = response_json["choices"][0]["message"]["content"]
165 .as_str()
166 .ok_or_else(|| SubXError::AiService("Invalid API response format".to_string()))?;
167
168 if let Some(usage_obj) = response_json.get("usage") {
170 if let (Some(p), Some(c), Some(t)) = (
171 usage_obj.get("prompt_tokens").and_then(Value::as_u64),
172 usage_obj.get("completion_tokens").and_then(Value::as_u64),
173 usage_obj.get("total_tokens").and_then(Value::as_u64),
174 ) {
175 let stats = AiUsageStats {
176 model: self.model.clone(),
177 prompt_tokens: p as u32,
178 completion_tokens: c as u32,
179 total_tokens: t as u32,
180 };
181 display_ai_usage(&stats);
182 }
183 }
184
185 Ok(content.to_string())
186 }
187
188 async fn make_request_with_retry(
189 &self,
190 request: reqwest::RequestBuilder,
191 ) -> reqwest::Result<reqwest::Response> {
192 let mut attempts = 0;
193 loop {
194 match request.try_clone().unwrap().send().await {
195 Ok(resp) => {
196 if resp.status().is_server_error() && (attempts as u32) < self.retry_attempts {
198 attempts += 1;
199 log::warn!(
200 "Request attempt {} failed with status {}. Retrying in {}ms...",
201 attempts,
202 resp.status(),
203 self.retry_delay_ms
204 );
205 time::sleep(Duration::from_millis(self.retry_delay_ms)).await;
206 continue;
207 }
208 if attempts > 0 {
209 log::info!("Request succeeded after {} retry attempts", attempts);
210 }
211 return Ok(resp);
212 }
213 Err(e) if (attempts as u32) < self.retry_attempts => {
214 attempts += 1;
215 log::warn!(
216 "Request attempt {} failed: {}. Retrying in {}ms...",
217 attempts,
218 e,
219 self.retry_delay_ms
220 );
221
222 if e.is_timeout() {
223 log::warn!(
224 "This appears to be a timeout error. If this persists, consider increasing 'ai.request_timeout_seconds' in your configuration."
225 );
226 }
227
228 time::sleep(Duration::from_millis(self.retry_delay_ms)).await;
229 continue;
230 }
231 Err(e) => {
232 log::error!(
233 "Request failed after {} attempts. Final error: {}",
234 attempts + 1,
235 e
236 );
237
238 if e.is_timeout() {
239 log::error!(
240 "AI service error: Request timed out after multiple attempts. \
241 This usually indicates network connectivity issues or server overload. \
242 Try increasing 'ai.request_timeout_seconds' configuration. \
243 Hint: check network connection and API service status"
244 );
245 } else if e.is_connect() {
246 log::error!(
247 "AI service error: Connection failed. \
248 Hint: check network connection and API base URL settings"
249 );
250 }
251
252 return Err(e);
253 }
254 }
255 }
256 }
257}
258
259#[async_trait]
260impl AIProvider for OpenRouterClient {
261 async fn analyze_content(&self, request: AnalysisRequest) -> Result<MatchResult> {
262 let prompt = self.build_analysis_prompt(&request);
263 let messages = vec![
264 json!({"role": "system", "content": "You are a professional subtitle matching assistant that can analyze the correspondence between video and subtitle files."}),
265 json!({"role": "user", "content": prompt}),
266 ];
267 let response = self.chat_completion(messages).await?;
268 self.parse_match_result(&response)
269 }
270
271 async fn verify_match(&self, verification: VerificationRequest) -> Result<ConfidenceScore> {
272 let prompt = self.build_verification_prompt(&verification);
273 let messages = vec![
274 json!({"role": "system", "content": "Please evaluate the confidence level of subtitle matching and provide a score between 0-1."}),
275 json!({"role": "user", "content": prompt}),
276 ];
277 let response = self.chat_completion(messages).await?;
278 self.parse_confidence_score(&response)
279 }
280}
281
#[cfg(test)]
mod tests {
    use super::*;
    use mockall::mock;
    use serde_json::json;
    use wiremock::matchers::{header, method, path};
    use wiremock::{Mock, MockServer, ResponseTemplate};

    mock! {
        AIClient {}

        #[async_trait]
        impl AIProvider for AIClient {
            async fn analyze_content(&self, request: AnalysisRequest) -> crate::Result<MatchResult>;
            async fn verify_match(&self, verification: VerificationRequest) -> crate::Result<ConfidenceScore>;
        }
    }

    /// Model identifier shared by every test.
    const TEST_MODEL: &str = "deepseek/deepseek-r1-0528:free";
    /// Default OpenRouter API root.
    const OPENROUTER_URL: &str = "https://openrouter.ai/api/v1";
    /// Path the client is expected to POST to.
    const COMPLETIONS_PATH: &str = "/chat/completions";

    /// Builds a client (temperature 0.3, 1000 max tokens) whose base URL is
    /// pointed at the given mock server.
    fn client_against(
        server: &MockServer,
        api_key: &str,
        retry_attempts: u32,
        retry_delay_ms: u64,
    ) -> OpenRouterClient {
        let mut client = OpenRouterClient::new(
            api_key.into(),
            TEST_MODEL.into(),
            0.3,
            1000,
            retry_attempts,
            retry_delay_ms,
        );
        client.base_url = server.uri();
        client
    }

    /// A one-element conversation containing a single user message.
    fn single_user_message() -> Vec<Value> {
        vec![json!({"role":"user","content":"test"})]
    }

    /// Baseline configuration; only the API key and base URL vary per test.
    fn config_with(api_key: Option<&str>, base_url: &str) -> crate::config::AIConfig {
        crate::config::AIConfig {
            provider: "openrouter".to_string(),
            api_key: api_key.map(str::to_string),
            model: TEST_MODEL.to_string(),
            base_url: base_url.to_string(),
            max_sample_length: 500,
            temperature: 0.3,
            max_tokens: 1000,
            retry_attempts: 2,
            retry_delay_ms: 100,
            request_timeout_seconds: 30,
            api_version: None,
        }
    }

    #[tokio::test]
    async fn test_openrouter_client_creation() {
        let client = OpenRouterClient::new("test-key".into(), TEST_MODEL.into(), 0.5, 1000, 2, 100);

        assert_eq!(client.api_key, "test-key");
        assert_eq!(client.model, "deepseek/deepseek-r1-0528:free");
        assert_eq!(client.temperature, 0.5);
        assert_eq!(client.max_tokens, 1000);
        assert_eq!(client.retry_attempts, 2);
        assert_eq!(client.retry_delay_ms, 100);
        assert_eq!(client.base_url, "https://openrouter.ai/api/v1");
    }

    #[tokio::test]
    async fn test_openrouter_client_creation_with_custom_base_url() {
        let custom_url = "https://custom-openrouter.ai/api/v1";
        let client = OpenRouterClient::new_with_base_url_and_timeout(
            "test-key".into(),
            TEST_MODEL.into(),
            0.3,
            2000,
            3,
            200,
            custom_url.into(),
            60,
        );

        assert_eq!(client.base_url, "https://custom-openrouter.ai/api/v1");
    }

    #[tokio::test]
    async fn test_chat_completion_success() {
        let server = MockServer::start().await;
        let success_body = json!({
            "choices": [{"message": {"content": "test response content"}}],
            "usage": { "prompt_tokens": 10, "completion_tokens": 5, "total_tokens": 15 }
        });
        Mock::given(method("POST"))
            .and(path(COMPLETIONS_PATH))
            .and(header("authorization", "Bearer test-key"))
            .and(header(
                "HTTP-Referer",
                "https://github.com/jim60105/subx-cli",
            ))
            .and(header("X-Title", "Subx"))
            .respond_with(ResponseTemplate::new(200).set_body_json(success_body))
            .mount(&server)
            .await;

        let client = client_against(&server, "test-key", 1, 0);

        let reply = client
            .chat_completion(single_user_message())
            .await
            .unwrap();
        assert_eq!(reply, "test response content");
    }

    #[tokio::test]
    async fn test_chat_completion_error_handling() {
        let server = MockServer::start().await;
        let error_body = json!({
            "error": {"message":"Invalid API key"}
        });
        Mock::given(method("POST"))
            .and(path(COMPLETIONS_PATH))
            .respond_with(ResponseTemplate::new(401).set_body_json(error_body))
            .mount(&server)
            .await;

        let client = client_against(&server, "bad-key", 1, 0);

        let outcome = client.chat_completion(single_user_message()).await;
        assert!(outcome.is_err());
        let message = outcome.err().unwrap().to_string();
        assert!(message.contains("OpenRouter API error 401"));
    }

    #[tokio::test]
    async fn test_retry_mechanism() {
        let server = MockServer::start().await;

        // The first request is answered with a server error...
        Mock::given(method("POST"))
            .and(path(COMPLETIONS_PATH))
            .respond_with(ResponseTemplate::new(500))
            .up_to_n_times(1)
            .mount(&server)
            .await;

        // ...and every later one with a valid completion.
        let success_body = json!({
            "choices": [{"message": {"content": "success after retry"}}]
        });
        Mock::given(method("POST"))
            .and(path(COMPLETIONS_PATH))
            .respond_with(ResponseTemplate::new(200).set_body_json(success_body))
            .mount(&server)
            .await;

        let client = client_against(&server, "test-key", 2, 50);

        let reply = client
            .chat_completion(single_user_message())
            .await
            .unwrap();
        assert_eq!(reply, "success after retry");
    }

    #[test]
    fn test_openrouter_client_from_config() {
        let config = crate::config::AIConfig {
            temperature: 0.7,
            max_tokens: 2000,
            retry_attempts: 3,
            retry_delay_ms: 150,
            request_timeout_seconds: 120,
            ..config_with(Some("test-key"), OPENROUTER_URL)
        };

        let client = OpenRouterClient::from_config(&config).unwrap();

        assert_eq!(client.api_key, "test-key");
        assert_eq!(client.model, "deepseek/deepseek-r1-0528:free");
        assert_eq!(client.temperature, 0.7);
        assert_eq!(client.max_tokens, 2000);
        assert_eq!(client.retry_attempts, 3);
        assert_eq!(client.retry_delay_ms, 150);
    }

    #[test]
    fn test_openrouter_client_from_config_missing_api_key() {
        let config = config_with(None, OPENROUTER_URL);

        let outcome = OpenRouterClient::from_config(&config);

        assert!(outcome.is_err());
        let message = outcome.err().unwrap().to_string();
        assert!(message.contains("Missing OpenRouter API Key"));
    }

    #[test]
    fn test_openrouter_client_from_config_invalid_base_url() {
        let config = config_with(Some("test-key"), "ftp://invalid.url");

        let outcome = OpenRouterClient::from_config(&config);

        assert!(outcome.is_err());
        let message = outcome.err().unwrap().to_string();
        assert!(message.contains("must use http or https protocol"));
    }

    #[test]
    fn test_prompt_building_and_parsing() {
        let client = OpenRouterClient::new("test-key".into(), TEST_MODEL.into(), 0.1, 1000, 0, 0);
        let request = AnalysisRequest {
            video_files: vec!["video1.mp4".into()],
            subtitle_files: vec!["subtitle1.srt".into()],
            content_samples: vec![],
        };

        let prompt = client.build_analysis_prompt(&request);
        for expected in ["video1.mp4", "subtitle1.srt", "JSON"] {
            assert!(prompt.contains(expected));
        }

        let json_response = r#"{ "matches": [], "confidence":0.9, "reasoning":"test reason" }"#;
        let parsed = client.parse_match_result(json_response).unwrap();
        assert_eq!(parsed.confidence, 0.9);
        assert_eq!(parsed.reasoning, "test reason");
    }
}