//! OpenRouter AI provider client (subx_cli/services/ai/openrouter.rs).

1use crate::Result;
2use crate::cli::display_ai_usage;
3use crate::error::SubXError;
4use crate::services::ai::AiUsageStats;
5use crate::services::ai::{
6    AIProvider, AnalysisRequest, ConfidenceScore, MatchResult, VerificationRequest,
7};
8use async_trait::async_trait;
9use reqwest::Client;
10use serde_json::{Value, json};
11use std::time::Duration;
12use tokio::time;
13
/// OpenRouter client implementation
///
/// Wraps a `reqwest::Client` configured for OpenRouter's OpenAI-compatible
/// chat-completions API together with the model, sampling, and retry settings.
#[derive(Debug)]
pub struct OpenRouterClient {
    // Underlying HTTP client, built with a request timeout in the constructor.
    client: Client,
    // OpenRouter API key, sent as a Bearer token on every request.
    api_key: String,
    // Model identifier, e.g. "deepseek/deepseek-r1-0528:free".
    model: String,
    // Sampling temperature forwarded verbatim in the request body.
    temperature: f32,
    // Upper bound on tokens the model may generate per request.
    max_tokens: u32,
    // Number of retries after a failed attempt (0 = single attempt only).
    retry_attempts: u32,
    // Fixed delay between retry attempts, in milliseconds.
    retry_delay_ms: u64,
    // API base URL; trailing slashes are stripped by the constructor.
    base_url: String,
}
26
27impl OpenRouterClient {
28    /// Create new OpenRouterClient with default configuration
29    pub fn new(
30        api_key: String,
31        model: String,
32        temperature: f32,
33        max_tokens: u32,
34        retry_attempts: u32,
35        retry_delay_ms: u64,
36    ) -> Self {
37        Self::new_with_base_url_and_timeout(
38            api_key,
39            model,
40            temperature,
41            max_tokens,
42            retry_attempts,
43            retry_delay_ms,
44            "https://openrouter.ai/api/v1".to_string(),
45            120,
46        )
47    }
48
49    /// Create new OpenRouterClient with custom base URL and timeout
50    #[allow(clippy::too_many_arguments)]
51    pub fn new_with_base_url_and_timeout(
52        api_key: String,
53        model: String,
54        temperature: f32,
55        max_tokens: u32,
56        retry_attempts: u32,
57        retry_delay_ms: u64,
58        base_url: String,
59        request_timeout_seconds: u64,
60    ) -> Self {
61        let client = Client::builder()
62            .timeout(Duration::from_secs(request_timeout_seconds))
63            .build()
64            .expect("Failed to create HTTP client");
65
66        Self {
67            client,
68            api_key,
69            model,
70            temperature,
71            max_tokens,
72            retry_attempts,
73            retry_delay_ms,
74            base_url: base_url.trim_end_matches('/').to_string(),
75        }
76    }
77
78    /// Create client from unified configuration
79    pub fn from_config(config: &crate::config::AIConfig) -> crate::Result<Self> {
80        let api_key = config
81            .api_key
82            .as_ref()
83            .ok_or_else(|| SubXError::config("Missing OpenRouter API Key"))?;
84
85        // Validate base URL format
86        Self::validate_base_url(&config.base_url)?;
87
88        Ok(Self::new_with_base_url_and_timeout(
89            api_key.clone(),
90            config.model.clone(),
91            config.temperature,
92            config.max_tokens,
93            config.retry_attempts,
94            config.retry_delay_ms,
95            config.base_url.clone(),
96            config.request_timeout_seconds,
97        ))
98    }
99
100    /// Validate base URL format
101    fn validate_base_url(url: &str) -> crate::Result<()> {
102        use url::Url;
103        let parsed =
104            Url::parse(url).map_err(|e| SubXError::config(format!("Invalid base URL: {}", e)))?;
105
106        if !matches!(parsed.scheme(), "http" | "https") {
107            return Err(SubXError::config(
108                "Base URL must use http or https protocol".to_string(),
109            ));
110        }
111
112        if parsed.host().is_none() {
113            return Err(SubXError::config(
114                "Base URL must contain a valid hostname".to_string(),
115            ));
116        }
117
118        Ok(())
119    }
120
121    async fn chat_completion(&self, messages: Vec<Value>) -> Result<String> {
122        let request_body = json!({
123            "model": self.model,
124            "messages": messages,
125            "temperature": self.temperature,
126            "max_tokens": self.max_tokens,
127        });
128
129        let request = self
130            .client
131            .post(format!("{}/chat/completions", self.base_url))
132            .header("Authorization", format!("Bearer {}", self.api_key))
133            .header("Content-Type", "application/json")
134            .header("HTTP-Referer", "https://github.com/jim60105/subx-cli")
135            .header("X-Title", "Subx")
136            .json(&request_body);
137
138        let response = self.make_request_with_retry(request).await?;
139
140        if !response.status().is_success() {
141            let status = response.status();
142            let error_text = response.text().await?;
143            return Err(SubXError::AiService(format!(
144                "OpenRouter API error {}: {}",
145                status, error_text
146            )));
147        }
148
149        let response_json: Value = response.json().await?;
150        let content = response_json["choices"][0]["message"]["content"]
151            .as_str()
152            .ok_or_else(|| SubXError::AiService("Invalid API response format".to_string()))?;
153
154        // Parse usage statistics and display
155        if let Some(usage_obj) = response_json.get("usage") {
156            if let (Some(p), Some(c), Some(t)) = (
157                usage_obj.get("prompt_tokens").and_then(Value::as_u64),
158                usage_obj.get("completion_tokens").and_then(Value::as_u64),
159                usage_obj.get("total_tokens").and_then(Value::as_u64),
160            ) {
161                let stats = AiUsageStats {
162                    model: self.model.clone(),
163                    prompt_tokens: p as u32,
164                    completion_tokens: c as u32,
165                    total_tokens: t as u32,
166                };
167                display_ai_usage(&stats);
168            }
169        }
170
171        Ok(content.to_string())
172    }
173
174    async fn make_request_with_retry(
175        &self,
176        request: reqwest::RequestBuilder,
177    ) -> reqwest::Result<reqwest::Response> {
178        let mut attempts = 0;
179        loop {
180            match request.try_clone().unwrap().send().await {
181                Ok(resp) => {
182                    // Retry on server error statuses (5xx) if attempts remain
183                    if resp.status().is_server_error() && (attempts as u32) < self.retry_attempts {
184                        attempts += 1;
185                        log::warn!(
186                            "Request attempt {} failed with status {}. Retrying in {}ms...",
187                            attempts,
188                            resp.status(),
189                            self.retry_delay_ms
190                        );
191                        time::sleep(Duration::from_millis(self.retry_delay_ms)).await;
192                        continue;
193                    }
194                    if attempts > 0 {
195                        log::info!("Request succeeded after {} retry attempts", attempts);
196                    }
197                    return Ok(resp);
198                }
199                Err(e) if (attempts as u32) < self.retry_attempts => {
200                    attempts += 1;
201                    log::warn!(
202                        "Request attempt {} failed: {}. Retrying in {}ms...",
203                        attempts,
204                        e,
205                        self.retry_delay_ms
206                    );
207
208                    if e.is_timeout() {
209                        log::warn!(
210                            "This appears to be a timeout error. If this persists, consider increasing 'ai.request_timeout_seconds' in your configuration."
211                        );
212                    }
213
214                    time::sleep(Duration::from_millis(self.retry_delay_ms)).await;
215                    continue;
216                }
217                Err(e) => {
218                    log::error!(
219                        "Request failed after {} attempts. Final error: {}",
220                        attempts + 1,
221                        e
222                    );
223
224                    if e.is_timeout() {
225                        log::error!(
226                            "AI service error: Request timed out after multiple attempts. \
227                        This usually indicates network connectivity issues or server overload. \
228                        Try increasing 'ai.request_timeout_seconds' configuration. \
229                        Hint: check network connection and API service status"
230                        );
231                    } else if e.is_connect() {
232                        log::error!(
233                            "AI service error: Connection failed. \
234                        Hint: check network connection and API base URL settings"
235                        );
236                    }
237
238                    return Err(e);
239                }
240            }
241        }
242    }
243}
244
245#[async_trait]
246impl AIProvider for OpenRouterClient {
247    async fn analyze_content(&self, request: AnalysisRequest) -> Result<MatchResult> {
248        let prompt = self.build_analysis_prompt(&request);
249        let messages = vec![
250            json!({"role": "system", "content": "You are a professional subtitle matching assistant that can analyze the correspondence between video and subtitle files."}),
251            json!({"role": "user", "content": prompt}),
252        ];
253        let response = self.chat_completion(messages).await?;
254        self.parse_match_result(&response)
255    }
256
257    async fn verify_match(&self, verification: VerificationRequest) -> Result<ConfidenceScore> {
258        let prompt = self.build_verification_prompt(&verification);
259        let messages = vec![
260            json!({"role": "system", "content": "Please evaluate the confidence level of subtitle matching and provide a score between 0-1."}),
261            json!({"role": "user", "content": prompt}),
262        ];
263        let response = self.chat_completion(messages).await?;
264        self.parse_confidence_score(&response)
265    }
266}
267
268// Prompt building and response parsing methods (copied from OpenAIClient)
269impl OpenRouterClient {
270    /// Build content analysis prompt
271    pub fn build_analysis_prompt(&self, request: &AnalysisRequest) -> String {
272        let mut prompt = String::new();
273        prompt.push_str("Please analyze the matching relationship between the following video and subtitle files. Each file has a unique ID that you must use in your response.\n\n");
274
275        prompt.push_str("Video files:\n");
276        for video in &request.video_files {
277            prompt.push_str(&format!("- {}\n", video));
278        }
279
280        prompt.push_str("\nSubtitle files:\n");
281        for subtitle in &request.subtitle_files {
282            prompt.push_str(&format!("- {}\n", subtitle));
283        }
284
285        if !request.content_samples.is_empty() {
286            prompt.push_str("\nSubtitle content preview:\n");
287            for sample in &request.content_samples {
288                prompt.push_str(&format!("File: {}\n", sample.filename));
289                prompt.push_str(&format!("Content: {}\n\n", sample.content_preview));
290            }
291        }
292
293        prompt.push_str(
294            "Please provide matching suggestions based on filename patterns, content similarity, and other factors.\n\
295            Response format must be JSON using the file IDs:\n\
296            {\n\
297              \"matches\": [\n\
298                {\n\
299                  \"video_file_id\": \"file_abc123456789abcd\",\n\
300                  \"subtitle_file_id\": \"file_def456789abcdef0\",\n\
301                  \"confidence\": 0.95,\n\
302                  \"match_factors\": [\"filename_similarity\", \"content_correlation\"]\n\
303                }\n\
304              ],\n\
305              \"confidence\": 0.9,\n\
306              \"reasoning\": \"Explanation for the matching decisions\"\n\
307            }",
308        );
309
310        prompt
311    }
312
313    /// Parse matching results from AI response
314    pub fn parse_match_result(&self, response: &str) -> Result<MatchResult> {
315        let json_start = response.find('{').unwrap_or(0);
316        let json_end = response.rfind('}').map(|i| i + 1).unwrap_or(response.len());
317        let json_str = &response[json_start..json_end];
318
319        serde_json::from_str(json_str)
320            .map_err(|e| SubXError::AiService(format!("AI response parsing failed: {}", e)))
321    }
322
323    /// Build verification prompt
324    pub fn build_verification_prompt(&self, request: &VerificationRequest) -> String {
325        let mut prompt = String::new();
326        prompt.push_str(
327            "Please evaluate the confidence level based on the following matching information:\n",
328        );
329        prompt.push_str(&format!("Video file: {}\n", request.video_file));
330        prompt.push_str(&format!("Subtitle file: {}\n", request.subtitle_file));
331        prompt.push_str("Matching factors:\n");
332        for factor in &request.match_factors {
333            prompt.push_str(&format!("- {}\n", factor));
334        }
335        prompt.push_str(
336            "\nPlease respond in JSON format as follows:\n{\n  \"score\": 0.9,\n  \"factors\": [\"...\"]\n}",
337        );
338        prompt
339    }
340
341    /// Parse confidence score from AI response
342    pub fn parse_confidence_score(&self, response: &str) -> Result<ConfidenceScore> {
343        let json_start = response.find('{').unwrap_or(0);
344        let json_end = response.rfind('}').map(|i| i + 1).unwrap_or(response.len());
345        let json_str = &response[json_start..json_end];
346
347        serde_json::from_str(json_str)
348            .map_err(|e| SubXError::AiService(format!("AI confidence parsing failed: {}", e)))
349    }
350}
#[cfg(test)]
mod tests {
    //! Unit tests for `OpenRouterClient`: construction, config loading,
    //! HTTP behavior (via wiremock), and prompt building/parsing.
    use super::*;
    use mockall::mock;
    use serde_json::json;
    use wiremock::matchers::{header, method, path};
    use wiremock::{Mock, MockServer, ResponseTemplate};

