1use crate::Result;
2use crate::cli::display_ai_usage;
3use crate::error::SubXError;
4use crate::services::ai::AiUsageStats;
5use crate::services::ai::{
6 AIProvider, AnalysisRequest, ConfidenceScore, MatchResult, VerificationRequest,
7};
8use async_trait::async_trait;
9use reqwest::Client;
10use serde_json::Value;
11use serde_json::json;
12use std::time::Duration;
13use tokio::time;
14
15#[derive(Debug)]
18pub struct OpenAIClient {
19 client: Client,
20 api_key: String,
21 model: String,
22 temperature: f32,
23 retry_attempts: u32,
24 retry_delay_ms: u64,
25 base_url: String,
26}
27
#[cfg(test)]
mod tests {
    use super::*;
    use mockall::{mock, predicate::eq};
    use serde_json::json;
    use wiremock::matchers::{header, method, path};
    use wiremock::{Mock, MockServer, ResponseTemplate};

    // mockall-generated stand-in for `AIProvider`, used to exercise the trait
    // contract without any network traffic.
    mock! {
        AIClient {}

        #[async_trait]
        impl AIProvider for AIClient {
            async fn analyze_content(&self, request: AnalysisRequest) -> crate::Result<MatchResult>;
            async fn verify_match(&self, verification: VerificationRequest) -> crate::Result<ConfidenceScore>;
        }
    }

    /// Builds an OpenAI `AIConfig` with the given key and base URL; all other
    /// fields use fixed test values.
    fn openai_config(api_key: Option<&str>, base_url: &str) -> crate::config::AIConfig {
        crate::config::AIConfig {
            provider: "openai".to_string(),
            api_key: api_key.map(str::to_string),
            model: "gpt-test".to_string(),
            base_url: base_url.to_string(),
            temperature: 0.7,
            retry_attempts: 2,
            retry_delay_ms: 150,
            max_sample_length: 500,
        }
    }

    // Construction is synchronous, so no async runtime is needed here.
    #[test]
    fn test_openai_client_creation() {
        let client = OpenAIClient::new("test-key".into(), "gpt-4o-mini".into(), 0.5, 2, 100);
        assert_eq!(client.api_key, "test-key");
        assert_eq!(client.model, "gpt-4o-mini");
        assert_eq!(client.temperature, 0.5);
        assert_eq!(client.retry_attempts, 2);
        assert_eq!(client.retry_delay_ms, 100);
    }

    #[test]
    fn test_new_with_base_url_trims_trailing_slashes() {
        let client = OpenAIClient::new_with_base_url(
            "k".into(),
            "m".into(),
            0.1,
            0,
            0,
            "https://example.com/v1//".into(),
        );
        assert_eq!(client.base_url, "https://example.com/v1");
    }

    #[tokio::test]
    async fn test_chat_completion_success() {
        let server = MockServer::start().await;
        Mock::given(method("POST"))
            .and(path("/chat/completions"))
            .and(header("authorization", "Bearer test-key"))
            .respond_with(ResponseTemplate::new(200).set_body_json(json!({
                "choices": [{"message": {"content": "測試回應內容"}}]
            })))
            .mount(&server)
            .await;
        let mut client = OpenAIClient::new("test-key".into(), "gpt-4o-mini".into(), 0.3, 1, 0);
        client.base_url = server.uri();
        let messages = vec![json!({"role":"user","content":"測試"})];
        let resp = client.chat_completion(messages).await.unwrap();
        assert_eq!(resp, "測試回應內容");
    }

    #[tokio::test]
    async fn test_chat_completion_error() {
        let server = MockServer::start().await;
        Mock::given(method("POST"))
            .and(path("/chat/completions"))
            .respond_with(ResponseTemplate::new(400).set_body_json(json!({
                "error": {"message":"Invalid API key"}
            })))
            .mount(&server)
            .await;
        let mut client = OpenAIClient::new("bad-key".into(), "gpt-4o-mini".into(), 0.3, 1, 0);
        client.base_url = server.uri();
        let messages = vec![json!({"role":"user","content":"測試"})];
        let result = client.chat_completion(messages).await;
        assert!(result.is_err());
    }

