//! OpenRouter AI provider client (`subx_cli/services/ai/openrouter.rs`).
1use crate::Result;
2use crate::cli::display_ai_usage;
3use crate::error::SubXError;
4use crate::services::ai::AiUsageStats;
5use crate::services::ai::{
6    AIProvider, AnalysisRequest, ConfidenceScore, MatchResult, VerificationRequest,
7};
8use async_trait::async_trait;
9use reqwest::Client;
10use serde_json::{Value, json};
11use std::time::Duration;
12use tokio::time;
13
14use crate::services::ai::prompts::{PromptBuilder, ResponseParser};
15use crate::services::ai::retry::HttpRetryClient;
16
/// OpenRouter client implementation
pub struct OpenRouterClient {
    // Underlying HTTP client, pre-configured with a request timeout.
    client: Client,
    // Bearer token used to authenticate against the OpenRouter API.
    api_key: String,
    // Model identifier, e.g. "deepseek/deepseek-r1-0528:free".
    model: String,
    // Sampling temperature forwarded verbatim to the chat completion endpoint.
    temperature: f32,
    // Maximum number of tokens the model may generate per response.
    max_tokens: u32,
    // Number of retries attempted on transient (network / 5xx) failures.
    retry_attempts: u32,
    // Fixed delay between retry attempts, in milliseconds.
    retry_delay_ms: u64,
    // API base URL; any trailing '/' is stripped at construction time.
    base_url: String,
}
28
29impl std::fmt::Debug for OpenRouterClient {
30    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
31        f.debug_struct("OpenRouterClient")
32            .field("client", &self.client)
33            .field("api_key", &"[REDACTED]")
34            .field("model", &self.model)
35            .field("temperature", &self.temperature)
36            .field("max_tokens", &self.max_tokens)
37            .field("retry_attempts", &self.retry_attempts)
38            .field("retry_delay_ms", &self.retry_delay_ms)
39            .field("base_url", &self.base_url)
40            .finish()
41    }
42}
43
// Reuse the shared default prompt-building helpers for this provider.
impl PromptBuilder for OpenRouterClient {}
// Reuse the shared default response-parsing helpers for this provider.
impl ResponseParser for OpenRouterClient {}
// Expose this client's retry configuration to the shared HTTP retry helper.
impl HttpRetryClient for OpenRouterClient {
    fn retry_attempts(&self) -> u32 {
        self.retry_attempts
    }
    fn retry_delay_ms(&self) -> u64 {
        self.retry_delay_ms
    }
}
54
impl OpenRouterClient {
    /// Create new OpenRouterClient with default configuration
    ///
    /// Uses the public OpenRouter endpoint and a 120-second request timeout.
    pub fn new(
        api_key: String,
        model: String,
        temperature: f32,
        max_tokens: u32,
        retry_attempts: u32,
        retry_delay_ms: u64,
    ) -> Self {
        Self::new_with_base_url_and_timeout(
            api_key,
            model,
            temperature,
            max_tokens,
            retry_attempts,
            retry_delay_ms,
            "https://openrouter.ai/api/v1".to_string(),
            120,
        )
    }

    /// Create new OpenRouterClient with custom base URL and timeout
    ///
    /// The base URL's trailing slash (if any) is stripped so that request
    /// paths can be appended with a single `/` separator.
    ///
    /// # Panics
    /// Panics if the underlying HTTP client cannot be built (this indicates
    /// a broken TLS/runtime environment, not user error).
    #[allow(clippy::too_many_arguments)]
    pub fn new_with_base_url_and_timeout(
        api_key: String,
        model: String,
        temperature: f32,
        max_tokens: u32,
        retry_attempts: u32,
        retry_delay_ms: u64,
        base_url: String,
        request_timeout_seconds: u64,
    ) -> Self {
        let client = Client::builder()
            .timeout(Duration::from_secs(request_timeout_seconds))
            .build()
            .expect("Failed to create HTTP client");

        Self {
            client,
            api_key,
            model,
            temperature,
            max_tokens,
            retry_attempts,
            retry_delay_ms,
            // Normalize so `format!("{}/path", base_url)` never yields "//".
            base_url: base_url.trim_end_matches('/').to_string(),
        }
    }

