1use crate::Result;
2use crate::cli::display_ai_usage;
3use crate::error::SubXError;
4use crate::services::ai::AiUsageStats;
5use crate::services::ai::{
6 AIProvider, AnalysisRequest, ConfidenceScore, MatchResult, VerificationRequest,
7};
8use async_trait::async_trait;
9use reqwest::Client;
10use serde_json::Value;
11use serde_json::json;
12use std::time::Duration;
13
14use crate::services::ai::hosted_hint::{append_local_hint, maybe_attach_local_hint};
15use crate::services::ai::prompts::{PromptBuilder, ResponseParser};
16use crate::services::ai::retry::HttpRetryClient;
17
/// Client for the OpenAI chat-completions API (or any OpenAI-compatible endpoint).
pub struct OpenAIClient {
    // Reusable HTTP client, configured with a request timeout at construction.
    client: Client,
    // Secret bearer token; never printed — see the manual `Debug` impl.
    api_key: String,
    // Model identifier sent with every request, e.g. "gpt-4.1-mini".
    model: String,
    // Sampling temperature forwarded verbatim in the request body.
    temperature: f32,
    // Upper bound on completion tokens requested per call.
    max_tokens: u32,
    // Retry policy exposed through the `HttpRetryClient` impl below.
    retry_attempts: u32,
    retry_delay_ms: u64,
    // Endpoint root without a trailing slash (stripped in the constructor).
    base_url: String,
}
29
// Manual `Debug` implementation so the API key can never leak into logs or
// error output; every other field is rendered normally.
impl std::fmt::Debug for OpenAIClient {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("OpenAIClient")
            .field("client", &self.client)
            // Fixed placeholder instead of the secret.
            .field("api_key", &"[REDACTED]")
            .field("model", &self.model)
            .field("temperature", &self.temperature)
            .field("max_tokens", &self.max_tokens)
            .field("retry_attempts", &self.retry_attempts)
            .field("retry_delay_ms", &self.retry_delay_ms)
            .field("base_url", &self.base_url)
            .finish()
    }
}
44
// Prompt construction and response parsing come from the shared default
// implementations in `services::ai::prompts`.
impl PromptBuilder for OpenAIClient {}
impl ResponseParser for OpenAIClient {}
// Exposes this client's configured retry policy to the shared retry helper
// (`make_request_with_retry` used in `chat_completion`).
impl HttpRetryClient for OpenAIClient {
    fn retry_attempts(&self) -> u32 {
        self.retry_attempts
    }
    fn retry_delay_ms(&self) -> u64 {
        self.retry_delay_ms
    }
}
55
#[cfg(test)]
mod tests {
    use super::*;
    use mockall::{mock, predicate::eq};
    use serde_json::json;
    use wiremock::matchers::{header, method, path};
    use wiremock::{Mock, MockServer, ResponseTemplate};

    // Generated mock of the `AIProvider` trait, used to test call expectations
    // without any HTTP traffic.
    mock! {
        AIClient {}

        #[async_trait]
        impl AIProvider for AIClient {
            async fn analyze_content(&self, request: AnalysisRequest) -> crate::Result<MatchResult>;
            async fn verify_match(&self, verification: VerificationRequest) -> crate::Result<ConfidenceScore>;
        }
    }

    // Constructor stores each argument unchanged.
    #[tokio::test]
    async fn test_openai_client_creation() {
        let client = OpenAIClient::new("test-key".into(), "gpt-4.1-mini".into(), 0.5, 1000, 2, 100);
        assert_eq!(client.api_key, "test-key");
        assert_eq!(client.model, "gpt-4.1-mini");
        assert_eq!(client.temperature, 0.5);
        assert_eq!(client.max_tokens, 1000);
        assert_eq!(client.retry_attempts, 2);
        assert_eq!(client.retry_delay_ms, 100);
    }

