1use crate::Result;
2use crate::cli::display_ai_usage;
3use crate::error::SubXError;
4use crate::services::ai::AiUsageStats;
5use crate::services::ai::{
6 AIProvider, AnalysisRequest, ConfidenceScore, MatchResult, VerificationRequest,
7};
8use async_trait::async_trait;
9use reqwest::Client;
10use serde_json::{Value, json};
11use std::time::Duration;
12use tokio::time;
13
14use crate::services::ai::prompts::{PromptBuilder, ResponseParser};
15use crate::services::ai::retry::HttpRetryClient;
16
17pub struct OpenRouterClient {
19 client: Client,
20 api_key: String,
21 model: String,
22 temperature: f32,
23 max_tokens: u32,
24 retry_attempts: u32,
25 retry_delay_ms: u64,
26 base_url: String,
27}
28
29impl std::fmt::Debug for OpenRouterClient {
30 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
31 f.debug_struct("OpenRouterClient")
32 .field("client", &self.client)
33 .field("api_key", &"[REDACTED]")
34 .field("model", &self.model)
35 .field("temperature", &self.temperature)
36 .field("max_tokens", &self.max_tokens)
37 .field("retry_attempts", &self.retry_attempts)
38 .field("retry_delay_ms", &self.retry_delay_ms)
39 .field("base_url", &self.base_url)
40 .finish()
41 }
42}
43
44impl PromptBuilder for OpenRouterClient {}
45impl ResponseParser for OpenRouterClient {}
46impl HttpRetryClient for OpenRouterClient {
47 fn retry_attempts(&self) -> u32 {
48 self.retry_attempts
49 }
50 fn retry_delay_ms(&self) -> u64 {
51 self.retry_delay_ms
52 }
53}
54
55impl OpenRouterClient {
56 pub fn new(
58 api_key: String,
59 model: String,
60 temperature: f32,
61 max_tokens: u32,
62 retry_attempts: u32,
63 retry_delay_ms: u64,
64 ) -> Self {
65 Self::new_with_base_url_and_timeout(
66 api_key,
67 model,
68 temperature,
69 max_tokens,
70 retry_attempts,
71 retry_delay_ms,
72 "https://openrouter.ai/api/v1".to_string(),
73 120,
74 )
75 }
76
77 #[allow(clippy::too_many_arguments)]
79 pub fn new_with_base_url_and_timeout(
80 api_key: String,
81 model: String,
82 temperature: f32,
83 max_tokens: u32,
84 retry_attempts: u32,
85 retry_delay_ms: u64,
86 base_url: String,
87 request_timeout_seconds: u64,
88 ) -> Self {
89 let client = Client::builder()
90 .timeout(Duration::from_secs(request_timeout_seconds))
91 .build()
92 .expect("Failed to create HTTP client");
93
94 Self {
95 client,
96 api_key,
97 model,
98 temperature,
99 max_tokens,
100 retry_attempts,
101 retry_delay_ms,
102 base_url: base_url.trim_end_matches('/').to_string(),
103 }
104 }
105
106 pub fn from_config(config: &crate::config::AIConfig) -> crate::Result<Self> {
108 let api_key = config
109 .api_key
110 .as_ref()
111 .ok_or_else(|| SubXError::config("Missing OpenRouter API Key"))?;
112
113 Self::validate_base_url(&config.base_url)?;
115 crate::services::ai::security::warn_on_insecure_http_str(&config.base_url, api_key);
116
117 Ok(Self::new_with_base_url_and_timeout(
118 api_key.clone(),
119 config.model.clone(),
120 config.temperature,
121 config.max_tokens,
122 config.retry_attempts,
123 config.retry_delay_ms,
124 config.base_url.clone(),
125 config.request_timeout_seconds,
126 ))
127 }
128
129 fn validate_base_url(url: &str) -> crate::Result<()> {
131 use url::Url;
132 let parsed =
133 Url::parse(url).map_err(|e| SubXError::config(format!("Invalid base URL: {}", e)))?;
134
135 if !matches!(parsed.scheme(), "http" | "https") {
136 return Err(SubXError::config(
137 "Base URL must use http or https protocol".to_string(),
138 ));
139 }
140
141 if parsed.host().is_none() {
142 return Err(SubXError::config(
143 "Base URL must contain a valid hostname".to_string(),
144 ));
145 }
146
147 Ok(())
148 }
149
150 async fn chat_completion(&self, messages: Vec<Value>) -> Result<String> {
151 let request_body = json!({
152 "model": self.model,
153 "messages": messages,
154 "temperature": self.temperature,
155 "max_tokens": self.max_tokens,
156 });
157
158 let request = self
159 .client
160 .post(format!("{}/chat/completions", self.base_url))
161 .header("Authorization", format!("Bearer {}", self.api_key))
162 .header("Content-Type", "application/json")
163 .header("HTTP-Referer", "https://github.com/jim60105/subx-cli")
164 .header("X-Title", "Subx")
165 .json(&request_body);
166
167 let mut response = self.make_request_with_retry(request).await?;
168
169 const MAX_AI_RESPONSE_BYTES: u64 = 10 * 1024 * 1024; if let Some(len) = response.content_length() {
171 if len > MAX_AI_RESPONSE_BYTES {
172 return Err(SubXError::AiService(format!(
173 "AI response too large: {} bytes (limit: {} bytes)",
174 len, MAX_AI_RESPONSE_BYTES
175 )));
176 }
177 }
178
179 if !response.status().is_success() {
180 let status = response.status();
181 let error_text = response.text().await?;
182 let safe_body = crate::services::ai::error_sanitizer::sanitize_url_in_error(
183 &crate::services::ai::error_sanitizer::truncate_error_body(
184 &error_text,
185 crate::services::ai::error_sanitizer::DEFAULT_ERROR_BODY_MAX_LEN,
186 ),
187 );
188 return Err(SubXError::AiService(format!(
189 "OpenRouter API error {}: {}",
190 status, safe_body
191 )));
192 }
193
194 let mut body = Vec::new();
197 while let Some(chunk) = response.chunk().await? {
198 body.extend_from_slice(&chunk);
199 if body.len() as u64 > MAX_AI_RESPONSE_BYTES {
200 return Err(SubXError::AiService(format!(
201 "AI response too large: {} bytes read (limit: {} bytes)",
202 body.len(),
203 MAX_AI_RESPONSE_BYTES
204 )));
205 }
206 }
207 let response_json: Value = serde_json::from_slice(&body)
208 .map_err(|e| SubXError::AiService(format!("Failed to parse AI response: {}", e)))?;
209 let content = response_json["choices"][0]["message"]["content"]
210 .as_str()
211 .ok_or_else(|| SubXError::AiService("Invalid API response format".to_string()))?;
212
213 if let Some(usage_obj) = response_json.get("usage") {
215 if let (Some(p), Some(c), Some(t)) = (
216 usage_obj.get("prompt_tokens").and_then(Value::as_u64),
217 usage_obj.get("completion_tokens").and_then(Value::as_u64),
218 usage_obj.get("total_tokens").and_then(Value::as_u64),
219 ) {
220 let stats = AiUsageStats {
221 model: self.model.clone(),
222 prompt_tokens: p as u32,
223 completion_tokens: c as u32,
224 total_tokens: t as u32,
225 };
226 display_ai_usage(&stats);
227 }
228 }
229
230 Ok(content.to_string())
231 }
232
233 async fn make_request_with_retry(
234 &self,
235 request: reqwest::RequestBuilder,
236 ) -> crate::Result<reqwest::Response> {
237 let mut attempts = 0;
238 loop {
239 let cloned = request.try_clone().ok_or_else(|| {
240 crate::error::SubXError::AiService(
241 "Request body cannot be cloned for retry".to_string(),
242 )
243 })?;
244 match cloned.send().await {
245 Ok(resp) => {
246 if resp.status().is_server_error() && (attempts as u32) < self.retry_attempts {
248 attempts += 1;
249 log::warn!(
250 "Request attempt {} failed with status {}. Retrying in {}ms...",
251 attempts,
252 resp.status(),
253 self.retry_delay_ms
254 );
255 time::sleep(Duration::from_millis(self.retry_delay_ms)).await;
256 continue;
257 }
258 if attempts > 0 {
259 log::info!("Request succeeded after {} retry attempts", attempts);
260 }
261 return Ok(resp);
262 }
263 Err(e) if (attempts as u32) < self.retry_attempts => {
264 attempts += 1;
265 log::warn!(
266 "Request attempt {} failed: {}. Retrying in {}ms...",
267 attempts,
268 e,
269 self.retry_delay_ms
270 );
271
272 if e.is_timeout() {
273 log::warn!(
274 "This appears to be a timeout error. If this persists, consider increasing 'ai.request_timeout_seconds' in your configuration."
275 );
276 }
277
278 time::sleep(Duration::from_millis(self.retry_delay_ms)).await;
279 continue;
280 }
281 Err(e) => {
282 log::error!(
283 "Request failed after {} attempts. Final error: {}",
284 attempts + 1,
285 e
286 );
287
288 if e.is_timeout() {
289 log::error!(
290 "AI service error: Request timed out after multiple attempts. \
291 This usually indicates network connectivity issues or server overload. \
292 Try increasing 'ai.request_timeout_seconds' configuration. \
293 Hint: check network connection and API service status"
294 );
295 } else if e.is_connect() {
296 log::error!(
297 "AI service error: Connection failed. \
298 Hint: check network connection and API base URL settings"
299 );
300 }
301
302 return Err(e.into());
303 }
304 }
305 }
306 }
307}
308
309#[async_trait]
310impl AIProvider for OpenRouterClient {
311 async fn analyze_content(&self, request: AnalysisRequest) -> Result<MatchResult> {
312 let prompt = self.build_analysis_prompt(&request);
313 let messages = vec![
314 json!({"role": "system", "content": "You are a professional subtitle matching assistant that can analyze the correspondence between video and subtitle files."}),
315 json!({"role": "user", "content": prompt}),
316 ];
317 let response = self.chat_completion(messages).await?;
318 self.parse_match_result(&response)
319 }
320
321 async fn verify_match(&self, verification: VerificationRequest) -> Result<ConfidenceScore> {
322 let prompt = self.build_verification_prompt(&verification);
323 let messages = vec![
324 json!({"role": "system", "content": "Please evaluate the confidence level of subtitle matching and provide a score between 0-1."}),
325 json!({"role": "user", "content": prompt}),
326 ];
327 let response = self.chat_completion(messages).await?;
328 self.parse_confidence_score(&response)
329 }
330}
331
#[cfg(test)]
mod tests {
    use super::*;
    use serde_json::json;
    use wiremock::matchers::{header, method, path};
    use wiremock::{Mock, MockServer, ResponseTemplate};

