Skip to main content

subx_cli/services/ai/
openrouter.rs

1use crate::Result;
2use crate::cli::display_ai_usage;
3use crate::error::SubXError;
4use crate::services::ai::AiUsageStats;
5use crate::services::ai::{
6    AIProvider, AnalysisRequest, ConfidenceScore, MatchResult, VerificationRequest,
7};
8use async_trait::async_trait;
9use reqwest::Client;
10use serde_json::{Value, json};
11use std::time::Duration;
12use tokio::time;
13
14use crate::services::ai::hosted_hint::{append_local_hint, maybe_attach_local_hint};
15use crate::services::ai::prompts::{PromptBuilder, ResponseParser};
16use crate::services::ai::retry::HttpRetryClient;
17
/// OpenRouter client implementation
///
/// Sends requests to an OpenAI-style Chat Completions endpoint at
/// `{base_url}/chat/completions` and implements [`AIProvider`] on top of it.
/// The hand-written `Debug` impl prints `api_key` as `[REDACTED]`.
pub struct OpenRouterClient {
    /// HTTP client, built with the configured per-request timeout.
    client: Client,
    /// Secret API key, sent as an `Authorization: Bearer` header.
    api_key: String,
    /// Model identifier sent as `model` in the request body
    /// (e.g. `deepseek/deepseek-r1-0528:free`).
    model: String,
    /// Sampling temperature sent as `temperature` in the request body.
    temperature: f32,
    /// Token limit sent as `max_tokens` in the request body.
    max_tokens: u32,
    /// Number of retries allowed after the first attempt, applied to transport
    /// errors and 5xx responses.
    retry_attempts: u32,
    /// Delay between retries, in milliseconds.
    retry_delay_ms: u64,
    /// API base URL; the constructors trim any trailing `/` from it.
    base_url: String,
}
29
30impl std::fmt::Debug for OpenRouterClient {
31    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
32        f.debug_struct("OpenRouterClient")
33            .field("client", &self.client)
34            .field("api_key", &"[REDACTED]")
35            .field("model", &self.model)
36            .field("temperature", &self.temperature)
37            .field("max_tokens", &self.max_tokens)
38            .field("retry_attempts", &self.retry_attempts)
39            .field("retry_delay_ms", &self.retry_delay_ms)
40            .field("base_url", &self.base_url)
41            .finish()
42    }
43}
44
45impl PromptBuilder for OpenRouterClient {}
46impl ResponseParser for OpenRouterClient {}
47impl HttpRetryClient for OpenRouterClient {
48    fn retry_attempts(&self) -> u32 {
49        self.retry_attempts
50    }
51    fn retry_delay_ms(&self) -> u64 {
52        self.retry_delay_ms
53    }
54}
55
56impl OpenRouterClient {
57    /// Create new OpenRouterClient with default configuration
58    pub fn new(
59        api_key: String,
60        model: String,
61        temperature: f32,
62        max_tokens: u32,
63        retry_attempts: u32,
64        retry_delay_ms: u64,
65    ) -> Self {
66        Self::new_with_base_url_and_timeout(
67            api_key,
68            model,
69            temperature,
70            max_tokens,
71            retry_attempts,
72            retry_delay_ms,
73            "https://openrouter.ai/api/v1".to_string(),
74            120,
75        )
76    }
77
78    /// Create new OpenRouterClient with custom base URL and timeout
79    #[allow(clippy::too_many_arguments)]
80    pub fn new_with_base_url_and_timeout(
81        api_key: String,
82        model: String,
83        temperature: f32,
84        max_tokens: u32,
85        retry_attempts: u32,
86        retry_delay_ms: u64,
87        base_url: String,
88        request_timeout_seconds: u64,
89    ) -> Self {
90        let client = Client::builder()
91            .timeout(Duration::from_secs(request_timeout_seconds))
92            .build()
93            .expect("Failed to create HTTP client");
94
95        Self {
96            client,
97            api_key,
98            model,
99            temperature,
100            max_tokens,
101            retry_attempts,
102            retry_delay_ms,
103            base_url: base_url.trim_end_matches('/').to_string(),
104        }
105    }
106
107    /// Create client from unified configuration
108    pub fn from_config(config: &crate::config::AIConfig) -> crate::Result<Self> {
109        let api_key = config
110            .api_key
111            .as_ref()
112            .ok_or_else(|| SubXError::config("Missing OpenRouter API Key"))?;
113
114        // Validate base URL format
115        Self::validate_base_url(&config.base_url)?;
116        crate::services::ai::security::warn_on_insecure_http_str(&config.base_url, api_key);
117
118        Ok(Self::new_with_base_url_and_timeout(
119            api_key.clone(),
120            config.model.clone(),
121            config.temperature,
122            config.max_tokens,
123            config.retry_attempts,
124            config.retry_delay_ms,
125            config.base_url.clone(),
126            config.request_timeout_seconds,
127        ))
128    }
129
130    /// Validate base URL format
131    fn validate_base_url(url: &str) -> crate::Result<()> {
132        use url::Url;
133        let parsed =
134            Url::parse(url).map_err(|e| SubXError::config(format!("Invalid base URL: {}", e)))?;
135
136        if !matches!(parsed.scheme(), "http" | "https") {
137            return Err(SubXError::config(
138                "Base URL must use http or https protocol".to_string(),
139            ));
140        }
141
142        if parsed.host().is_none() {
143            return Err(SubXError::config(
144                "Base URL must contain a valid hostname".to_string(),
145            ));
146        }
147
148        Ok(())
149    }
150
151    /// Send a raw chat completion request to the OpenRouter Chat Completions API.
152    pub async fn chat_completion(&self, messages: Vec<Value>) -> Result<String> {
153        let request_body = json!({
154            "model": self.model,
155            "messages": messages,
156            "temperature": self.temperature,
157            "max_tokens": self.max_tokens,
158        });
159
160        let request = self
161            .client
162            .post(format!("{}/chat/completions", self.base_url))
163            .header("Authorization", format!("Bearer {}", self.api_key))
164            .header("Content-Type", "application/json")
165            .header("HTTP-Referer", "https://github.com/jim60105/subx-cli")
166            .header("X-Title", "Subx")
167            .json(&request_body);
168
169        let mut response = match self.make_request_with_retry(request).await {
170            Ok(r) => r,
171            Err(e) => return Err(maybe_attach_local_hint(e, &self.base_url)),
172        };
173
174        const MAX_AI_RESPONSE_BYTES: u64 = 10 * 1024 * 1024; // 10 MiB
175        if let Some(len) = response.content_length() {
176            if len > MAX_AI_RESPONSE_BYTES {
177                return Err(SubXError::AiService(format!(
178                    "AI response too large: {} bytes (limit: {} bytes)",
179                    len, MAX_AI_RESPONSE_BYTES
180                )));
181            }
182        }
183
184        if !response.status().is_success() {
185            let status = response.status();
186            let error_text = response.text().await?;
187            let safe_body = crate::services::ai::error_sanitizer::sanitize_url_in_error(
188                &crate::services::ai::error_sanitizer::truncate_error_body(
189                    &error_text,
190                    crate::services::ai::error_sanitizer::DEFAULT_ERROR_BODY_MAX_LEN,
191                ),
192            );
193            return Err(SubXError::AiService(format!(
194                "OpenRouter API error {}: {}",
195                status, safe_body
196            )));
197        }
198
199        // Bounded chunked read to guard against oversized responses when
200        // content_length() is not reported by the server.
201        let mut body = Vec::new();
202        while let Some(chunk) = response.chunk().await? {
203            body.extend_from_slice(&chunk);
204            if body.len() as u64 > MAX_AI_RESPONSE_BYTES {
205                return Err(SubXError::AiService(format!(
206                    "AI response too large: {} bytes read (limit: {} bytes)",
207                    body.len(),
208                    MAX_AI_RESPONSE_BYTES
209                )));
210            }
211        }
212        let response_json: Value = serde_json::from_slice(&body)
213            .map_err(|e| SubXError::AiService(format!("Failed to parse AI response: {}", e)))?;
214        let content = response_json["choices"][0]["message"]["content"]
215            .as_str()
216            .ok_or_else(|| {
217                SubXError::AiService(append_local_hint("Invalid API response format"))
218            })?;
219
220        // Parse usage statistics and display
221        if let Some(usage_obj) = response_json.get("usage") {
222            if let (Some(p), Some(c), Some(t)) = (
223                usage_obj.get("prompt_tokens").and_then(Value::as_u64),
224                usage_obj.get("completion_tokens").and_then(Value::as_u64),
225                usage_obj.get("total_tokens").and_then(Value::as_u64),
226            ) {
227                let stats = AiUsageStats {
228                    model: self.model.clone(),
229                    prompt_tokens: p as u32,
230                    completion_tokens: c as u32,
231                    total_tokens: t as u32,
232                };
233                display_ai_usage(&stats);
234            }
235        }
236
237        Ok(content.to_string())
238    }
239
240    async fn make_request_with_retry(
241        &self,
242        request: reqwest::RequestBuilder,
243    ) -> crate::Result<reqwest::Response> {
244        let mut attempts = 0;
245        loop {
246            let cloned = request.try_clone().ok_or_else(|| {
247                crate::error::SubXError::AiService(
248                    "Request body cannot be cloned for retry".to_string(),
249                )
250            })?;
251            match cloned.send().await {
252                Ok(resp) => {
253                    // Retry on server error statuses (5xx) if attempts remain
254                    if resp.status().is_server_error() && (attempts as u32) < self.retry_attempts {
255                        attempts += 1;
256                        log::warn!(
257                            "Request attempt {} failed with status {}. Retrying in {}ms...",
258                            attempts,
259                            resp.status(),
260                            self.retry_delay_ms
261                        );
262                        time::sleep(Duration::from_millis(self.retry_delay_ms)).await;
263                        continue;
264                    }
265                    if attempts > 0 {
266                        log::info!("Request succeeded after {} retry attempts", attempts);
267                    }
268                    return Ok(resp);
269                }
270                Err(e) if (attempts as u32) < self.retry_attempts => {
271                    attempts += 1;
272                    log::warn!(
273                        "Request attempt {} failed: {}. Retrying in {}ms...",
274                        attempts,
275                        e,
276                        self.retry_delay_ms
277                    );
278
279                    if e.is_timeout() {
280                        log::warn!(
281                            "This appears to be a timeout error. If this persists, consider increasing 'ai.request_timeout_seconds' in your configuration."
282                        );
283                    }
284
285                    time::sleep(Duration::from_millis(self.retry_delay_ms)).await;
286                    continue;
287                }
288                Err(e) => {
289                    log::error!(
290                        "Request failed after {} attempts. Final error: {}",
291                        attempts + 1,
292                        e
293                    );
294
295                    if e.is_timeout() {
296                        log::error!(
297                            "AI service error: Request timed out after multiple attempts. \
298                        This usually indicates network connectivity issues or server overload. \
299                        Try increasing 'ai.request_timeout_seconds' configuration. \
300                        Hint: check network connection and API service status"
301                        );
302                    } else if e.is_connect() {
303                        log::error!(
304                            "AI service error: Connection failed. \
305                        Hint: check network connection and API base URL settings"
306                        );
307                    }
308
309                    return Err(e.into());
310                }
311            }
312        }
313    }
314}
315
316#[async_trait]
317impl AIProvider for OpenRouterClient {
318    async fn analyze_content(&self, request: AnalysisRequest) -> Result<MatchResult> {
319        let prompt = self.build_analysis_prompt(&request);
320        let messages = vec![
321            json!({"role": "system", "content": "You are a professional subtitle matching assistant that can analyze the correspondence between video and subtitle files."}),
322            json!({"role": "user", "content": prompt}),
323        ];
324        let response = self.chat_completion(messages).await?;
325        self.parse_match_result(&response)
326    }
327
328    async fn verify_match(&self, verification: VerificationRequest) -> Result<ConfidenceScore> {
329        let prompt = self.build_verification_prompt(&verification);
330        let messages = vec![
331            json!({"role": "system", "content": "Please evaluate the confidence level of subtitle matching and provide a score between 0-1."}),
332            json!({"role": "user", "content": prompt}),
333        ];
334        let response = self.chat_completion(messages).await?;
335        self.parse_confidence_score(&response)
336    }
337
338    async fn chat_completion(&self, messages: Vec<Value>) -> Result<String> {
339        OpenRouterClient::chat_completion(self, messages).await
340    }
341}
342
#[cfg(test)]
mod tests {
    use super::*;
    use mockall::mock;
    use serde_json::json;
    use wiremock::matchers::{header, method, path};
    use wiremock::{Mock, MockServer, ResponseTemplate};

