// subx_cli/services/ai/openai.rs

1use crate::Result;
2use crate::cli::display_ai_usage;
3use crate::error::SubXError;
4use crate::services::ai::AiUsageStats;
5use crate::services::ai::{
6    AIProvider, AnalysisRequest, ConfidenceScore, MatchResult, VerificationRequest,
7};
8use async_trait::async_trait;
9use reqwest::Client;
10use serde_json::Value;
11use serde_json::json;
12use std::time::Duration;
13use tokio::time;
14
/// OpenAI chat-completion client implementing [`AIProvider`].
///
/// Bundles the HTTP client with the per-request configuration (model,
/// temperature, retry policy, endpoint). Construct via [`OpenAIClient::new`],
/// [`OpenAIClient::new_with_base_url`], or [`OpenAIClient::from_config`].
#[derive(Debug)]
pub struct OpenAIClient {
    client: Client,      // reqwest client built with a 30 s timeout
    api_key: String,     // sent as `Authorization: Bearer <key>` on each request
    model: String,       // model name included in every request body
    temperature: f32,    // sampling temperature forwarded to the API
    retry_attempts: u32, // extra send attempts after a transport failure
    retry_delay_ms: u64, // fixed delay between retry attempts, in milliseconds
    base_url: String,    // API root; stored without a trailing '/'
}
27
// Mock-based tests: exercise OpenAIClient against a wiremock HTTP server and
// the AIProvider trait against a mockall-generated double.
#[cfg(test)]
mod tests {
    use super::*;
    use mockall::{mock, predicate::eq};
    use serde_json::json;
    use wiremock::matchers::{header, method, path};
    use wiremock::{Mock, MockServer, ResponseTemplate};

    // `mock!` generates `MockAIClient`, an AIProvider double with
    // expect_analyze_content / expect_verify_match setters.
    mock! {
        AIClient {}

        #[async_trait]
        impl AIProvider for AIClient {
            async fn analyze_content(&self, request: AnalysisRequest) -> crate::Result<MatchResult>;
            async fn verify_match(&self, verification: VerificationRequest) -> crate::Result<ConfidenceScore>;
        }
    }

    // The constructor should store every argument verbatim.
    #[tokio::test]
    async fn test_openai_client_creation() {
        let client = OpenAIClient::new("test-key".into(), "gpt-4o-mini".into(), 0.5, 2, 100);
        assert_eq!(client.api_key, "test-key");
        assert_eq!(client.model, "gpt-4o-mini");
        assert_eq!(client.temperature, 0.5);
        assert_eq!(client.retry_attempts, 2);
        assert_eq!(client.retry_delay_ms, 100);
    }

    // A 200 response with a well-formed body yields the first choice's content.
    #[tokio::test]
    async fn test_chat_completion_success() {
        let server = MockServer::start().await;
        Mock::given(method("POST"))
            .and(path("/chat/completions"))
            .and(header("authorization", "Bearer test-key"))
            .respond_with(ResponseTemplate::new(200).set_body_json(json!({
                "choices": [{"message": {"content": "測試回應內容"}}]
            })))
            .mount(&server)
            .await;
        let mut client = OpenAIClient::new("test-key".into(), "gpt-4o-mini".into(), 0.3, 1, 0);
        client.base_url = server.uri();
        let messages = vec![json!({"role":"user","content":"測試"})];
        let resp = client.chat_completion(messages).await.unwrap();
        assert_eq!(resp, "測試回應內容");
    }

    // A non-2xx HTTP status must be surfaced as an Err, not a panic.
    #[tokio::test]
    async fn test_chat_completion_error() {
        let server = MockServer::start().await;
        Mock::given(method("POST"))
            .and(path("/chat/completions"))
            .respond_with(ResponseTemplate::new(400).set_body_json(json!({
                "error": {"message":"Invalid API key"}
            })))
            .mount(&server)
            .await;
        let mut client = OpenAIClient::new("bad-key".into(), "gpt-4o-mini".into(), 0.3, 1, 0);
        client.base_url = server.uri();
        let messages = vec![json!({"role":"user","content":"測試"})];
        let result = client.chat_completion(messages).await;
        assert!(result.is_err());
    }

    // The AIProvider trait can be mocked: analyze_content returns the canned
    // result exactly once for the expected request.
    #[tokio::test]
    async fn test_analyze_content_with_mock() {
        let mut mock = MockAIClient::new();
        let req = AnalysisRequest {
            video_files: vec!["v.mp4".into()],
            subtitle_files: vec!["s.srt".into()],
            content_samples: vec![],
        };
        let expected = MatchResult {
            matches: vec![],
            confidence: 0.5,
            reasoning: "OK".into(),
        };
        mock.expect_analyze_content()
            .with(eq(req.clone()))
            .times(1)
            .returning(move |_| Ok(expected.clone()));
        let res = mock.analyze_content(req.clone()).await.unwrap();
        assert_eq!(res.confidence, 0.5);
    }

