1use crate::Result;
2use crate::cli::display_ai_usage;
3use crate::error::SubXError;
4use crate::services::ai::AiUsageStats;
5use crate::services::ai::{
6 AIProvider, AnalysisRequest, ConfidenceScore, MatchResult, VerificationRequest,
7};
8use async_trait::async_trait;
9use reqwest::Client;
10use serde_json::{Value, json};
11use std::time::Duration;
12use tokio::time;
13
14use crate::services::ai::hosted_hint::{append_local_hint, maybe_attach_local_hint};
15use crate::services::ai::prompts::{PromptBuilder, ResponseParser};
16use crate::services::ai::retry::HttpRetryClient;
17
/// HTTP client for the OpenRouter chat-completions API.
///
/// Holds the underlying `reqwest` client plus the model/request/retry
/// configuration, typically resolved from [`crate::config::AIConfig`]
/// via [`OpenRouterClient::from_config`].
pub struct OpenRouterClient {
    /// Underlying HTTP client, built with the configured request timeout.
    client: Client,
    /// Bearer token sent in the `Authorization` header (redacted in `Debug`).
    api_key: String,
    /// Model identifier forwarded in the request body,
    /// e.g. "deepseek/deepseek-r1-0528:free".
    model: String,
    /// Sampling temperature forwarded in the request body.
    temperature: f32,
    /// `max_tokens` value forwarded in the request body.
    max_tokens: u32,
    /// Number of retries performed on 5xx responses and transport errors.
    retry_attempts: u32,
    /// Fixed delay between retry attempts, in milliseconds.
    retry_delay_ms: u64,
    /// API base URL with any trailing '/' removed (see constructor).
    base_url: String,
}
29
// Manual `Debug` implementation so the API key is never printed in logs or
// error output; every other field uses the derive-style field-by-field form.
impl std::fmt::Debug for OpenRouterClient {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("OpenRouterClient")
            .field("client", &self.client)
            // Deliberately masked: the real key must not leak via `{:?}`.
            .field("api_key", &"[REDACTED]")
            .field("model", &self.model)
            .field("temperature", &self.temperature)
            .field("max_tokens", &self.max_tokens)
            .field("retry_attempts", &self.retry_attempts)
            .field("retry_delay_ms", &self.retry_delay_ms)
            .field("base_url", &self.base_url)
            .finish()
    }
}
44
// Empty impls: both traits are usable with no overrides here (their methods,
// e.g. `build_analysis_prompt` / `parse_match_result`, are provided by the
// traits themselves and called on `self` below).
impl PromptBuilder for OpenRouterClient {}
impl ResponseParser for OpenRouterClient {}
// Expose this client's retry configuration to the shared HTTP retry helper.
impl HttpRetryClient for OpenRouterClient {
    /// Maximum number of retry attempts configured for this client.
    fn retry_attempts(&self) -> u32 {
        self.retry_attempts
    }
    /// Delay between retry attempts, in milliseconds.
    fn retry_delay_ms(&self) -> u64 {
        self.retry_delay_ms
    }
}
55
56impl OpenRouterClient {
57 pub fn new(
59 api_key: String,
60 model: String,
61 temperature: f32,
62 max_tokens: u32,
63 retry_attempts: u32,
64 retry_delay_ms: u64,
65 ) -> Self {
66 Self::new_with_base_url_and_timeout(
67 api_key,
68 model,
69 temperature,
70 max_tokens,
71 retry_attempts,
72 retry_delay_ms,
73 "https://openrouter.ai/api/v1".to_string(),
74 120,
75 )
76 }
77
78 #[allow(clippy::too_many_arguments)]
80 pub fn new_with_base_url_and_timeout(
81 api_key: String,
82 model: String,
83 temperature: f32,
84 max_tokens: u32,
85 retry_attempts: u32,
86 retry_delay_ms: u64,
87 base_url: String,
88 request_timeout_seconds: u64,
89 ) -> Self {
90 let client = Client::builder()
91 .timeout(Duration::from_secs(request_timeout_seconds))
92 .build()
93 .expect("Failed to create HTTP client");
94
95 Self {
96 client,
97 api_key,
98 model,
99 temperature,
100 max_tokens,
101 retry_attempts,
102 retry_delay_ms,
103 base_url: base_url.trim_end_matches('/').to_string(),
104 }
105 }
106
107 pub fn from_config(config: &crate::config::AIConfig) -> crate::Result<Self> {
109 let api_key = config
110 .api_key
111 .as_ref()
112 .ok_or_else(|| SubXError::config("Missing OpenRouter API Key"))?;
113
114 Self::validate_base_url(&config.base_url)?;
116 crate::services::ai::security::warn_on_insecure_http_str(&config.base_url, api_key);
117
118 Ok(Self::new_with_base_url_and_timeout(
119 api_key.clone(),
120 config.model.clone(),
121 config.temperature,
122 config.max_tokens,
123 config.retry_attempts,
124 config.retry_delay_ms,
125 config.base_url.clone(),
126 config.request_timeout_seconds,
127 ))
128 }
129
130 fn validate_base_url(url: &str) -> crate::Result<()> {
132 use url::Url;
133 let parsed =
134 Url::parse(url).map_err(|e| SubXError::config(format!("Invalid base URL: {}", e)))?;
135
136 if !matches!(parsed.scheme(), "http" | "https") {
137 return Err(SubXError::config(
138 "Base URL must use http or https protocol".to_string(),
139 ));
140 }
141
142 if parsed.host().is_none() {
143 return Err(SubXError::config(
144 "Base URL must contain a valid hostname".to_string(),
145 ));
146 }
147
148 Ok(())
149 }
150
151 pub async fn chat_completion(&self, messages: Vec<Value>) -> Result<String> {
153 let request_body = json!({
154 "model": self.model,
155 "messages": messages,
156 "temperature": self.temperature,
157 "max_tokens": self.max_tokens,
158 });
159
160 let request = self
161 .client
162 .post(format!("{}/chat/completions", self.base_url))
163 .header("Authorization", format!("Bearer {}", self.api_key))
164 .header("Content-Type", "application/json")
165 .header("HTTP-Referer", "https://github.com/jim60105/subx-cli")
166 .header("X-Title", "Subx")
167 .json(&request_body);
168
169 let mut response = match self.make_request_with_retry(request).await {
170 Ok(r) => r,
171 Err(e) => return Err(maybe_attach_local_hint(e, &self.base_url)),
172 };
173
174 const MAX_AI_RESPONSE_BYTES: u64 = 10 * 1024 * 1024; if let Some(len) = response.content_length() {
176 if len > MAX_AI_RESPONSE_BYTES {
177 return Err(SubXError::AiService(format!(
178 "AI response too large: {} bytes (limit: {} bytes)",
179 len, MAX_AI_RESPONSE_BYTES
180 )));
181 }
182 }
183
184 if !response.status().is_success() {
185 let status = response.status();
186 let error_text = response.text().await?;
187 let safe_body = crate::services::ai::error_sanitizer::sanitize_url_in_error(
188 &crate::services::ai::error_sanitizer::truncate_error_body(
189 &error_text,
190 crate::services::ai::error_sanitizer::DEFAULT_ERROR_BODY_MAX_LEN,
191 ),
192 );
193 return Err(SubXError::AiService(format!(
194 "OpenRouter API error {}: {}",
195 status, safe_body
196 )));
197 }
198
199 let mut body = Vec::new();
202 while let Some(chunk) = response.chunk().await? {
203 body.extend_from_slice(&chunk);
204 if body.len() as u64 > MAX_AI_RESPONSE_BYTES {
205 return Err(SubXError::AiService(format!(
206 "AI response too large: {} bytes read (limit: {} bytes)",
207 body.len(),
208 MAX_AI_RESPONSE_BYTES
209 )));
210 }
211 }
212 let response_json: Value = serde_json::from_slice(&body)
213 .map_err(|e| SubXError::AiService(format!("Failed to parse AI response: {}", e)))?;
214 let content = response_json["choices"][0]["message"]["content"]
215 .as_str()
216 .ok_or_else(|| {
217 SubXError::AiService(append_local_hint("Invalid API response format"))
218 })?;
219
220 if let Some(usage_obj) = response_json.get("usage") {
222 if let (Some(p), Some(c), Some(t)) = (
223 usage_obj.get("prompt_tokens").and_then(Value::as_u64),
224 usage_obj.get("completion_tokens").and_then(Value::as_u64),
225 usage_obj.get("total_tokens").and_then(Value::as_u64),
226 ) {
227 let stats = AiUsageStats {
228 model: self.model.clone(),
229 prompt_tokens: p as u32,
230 completion_tokens: c as u32,
231 total_tokens: t as u32,
232 };
233 display_ai_usage(&stats);
234 }
235 }
236
237 Ok(content.to_string())
238 }
239
240 async fn make_request_with_retry(
241 &self,
242 request: reqwest::RequestBuilder,
243 ) -> crate::Result<reqwest::Response> {
244 let mut attempts = 0;
245 loop {
246 let cloned = request.try_clone().ok_or_else(|| {
247 crate::error::SubXError::AiService(
248 "Request body cannot be cloned for retry".to_string(),
249 )
250 })?;
251 match cloned.send().await {
252 Ok(resp) => {
253 if resp.status().is_server_error() && (attempts as u32) < self.retry_attempts {
255 attempts += 1;
256 log::warn!(
257 "Request attempt {} failed with status {}. Retrying in {}ms...",
258 attempts,
259 resp.status(),
260 self.retry_delay_ms
261 );
262 time::sleep(Duration::from_millis(self.retry_delay_ms)).await;
263 continue;
264 }
265 if attempts > 0 {
266 log::info!("Request succeeded after {} retry attempts", attempts);
267 }
268 return Ok(resp);
269 }
270 Err(e) if (attempts as u32) < self.retry_attempts => {
271 attempts += 1;
272 log::warn!(
273 "Request attempt {} failed: {}. Retrying in {}ms...",
274 attempts,
275 e,
276 self.retry_delay_ms
277 );
278
279 if e.is_timeout() {
280 log::warn!(
281 "This appears to be a timeout error. If this persists, consider increasing 'ai.request_timeout_seconds' in your configuration."
282 );
283 }
284
285 time::sleep(Duration::from_millis(self.retry_delay_ms)).await;
286 continue;
287 }
288 Err(e) => {
289 log::error!(
290 "Request failed after {} attempts. Final error: {}",
291 attempts + 1,
292 e
293 );
294
295 if e.is_timeout() {
296 log::error!(
297 "AI service error: Request timed out after multiple attempts. \
298 This usually indicates network connectivity issues or server overload. \
299 Try increasing 'ai.request_timeout_seconds' configuration. \
300 Hint: check network connection and API service status"
301 );
302 } else if e.is_connect() {
303 log::error!(
304 "AI service error: Connection failed. \
305 Hint: check network connection and API base URL settings"
306 );
307 }
308
309 return Err(e.into());
310 }
311 }
312 }
313 }
314}
315
316#[async_trait]
317impl AIProvider for OpenRouterClient {
318 async fn analyze_content(&self, request: AnalysisRequest) -> Result<MatchResult> {
319 let prompt = self.build_analysis_prompt(&request);
320 let messages = vec![
321 json!({"role": "system", "content": "You are a professional subtitle matching assistant that can analyze the correspondence between video and subtitle files."}),
322 json!({"role": "user", "content": prompt}),
323 ];
324 let response = self.chat_completion(messages).await?;
325 self.parse_match_result(&response)
326 }
327
328 async fn verify_match(&self, verification: VerificationRequest) -> Result<ConfidenceScore> {
329 let prompt = self.build_verification_prompt(&verification);
330 let messages = vec![
331 json!({"role": "system", "content": "Please evaluate the confidence level of subtitle matching and provide a score between 0-1."}),
332 json!({"role": "user", "content": prompt}),
333 ];
334 let response = self.chat_completion(messages).await?;
335 self.parse_confidence_score(&response)
336 }
337
338 async fn chat_completion(&self, messages: Vec<Value>) -> Result<String> {
339 OpenRouterClient::chat_completion(self, messages).await
340 }
341}
342
#[cfg(test)]
mod tests {
    //! Unit tests for `OpenRouterClient`: construction, config parsing, the
    //! HTTP success/error paths, the retry mechanism, and the local-provider
    //! hint behavior. HTTP interactions are simulated with `wiremock`.
    use super::*;
    use mockall::mock;
    use serde_json::json;
    use wiremock::matchers::{header, method, path};
    use wiremock::{Mock, MockServer, ResponseTemplate};

