subx_cli/services/ai/
openrouter.rs

1use crate::Result;
2use crate::cli::display_ai_usage;
3use crate::error::SubXError;
4use crate::services::ai::AiUsageStats;
5use crate::services::ai::{
6    AIProvider, AnalysisRequest, ConfidenceScore, MatchResult, VerificationRequest,
7};
8use async_trait::async_trait;
9use reqwest::Client;
10use serde_json::{Value, json};
11use std::time::Duration;
12use tokio::time;
13
14use crate::services::ai::prompts::{PromptBuilder, ResponseParser};
15use crate::services::ai::retry::HttpRetryClient;
16
17/// OpenRouter client implementation
18#[derive(Debug)]
19pub struct OpenRouterClient {
20    client: Client,
21    api_key: String,
22    model: String,
23    temperature: f32,
24    max_tokens: u32,
25    retry_attempts: u32,
26    retry_delay_ms: u64,
27    base_url: String,
28}
29
30impl PromptBuilder for OpenRouterClient {}
31impl ResponseParser for OpenRouterClient {}
32impl HttpRetryClient for OpenRouterClient {
33    fn retry_attempts(&self) -> u32 {
34        self.retry_attempts
35    }
36    fn retry_delay_ms(&self) -> u64 {
37        self.retry_delay_ms
38    }
39}
40
41impl OpenRouterClient {
42    /// Create new OpenRouterClient with default configuration
43    pub fn new(
44        api_key: String,
45        model: String,
46        temperature: f32,
47        max_tokens: u32,
48        retry_attempts: u32,
49        retry_delay_ms: u64,
50    ) -> Self {
51        Self::new_with_base_url_and_timeout(
52            api_key,
53            model,
54            temperature,
55            max_tokens,
56            retry_attempts,
57            retry_delay_ms,
58            "https://openrouter.ai/api/v1".to_string(),
59            120,
60        )
61    }
62
63    /// Create new OpenRouterClient with custom base URL and timeout
64    #[allow(clippy::too_many_arguments)]
65    pub fn new_with_base_url_and_timeout(
66        api_key: String,
67        model: String,
68        temperature: f32,
69        max_tokens: u32,
70        retry_attempts: u32,
71        retry_delay_ms: u64,
72        base_url: String,
73        request_timeout_seconds: u64,
74    ) -> Self {
75        let client = Client::builder()
76            .timeout(Duration::from_secs(request_timeout_seconds))
77            .build()
78            .expect("Failed to create HTTP client");
79
80        Self {
81            client,
82            api_key,
83            model,
84            temperature,
85            max_tokens,
86            retry_attempts,
87            retry_delay_ms,
88            base_url: base_url.trim_end_matches('/').to_string(),
89        }
90    }
91
92    /// Create client from unified configuration
93    pub fn from_config(config: &crate::config::AIConfig) -> crate::Result<Self> {
94        let api_key = config
95            .api_key
96            .as_ref()
97            .ok_or_else(|| SubXError::config("Missing OpenRouter API Key"))?;
98
99        // Validate base URL format
100        Self::validate_base_url(&config.base_url)?;
101
102        Ok(Self::new_with_base_url_and_timeout(
103            api_key.clone(),
104            config.model.clone(),
105            config.temperature,
106            config.max_tokens,
107            config.retry_attempts,
108            config.retry_delay_ms,
109            config.base_url.clone(),
110            config.request_timeout_seconds,
111        ))
112    }
113
114    /// Validate base URL format
115    fn validate_base_url(url: &str) -> crate::Result<()> {
116        use url::Url;
117        let parsed =
118            Url::parse(url).map_err(|e| SubXError::config(format!("Invalid base URL: {}", e)))?;
119
120        if !matches!(parsed.scheme(), "http" | "https") {
121            return Err(SubXError::config(
122                "Base URL must use http or https protocol".to_string(),
123            ));
124        }
125
126        if parsed.host().is_none() {
127            return Err(SubXError::config(
128                "Base URL must contain a valid hostname".to_string(),
129            ));
130        }
131
132        Ok(())
133    }
134
135    async fn chat_completion(&self, messages: Vec<Value>) -> Result<String> {
136        let request_body = json!({
137            "model": self.model,
138            "messages": messages,
139            "temperature": self.temperature,
140            "max_tokens": self.max_tokens,
141        });
142
143        let request = self
144            .client
145            .post(format!("{}/chat/completions", self.base_url))
146            .header("Authorization", format!("Bearer {}", self.api_key))
147            .header("Content-Type", "application/json")
148            .header("HTTP-Referer", "https://github.com/jim60105/subx-cli")
149            .header("X-Title", "Subx")
150            .json(&request_body);
151
152        let response = self.make_request_with_retry(request).await?;
153
154        if !response.status().is_success() {
155            let status = response.status();
156            let error_text = response.text().await?;
157            return Err(SubXError::AiService(format!(
158                "OpenRouter API error {}: {}",
159                status, error_text
160            )));
161        }
162
163        let response_json: Value = response.json().await?;
164        let content = response_json["choices"][0]["message"]["content"]
165            .as_str()
166            .ok_or_else(|| SubXError::AiService("Invalid API response format".to_string()))?;
167
168        // Parse usage statistics and display
169        if let Some(usage_obj) = response_json.get("usage") {
170            if let (Some(p), Some(c), Some(t)) = (
171                usage_obj.get("prompt_tokens").and_then(Value::as_u64),
172                usage_obj.get("completion_tokens").and_then(Value::as_u64),
173                usage_obj.get("total_tokens").and_then(Value::as_u64),
174            ) {
175                let stats = AiUsageStats {
176                    model: self.model.clone(),
177                    prompt_tokens: p as u32,
178                    completion_tokens: c as u32,
179                    total_tokens: t as u32,
180                };
181                display_ai_usage(&stats);
182            }
183        }
184
185        Ok(content.to_string())
186    }
187
188    async fn make_request_with_retry(
189        &self,
190        request: reqwest::RequestBuilder,
191    ) -> reqwest::Result<reqwest::Response> {
192        let mut attempts = 0;
193        loop {
194            match request.try_clone().unwrap().send().await {
195                Ok(resp) => {
196                    // Retry on server error statuses (5xx) if attempts remain
197                    if resp.status().is_server_error() && (attempts as u32) < self.retry_attempts {
198                        attempts += 1;
199                        log::warn!(
200                            "Request attempt {} failed with status {}. Retrying in {}ms...",
201                            attempts,
202                            resp.status(),
203                            self.retry_delay_ms
204                        );
205                        time::sleep(Duration::from_millis(self.retry_delay_ms)).await;
206                        continue;
207                    }
208                    if attempts > 0 {
209                        log::info!("Request succeeded after {} retry attempts", attempts);
210                    }
211                    return Ok(resp);
212                }
213                Err(e) if (attempts as u32) < self.retry_attempts => {
214                    attempts += 1;
215                    log::warn!(
216                        "Request attempt {} failed: {}. Retrying in {}ms...",
217                        attempts,
218                        e,
219                        self.retry_delay_ms
220                    );
221
222                    if e.is_timeout() {
223                        log::warn!(
224                            "This appears to be a timeout error. If this persists, consider increasing 'ai.request_timeout_seconds' in your configuration."
225                        );
226                    }
227
228                    time::sleep(Duration::from_millis(self.retry_delay_ms)).await;
229                    continue;
230                }
231                Err(e) => {
232                    log::error!(
233                        "Request failed after {} attempts. Final error: {}",
234                        attempts + 1,
235                        e
236                    );
237
238                    if e.is_timeout() {
239                        log::error!(
240                            "AI service error: Request timed out after multiple attempts. \
241                        This usually indicates network connectivity issues or server overload. \
242                        Try increasing 'ai.request_timeout_seconds' configuration. \
243                        Hint: check network connection and API service status"
244                        );
245                    } else if e.is_connect() {
246                        log::error!(
247                            "AI service error: Connection failed. \
248                        Hint: check network connection and API base URL settings"
249                        );
250                    }
251
252                    return Err(e);
253                }
254            }
255        }
256    }
257}
258
259#[async_trait]
260impl AIProvider for OpenRouterClient {
261    async fn analyze_content(&self, request: AnalysisRequest) -> Result<MatchResult> {
262        let prompt = self.build_analysis_prompt(&request);
263        let messages = vec![
264            json!({"role": "system", "content": "You are a professional subtitle matching assistant that can analyze the correspondence between video and subtitle files."}),
265            json!({"role": "user", "content": prompt}),
266        ];
267        let response = self.chat_completion(messages).await?;
268        self.parse_match_result(&response)
269    }
270
271    async fn verify_match(&self, verification: VerificationRequest) -> Result<ConfidenceScore> {
272        let prompt = self.build_verification_prompt(&verification);
273        let messages = vec![
274            json!({"role": "system", "content": "Please evaluate the confidence level of subtitle matching and provide a score between 0-1."}),
275            json!({"role": "user", "content": prompt}),
276        ];
277        let response = self.chat_completion(messages).await?;
278        self.parse_confidence_score(&response)
279    }
280}
281
#[cfg(test)]
mod tests {
    use super::*;
    use serde_json::json;
    use wiremock::matchers::{header, method, path};
    use wiremock::{Mock, MockServer, ResponseTemplate};