    // Happy path: a well-formed choices/message/content body is unwrapped to
    // the plain content string, and the bearer header carries the API key.
    #[tokio::test]
    async fn test_chat_completion_success() {
        let server = MockServer::start().await;
        Mock::given(method("POST"))
            .and(path("/chat/completions"))
            .and(header("authorization", "Bearer test-key"))
            .respond_with(ResponseTemplate::new(200).set_body_json(json!({
                "choices": [{"message": {"content": "test response content"}}]
            })))
            .mount(&server)
            .await;
        let mut client =
            OpenAIClient::new("test-key".into(), "gpt-4.1-mini".into(), 0.3, 1000, 1, 0);
        client.base_url = server.uri();
        let messages = vec![json!({"role":"user","content":"test"})];
        let resp = client.chat_completion(messages).await.unwrap();
        assert_eq!(resp, "test response content");
    }

    // A non-2xx status must surface as an error rather than a parsed body.
    #[tokio::test]
    async fn test_chat_completion_error() {
        let server = MockServer::start().await;
        Mock::given(method("POST"))
            .and(path("/chat/completions"))
            .respond_with(ResponseTemplate::new(400).set_body_json(json!({
                "error": {"message":"Invalid API key"}
            })))
            .mount(&server)
            .await;
        let mut client =
            OpenAIClient::new("bad-key".into(), "gpt-4.1-mini".into(), 0.3, 1000, 1, 0);
        client.base_url = server.uri();
        let messages = vec![json!({"role":"user","content":"test"})];
        let result = client.chat_completion(messages).await;
        assert!(result.is_err());
    }

    // Exercises the mockall mock: the expected request is matched by equality
    // and the canned result is returned exactly once.
    #[tokio::test]
    async fn test_analyze_content_with_mock() {
        let mut mock = MockAIClient::new();
        let req = AnalysisRequest {
            video_files: vec!["v.mp4".into()],
            subtitle_files: vec!["s.srt".into()],
            content_samples: vec![],
        };
        let expected = MatchResult {
            matches: vec![],
            confidence: 0.5,
            reasoning: "OK".into(),
        };
        mock.expect_analyze_content()
            .with(eq(req.clone()))
            .times(1)
            .returning(move |_| Ok(expected.clone()));
        let res = mock.analyze_content(req.clone()).await.unwrap();
        assert_eq!(res.confidence, 0.5);
    }

    // The PromptBuilder/ResponseParser default impls: prompt mentions the
    // input file names and asks for JSON; parser round-trips a JSON reply.
    #[test]
    fn test_prompt_building_and_parsing() {
        let client = OpenAIClient::new("k".into(), "m".into(), 0.1, 1000, 0, 0);
        let request = AnalysisRequest {
            video_files: vec!["F1.mp4".into()],
            subtitle_files: vec!["S1.srt".into()],
            content_samples: vec![],
        };
        let prompt = client.build_analysis_prompt(&request);
        assert!(prompt.contains("F1.mp4"));
        assert!(prompt.contains("S1.srt"));
        assert!(prompt.contains("JSON"));
        let json_resp = r#"{ "matches": [], "confidence":0.9, "reasoning":"r" }"#;
        let mr = client.parse_match_result(json_resp).unwrap();
        assert_eq!(mr.confidence, 0.9);
    }

    // from_config copies the relevant AIConfig fields into the client.
    #[test]
    fn test_openai_client_from_config() {
        let config = crate::config::AIConfig {
            provider: "openai".to_string(),
            api_key: Some("test-key".to_string()),
            model: "gpt-test".to_string(),
            base_url: "https://custom.openai.com/v1".to_string(),
            max_sample_length: 500,
            temperature: 0.7,
            max_tokens: 2000,
            retry_attempts: 2,
            retry_delay_ms: 150,
            request_timeout_seconds: 60,
            api_version: None,
        };
        let client = OpenAIClient::from_config(&config).unwrap();
        assert_eq!(client.api_key, "test-key");
        assert_eq!(client.model, "gpt-test");
        assert_eq!(client.temperature, 0.7);
        assert_eq!(client.max_tokens, 2000);
    }