    #[tokio::test]
    async fn test_analyze_content_with_mock() {
        let mut mock = MockAIClient::new();
        let req = AnalysisRequest {
            video_files: vec!["v.mp4".into()],
            subtitle_files: vec!["s.srt".into()],
            content_samples: vec![],
        };
        let expected = MatchResult {
            matches: vec![],
            confidence: 0.5,
            reasoning: "OK".into(),
        };
        mock.expect_analyze_content()
            .with(eq(req.clone()))
            .times(1)
            .returning(move |_| Ok(expected.clone()));
        let res = mock.analyze_content(req.clone()).await.unwrap();
        assert_eq!(res.confidence, 0.5);
    }

    #[test]
    fn test_prompt_building_and_parsing() {
        let client = OpenAIClient::new("k".into(), "m".into(), 0.1, 0, 0);
        let request = AnalysisRequest {
            video_files: vec!["F1.mp4".into()],
            subtitle_files: vec!["S1.srt".into()],
            content_samples: vec![],
        };
        let prompt = client.build_analysis_prompt(&request);
        assert!(prompt.contains("F1.mp4"));
        assert!(prompt.contains("S1.srt"));
        assert!(prompt.contains("JSON"));
        let json_resp = r#"{ "matches": [], "confidence":0.9, "reasoning":"r" }"#;
        let mr = client.parse_match_result(json_resp).unwrap();
        assert_eq!(mr.confidence, 0.9);
    }

    #[test]
    fn test_openai_client_from_config() {
        let config = openai_config(Some("test-key"), "https://custom.openai.com/v1");
        let client = OpenAIClient::from_config(&config).unwrap();
        assert_eq!(client.api_key, "test-key");
        assert_eq!(client.model, "gpt-test");
        assert_eq!(client.temperature, 0.7);
        assert_eq!(client.base_url, "https://custom.openai.com/v1");
    }

    #[test]
    fn test_openai_client_from_config_invalid_base_url() {
        let config = openai_config(Some("test-key"), "ftp://invalid.url");
        let err = OpenAIClient::from_config(&config).unwrap_err();
        assert!(err.to_string().contains("base URL 必須使用 http 或 https 協定"));
    }