    /// `new` stores its arguments verbatim and targets the public endpoint.
    #[test]
    fn test_openrouter_client_creation() {
        let client = OpenRouterClient::new(
            "test-key".into(),
            "deepseek/deepseek-r1-0528:free".into(),
            0.5,
            1000,
            2,
            100,
        );
        assert_eq!(client.api_key, "test-key");
        assert_eq!(client.model, "deepseek/deepseek-r1-0528:free");
        assert_eq!(client.temperature, 0.5);
        assert_eq!(client.max_tokens, 1000);
        assert_eq!(client.retry_attempts, 2);
        assert_eq!(client.retry_delay_ms, 100);
        assert_eq!(client.base_url, "https://openrouter.ai/api/v1");
    }

    /// A custom base URL without trailing slashes is kept unchanged.
    #[test]
    fn test_openrouter_client_creation_with_custom_base_url() {
        let client = OpenRouterClient::new_with_base_url_and_timeout(
            "test-key".into(),
            "deepseek/deepseek-r1-0528:free".into(),
            0.3,
            2000,
            3,
            200,
            "https://custom-openrouter.ai/api/v1".into(),
            60,
        );
        assert_eq!(client.base_url, "https://custom-openrouter.ai/api/v1");
    }

    /// Trailing slashes are stripped so endpoint paths join with a single `/`.
    #[test]
    fn test_custom_base_url_trailing_slashes_are_trimmed() {
        let client = OpenRouterClient::new_with_base_url_and_timeout(
            "test-key".into(),
            "deepseek/deepseek-r1-0528:free".into(),
            0.3,
            2000,
            0,
            0,
            "https://custom-openrouter.ai/api/v1//".into(),
            60,
        );
        assert_eq!(client.base_url, "https://custom-openrouter.ai/api/v1");
    }