    // Non-http(s) schemes are rejected by validate_base_url via from_config.
    #[test]
    fn test_openai_client_from_config_invalid_base_url() {
        let config = crate::config::AIConfig {
            provider: "openai".to_string(),
            api_key: Some("test-key".to_string()),
            model: "gpt-test".to_string(),
            base_url: "ftp://invalid.url".to_string(),
            max_sample_length: 500,
            temperature: 0.7,
            max_tokens: 1000,
            retry_attempts: 2,
            retry_delay_ms: 150,
            request_timeout_seconds: 30,
            api_version: None,
        };
        let err = OpenAIClient::from_config(&config).unwrap_err();
        assert!(
            err.to_string()
                .contains("Base URL must use http or https protocol")
        );
    }

    // Connection refused against a loopback base URL should carry the
    // "did you mean a local provider (ollama)?" hint.
    #[tokio::test]
    async fn test_hosted_hint_connection_refused_loopback() {
        let port = pick_unused_port().await;
        let mut client = OpenAIClient::new("k".into(), "gpt-4.1-mini".into(), 0.0, 16, 0, 0);
        client.base_url = format!("http://127.0.0.1:{}", port);
        let err = client
            .chat_completion(vec![json!({"role":"user","content":"x"})])
            .await
            .unwrap_err();
        let msg = err.to_string();
        assert!(
            msg.contains("ollama") && msg.contains("ai.provider"),
            "expected local-provider hint, got: {msg}"
        );
    }

    // Same hint for RFC 1918 private addresses (short timeout keeps it fast).
    #[tokio::test]
    async fn test_hosted_hint_connection_refused_rfc1918() {
        let client = OpenAIClient::new_with_base_url_and_timeout(
            "k".into(),
            "gpt-4.1-mini".into(),
            0.0,
            16,
            0,
            0,
            "http://192.168.0.1:1".to_string(),
            1,
        );
        let err = client
            .chat_completion(vec![json!({"role":"user","content":"x"})])
            .await
            .unwrap_err();
        let msg = err.to_string();
        assert!(
            msg.contains("ollama") && msg.contains("ai.provider"),
            "expected local-provider hint, got: {msg}"
        );
    }

    // An HTTP 200 with a non-OpenAI-shaped body yields the parse-shape error
    // AND the local-provider hint (a local server often answers this way).
    #[tokio::test]
    async fn test_hosted_hint_http_200_non_openai_body() {
        let server = MockServer::start().await;
        Mock::given(method("POST"))
            .and(path("/chat/completions"))
            .respond_with(ResponseTemplate::new(200).set_body_json(json!({ "hello": "world" })))
            .mount(&server)
            .await;
        let mut client = OpenAIClient::new("k".into(), "gpt-4.1-mini".into(), 0.0, 16, 0, 0);
        client.base_url = server.uri();
        let err = client
            .chat_completion(vec![json!({"role":"user","content":"x"})])
            .await
            .unwrap_err();
        let msg = err.to_string();
        assert!(
            msg.contains("Invalid API response format"),
            "expected base parse-shape message: {msg}"
        );
        assert!(
            msg.contains("ollama") && msg.contains("ai.provider"),
            "expected local-provider hint: {msg}"
        );
    }

    // Failures against a public (TEST-NET-1) host must NOT get the hint.
    #[tokio::test]
    async fn test_hosted_hint_not_emitted_for_public_host() {
        let client = OpenAIClient::new_with_base_url_and_timeout(
            "k".into(),
            "gpt-4.1-mini".into(),
            0.0,
            16,
            0,
            0,
            "https://192.0.2.1/v1".to_string(),
            1,
        );
        let err = client
            .chat_completion(vec![json!({"role":"user","content":"x"})])
            .await
            .unwrap_err();
        let msg = err.to_string();
        assert!(
            !msg.contains("ollama"),
            "public-host failure must NOT carry the local-provider hint: {msg}"
        );
    }