    /// Create client from unified configuration
    ///
    /// # Errors
    /// Returns a config error when the API key is missing or the base URL
    /// is malformed. Also emits a warning if the URL uses plain HTTP.
    pub fn from_config(config: &crate::config::AIConfig) -> crate::Result<Self> {
        let api_key = config
            .api_key
            .as_ref()
            .ok_or_else(|| SubXError::config("Missing OpenRouter API Key"))?;

        // Validate base URL format
        Self::validate_base_url(&config.base_url)?;
        // Warn (but do not fail) when the key would travel over plain HTTP.
        crate::services::ai::security::warn_on_insecure_http_str(&config.base_url, api_key);

        Ok(Self::new_with_base_url_and_timeout(
            api_key.clone(),
            config.model.clone(),
            config.temperature,
            config.max_tokens,
            config.retry_attempts,
            config.retry_delay_ms,
            config.base_url.clone(),
            config.request_timeout_seconds,
        ))
    }

    /// Validate base URL format
    ///
    /// Accepts only http/https URLs that contain a hostname.
    fn validate_base_url(url: &str) -> crate::Result<()> {
        use url::Url;
        let parsed =
            Url::parse(url).map_err(|e| SubXError::config(format!("Invalid base URL: {}", e)))?;

        if !matches!(parsed.scheme(), "http" | "https") {
            return Err(SubXError::config(
                "Base URL must use http or https protocol".to_string(),
            ));
        }

        if parsed.host().is_none() {
            return Err(SubXError::config(
                "Base URL must contain a valid hostname".to_string(),
            ));
        }

        Ok(())
    }

    /// Send a chat-completion request and return the first choice's content.
    ///
    /// Applies retries (see `make_request_with_retry`), enforces a response
    /// size cap both via `Content-Length` and during chunked reads, and, when
    /// the response includes usage statistics, displays them to the user.
    async fn chat_completion(&self, messages: Vec<Value>) -> Result<String> {
        let request_body = json!({
            "model": self.model,
            "messages": messages,
            "temperature": self.temperature,
            "max_tokens": self.max_tokens,
        });

        let request = self
            .client
            .post(format!("{}/chat/completions", self.base_url))
            .header("Authorization", format!("Bearer {}", self.api_key))
            .header("Content-Type", "application/json")
            .header("HTTP-Referer", "https://github.com/jim60105/subx-cli")
            .header("X-Title", "Subx")
            .json(&request_body);

        let mut response = self.make_request_with_retry(request).await?;

        const MAX_AI_RESPONSE_BYTES: u64 = 10 * 1024 * 1024; // 10 MiB
        // Fast-path rejection when the server declares an oversized body.
        if let Some(len) = response.content_length() {
            if len > MAX_AI_RESPONSE_BYTES {
                return Err(SubXError::AiService(format!(
                    "AI response too large: {} bytes (limit: {} bytes)",
                    len, MAX_AI_RESPONSE_BYTES
                )));
            }
        }

        if !response.status().is_success() {
            let status = response.status();
            let error_text = response.text().await?;
            // Truncate and sanitize the body so error output can't leak
            // credentials embedded in URLs or grow unbounded.
            let safe_body = crate::services::ai::error_sanitizer::sanitize_url_in_error(
                &crate::services::ai::error_sanitizer::truncate_error_body(
                    &error_text,
                    crate::services::ai::error_sanitizer::DEFAULT_ERROR_BODY_MAX_LEN,
                ),
            );
            return Err(SubXError::AiService(format!(
                "OpenRouter API error {}: {}",
                status, safe_body
            )));
        }

        // Bounded chunked read to guard against oversized responses when
        // content_length() is not reported by the server.
        let mut body = Vec::new();
        while let Some(chunk) = response.chunk().await? {
            body.extend_from_slice(&chunk);
            if body.len() as u64 > MAX_AI_RESPONSE_BYTES {
                return Err(SubXError::AiService(format!(
                    "AI response too large: {} bytes read (limit: {} bytes)",
                    body.len(),
                    MAX_AI_RESPONSE_BYTES
                )));
            }
        }
        let response_json: Value = serde_json::from_slice(&body)
            .map_err(|e| SubXError::AiService(format!("Failed to parse AI response: {}", e)))?;
        // Only the first choice is consumed; a missing/non-string content
        // field is treated as a malformed API response.
        let content = response_json["choices"][0]["message"]["content"]
            .as_str()
            .ok_or_else(|| SubXError::AiService("Invalid API response format".to_string()))?;

        // Parse usage statistics and display
        // (best-effort: silently skipped when any token field is absent).
        if let Some(usage_obj) = response_json.get("usage") {
            if let (Some(p), Some(c), Some(t)) = (
                usage_obj.get("prompt_tokens").and_then(Value::as_u64),
                usage_obj.get("completion_tokens").and_then(Value::as_u64),
                usage_obj.get("total_tokens").and_then(Value::as_u64),
            ) {
                let stats = AiUsageStats {
                    model: self.model.clone(),
                    prompt_tokens: p as u32,
                    completion_tokens: c as u32,
                    total_tokens: t as u32,
                };
                display_ai_usage(&stats);
            }
        }

        Ok(content.to_string())
    }