    #[tokio::test]
    async fn test_chat_completion_success() {
        let server = MockServer::start().await;
        Mock::given(method("POST"))
            .and(path("/chat/completions"))
            .and(header("authorization", "Bearer test-key"))
            .and(header(
                "HTTP-Referer",
                "https://github.com/jim60105/subx-cli",
            ))
            .and(header("X-Title", "Subx"))
            .respond_with(ResponseTemplate::new(200).set_body_json(json!({
                "choices": [{"message": {"content": "test response content"}}],
                "usage": { "prompt_tokens": 10, "completion_tokens": 5, "total_tokens": 15 }
            })))
            .mount(&server)
            .await;

        let mut client = OpenRouterClient::new(
            "test-key".into(),
            "deepseek/deepseek-r1-0528:free".into(),
            0.3,
            1000,
            1,
            0,
        );
        client.base_url = server.uri();

        let messages = vec![json!({"role":"user","content":"test"})];
        let resp = client.chat_completion(messages).await.unwrap();
        assert_eq!(resp, "test response content");
    }

    #[tokio::test]
    async fn test_chat_completion_error_handling() {
        let server = MockServer::start().await;
        Mock::given(method("POST"))
            .and(path("/chat/completions"))
            .respond_with(ResponseTemplate::new(401).set_body_json(json!({
                "error": {"message":"Invalid API key"}
            })))
            .mount(&server)
            .await;

        let mut client = OpenRouterClient::new(
            "bad-key".into(),
            "deepseek/deepseek-r1-0528:free".into(),
            0.3,
            1000,
            1,
            0,
        );
        client.base_url = server.uri();

        let messages = vec![json!({"role":"user","content":"test"})];
        let result = client.chat_completion(messages).await;
        assert!(result.is_err());
        assert!(
            result
                .err()
                .unwrap()
                .to_string()
                .contains("OpenRouter API error 401")
        );
    }

    /// 4xx responses are returned immediately without consuming any retries.
    #[tokio::test]
    async fn test_client_error_is_not_retried() {
        let server = MockServer::start().await;
        Mock::given(method("POST"))
            .and(path("/chat/completions"))
            .respond_with(ResponseTemplate::new(401))
            // Verified when `server` is dropped at the end of the test.
            .expect(1)
            .mount(&server)
            .await;

        let mut client = OpenRouterClient::new(
            "bad-key".into(),
            "deepseek/deepseek-r1-0528:free".into(),
            0.3,
            1000,
            2,
            0,
        );
        client.base_url = server.uri();

        let messages = vec![json!({"role":"user","content":"test"})];
        let result = client.chat_completion(messages).await;
        assert!(result.is_err());
    }

    #[tokio::test]
    async fn test_retry_mechanism() {
        let server = MockServer::start().await;

        // First request fails, second succeeds
        Mock::given(method("POST"))
            .and(path("/chat/completions"))
            .respond_with(ResponseTemplate::new(500))
            .up_to_n_times(1)
            .mount(&server)
            .await;

        Mock::given(method("POST"))
            .and(path("/chat/completions"))
            .respond_with(ResponseTemplate::new(200).set_body_json(json!({
                "choices": [{"message": {"content": "success after retry"}}]
            })))
            .mount(&server)
            .await;

        let mut client = OpenRouterClient::new(
            "test-key".into(),
            "deepseek/deepseek-r1-0528:free".into(),
            0.3,
            1000,
            2,  // Allow 2 retries
            50, // Short delay for testing
        );
        client.base_url = server.uri();

        let messages = vec![json!({"role":"user","content":"test"})];
        let result = client.chat_completion(messages).await.unwrap();
        assert_eq!(result, "success after retry");
    }