    /// Model identifier shared by all tests.
    const TEST_MODEL: &str = "deepseek/deepseek-r1-0528:free";

    /// Public OpenRouter endpoint, as used by `OpenRouterClient::new`.
    const DEFAULT_URL: &str = "https://openrouter.ai/api/v1";

    /// Builds a client with the given key and retry settings whose requests go to `server`.
    fn client_for(
        server: &MockServer,
        api_key: &str,
        retry_attempts: u32,
        retry_delay_ms: u64,
    ) -> OpenRouterClient {
        let mut client = OpenRouterClient::new(
            api_key.into(),
            TEST_MODEL.into(),
            0.3,
            1000,
            retry_attempts,
            retry_delay_ms,
        );
        client.base_url = server.uri();
        client
    }

    /// Baseline OpenRouter configuration; tests override individual fields via struct update.
    fn openrouter_config(api_key: Option<&str>, base_url: &str) -> crate::config::AIConfig {
        crate::config::AIConfig {
            provider: "openrouter".to_string(),
            api_key: api_key.map(String::from),
            model: TEST_MODEL.to_string(),
            base_url: base_url.to_string(),
            max_sample_length: 500,
            temperature: 0.3,
            max_tokens: 1000,
            retry_attempts: 2,
            retry_delay_ms: 100,
            request_timeout_seconds: 30,
            api_version: None,
        }
    }