    /// Send a request, retrying on transport errors and 5xx responses.
    ///
    /// Retries up to `self.retry_attempts` times with a fixed delay of
    /// `self.retry_delay_ms` between attempts. Non-5xx responses (including
    /// 4xx) are returned to the caller without retrying. The request builder
    /// must be cloneable (streaming bodies are not) or an error is returned.
    async fn make_request_with_retry(
        &self,
        request: reqwest::RequestBuilder,
    ) -> crate::Result<reqwest::Response> {
        let mut attempts = 0;
        loop {
            // A fresh clone is needed each iteration because `send` consumes
            // the builder.
            let cloned = request.try_clone().ok_or_else(|| {
                crate::error::SubXError::AiService(
                    "Request body cannot be cloned for retry".to_string(),
                )
            })?;
            match cloned.send().await {
                Ok(resp) => {
                    // Retry on server error statuses (5xx) if attempts remain
                    if resp.status().is_server_error() && (attempts as u32) < self.retry_attempts {
                        attempts += 1;
                        log::warn!(
                            "Request attempt {} failed with status {}. Retrying in {}ms...",
                            attempts,
                            resp.status(),
                            self.retry_delay_ms
                        );
                        time::sleep(Duration::from_millis(self.retry_delay_ms)).await;
                        continue;
                    }
                    if attempts > 0 {
                        log::info!("Request succeeded after {} retry attempts", attempts);
                    }
                    return Ok(resp);
                }
                // Transport-level failure (timeout, connect, etc.) with
                // retries remaining: log, wait, and try again.
                Err(e) if (attempts as u32) < self.retry_attempts => {
                    attempts += 1;
                    log::warn!(
                        "Request attempt {} failed: {}. Retrying in {}ms...",
                        attempts,
                        e,
                        self.retry_delay_ms
                    );

                    if e.is_timeout() {
                        log::warn!(
                            "This appears to be a timeout error. If this persists, consider increasing 'ai.request_timeout_seconds' in your configuration."
                        );
                    }

                    time::sleep(Duration::from_millis(self.retry_delay_ms)).await;
                    continue;
                }
                // Retries exhausted: classify the failure for the user, then
                // propagate the underlying reqwest error.
                Err(e) => {
                    log::error!(
                        "Request failed after {} attempts. Final error: {}",
                        attempts + 1,
                        e
                    );

                    if e.is_timeout() {
                        log::error!(
                            "AI service error: Request timed out after multiple attempts. \
                        This usually indicates network connectivity issues or server overload. \
                        Try increasing 'ai.request_timeout_seconds' configuration. \
                        Hint: check network connection and API service status"
                        );
                    } else if e.is_connect() {
                        log::error!(
                            "AI service error: Connection failed. \
                        Hint: check network connection and API base URL settings"
                        );
                    }

                    return Err(e.into());
                }
            }
        }
    }
}
308
309#[async_trait]
310impl AIProvider for OpenRouterClient {
311    async fn analyze_content(&self, request: AnalysisRequest) -> Result<MatchResult> {
312        let prompt = self.build_analysis_prompt(&request);
313        let messages = vec![
314            json!({"role": "system", "content": "You are a professional subtitle matching assistant that can analyze the correspondence between video and subtitle files."}),
315            json!({"role": "user", "content": prompt}),
316        ];
317        let response = self.chat_completion(messages).await?;
318        self.parse_match_result(&response)
319    }
320
321    async fn verify_match(&self, verification: VerificationRequest) -> Result<ConfidenceScore> {
322        let prompt = self.build_verification_prompt(&verification);
323        let messages = vec![
324            json!({"role": "system", "content": "Please evaluate the confidence level of subtitle matching and provide a score between 0-1."}),
325            json!({"role": "user", "content": prompt}),
326        ];
327        let response = self.chat_completion(messages).await?;
328        self.parse_confidence_score(&response)
329    }
330}
331
#[cfg(test)]
mod tests {
    use super::*;
    use mockall::mock;
    use serde_json::json;
    use wiremock::matchers::{header, method, path};
    use wiremock::{Mock, MockServer, ResponseTemplate};