    // Generated mock of the AIProvider trait (available to sibling test modules).
    mock! {
        AIClient {}

        #[async_trait]
        impl AIProvider for AIClient {
            async fn analyze_content(&self, request: AnalysisRequest) -> crate::Result<MatchResult>;
            async fn verify_match(&self, verification: VerificationRequest) -> crate::Result<ConfidenceScore>;
        }
    }

    // `new` should store all parameters and default to the official endpoint.
    #[tokio::test]
    async fn test_openrouter_client_creation() {
        let client = OpenRouterClient::new(
            "test-key".into(),
            "deepseek/deepseek-r1-0528:free".into(),
            0.5,
            1000,
            2,
            100,
        );
        assert_eq!(client.api_key, "test-key");
        assert_eq!(client.model, "deepseek/deepseek-r1-0528:free");
        assert_eq!(client.temperature, 0.5);
        assert_eq!(client.max_tokens, 1000);
        assert_eq!(client.retry_attempts, 2);
        assert_eq!(client.retry_delay_ms, 100);
        assert_eq!(client.base_url, "https://openrouter.ai/api/v1");
    }

    // A custom base URL should be stored verbatim (no trailing slash to strip here).
    #[tokio::test]
    async fn test_openrouter_client_creation_with_custom_base_url() {
        let client = OpenRouterClient::new_with_base_url_and_timeout(
            "test-key".into(),
            "deepseek/deepseek-r1-0528:free".into(),
            0.3,
            2000,
            3,
            200,
            "https://custom-openrouter.ai/api/v1".into(),
            60,
        );
        assert_eq!(client.base_url, "https://custom-openrouter.ai/api/v1");
    }