    // Generates `MockAIClient`, a stub `AIProvider` available to other tests
    // that need an AI provider without real HTTP traffic.
    mock! {
        AIClient {}

        #[async_trait]
        impl AIProvider for AIClient {
            async fn analyze_content(&self, request: AnalysisRequest) -> crate::Result<MatchResult>;
            async fn verify_match(&self, verification: VerificationRequest) -> crate::Result<ConfidenceScore>;
        }
    }

    // `new` stores every field and defaults to the public OpenRouter URL.
    #[tokio::test]
    async fn test_openrouter_client_creation() {
        let client = OpenRouterClient::new(
            "test-key".into(),
            "deepseek/deepseek-r1-0528:free".into(),
            0.5,
            1000,
            2,
            100,
        );
        assert_eq!(client.api_key, "test-key");
        assert_eq!(client.model, "deepseek/deepseek-r1-0528:free");
        assert_eq!(client.temperature, 0.5);
        assert_eq!(client.max_tokens, 1000);
        assert_eq!(client.retry_attempts, 2);
        assert_eq!(client.retry_delay_ms, 100);
        assert_eq!(client.base_url, "https://openrouter.ai/api/v1");
    }

    // A custom base URL passed to the full constructor is kept verbatim.
    #[tokio::test]
    async fn test_openrouter_client_creation_with_custom_base_url() {
        let client = OpenRouterClient::new_with_base_url_and_timeout(
            "test-key".into(),
            "deepseek/deepseek-r1-0528:free".into(),
            0.3,
            2000,
            3,
            200,
            "https://custom-openrouter.ai/api/v1".into(),
            60,
        );
        assert_eq!(client.base_url, "https://custom-openrouter.ai/api/v1");
    }