    /// Model identifier shared by every test in this module.
    const MODEL: &str = "deepseek/deepseek-r1-0528:free";

    mock! {
        AIClient {}

        #[async_trait]
        impl AIProvider for AIClient {
            async fn analyze_content(&self, request: AnalysisRequest) -> crate::Result<MatchResult>;
            async fn verify_match(&self, verification: VerificationRequest) -> crate::Result<ConfidenceScore>;
        }
    }

    /// Builds a client that sends requests to `base_url` (typically a wiremock
    /// server) with the same 120-second timeout `OpenRouterClient::new` uses.
    fn client_at(
        base_url: String,
        api_key: &str,
        temperature: f32,
        max_tokens: u32,
        retry_attempts: u32,
        retry_delay_ms: u64,
    ) -> OpenRouterClient {
        OpenRouterClient::new_with_base_url_and_timeout(
            api_key.into(),
            MODEL.into(),
            temperature,
            max_tokens,
            retry_attempts,
            retry_delay_ms,
            base_url,
            120,
        )
    }

    /// Mounts a `POST /chat/completions` mock that answers every request with `template`.
    async fn mount_chat_completions(server: &MockServer, template: ResponseTemplate) {
        Mock::given(method("POST"))
            .and(path("/chat/completions"))
            .respond_with(template)
            .mount(server)
            .await;
    }