    /// Persistent 5xx responses are retried `retry_attempts` times, then surfaced.
    #[tokio::test]
    async fn test_retry_exhausted_returns_server_error() {
        let server = MockServer::start().await;
        Mock::given(method("POST"))
            .and(path("/chat/completions"))
            .respond_with(ResponseTemplate::new(503))
            // Initial attempt plus 2 retries; verified when `server` is dropped.
            .expect(3)
            .mount(&server)
            .await;

        let mut client = OpenRouterClient::new(
            "test-key".into(),
            "deepseek/deepseek-r1-0528:free".into(),
            0.3,
            1000,
            2,
            0,
        );
        client.base_url = server.uri();

        let messages = vec![json!({"role":"user","content":"test"})];
        let err = client
            .chat_completion(messages)
            .await
            .expect_err("exhausted retries must surface an error");
        assert!(err.to_string().contains("OpenRouter API error 503"));
    }

    #[test]
    fn test_openrouter_client_from_config() {
        let config = crate::config::AIConfig {
            provider: "openrouter".to_string(),
            api_key: Some("test-key".to_string()),
            model: "deepseek/deepseek-r1-0528:free".to_string(),
            base_url: "https://openrouter.ai/api/v1".to_string(),
            max_sample_length: 500,
            temperature: 0.7,
            max_tokens: 2000,
            retry_attempts: 3,
            retry_delay_ms: 150,
            request_timeout_seconds: 120,
            api_version: None,
        };

        let client = OpenRouterClient::from_config(&config).unwrap();
        assert_eq!(client.api_key, "test-key");
        assert_eq!(client.model, "deepseek/deepseek-r1-0528:free");
        assert_eq!(client.temperature, 0.7);
        assert_eq!(client.max_tokens, 2000);
        assert_eq!(client.retry_attempts, 3);
        assert_eq!(client.retry_delay_ms, 150);
    }

    #[test]
    fn test_openrouter_client_from_config_missing_api_key() {
        let config = crate::config::AIConfig {
            provider: "openrouter".to_string(),
            api_key: None,
            model: "deepseek/deepseek-r1-0528:free".to_string(),
            base_url: "https://openrouter.ai/api/v1".to_string(),
            max_sample_length: 500,
            temperature: 0.3,
            max_tokens: 1000,
            retry_attempts: 2,
            retry_delay_ms: 100,
            request_timeout_seconds: 30,
            api_version: None,
        };

        let result = OpenRouterClient::from_config(&config);
        assert!(result.is_err());
        assert!(
            result
                .err()
                .unwrap()
                .to_string()
                .contains("Missing OpenRouter API Key")
        );
    }

    #[test]
    fn test_openrouter_client_from_config_invalid_base_url() {
        let config = crate::config::AIConfig {
            provider: "openrouter".to_string(),
            api_key: Some("test-key".to_string()),
            model: "deepseek/deepseek-r1-0528:free".to_string(),
            base_url: "ftp://invalid.url".to_string(),
            max_sample_length: 500,
            temperature: 0.3,
            max_tokens: 1000,
            retry_attempts: 2,
            retry_delay_ms: 100,
            request_timeout_seconds: 30,
            api_version: None,
        };

        let result = OpenRouterClient::from_config(&config);
        assert!(result.is_err());
        assert!(
            result
                .err()
                .unwrap()
                .to_string()
                .contains("must use http or https protocol")
        );
    }

    #[test]
    fn test_prompt_building_and_parsing() {
        let client = OpenRouterClient::new(
            "test-key".into(),
            "deepseek/deepseek-r1-0528:free".into(),
            0.1,
            1000,
            0,
            0,
        );
        let request = AnalysisRequest {
            video_files: vec!["video1.mp4".into()],
            subtitle_files: vec!["subtitle1.srt".into()],
            content_samples: vec![],
        };

        let prompt = client.build_analysis_prompt(&request);
        assert!(prompt.contains("video1.mp4"));
        assert!(prompt.contains("subtitle1.srt"));
        assert!(prompt.contains("JSON"));

        let json_response = r#"{ "matches": [], "confidence":0.9, "reasoning":"test reason" }"#;
        let match_result = client.parse_match_result(json_response).unwrap();
        assert_eq!(match_result.confidence, 0.9);
        assert_eq!(match_result.reasoning, "test reason");
    }
}