    #[test]
    fn test_openai_client_from_config_missing_api_key() {
        let config = openai_config(None, "https://custom.openai.com/v1");
        let err = OpenAIClient::from_config(&config).unwrap_err();
        assert!(err.to_string().contains("缺少 OpenAI API Key"));
    }
}
169
170impl OpenAIClient {
171 pub fn new(
173 api_key: String,
174 model: String,
175 temperature: f32,
176 retry_attempts: u32,
177 retry_delay_ms: u64,
178 ) -> Self {
179 Self::new_with_base_url(
180 api_key,
181 model,
182 temperature,
183 retry_attempts,
184 retry_delay_ms,
185 "https://api.openai.com/v1".to_string(),
186 )
187 }
188
189 pub fn new_with_base_url(
191 api_key: String,
192 model: String,
193 temperature: f32,
194 retry_attempts: u32,
195 retry_delay_ms: u64,
196 base_url: String,
197 ) -> Self {
198 let client = Client::builder()
199 .timeout(Duration::from_secs(30))
200 .build()
201 .expect("建立 HTTP 客戶端失敗");
202 Self {
203 client,
204 api_key,
205 model,
206 temperature,
207 retry_attempts,
208 retry_delay_ms,
209 base_url: base_url.trim_end_matches('/').to_string(),
210 }
211 }
212
213 pub fn from_config(config: &crate::config::AIConfig) -> crate::Result<Self> {
215 let api_key = config
216 .api_key
217 .as_ref()
218 .ok_or_else(|| crate::error::SubXError::config("缺少 OpenAI API Key"))?;
219
220 Self::validate_base_url(&config.base_url)?;
222
223 Ok(Self::new_with_base_url(
224 api_key.clone(),
225 config.model.clone(),
226 config.temperature,
227 config.retry_attempts,
228 config.retry_delay_ms,
229 config.base_url.clone(),
230 ))
231 }
232
233 fn validate_base_url(url: &str) -> crate::Result<()> {
235 use url::Url;
236 let parsed = Url::parse(url)
237 .map_err(|e| crate::error::SubXError::config(format!("無效的 base URL: {}", e)))?;
238
239 if !matches!(parsed.scheme(), "http" | "https") {
240 return Err(crate::error::SubXError::config(
241 "base URL 必須使用 http 或 https 協定".to_string(),
242 ));
243 }
244
245 if parsed.host().is_none() {
246 return Err(crate::error::SubXError::config(
247 "base URL 必須包含有效的主機名稱".to_string(),
248 ));
249 }
250
251 Ok(())
252 }
253
254 async fn chat_completion(&self, messages: Vec<serde_json::Value>) -> Result<String> {
255 let request_body = json!({
256 "model": self.model,
257 "messages": messages,
258 "temperature": self.temperature,
259 "max_tokens": 1000,
260 });
261
262 let request = self
263 .client
264 .post(format!("{}/chat/completions", self.base_url))
265 .header("Authorization", format!("Bearer {}", self.api_key))
266 .header("Content-Type", "application/json")
267 .json(&request_body);
268 let response = self.make_request_with_retry(request).await?;
269
270 if !response.status().is_success() {
271 let status = response.status();
272 let error_text = response.text().await?;
273 return Err(SubXError::AiService(format!(
274 "OpenAI API 錯誤 {}: {}",
275 status, error_text
276 )));
277 }
278
279 let response_json: Value = response.json().await?;
280 let content = response_json["choices"][0]["message"]["content"]
281 .as_str()
282 .ok_or_else(|| SubXError::AiService("無效的 API 回應格式".to_string()))?;
283
284 if let Some(usage_obj) = response_json.get("usage") {
286 if let (Some(p), Some(c), Some(t)) = (
287 usage_obj.get("prompt_tokens").and_then(Value::as_u64),
288 usage_obj.get("completion_tokens").and_then(Value::as_u64),
289 usage_obj.get("total_tokens").and_then(Value::as_u64),
290 ) {
291 let stats = AiUsageStats {
292 model: self.model.clone(),
293 prompt_tokens: p as u32,
294 completion_tokens: c as u32,
295 total_tokens: t as u32,
296 };
297 display_ai_usage(&stats);
298 }
299 }
300
301 Ok(content.to_string())
302 }
303}
304
305#[async_trait]
306impl AIProvider for OpenAIClient {
307 async fn analyze_content(&self, request: AnalysisRequest) -> Result<MatchResult> {
308 let prompt = self.build_analysis_prompt(&request);
309 let messages = vec![
310 json!({"role": "system", "content": "你是一個專業的字幕匹配助手,能夠分析影片和字幕檔案的對應關係。"}),
311 json!({"role": "user", "content": prompt}),
312 ];
313 let response = self.chat_completion(messages).await?;
314 self.parse_match_result(&response)
315 }
316
317 async fn verify_match(&self, verification: VerificationRequest) -> Result<ConfidenceScore> {
318 let prompt = self.build_verification_prompt(&verification);
319 let messages = vec![
320 json!({"role": "system", "content": "請評估字幕匹配的信心度,提供 0-1 之間的分數。"}),
321 json!({"role": "user", "content": prompt}),
322 ];
323 let response = self.chat_completion(messages).await?;
324 self.parse_confidence_score(&response)
325 }
326}
327
328impl OpenAIClient {
329 async fn make_request_with_retry(
330 &self,
331 request: reqwest::RequestBuilder,
332 ) -> reqwest::Result<reqwest::Response> {
333 let mut attempts = 0;
334 loop {
335 match request.try_clone().unwrap().send().await {
336 Ok(resp) => return Ok(resp),
337 Err(_e) if (attempts as u32) < self.retry_attempts => {
338 attempts += 1;
339 time::sleep(Duration::from_millis(self.retry_delay_ms)).await;
340 continue;
341 }
342 Err(e) => return Err(e),
343 }
344 }
345 }
346}