    /// A single-message conversation containing one user turn.
    fn user_message(text: &str) -> Vec<Value> {
        vec![json!({"role":"user","content":text})]
    }

    /// A complete, valid OpenRouter configuration used as a baseline.
    fn sample_config() -> crate::config::AIConfig {
        crate::config::AIConfig {
            provider: "openrouter".to_string(),
            api_key: Some("test-key".to_string()),
            model: MODEL.to_string(),
            base_url: "https://openrouter.ai/api/v1".to_string(),
            max_sample_length: 500,
            temperature: 0.7,
            max_tokens: 2000,
            retry_attempts: 3,
            retry_delay_ms: 150,
            request_timeout_seconds: 120,
            api_version: None,
        }
    }

    #[tokio::test]
    async fn test_openrouter_client_creation() {
        let client = OpenRouterClient::new("test-key".into(), MODEL.into(), 0.5, 1000, 2, 100);
        assert_eq!(client.api_key, "test-key");
        assert_eq!(client.model, MODEL);
        assert_eq!(client.temperature, 0.5);
        assert_eq!(client.max_tokens, 1000);
        assert_eq!(client.retry_attempts, 2);
        assert_eq!(client.retry_delay_ms, 100);
        assert_eq!(client.base_url, "https://openrouter.ai/api/v1");
    }

