use crate::Result;
use crate::cli::display_ai_usage;
use crate::error::SubXError;
use crate::services::ai::AiUsageStats;
use crate::services::ai::{
    AIProvider, AnalysisRequest, ConfidenceScore, MatchResult, VerificationRequest,
};
use async_trait::async_trait;
use reqwest::Client;
use serde_json::Value;
use serde_json::json;
use std::time::Duration;

use crate::services::ai::prompts::{PromptBuilder, ResponseParser};
use crate::services::ai::retry::HttpRetryClient;

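/// OpenAI chat-completion client used for subtitle/video matching analysis.
///
/// Wraps a `reqwest::Client` together with the API key, model parameters,
/// retry policy, and endpoint base URL.
///
/// A minimal construction sketch (values are placeholders):
///
/// ```ignore
/// let client = OpenAIClient::new(
///     "sk-...".to_string(),       // API key
///     "gpt-4.1-mini".to_string(), // model name
///     0.3,                        // temperature
///     1000,                       // max_tokens
///     3,                          // retry_attempts
///     500,                        // retry_delay_ms
/// );
/// ```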
pub struct OpenAIClient {
    client: Client,
    api_key: String,
    model: String,
    temperature: f32,
    max_tokens: u32,
    retry_attempts: u32,
    retry_delay_ms: u64,
    base_url: String,
}

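// Manual `Debug` implementation so the API key is redacted in debug output.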
impl std::fmt::Debug for OpenAIClient {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("OpenAIClient")
            .field("client", &self.client)
            .field("api_key", &"[REDACTED]")
            .field("model", &self.model)
            .field("temperature", &self.temperature)
            .field("max_tokens", &self.max_tokens)
            .field("retry_attempts", &self.retry_attempts)
            .field("retry_delay_ms", &self.retry_delay_ms)
            .field("base_url", &self.base_url)
            .finish()
    }
}

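// Prompt construction and response parsing come from the traits' default
// methods; `HttpRetryClient` only needs access to this client's retry settings.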
impl PromptBuilder for OpenAIClient {}
impl ResponseParser for OpenAIClient {}
impl HttpRetryClient for OpenAIClient {
    fn retry_attempts(&self) -> u32 {
        self.retry_attempts
    }
    fn retry_delay_ms(&self) -> u64 {
        self.retry_delay_ms
    }
}

#[cfg(test)]
mod tests {
    use super::*;
    use mockall::{mock, predicate::eq};
    use serde_json::json;
    use wiremock::matchers::{header, method, path};
    use wiremock::{Mock, MockServer, ResponseTemplate};

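    // Mock `AIProvider` so the trait can be exercised without real HTTP calls.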
    mock! {
        AIClient {}

        #[async_trait]
        impl AIProvider for AIClient {
            async fn analyze_content(&self, request: AnalysisRequest) -> crate::Result<MatchResult>;
            async fn verify_match(&self, verification: VerificationRequest) -> crate::Result<ConfidenceScore>;
        }
    }

    #[tokio::test]
    async fn test_openai_client_creation() {
        let client =
            OpenAIClient::new("test-key".into(), "gpt-4.1-mini".into(), 0.5, 1000, 2, 100);
        assert_eq!(client.api_key, "test-key");
        assert_eq!(client.model, "gpt-4.1-mini");
        assert_eq!(client.temperature, 0.5);
        assert_eq!(client.max_tokens, 1000);
        assert_eq!(client.retry_attempts, 2);
        assert_eq!(client.retry_delay_ms, 100);
    }

    #[tokio::test]
    async fn test_chat_completion_success() {
        let server = MockServer::start().await;
        Mock::given(method("POST"))
            .and(path("/chat/completions"))
            .and(header("authorization", "Bearer test-key"))
            .respond_with(ResponseTemplate::new(200).set_body_json(json!({
                "choices": [{"message": {"content": "test response content"}}]
            })))
            .mount(&server)
            .await;
        let mut client =
            OpenAIClient::new("test-key".into(), "gpt-4.1-mini".into(), 0.3, 1000, 1, 0);
        client.base_url = server.uri();
        let messages = vec![json!({"role":"user","content":"test"})];
        let resp = client.chat_completion(messages).await.unwrap();
        assert_eq!(resp, "test response content");
    }

    #[tokio::test]
    async fn test_chat_completion_error() {
        let server = MockServer::start().await;
        Mock::given(method("POST"))
            .and(path("/chat/completions"))
            .respond_with(ResponseTemplate::new(400).set_body_json(json!({
                "error": {"message":"Invalid API key"}
            })))
            .mount(&server)
            .await;
        let mut client =
            OpenAIClient::new("bad-key".into(), "gpt-4.1-mini".into(), 0.3, 1000, 1, 0);
        client.base_url = server.uri();
        let messages = vec![json!({"role":"user","content":"test"})];
        let result = client.chat_completion(messages).await;
        assert!(result.is_err());
    }

    #[tokio::test]
    async fn test_analyze_content_with_mock() {
        let mut mock = MockAIClient::new();
        let req = AnalysisRequest {
            video_files: vec!["v.mp4".into()],
            subtitle_files: vec!["s.srt".into()],
            content_samples: vec![],
        };
        let expected = MatchResult {
            matches: vec![],
            confidence: 0.5,
            reasoning: "OK".into(),
        };
        mock.expect_analyze_content()
            .with(eq(req.clone()))
            .times(1)
            .returning(move |_| Ok(expected.clone()));
        let res = mock.analyze_content(req.clone()).await.unwrap();
        assert_eq!(res.confidence, 0.5);
    }

    #[test]
    fn test_prompt_building_and_parsing() {
        let client = OpenAIClient::new("k".into(), "m".into(), 0.1, 1000, 0, 0);
        let request = AnalysisRequest {
            video_files: vec!["F1.mp4".into()],
            subtitle_files: vec!["S1.srt".into()],
            content_samples: vec![],
        };
        let prompt = client.build_analysis_prompt(&request);
        assert!(prompt.contains("F1.mp4"));
        assert!(prompt.contains("S1.srt"));
        assert!(prompt.contains("JSON"));
        let json_resp = r#"{ "matches": [], "confidence":0.9, "reasoning":"r" }"#;
        let mr = client.parse_match_result(json_resp).unwrap();
        assert_eq!(mr.confidence, 0.9);
    }

    #[test]
    fn test_openai_client_from_config() {
        let config = crate::config::AIConfig {
            provider: "openai".to_string(),
            api_key: Some("test-key".to_string()),
            model: "gpt-test".to_string(),
            base_url: "https://custom.openai.com/v1".to_string(),
            max_sample_length: 500,
            temperature: 0.7,
            max_tokens: 2000,
            retry_attempts: 2,
            retry_delay_ms: 150,
            request_timeout_seconds: 60,
            api_version: None,
        };
        let client = OpenAIClient::from_config(&config).unwrap();
        assert_eq!(client.api_key, "test-key");
        assert_eq!(client.model, "gpt-test");
        assert_eq!(client.temperature, 0.7);
        assert_eq!(client.max_tokens, 2000);
    }

    #[test]
    fn test_openai_client_from_config_invalid_base_url() {
        let config = crate::config::AIConfig {
            provider: "openai".to_string(),
            api_key: Some("test-key".to_string()),
            model: "gpt-test".to_string(),
            base_url: "ftp://invalid.url".to_string(),
            max_sample_length: 500,
            temperature: 0.7,
            max_tokens: 1000,
            retry_attempts: 2,
            retry_delay_ms: 150,
            request_timeout_seconds: 30,
            api_version: None,
        };
        let err = OpenAIClient::from_config(&config).unwrap_err();
        assert!(
            err.to_string()
                .contains("Base URL must use http or https protocol")
        );
    }
}

impl OpenAIClient {
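    /// Creates a client that talks to the default OpenAI endpoint
    /// (`https://api.openai.com/v1`).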
    pub fn new(
        api_key: String,
        model: String,
        temperature: f32,
        max_tokens: u32,
        retry_attempts: u32,
        retry_delay_ms: u64,
    ) -> Self {
        Self::new_with_base_url(
            api_key,
            model,
            temperature,
            max_tokens,
            retry_attempts,
            retry_delay_ms,
            "https://api.openai.com/v1".to_string(),
        )
    }

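    /// Creates a client against a custom base URL with the default
    /// 30-second request timeout.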
    pub fn new_with_base_url(
        api_key: String,
        model: String,
        temperature: f32,
        max_tokens: u32,
        retry_attempts: u32,
        retry_delay_ms: u64,
        base_url: String,
    ) -> Self {
        Self::new_with_base_url_and_timeout(
            api_key,
            model,
            temperature,
            max_tokens,
            retry_attempts,
            retry_delay_ms,
            base_url,
            30,
        )
    }

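    /// Creates a fully configured client; trailing slashes are stripped from
    /// `base_url` so request paths can be appended safely.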
    #[allow(clippy::too_many_arguments)]
    pub fn new_with_base_url_and_timeout(
        api_key: String,
        model: String,
        temperature: f32,
        max_tokens: u32,
        retry_attempts: u32,
        retry_delay_ms: u64,
        base_url: String,
        request_timeout_seconds: u64,
    ) -> Self {
        let client = Client::builder()
            .timeout(Duration::from_secs(request_timeout_seconds))
            .build()
            .expect("Failed to create HTTP client");
        Self {
            client,
            api_key,
            model,
            temperature,
            max_tokens,
            retry_attempts,
            retry_delay_ms,
            base_url: base_url.trim_end_matches('/').to_string(),
        }
    }

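    /// Builds a client from the application's `AIConfig`, validating the base
    /// URL and warning when an API key would be sent over plain HTTP.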
    pub fn from_config(config: &crate::config::AIConfig) -> crate::Result<Self> {
        let api_key = config
            .api_key
            .as_ref()
            .ok_or_else(|| crate::error::SubXError::config("Missing OpenAI API Key"))?;

        Self::validate_base_url(&config.base_url)?;
        crate::services::ai::security::warn_on_insecure_http_str(&config.base_url, api_key);

        Ok(Self::new_with_base_url_and_timeout(
            api_key.clone(),
            config.model.clone(),
            config.temperature,
            config.max_tokens,
            config.retry_attempts,
            config.retry_delay_ms,
            config.base_url.clone(),
            config.request_timeout_seconds,
        ))
    }

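    /// Rejects base URLs that are unparsable, use a scheme other than
    /// `http`/`https`, or lack a hostname.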
    fn validate_base_url(url: &str) -> crate::Result<()> {
        use url::Url;
        let parsed = Url::parse(url)
            .map_err(|e| crate::error::SubXError::config(format!("Invalid base URL: {}", e)))?;

        if !matches!(parsed.scheme(), "http" | "https") {
            return Err(crate::error::SubXError::config(
                "Base URL must use http or https protocol".to_string(),
            ));
        }

        if parsed.host().is_none() {
            return Err(crate::error::SubXError::config(
                "Base URL must contain a valid hostname".to_string(),
            ));
        }

        Ok(())
    }

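    /// Sends a chat-completion request and returns the first choice's message
    /// content, reporting token usage when the API includes it.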
    async fn chat_completion(&self, messages: Vec<serde_json::Value>) -> Result<String> {
        let request_body = json!({
            "model": self.model,
            "messages": messages,
            "temperature": self.temperature,
            "max_tokens": self.max_tokens,
        });

        let request = self
            .client
            .post(format!("{}/chat/completions", self.base_url))
            .header("Authorization", format!("Bearer {}", self.api_key))
            .header("Content-Type", "application/json")
            .json(&request_body);
        let mut response = self.make_request_with_retry(request).await?;

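        // Cap AI responses at 10 MiB: reject early when Content-Length is
        // declared, and enforce the same limit again while streaming the body.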
        const MAX_AI_RESPONSE_BYTES: u64 = 10 * 1024 * 1024;
        if let Some(len) = response.content_length() {
            if len > MAX_AI_RESPONSE_BYTES {
                return Err(SubXError::AiService(format!(
                    "AI response too large: {} bytes (limit: {} bytes)",
                    len, MAX_AI_RESPONSE_BYTES
                )));
            }
        }

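        // Non-success statuses: truncate the error body and scrub any URLs
        // before surfacing it.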
        if !response.status().is_success() {
            let status = response.status();
            let error_text = response.text().await?;
            let safe_body = crate::services::ai::error_sanitizer::sanitize_url_in_error(
                &crate::services::ai::error_sanitizer::truncate_error_body(
                    &error_text,
                    crate::services::ai::error_sanitizer::DEFAULT_ERROR_BODY_MAX_LEN,
                ),
            );
            return Err(SubXError::AiService(format!(
                "OpenAI API error {}: {}",
                status, safe_body
            )));
        }

        let mut body = Vec::new();
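        // Stream the body in chunks so the size limit is enforced even when the
        // server omits Content-Length.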
        while let Some(chunk) = response.chunk().await? {
            body.extend_from_slice(&chunk);
            if body.len() as u64 > MAX_AI_RESPONSE_BYTES {
                return Err(SubXError::AiService(format!(
                    "AI response too large: {} bytes read (limit: {} bytes)",
                    body.len(),
                    MAX_AI_RESPONSE_BYTES
                )));
            }
        }
        let response_json: Value = serde_json::from_slice(&body)
            .map_err(|e| SubXError::AiService(format!("Failed to parse AI response: {}", e)))?;
        let content = response_json["choices"][0]["message"]["content"]
            .as_str()
            .ok_or_else(|| SubXError::AiService("Invalid API response format".to_string()))?;

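        // Report token usage to the CLI when the API returns a `usage` object.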
        if let Some(usage_obj) = response_json.get("usage") {
            if let (Some(p), Some(c), Some(t)) = (
                usage_obj.get("prompt_tokens").and_then(Value::as_u64),
                usage_obj.get("completion_tokens").and_then(Value::as_u64),
                usage_obj.get("total_tokens").and_then(Value::as_u64),
            ) {
                let stats = AiUsageStats {
                    model: self.model.clone(),
                    prompt_tokens: p as u32,
                    completion_tokens: c as u32,
                    total_tokens: t as u32,
                };
                display_ai_usage(&stats);
            }
        }

        Ok(content.to_string())
    }
}

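// `AIProvider` implementation: build a prompt, run a chat completion, then
// parse the model's JSON reply into the typed result.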
#[async_trait]
impl AIProvider for OpenAIClient {
    async fn analyze_content(&self, request: AnalysisRequest) -> Result<MatchResult> {
        let prompt = self.build_analysis_prompt(&request);
        let messages = vec![
            json!({"role": "system", "content": Self::get_analysis_system_message()}),
            json!({"role": "user", "content": prompt}),
        ];
        let response = self.chat_completion(messages).await?;
        self.parse_match_result(&response)
    }

    async fn verify_match(&self, verification: VerificationRequest) -> Result<ConfidenceScore> {
        let prompt = self.build_verification_prompt(&verification);
        let messages = vec![
            json!({"role": "system", "content": Self::get_verification_system_message()}),
            json!({"role": "user", "content": prompt}),
        ];
        let response = self.chat_completion(messages).await?;
        self.parse_confidence_score(&response)
    }
}