    // Happy path: auth + attribution headers are sent, and the message content
    // from choices[0] is returned.
    #[tokio::test]
    async fn test_chat_completion_success() {
        let server = MockServer::start().await;
        Mock::given(method("POST"))
            .and(path("/chat/completions"))
            .and(header("authorization", "Bearer test-key"))
            .and(header(
                "HTTP-Referer",
                "https://github.com/jim60105/subx-cli",
            ))
            .and(header("X-Title", "Subx"))
            .respond_with(ResponseTemplate::new(200).set_body_json(json!({
                "choices": [{"message": {"content": "test response content"}}],
                "usage": { "prompt_tokens": 10, "completion_tokens": 5, "total_tokens": 15 }
            })))
            .mount(&server)
            .await;

        let mut client = OpenRouterClient::new(
            "test-key".into(),
            "deepseek/deepseek-r1-0528:free".into(),
            0.3,
            1000,
            1,
            0,
        );
        // Redirect the client at the mock server instead of the real API.
        client.base_url = server.uri();

        let messages = vec![json!({"role":"user","content":"test"})];
        let resp = client.chat_completion(messages).await.unwrap();
        assert_eq!(resp, "test response content");
    }

    // A 4xx response must surface as an AiService error carrying the status code.
    #[tokio::test]
    async fn test_chat_completion_error_handling() {
        let server = MockServer::start().await;
        Mock::given(method("POST"))
            .and(path("/chat/completions"))
            .respond_with(ResponseTemplate::new(401).set_body_json(json!({
                "error": {"message":"Invalid API key"}
            })))
            .mount(&server)
            .await;

        let mut client = OpenRouterClient::new(
            "bad-key".into(),
            "deepseek/deepseek-r1-0528:free".into(),
            0.3,
            1000,
            1,
            0,
        );
        client.base_url = server.uri();

        let messages = vec![json!({"role":"user","content":"test"})];
        let result = client.chat_completion(messages).await;
        assert!(result.is_err());
        assert!(
            result
                .err()
                .unwrap()
                .to_string()
                .contains("OpenRouter API error 401")
        );
    }