    /// Asserts that `result` is an error whose message contains `needle`.
    fn assert_err_contains<T: std::fmt::Debug>(result: crate::Result<T>, needle: &str) {
        let message = result.expect_err("expected an error").to_string();
        assert!(
            message.contains(needle),
            "error {message:?} does not contain {needle:?}"
        );
    }

    /// A single user message used as the request payload in HTTP tests.
    fn user_messages() -> Vec<Value> {
        vec![json!({"role":"user","content":"test"})]
    }

    #[tokio::test]
    async fn test_openrouter_client_creation() {
        let client = OpenRouterClient::new("test-key".into(), TEST_MODEL.into(), 0.5, 1000, 2, 100);
        assert_eq!(client.api_key, "test-key");
        assert_eq!(client.model, TEST_MODEL);
        assert_eq!(client.temperature, 0.5);
        assert_eq!(client.max_tokens, 1000);
        assert_eq!(client.retry_attempts, 2);
        assert_eq!(client.retry_delay_ms, 100);
        assert_eq!(client.base_url, DEFAULT_URL);
    }

    #[tokio::test]
    async fn test_openrouter_client_creation_with_custom_base_url() {
        let client = OpenRouterClient::new_with_base_url_and_timeout(
            "test-key".into(),
            TEST_MODEL.into(),
            0.3,
            2000,
            3,
            200,
            "https://custom-openrouter.ai/api/v1".into(),
            60,
        );
        assert_eq!(client.base_url, "https://custom-openrouter.ai/api/v1");
    }