    // Happy path: auth/attribution headers are sent, and message content is
    // extracted from an OpenAI-shaped response that includes usage stats.
    #[tokio::test]
    async fn test_chat_completion_success() {
        let server = MockServer::start().await;
        Mock::given(method("POST"))
            .and(path("/chat/completions"))
            .and(header("authorization", "Bearer test-key"))
            .and(header(
                "HTTP-Referer",
                "https://github.com/jim60105/subx-cli",
            ))
            .and(header("X-Title", "Subx"))
            .respond_with(ResponseTemplate::new(200).set_body_json(json!({
                "choices": [{"message": {"content": "test response content"}}],
                "usage": { "prompt_tokens": 10, "completion_tokens": 5, "total_tokens": 15 }
            })))
            .mount(&server)
            .await;

        let mut client = OpenRouterClient::new(
            "test-key".into(),
            "deepseek/deepseek-r1-0528:free".into(),
            0.3,
            1000,
            1,
            0,
        );
        // Redirect the client to the in-process mock server.
        client.base_url = server.uri();

        let messages = vec![json!({"role":"user","content":"test"})];
        let resp = client.chat_completion(messages).await.unwrap();
        assert_eq!(resp, "test response content");
    }

    // Non-success statuses surface as an "OpenRouter API error <status>" error.
    #[tokio::test]
    async fn test_chat_completion_error_handling() {
        let server = MockServer::start().await;
        Mock::given(method("POST"))
            .and(path("/chat/completions"))
            .respond_with(ResponseTemplate::new(401).set_body_json(json!({
                "error": {"message":"Invalid API key"}
            })))
            .mount(&server)
            .await;

        let mut client = OpenRouterClient::new(
            "bad-key".into(),
            "deepseek/deepseek-r1-0528:free".into(),
            0.3,
            1000,
            1,
            0,
        );
        client.base_url = server.uri();

        let messages = vec![json!({"role":"user","content":"test"})];
        let result = client.chat_completion(messages).await;
        assert!(result.is_err());
        assert!(
            result
                .err()
                .unwrap()
                .to_string()
                .contains("OpenRouter API error 401")
        );
    }