    #[tokio::test]
    async fn test_openrouter_client_creation_with_custom_base_url() {
        let client = OpenRouterClient::new_with_base_url_and_timeout(
            "test-key".into(),
            MODEL.into(),
            0.3,
            2000,
            3,
            200,
            "https://custom-openrouter.ai/api/v1".into(),
            60,
        );
        assert_eq!(client.base_url, "https://custom-openrouter.ai/api/v1");
    }

    #[tokio::test]
    async fn test_chat_completion_success() {
        let server = MockServer::start().await;
        Mock::given(method("POST"))
            .and(path("/chat/completions"))
            .and(header("authorization", "Bearer test-key"))
            .and(header(
                "HTTP-Referer",
                "https://github.com/jim60105/subx-cli",
            ))
            .and(header("X-Title", "Subx"))
            .respond_with(ResponseTemplate::new(200).set_body_json(json!({
                "choices": [{"message": {"content": "test response content"}}],
                "usage": { "prompt_tokens": 10, "completion_tokens": 5, "total_tokens": 15 }
            })))
            .mount(&server)
            .await;

        let client = client_at(server.uri(), "test-key", 0.3, 1000, 1, 0);
        let reply = client.chat_completion(user_message("test")).await.unwrap();
        assert_eq!(reply, "test response content");
    }

    #[tokio::test]
    async fn test_chat_completion_error_handling() {
        let server = MockServer::start().await;
        mount_chat_completions(
            &server,
            ResponseTemplate::new(401).set_body_json(json!({
                "error": {"message":"Invalid API key"}
            })),
        )
        .await;

        let client = client_at(server.uri(), "bad-key", 0.3, 1000, 1, 0);
        let err = client.chat_completion(user_message("test")).await.unwrap_err();
        assert!(err.to_string().contains("OpenRouter API error 401"));
    }

    #[tokio::test]
    async fn test_retry_mechanism() {
        let server = MockServer::start().await;

        // The first request gets a 500; every later one succeeds.
        Mock::given(method("POST"))
            .and(path("/chat/completions"))
            .respond_with(ResponseTemplate::new(500))
            .up_to_n_times(1)
            .mount(&server)
            .await;
        mount_chat_completions(
            &server,
            ResponseTemplate::new(200).set_body_json(json!({
                "choices": [{"message": {"content": "success after retry"}}]
            })),
        )
        .await;

        // Two retries allowed, with a short delay to keep the test fast.
        let client = client_at(server.uri(), "test-key", 0.3, 1000, 2, 50);
        let reply = client.chat_completion(user_message("test")).await.unwrap();
        assert_eq!(reply, "success after retry");
    }