    // A 500 on the first attempt should be retried and the second mock's 200
    // returned. NOTE: mount order matters — the limited-use 500 mock is
    // consumed first, then requests fall through to the 200 mock.
    #[tokio::test]
    async fn test_retry_mechanism() {
        let server = MockServer::start().await;

        // First request fails, second succeeds
        Mock::given(method("POST"))
            .and(path("/chat/completions"))
            .respond_with(ResponseTemplate::new(500))
            .up_to_n_times(1)
            .mount(&server)
            .await;

        Mock::given(method("POST"))
            .and(path("/chat/completions"))
            .respond_with(ResponseTemplate::new(200).set_body_json(json!({
                "choices": [{"message": {"content": "success after retry"}}]
            })))
            .mount(&server)
            .await;

        let mut client = OpenRouterClient::new(
            "test-key".into(),
            "deepseek/deepseek-r1-0528:free".into(),
            0.3,
            1000,
            2,  // Allow 2 retries
            50, // Short delay for testing
        );
        client.base_url = server.uri();

        let messages = vec![json!({"role":"user","content":"test"})];
        let result = client.chat_completion(messages).await.unwrap();
        assert_eq!(result, "success after retry");
    }

    // Config loading copies every relevant field onto the client.
    #[test]
    fn test_openrouter_client_from_config() {
        let config = crate::config::AIConfig {
            provider: "openrouter".to_string(),
            api_key: Some("test-key".to_string()),
            model: "deepseek/deepseek-r1-0528:free".to_string(),
            base_url: "https://openrouter.ai/api/v1".to_string(),
            max_sample_length: 500,
            temperature: 0.7,
            max_tokens: 2000,
            retry_attempts: 3,
            retry_delay_ms: 150,
            request_timeout_seconds: 120,
            api_version: None,
        };

        let client = OpenRouterClient::from_config(&config).unwrap();
        assert_eq!(client.api_key, "test-key");
        assert_eq!(client.model, "deepseek/deepseek-r1-0528:free");
        assert_eq!(client.temperature, 0.7);
        assert_eq!(client.max_tokens, 2000);
        assert_eq!(client.retry_attempts, 3);
        assert_eq!(client.retry_delay_ms, 150);
    }