    #[test]
    fn test_base_url_trailing_slashes_are_trimmed() {
        let client = OpenRouterClient::new_with_base_url_and_timeout(
            "test-key".into(),
            TEST_MODEL.into(),
            0.3,
            1000,
            0,
            0,
            "https://example.com/api/v1///".into(),
            30,
        );
        assert_eq!(client.base_url, "https://example.com/api/v1");
    }

    #[test]
    fn test_debug_output_redacts_api_key() {
        let client =
            OpenRouterClient::new("super-secret-key".into(), TEST_MODEL.into(), 0.3, 1000, 0, 0);
        let rendered = format!("{client:?}");
        assert!(!rendered.contains("super-secret-key"));
        assert!(rendered.contains("[REDACTED]"));
    }

    #[tokio::test]
    async fn test_chat_completion_success() {
        let server = MockServer::start().await;
        Mock::given(method("POST"))
            .and(path("/chat/completions"))
            .and(header("authorization", "Bearer test-key"))
            .and(header(
                "HTTP-Referer",
                "https://github.com/jim60105/subx-cli",
            ))
            .and(header("X-Title", "Subx"))
            .respond_with(ResponseTemplate::new(200).set_body_json(json!({
                "choices": [{"message": {"content": "test response content"}}],
                "usage": { "prompt_tokens": 10, "completion_tokens": 5, "total_tokens": 15 }
            })))
            .mount(&server)
            .await;

        let client = client_for(&server, "test-key", 1, 0);
        let resp = client.chat_completion(user_messages()).await.unwrap();
        assert_eq!(resp, "test response content");
    }

    #[tokio::test]
    async fn test_chat_completion_error_handling() {
        let server = MockServer::start().await;
        Mock::given(method("POST"))
            .and(path("/chat/completions"))
            .respond_with(ResponseTemplate::new(401).set_body_json(json!({
                "error": {"message":"Invalid API key"}
            })))
            .mount(&server)
            .await;

        let client = client_for(&server, "bad-key", 1, 0);
        let result = client.chat_completion(user_messages()).await;
        assert_err_contains(result, "OpenRouter API error 401");
    }