    // First attempt hits a 500 (consumed by `up_to_n_times(1)`); the retry
    // falls through to the 200 mock and succeeds.
    #[tokio::test]
    async fn test_retry_mechanism() {
        let server = MockServer::start().await;

        Mock::given(method("POST"))
            .and(path("/chat/completions"))
            .respond_with(ResponseTemplate::new(500))
            .up_to_n_times(1)
            .mount(&server)
            .await;

        Mock::given(method("POST"))
            .and(path("/chat/completions"))
            .respond_with(ResponseTemplate::new(200).set_body_json(json!({
                "choices": [{"message": {"content": "success after retry"}}]
            })))
            .mount(&server)
            .await;

        let mut client = OpenRouterClient::new(
            "test-key".into(),
            "deepseek/deepseek-r1-0528:free".into(),
            0.3,
            2,
            50,
        );
        client.base_url = server.uri();

        let messages = vec![json!({"role":"user","content":"test"})];
        let result = client.chat_completion(messages).await.unwrap();
        assert_eq!(result, "success after retry");
    }

    // A fully-populated config maps field-for-field onto the client.
    #[test]
    fn test_openrouter_client_from_config() {
        let config = crate::config::AIConfig {
            provider: "openrouter".to_string(),
            api_key: Some("test-key".to_string()),
            model: "deepseek/deepseek-r1-0528:free".to_string(),
            base_url: "https://openrouter.ai/api/v1".to_string(),
            max_sample_length: 500,
            temperature: 0.7,
            max_tokens: 2000,
            retry_attempts: 3,
            retry_delay_ms: 150,
            request_timeout_seconds: 120,
            api_version: None,
        };

        let client = OpenRouterClient::from_config(&config).unwrap();
        assert_eq!(client.api_key, "test-key");
        assert_eq!(client.model, "deepseek/deepseek-r1-0528:free");
        assert_eq!(client.temperature, 0.7);
        assert_eq!(client.max_tokens, 2000);
        assert_eq!(client.retry_attempts, 3);
        assert_eq!(client.retry_delay_ms, 150);
    }