    // The analysis prompt must mention every file and ask for JSON output, and
    // parse_match_result must round-trip a minimal JSON match result.
    #[test]
    fn test_prompt_building_and_parsing() {
        let client = OpenAIClient::new("k".into(), "m".into(), 0.1, 0, 0);
        let request = AnalysisRequest {
            video_files: vec!["F1.mp4".into()],
            subtitle_files: vec!["S1.srt".into()],
            content_samples: vec![],
        };
        let prompt = client.build_analysis_prompt(&request);
        assert!(prompt.contains("F1.mp4"));
        assert!(prompt.contains("S1.srt"));
        assert!(prompt.contains("JSON"));
        let json_resp = r#"{ "matches": [], "confidence":0.9, "reasoning":"r" }"#;
        let mr = client.parse_match_result(json_resp).unwrap();
        assert_eq!(mr.confidence, 0.9);
    }

    // from_config should copy config fields, including a custom base URL.
    #[test]
    fn test_openai_client_from_config() {
        let config = crate::config::AIConfig {
            provider: "openai".to_string(),
            api_key: Some("test-key".to_string()),
            model: "gpt-test".to_string(),
            base_url: "https://custom.openai.com/v1".to_string(),
            temperature: 0.7,
            retry_attempts: 2,
            retry_delay_ms: 150,
            max_sample_length: 500,
        };
        let client = OpenAIClient::from_config(&config).unwrap();
        assert_eq!(client.api_key, "test-key");
        assert_eq!(client.model, "gpt-test");
        assert_eq!(client.temperature, 0.7);
        assert_eq!(client.base_url, "https://custom.openai.com/v1");
    }

    // A scheme other than http/https must be rejected by from_config.
    #[test]
    fn test_openai_client_from_config_invalid_base_url() {
        let config = crate::config::AIConfig {
            provider: "openai".to_string(),
            api_key: Some("test-key".to_string()),
            model: "gpt-test".to_string(),
            base_url: "ftp://invalid.url".to_string(),
            temperature: 0.7,
            retry_attempts: 2,
            retry_delay_ms: 150,
            max_sample_length: 500,
        };
        let err = OpenAIClient::from_config(&config).unwrap_err();
        // A non-http/https scheme should return the protocol error message.
        assert!(
            err.to_string()
                .contains("base URL 必須使用 http 或 https 協定")
        );
    }
}
169
impl OpenAIClient {
    /// Create a client that targets the default OpenAI endpoint
    /// (`https://api.openai.com/v1`).
    pub fn new(
        api_key: String,
        model: String,
        temperature: f32,
        retry_attempts: u32,
        retry_delay_ms: u64,
    ) -> Self {
        Self::new_with_base_url(
            api_key,
            model,
            temperature,
            retry_attempts,
            retry_delay_ms,
            "https://api.openai.com/v1".to_string(),
        )
    }

    /// Create a client with a custom `base_url` (tests point this at a
    /// wiremock server). Any trailing '/' on `base_url` is stripped so
    /// request paths can be appended with exactly one slash.
    ///
    /// # Panics
    /// Panics if the underlying reqwest client cannot be built.
    pub fn new_with_base_url(
        api_key: String,
        model: String,
        temperature: f32,
        retry_attempts: u32,
        retry_delay_ms: u64,
        base_url: String,
    ) -> Self {
        // 30 s overall request timeout for every call made by this client.
        let client = Client::builder()
            .timeout(Duration::from_secs(30))
            .build()
            .expect("建立 HTTP 客戶端失敗");
        Self {
            client,
            api_key,
            model,
            temperature,
            retry_attempts,
            retry_delay_ms,
            base_url: base_url.trim_end_matches('/').to_string(),
        }
    }

    /// Build a client from the unified [`crate::config::AIConfig`].
    ///
    /// # Errors
    /// Returns a config error when the API key is missing or the base URL
    /// fails [`Self::validate_base_url`].
    pub fn from_config(config: &crate::config::AIConfig) -> crate::Result<Self> {
        let api_key = config
            .api_key
            .as_ref()
            .ok_or_else(|| crate::error::SubXError::config("缺少 OpenAI API Key"))?;

        // Validate the base URL format before constructing the client.
        Self::validate_base_url(&config.base_url)?;

        Ok(Self::new_with_base_url(
            api_key.clone(),
            config.model.clone(),
            config.temperature,
            config.retry_attempts,
            config.retry_delay_ms,
            config.base_url.clone(),
        ))
    }