    // Missing API key must be rejected with a descriptive config error.
    #[test]
    fn test_openrouter_client_from_config_missing_api_key() {
        let config = crate::config::AIConfig {
            provider: "openrouter".to_string(),
            api_key: None,
            model: "deepseek/deepseek-r1-0528:free".to_string(),
            base_url: "https://openrouter.ai/api/v1".to_string(),
            max_sample_length: 500,
            temperature: 0.3,
            max_tokens: 1000,
            retry_attempts: 2,
            retry_delay_ms: 100,
            request_timeout_seconds: 30,
            api_version: None,
        };

        let result = OpenRouterClient::from_config(&config);
        assert!(result.is_err());
        assert!(
            result
                .err()
                .unwrap()
                .to_string()
                .contains("Missing OpenRouter API Key")
        );
    }

    // Non-http(s) schemes (here: ftp) must be rejected by base URL validation.
    #[test]
    fn test_openrouter_client_from_config_invalid_base_url() {
        let config = crate::config::AIConfig {
            provider: "openrouter".to_string(),
            api_key: Some("test-key".to_string()),
            model: "deepseek/deepseek-r1-0528:free".to_string(),
            base_url: "ftp://invalid.url".to_string(),
            max_sample_length: 500,
            temperature: 0.3,
            max_tokens: 1000,
            retry_attempts: 2,
            retry_delay_ms: 100,
            request_timeout_seconds: 30,
            api_version: None,
        };

        let result = OpenRouterClient::from_config(&config);
        assert!(result.is_err());
        assert!(
            result
                .err()
                .unwrap()
                .to_string()
                .contains("must use http or https protocol")
        );
    }

    // Round-trip: the prompt should mention both files and the JSON schema,
    // and a well-formed JSON reply should parse into a MatchResult.
    #[test]
    fn test_prompt_building_and_parsing() {
        let client = OpenRouterClient::new(
            "test-key".into(),
            "deepseek/deepseek-r1-0528:free".into(),
            0.1,
            1000,
            0,
            0,
        );
        let request = AnalysisRequest {
            video_files: vec!["video1.mp4".into()],
            subtitle_files: vec!["subtitle1.srt".into()],
            content_samples: vec![],
        };

        let prompt = client.build_analysis_prompt(&request);
        assert!(prompt.contains("video1.mp4"));
        assert!(prompt.contains("subtitle1.srt"));
        assert!(prompt.contains("JSON"));

        let json_response = r#"{ "matches": [], "confidence":0.9, "reasoning":"test reason" }"#;
        let match_result = client.parse_match_result(json_response).unwrap();
        assert_eq!(match_result.confidence, 0.9);
        assert_eq!(match_result.reasoning, "test reason");
    }
}