    // A missing API key must be rejected with a descriptive config error.
    #[test]
    fn test_openrouter_client_from_config_missing_api_key() {
        let config = crate::config::AIConfig {
            provider: "openrouter".to_string(),
            api_key: None,
            model: "deepseek/deepseek-r1-0528:free".to_string(),
            base_url: "https://openrouter.ai/api/v1".to_string(),
            max_sample_length: 500,
            temperature: 0.3,
            max_tokens: 1000,
            retry_attempts: 2,
            retry_delay_ms: 100,
            request_timeout_seconds: 30,
            api_version: None,
        };

        let result = OpenRouterClient::from_config(&config);
        assert!(result.is_err());
        assert!(
            result
                .err()
                .unwrap()
                .to_string()
                .contains("Missing OpenRouter API Key")
        );
    }

    // Non-http(s) schemes (here: ftp) must be rejected by URL validation.
    #[test]
    fn test_openrouter_client_from_config_invalid_base_url() {
        let config = crate::config::AIConfig {
            provider: "openrouter".to_string(),
            api_key: Some("test-key".to_string()),
            model: "deepseek/deepseek-r1-0528:free".to_string(),
            base_url: "ftp://invalid.url".to_string(),
            max_sample_length: 500,
            temperature: 0.3,
            max_tokens: 1000,
            retry_attempts: 2,
            retry_delay_ms: 100,
            request_timeout_seconds: 30,
            api_version: None,
        };

        let result = OpenRouterClient::from_config(&config);
        assert!(result.is_err());
        assert!(
            result
                .err()
                .unwrap()
                .to_string()
                .contains("must use http or https protocol")
        );
    }