    #[tokio::test]
    async fn test_chat_completion_rejects_missing_content() {
        let server = MockServer::start().await;
        Mock::given(method("POST"))
            .and(path("/chat/completions"))
            .respond_with(ResponseTemplate::new(200).set_body_json(json!({ "choices": [] })))
            .mount(&server)
            .await;

        let client = client_for(&server, "test-key", 0, 0);
        let result = client.chat_completion(user_messages()).await;
        assert_err_contains(result, "Invalid API response format");
    }

    #[tokio::test]
    async fn test_retry_mechanism() {
        let server = MockServer::start().await;

        // First request fails with a server error; the retry then hits the success mock.
        Mock::given(method("POST"))
            .and(path("/chat/completions"))
            .respond_with(ResponseTemplate::new(500))
            .up_to_n_times(1)
            .mount(&server)
            .await;

        Mock::given(method("POST"))
            .and(path("/chat/completions"))
            .respond_with(ResponseTemplate::new(200).set_body_json(json!({
                "choices": [{"message": {"content": "success after retry"}}]
            })))
            .mount(&server)
            .await;

        let client = client_for(&server, "test-key", 2, 50);
        let result = client.chat_completion(user_messages()).await.unwrap();
        assert_eq!(result, "success after retry");
    }

    #[tokio::test]
    async fn test_persistent_server_error_exhausts_retries() {
        let server = MockServer::start().await;
        // One initial attempt plus one retry; verified when `server` is dropped.
        Mock::given(method("POST"))
            .and(path("/chat/completions"))
            .respond_with(ResponseTemplate::new(503))
            .expect(2)
            .mount(&server)
            .await;

        let client = client_for(&server, "test-key", 1, 0);
        let result = client.chat_completion(user_messages()).await;
        assert_err_contains(result, "OpenRouter API error 503");
    }

    #[test]
    fn test_openrouter_client_from_config() {
        let config = crate::config::AIConfig {
            temperature: 0.7,
            max_tokens: 2000,
            retry_attempts: 3,
            retry_delay_ms: 150,
            request_timeout_seconds: 120,
            ..openrouter_config(Some("test-key"), DEFAULT_URL)
        };

        let client = OpenRouterClient::from_config(&config).unwrap();
        assert_eq!(client.api_key, "test-key");
        assert_eq!(client.model, TEST_MODEL);
        assert_eq!(client.temperature, 0.7);
        assert_eq!(client.max_tokens, 2000);
        assert_eq!(client.retry_attempts, 3);
        assert_eq!(client.retry_delay_ms, 150);
    }

    #[test]
    fn test_openrouter_client_from_config_missing_api_key() {
        let config = openrouter_config(None, DEFAULT_URL);
        assert_err_contains(
            OpenRouterClient::from_config(&config),
            "Missing OpenRouter API Key",
        );
    }

    #[test]
    fn test_openrouter_client_from_config_invalid_base_url() {
        let config = openrouter_config(Some("test-key"), "ftp://invalid.url");
        assert_err_contains(
            OpenRouterClient::from_config(&config),
            "must use http or https protocol",
        );
    }

    #[test]
    fn test_openrouter_client_from_config_unparseable_base_url() {
        let config = openrouter_config(Some("test-key"), "not a url");
        assert_err_contains(OpenRouterClient::from_config(&config), "Invalid base URL");
    }

    #[test]
    fn test_prompt_building_and_parsing() {
        let client = OpenRouterClient::new("test-key".into(), TEST_MODEL.into(), 0.1, 1000, 0, 0);
        let request = AnalysisRequest {
            video_files: vec!["video1.mp4".into()],
            subtitle_files: vec!["subtitle1.srt".into()],
            content_samples: vec![],
        };

        let prompt = client.build_analysis_prompt(&request);
        assert!(prompt.contains("video1.mp4"));
        assert!(prompt.contains("subtitle1.srt"));
        assert!(prompt.contains("JSON"));

        let json_response = r#"{ "matches": [], "confidence":0.9, "reasoning":"test reason" }"#;
        let match_result = client.parse_match_result(json_response).unwrap();
        assert_eq!(match_result.confidence, 0.9);
        assert_eq!(match_result.reasoning, "test reason");
    }
}