    // Mock AIProvider implementation available to sibling test modules.
    mock! {
        AIClient {}

        #[async_trait]
        impl AIProvider for AIClient {
            async fn analyze_content(&self, request: AnalysisRequest) -> crate::Result<MatchResult>;
            async fn verify_match(&self, verification: VerificationRequest) -> crate::Result<ConfidenceScore>;
        }
    }

    // `new` should store every field and default to the public endpoint.
    #[tokio::test]
    async fn test_openrouter_client_creation() {
        let client = OpenRouterClient::new(
            "test-key".into(),
            "deepseek/deepseek-r1-0528:free".into(),
            0.5,
            1000,
            2,
            100,
        );
        assert_eq!(client.api_key, "test-key");
        assert_eq!(client.model, "deepseek/deepseek-r1-0528:free");
        assert_eq!(client.temperature, 0.5);
        assert_eq!(client.max_tokens, 1000);
        assert_eq!(client.retry_attempts, 2);
        assert_eq!(client.retry_delay_ms, 100);
        assert_eq!(client.base_url, "https://openrouter.ai/api/v1");
    }

    // A custom base URL must be preserved verbatim (no trailing '/' here).
    #[tokio::test]
    async fn test_openrouter_client_creation_with_custom_base_url() {
        let client = OpenRouterClient::new_with_base_url_and_timeout(
            "test-key".into(),
            "deepseek/deepseek-r1-0528:free".into(),
            0.3,
            2000,
            3,
            200,
            "https://custom-openrouter.ai/api/v1".into(),
            60,
        );
        assert_eq!(client.base_url, "https://custom-openrouter.ai/api/v1");
    }

    // Happy path: auth/referer headers are sent and the first choice's
    // content is returned.
    #[tokio::test]
    async fn test_chat_completion_success() {
        let server = MockServer::start().await;
        Mock::given(method("POST"))
            .and(path("/chat/completions"))
            .and(header("authorization", "Bearer test-key"))
            .and(header(
                "HTTP-Referer",
                "https://github.com/jim60105/subx-cli",
            ))
            .and(header("X-Title", "Subx"))
            .respond_with(ResponseTemplate::new(200).set_body_json(json!({
                "choices": [{"message": {"content": "test response content"}}],
                "usage": { "prompt_tokens": 10, "completion_tokens": 5, "total_tokens": 15 }
            })))
            .mount(&server)
            .await;

        let mut client = OpenRouterClient::new(
            "test-key".into(),
            "deepseek/deepseek-r1-0528:free".into(),
            0.3,
            1000,
            1,
            0,
        );
        // Point the client at the mock server instead of the real API.
        client.base_url = server.uri();

        let messages = vec![json!({"role":"user","content":"test"})];
        let resp = client.chat_completion(messages).await.unwrap();
        assert_eq!(resp, "test response content");
    }

    // A 4xx response must surface as an error carrying the status code
    // (and must NOT be retried).
    #[tokio::test]
    async fn test_chat_completion_error_handling() {
        let server = MockServer::start().await;
        Mock::given(method("POST"))
            .and(path("/chat/completions"))
            .respond_with(ResponseTemplate::new(401).set_body_json(json!({
                "error": {"message":"Invalid API key"}
            })))
            .mount(&server)
            .await;

        let mut client = OpenRouterClient::new(
            "bad-key".into(),
            "deepseek/deepseek-r1-0528:free".into(),
            0.3,
            1000,
            1,
            0,
        );
        client.base_url = server.uri();

        let messages = vec![json!({"role":"user","content":"test"})];
        let result = client.chat_completion(messages).await;
        assert!(result.is_err());
        assert!(
            result
                .err()
                .unwrap()
                .to_string()
                .contains("OpenRouter API error 401")
        );
    }