    // Binds port 0 to let the OS pick a free port, then releases it so the
    // test can dial it and (very likely) get connection refused.
    async fn pick_unused_port() -> u16 {
        let listener = tokio::net::TcpListener::bind("127.0.0.1:0").await.unwrap();
        let port = listener.local_addr().unwrap().port();
        drop(listener);
        port
    }
}
319
320impl OpenAIClient {
321 pub fn new(
323 api_key: String,
324 model: String,
325 temperature: f32,
326 max_tokens: u32,
327 retry_attempts: u32,
328 retry_delay_ms: u64,
329 ) -> Self {
330 Self::new_with_base_url(
331 api_key,
332 model,
333 temperature,
334 max_tokens,
335 retry_attempts,
336 retry_delay_ms,
337 "https://api.openai.com/v1".to_string(),
338 )
339 }
340
341 pub fn new_with_base_url(
343 api_key: String,
344 model: String,
345 temperature: f32,
346 max_tokens: u32,
347 retry_attempts: u32,
348 retry_delay_ms: u64,
349 base_url: String,
350 ) -> Self {
351 Self::new_with_base_url_and_timeout(
353 api_key,
354 model,
355 temperature,
356 max_tokens,
357 retry_attempts,
358 retry_delay_ms,
359 base_url,
360 30,
361 )
362 }
363
364 #[allow(clippy::too_many_arguments)]
366 pub fn new_with_base_url_and_timeout(
367 api_key: String,
368 model: String,
369 temperature: f32,
370 max_tokens: u32,
371 retry_attempts: u32,
372 retry_delay_ms: u64,
373 base_url: String,
374 request_timeout_seconds: u64,
375 ) -> Self {
376 let client = Client::builder()
377 .timeout(Duration::from_secs(request_timeout_seconds))
378 .build()
379 .expect("Failed to create HTTP client");
380 Self {
381 client,
382 api_key,
383 model,
384 temperature,
385 max_tokens,
386 retry_attempts,
387 retry_delay_ms,
388 base_url: base_url.trim_end_matches('/').to_string(),
389 }
390 }
391
392 pub fn from_config(config: &crate::config::AIConfig) -> crate::Result<Self> {
394 let api_key = config
395 .api_key
396 .as_ref()
397 .ok_or_else(|| crate::error::SubXError::config("Missing OpenAI API Key"))?;
398
399 Self::validate_base_url(&config.base_url)?;
401 crate::services::ai::security::warn_on_insecure_http_str(&config.base_url, api_key);
402
403 Ok(Self::new_with_base_url_and_timeout(
404 api_key.clone(),
405 config.model.clone(),
406 config.temperature,
407 config.max_tokens,
408 config.retry_attempts,
409 config.retry_delay_ms,
410 config.base_url.clone(),
411 config.request_timeout_seconds,
412 ))
413 }
414
415 fn validate_base_url(url: &str) -> crate::Result<()> {
417 use url::Url;
418 let parsed = Url::parse(url)
419 .map_err(|e| crate::error::SubXError::config(format!("Invalid base URL: {}", e)))?;
420
421 if !matches!(parsed.scheme(), "http" | "https") {
422 return Err(crate::error::SubXError::config(
423 "Base URL must use http or https protocol".to_string(),
424 ));
425 }
426
427 if parsed.host().is_none() {
428 return Err(crate::error::SubXError::config(
429 "Base URL must contain a valid hostname".to_string(),
430 ));
431 }
432
433 Ok(())
434 }
435
436 pub async fn chat_completion(&self, messages: Vec<serde_json::Value>) -> Result<String> {
438 let request_body = json!({
439 "model": self.model,
440 "messages": messages,
441 "temperature": self.temperature,
442 "max_tokens": self.max_tokens,
443 });
444
445 let request = self
446 .client
447 .post(format!("{}/chat/completions", self.base_url))
448 .header("Authorization", format!("Bearer {}", self.api_key))
449 .header("Content-Type", "application/json")
450 .json(&request_body);
451 let mut response = match self.make_request_with_retry(request).await {
452 Ok(r) => r,
453 Err(e) => return Err(maybe_attach_local_hint(e, &self.base_url)),
454 };
455
456 const MAX_AI_RESPONSE_BYTES: u64 = 10 * 1024 * 1024; if let Some(len) = response.content_length() {
458 if len > MAX_AI_RESPONSE_BYTES {
459 return Err(SubXError::AiService(format!(
460 "AI response too large: {} bytes (limit: {} bytes)",
461 len, MAX_AI_RESPONSE_BYTES
462 )));
463 }
464 }
465
466 if !response.status().is_success() {
467 let status = response.status();
468 let error_text = response.text().await?;
469 let safe_body = crate::services::ai::error_sanitizer::sanitize_url_in_error(
470 &crate::services::ai::error_sanitizer::truncate_error_body(
471 &error_text,
472 crate::services::ai::error_sanitizer::DEFAULT_ERROR_BODY_MAX_LEN,
473 ),
474 );
475 return Err(SubXError::AiService(format!(
476 "OpenAI API error {}: {}",
477 status, safe_body
478 )));
479 }
480
481 let mut body = Vec::new();
484 while let Some(chunk) = response.chunk().await? {
485 body.extend_from_slice(&chunk);
486 if body.len() as u64 > MAX_AI_RESPONSE_BYTES {
487 return Err(SubXError::AiService(format!(
488 "AI response too large: {} bytes read (limit: {} bytes)",
489 body.len(),
490 MAX_AI_RESPONSE_BYTES
491 )));
492 }
493 }
494 let response_json: Value = serde_json::from_slice(&body)
495 .map_err(|e| SubXError::AiService(format!("Failed to parse AI response: {}", e)))?;
496 let content = response_json["choices"][0]["message"]["content"]
497 .as_str()
498 .ok_or_else(|| {
499 SubXError::AiService(append_local_hint("Invalid API response format"))
505 })?;
506
507 if let Some(usage_obj) = response_json.get("usage") {
509 if let (Some(p), Some(c), Some(t)) = (
510 usage_obj.get("prompt_tokens").and_then(Value::as_u64),
511 usage_obj.get("completion_tokens").and_then(Value::as_u64),
512 usage_obj.get("total_tokens").and_then(Value::as_u64),
513 ) {
514 let stats = AiUsageStats {
515 model: self.model.clone(),
516 prompt_tokens: p as u32,
517 completion_tokens: c as u32,
518 total_tokens: t as u32,
519 };
520 display_ai_usage(&stats);
521 }
522 }
523
524 Ok(content.to_string())
525 }
526}
527
528#[async_trait]
529impl AIProvider for OpenAIClient {
530 async fn analyze_content(&self, request: AnalysisRequest) -> Result<MatchResult> {
531 let prompt = self.build_analysis_prompt(&request);
532 let messages = vec![
533 json!({"role": "system", "content": Self::get_analysis_system_message()}),
534 json!({"role": "user", "content": prompt}),
535 ];
536 let response = self.chat_completion(messages).await?;
537 self.parse_match_result(&response)
538 }
539
540 async fn verify_match(&self, verification: VerificationRequest) -> Result<ConfidenceScore> {
541 let prompt = self.build_verification_prompt(&verification);
542 let messages = vec![
543 json!({"role": "system", "content": Self::get_verification_system_message()}),
544 json!({"role": "user", "content": prompt}),
545 ];
546 let response = self.chat_completion(messages).await?;
547 self.parse_confidence_score(&response)
548 }
549
550 async fn chat_completion(&self, messages: Vec<serde_json::Value>) -> Result<String> {
551 OpenAIClient::chat_completion(self, messages).await
552 }
553}