    // The PromptBuilder/ResponseParser helpers must produce a prompt that
    // mentions the inputs and parse a well-formed JSON match result.
    #[test]
    fn test_prompt_building_and_parsing() {
        let client = OpenRouterClient::new(
            "test-key".into(),
            "deepseek/deepseek-r1-0528:free".into(),
            0.1,
            1000,
            0,
            0,
        );
        let request = AnalysisRequest {
            video_files: vec!["video1.mp4".into()],
            subtitle_files: vec!["subtitle1.srt".into()],
            content_samples: vec![],
        };

        let prompt = client.build_analysis_prompt(&request);
        assert!(prompt.contains("video1.mp4"));
        assert!(prompt.contains("subtitle1.srt"));
        assert!(prompt.contains("JSON"));

        let json_response = r#"{ "matches": [], "confidence":0.9, "reasoning":"test reason" }"#;
        let match_result = client.parse_match_result(json_response).unwrap();
        assert_eq!(match_result.confidence, 0.9);
        assert_eq!(match_result.reasoning, "test reason");
    }

    // A refused connection to a loopback address should carry the
    // "switch to a local provider (ollama)" hint in the error message.
    #[tokio::test]
    async fn test_hosted_hint_connection_refused_loopback() {
        // Bind then drop to obtain a port that is guaranteed closed.
        let listener = tokio::net::TcpListener::bind("127.0.0.1:0").await.unwrap();
        let port = listener.local_addr().unwrap().port();
        drop(listener);
        let client = OpenRouterClient::new_with_base_url_and_timeout(
            "k".into(),
            "deepseek/deepseek-r1-0528:free".into(),
            0.0,
            16,
            0,
            0,
            format!("http://127.0.0.1:{}", port),
            1,
        );
        let err = client
            .chat_completion(vec![json!({"role":"user","content":"x"})])
            .await
            .unwrap_err();
        let msg = err.to_string();
        assert!(
            msg.contains("ollama") && msg.contains("ai.provider"),
            "expected local-provider hint: {msg}"
        );
    }

    // A 200 response whose body is not OpenAI-shaped should fail with the
    // parse-shape error AND carry the local-provider hint (mock server is
    // a loopback address).
    #[tokio::test]
    async fn test_hosted_hint_http_200_non_openai_body() {
        let server = MockServer::start().await;
        Mock::given(method("POST"))
            .and(path("/chat/completions"))
            .respond_with(ResponseTemplate::new(200).set_body_json(json!({ "hello": "world" })))
            .mount(&server)
            .await;
        let mut client = OpenRouterClient::new(
            "k".into(),
            "deepseek/deepseek-r1-0528:free".into(),
            0.0,
            16,
            0,
            0,
        );
        client.base_url = server.uri();
        let err = client
            .chat_completion(vec![json!({"role":"user","content":"x"})])
            .await
            .unwrap_err();
        let msg = err.to_string();
        assert!(
            msg.contains("Invalid API response format")
                && msg.contains("ollama")
                && msg.contains("ai.provider"),
            "expected hint-bearing parse-shape error: {msg}"
        );
    }

    // Failures against a public (non-local) address must NOT carry the hint.
    // 192.0.2.1 is TEST-NET-1, guaranteed unroutable.
    #[tokio::test]
    async fn test_hosted_hint_not_emitted_for_public_host() {
        let client = OpenRouterClient::new_with_base_url_and_timeout(
            "k".into(),
            "deepseek/deepseek-r1-0528:free".into(),
            0.0,
            16,
            0,
            0,
            "https://192.0.2.1/api/v1".to_string(),
            1,
        );
        let err = client
            .chat_completion(vec![json!({"role":"user","content":"x"})])
            .await
            .unwrap_err();
        let msg = err.to_string();
        assert!(
            !msg.contains("ollama"),
            "public-host failure must NOT carry the hint: {msg}"
        );
    }
}