    // A transient 500 followed by a 200 should succeed via the retry path.
    #[tokio::test]
    async fn test_retry_mechanism() {
        let server = MockServer::start().await;

        // First request fails, second succeeds
        Mock::given(method("POST"))
            .and(path("/chat/completions"))
            .respond_with(ResponseTemplate::new(500))
            .up_to_n_times(1)
            .mount(&server)
            .await;

        Mock::given(method("POST"))
            .and(path("/chat/completions"))
            .respond_with(ResponseTemplate::new(200).set_body_json(json!({
                "choices": [{"message": {"content": "success after retry"}}]
            })))
            .mount(&server)
            .await;

        let mut client = OpenRouterClient::new(
            "test-key".into(),
            "deepseek/deepseek-r1-0528:free".into(),
            0.3,
            1000,
            2,  // Allow 2 retries
            50, // Short delay for testing
        );
        client.base_url = server.uri();

        let messages = vec![json!({"role":"user","content":"test"})];
        let result = client.chat_completion(messages).await.unwrap();
        assert_eq!(result, "success after retry");
    }

    // `from_config` should copy every relevant config field onto the client.
    #[test]
    fn test_openrouter_client_from_config() {
        let config = crate::config::AIConfig {
            provider: "openrouter".to_string(),
            api_key: Some("test-key".to_string()),
            model: "deepseek/deepseek-r1-0528:free".to_string(),
            base_url: "https://openrouter.ai/api/v1".to_string(),
            max_sample_length: 500,
            temperature: 0.7,
            max_tokens: 2000,
            retry_attempts: 3,
            retry_delay_ms: 150,
            request_timeout_seconds: 120,
            api_version: None,
        };

        let client = OpenRouterClient::from_config(&config).unwrap();
        assert_eq!(client.api_key, "test-key");
        assert_eq!(client.model, "deepseek/deepseek-r1-0528:free");
        assert_eq!(client.temperature, 0.7);
        assert_eq!(client.max_tokens, 2000);
        assert_eq!(client.retry_attempts, 3);
        assert_eq!(client.retry_delay_ms, 150);
    }

    // A missing API key must fail with a descriptive config error.
    #[test]
    fn test_openrouter_client_from_config_missing_api_key() {
        let config = crate::config::AIConfig {
            provider: "openrouter".to_string(),
            api_key: None,
            model: "deepseek/deepseek-r1-0528:free".to_string(),
            base_url: "https://openrouter.ai/api/v1".to_string(),
            max_sample_length: 500,
            temperature: 0.3,
            max_tokens: 1000,
            retry_attempts: 2,
            retry_delay_ms: 100,
            request_timeout_seconds: 30,
            api_version: None,
        };

        let result = OpenRouterClient::from_config(&config);
        assert!(result.is_err());
        assert!(
            result
                .err()
                .unwrap()
                .to_string()
                .contains("Missing OpenRouter API Key")
        );
    }

    // Non-http(s) schemes (ftp here) must be rejected by URL validation.
    #[test]
    fn test_openrouter_client_from_config_invalid_base_url() {
        let config = crate::config::AIConfig {
            provider: "openrouter".to_string(),
            api_key: Some("test-key".to_string()),
            model: "deepseek/deepseek-r1-0528:free".to_string(),
            base_url: "ftp://invalid.url".to_string(),
            max_sample_length: 500,
            temperature: 0.3,
            max_tokens: 1000,
            retry_attempts: 2,
            retry_delay_ms: 100,
            request_timeout_seconds: 30,
            api_version: None,
        };

        let result = OpenRouterClient::from_config(&config);
        assert!(result.is_err());
        assert!(
            result
                .err()
                .unwrap()
                .to_string()
                .contains("must use http or https protocol")
        );
    }

    // Prompt builder should mention the input files and request JSON output;
    // the parser should round-trip a well-formed JSON reply.
    #[test]
    fn test_prompt_building_and_parsing() {
        let client = OpenRouterClient::new(
            "test-key".into(),
            "deepseek/deepseek-r1-0528:free".into(),
            0.1,
            1000,
            0,
            0,
        );
        let request = AnalysisRequest {
            video_files: vec!["video1.mp4".into()],
            subtitle_files: vec!["subtitle1.srt".into()],
            content_samples: vec![],
        };

        let prompt = client.build_analysis_prompt(&request);
        assert!(prompt.contains("video1.mp4"));
        assert!(prompt.contains("subtitle1.srt"));
        assert!(prompt.contains("JSON"));

        let json_response = r#"{ "matches": [], "confidence":0.9, "reasoning":"test reason" }"#;
        let match_result = client.parse_match_result(json_response).unwrap();
        assert_eq!(match_result.confidence, 0.9);
        assert_eq!(match_result.reasoning, "test reason");
    }
}