    #[test]
    fn test_openrouter_client_from_config() {
        let client = OpenRouterClient::from_config(&sample_config()).unwrap();
        assert_eq!(client.api_key, "test-key");
        assert_eq!(client.model, MODEL);
        assert_eq!(client.temperature, 0.7);
        assert_eq!(client.max_tokens, 2000);
        assert_eq!(client.retry_attempts, 3);
        assert_eq!(client.retry_delay_ms, 150);
    }

    #[test]
    fn test_openrouter_client_from_config_missing_api_key() {
        let config = crate::config::AIConfig {
            api_key: None,
            temperature: 0.3,
            max_tokens: 1000,
            retry_attempts: 2,
            retry_delay_ms: 100,
            request_timeout_seconds: 30,
            ..sample_config()
        };

        let err = OpenRouterClient::from_config(&config).unwrap_err();
        assert!(err.to_string().contains("Missing OpenRouter API Key"));
    }

    #[test]
    fn test_openrouter_client_from_config_invalid_base_url() {
        let config = crate::config::AIConfig {
            base_url: "ftp://invalid.url".to_string(),
            temperature: 0.3,
            max_tokens: 1000,
            retry_attempts: 2,
            retry_delay_ms: 100,
            request_timeout_seconds: 30,
            ..sample_config()
        };

        let err = OpenRouterClient::from_config(&config).unwrap_err();
        assert!(err.to_string().contains("must use http or https protocol"));
    }

    #[test]
    fn test_prompt_building_and_parsing() {
        let client = OpenRouterClient::new("test-key".into(), MODEL.into(), 0.1, 1000, 0, 0);
        let request = AnalysisRequest {
            video_files: vec!["video1.mp4".into()],
            subtitle_files: vec!["subtitle1.srt".into()],
            content_samples: vec![],
        };

        let prompt = client.build_analysis_prompt(&request);
        for needle in ["video1.mp4", "subtitle1.srt", "JSON"] {
            assert!(prompt.contains(needle));
        }

        let json_response = r#"{ "matches": [], "confidence":0.9, "reasoning":"test reason" }"#;
        let match_result = client.parse_match_result(json_response).unwrap();
        assert_eq!(match_result.confidence, 0.9);
        assert_eq!(match_result.reasoning, "test reason");
    }

    /// §3.6 — connection refused against `127.0.0.1` on a port with no
    /// listener results in a hint-bearing error.
    #[tokio::test]
    async fn test_hosted_hint_connection_refused_loopback() {
        // Grab a free port, then release it so nothing is listening there.
        let port = {
            let listener = tokio::net::TcpListener::bind("127.0.0.1:0").await.unwrap();
            listener.local_addr().unwrap().port()
        };
        let client = OpenRouterClient::new_with_base_url_and_timeout(
            "k".into(),
            MODEL.into(),
            0.0,
            16,
            0,
            0,
            format!("http://127.0.0.1:{}", port),
            1,
        );
        let msg = client
            .chat_completion(user_message("x"))
            .await
            .unwrap_err()
            .to_string();
        assert!(
            msg.contains("ollama") && msg.contains("ai.provider"),
            "expected local-provider hint: {msg}"
        );
    }

    /// §3.6 — HTTP 200 with a non-OpenAI body MUST surface the hint.
    #[tokio::test]
    async fn test_hosted_hint_http_200_non_openai_body() {
        let server = MockServer::start().await;
        mount_chat_completions(
            &server,
            ResponseTemplate::new(200).set_body_json(json!({ "hello": "world" })),
        )
        .await;
        let client = client_at(server.uri(), "k", 0.0, 16, 0, 0);
        let msg = client
            .chat_completion(user_message("x"))
            .await
            .unwrap_err()
            .to_string();
        assert!(
            msg.contains("Invalid API response format")
                && msg.contains("ollama")
                && msg.contains("ai.provider"),
            "expected hint-bearing parse-shape error: {msg}"
        );
    }

    /// §3.6 negative — a failure against a public host MUST NOT surface
    /// the hint. Uses TEST-NET-1 (RFC 5737) so the test is hermetic.
    #[tokio::test]
    async fn test_hosted_hint_not_emitted_for_public_host() {
        let client = OpenRouterClient::new_with_base_url_and_timeout(
            "k".into(),
            MODEL.into(),
            0.0,
            16,
            0,
            0,
            "https://192.0.2.1/api/v1".to_string(),
            1,
        );
        let msg = client
            .chat_completion(user_message("x"))
            .await
            .unwrap_err()
            .to_string();
        assert!(
            !msg.contains("ollama"),
            "public-host failure must NOT carry the hint: {msg}"
        );
    }
}