    /// Validate that `url` parses, uses http/https, and has a host.
    ///
    /// # Errors
    /// Returns a config error describing which check failed.
    fn validate_base_url(url: &str) -> crate::Result<()> {
        use url::Url;
        let parsed = Url::parse(url)
            .map_err(|e| crate::error::SubXError::config(format!("無效的 base URL: {}", e)))?;

        if !matches!(parsed.scheme(), "http" | "https") {
            return Err(crate::error::SubXError::config(
                "base URL 必須使用 http 或 https 協定".to_string(),
            ));
        }

        if parsed.host().is_none() {
            return Err(crate::error::SubXError::config(
                "base URL 必須包含有效的主機名稱".to_string(),
            ));
        }

        Ok(())
    }

    /// POST `messages` to `{base_url}/chat/completions` and return the first
    /// choice's message content. Token usage, when present in the response,
    /// is printed via `display_ai_usage`.
    ///
    /// # Errors
    /// Returns an error on transport failure (after retries), a non-success
    /// HTTP status, or a response body missing the expected content field.
    async fn chat_completion(&self, messages: Vec<serde_json::Value>) -> Result<String> {
        let request_body = json!({
            "model": self.model,
            "messages": messages,
            "temperature": self.temperature,
            "max_tokens": 1000,
        });

        let request = self
            .client
            .post(format!("{}/chat/completions", self.base_url))
            .header("Authorization", format!("Bearer {}", self.api_key))
            .header("Content-Type", "application/json")
            .json(&request_body);
        let response = self.make_request_with_retry(request).await?;

        if !response.status().is_success() {
            let status = response.status();
            let error_text = response.text().await?;
            return Err(SubXError::AiService(format!(
                "OpenAI API 錯誤 {}: {}",
                status, error_text
            )));
        }

        let response_json: Value = response.json().await?;
        let content = response_json["choices"][0]["message"]["content"]
            .as_str()
            .ok_or_else(|| SubXError::AiService("無效的 API 回應格式".to_string()))?;

        // Parse the usage statistics and display them; silently skipped when
        // any of the three token fields is absent or non-numeric.
        if let Some(usage_obj) = response_json.get("usage") {
            if let (Some(p), Some(c), Some(t)) = (
                usage_obj.get("prompt_tokens").and_then(Value::as_u64),
                usage_obj.get("completion_tokens").and_then(Value::as_u64),
                usage_obj.get("total_tokens").and_then(Value::as_u64),
            ) {
                let stats = AiUsageStats {
                    model: self.model.clone(),
                    prompt_tokens: p as u32,
                    completion_tokens: c as u32,
                    total_tokens: t as u32,
                };
                display_ai_usage(&stats);
            }
        }

        Ok(content.to_string())
    }
}
304
305#[async_trait]
306impl AIProvider for OpenAIClient {
307    async fn analyze_content(&self, request: AnalysisRequest) -> Result<MatchResult> {
308        let prompt = self.build_analysis_prompt(&request);
309        let messages = vec![
310            json!({"role": "system", "content": "你是一個專業的字幕匹配助手,能夠分析影片和字幕檔案的對應關係。"}),
311            json!({"role": "user", "content": prompt}),
312        ];
313        let response = self.chat_completion(messages).await?;
314        self.parse_match_result(&response)
315    }
316
317    async fn verify_match(&self, verification: VerificationRequest) -> Result<ConfidenceScore> {
318        let prompt = self.build_verification_prompt(&verification);
319        let messages = vec![
320            json!({"role": "system", "content": "請評估字幕匹配的信心度,提供 0-1 之間的分數。"}),
321            json!({"role": "user", "content": prompt}),
322        ];
323        let response = self.chat_completion(messages).await?;
324        self.parse_confidence_score(&response)
325    }
326}
327
328impl OpenAIClient {
329    async fn make_request_with_retry(
330        &self,
331        request: reqwest::RequestBuilder,
332    ) -> reqwest::Result<reqwest::Response> {
333        let mut attempts = 0;
334        loop {
335            match request.try_clone().unwrap().send().await {
336                Ok(resp) => return Ok(resp),
337                Err(_e) if (attempts as u32) < self.retry_attempts => {
338                    attempts += 1;
339                    time::sleep(Duration::from_millis(self.retry_delay_ms)).await;
340                    continue;
341                }
342                Err(e) => return Err(e),
343            }
344        }
345    }
346}