1use crate::Result;
2use crate::cli::display_ai_usage;
3use crate::error::SubXError;
4use crate::services::ai::AiUsageStats;
5use crate::services::ai::{
6 AIProvider, AnalysisRequest, ConfidenceScore, MatchResult, VerificationRequest,
7};
8use async_trait::async_trait;
9use reqwest::Client;
10use serde_json::Value;
11use serde_json::json;
12use std::time::Duration;
13
14use crate::services::ai::prompts::{PromptBuilder, ResponseParser};
15use crate::services::ai::retry::HttpRetryClient;
16
/// Client for the OpenAI chat-completions API.
///
/// Construct via [`OpenAIClient::new`], the `new_with_base_url*` variants,
/// or [`OpenAIClient::from_config`]. Prompt building and response parsing
/// come from the `PromptBuilder`/`ResponseParser` trait defaults; retries
/// are driven by the `HttpRetryClient` trait using the fields below.
#[derive(Debug)]
pub struct OpenAIClient {
    // Reusable HTTP client, built with the configured request timeout.
    client: Client,
    // Sent as `Authorization: Bearer <api_key>` on every request.
    api_key: String,
    // Model name placed in each request body (e.g. "gpt-4.1-mini").
    model: String,
    // Sampling temperature forwarded verbatim to the API.
    temperature: f32,
    // Upper bound on completion tokens per request.
    max_tokens: u32,
    // Retry policy, exposed through the HttpRetryClient impl below.
    retry_attempts: u32,
    retry_delay_ms: u64,
    // API root with any trailing '/' stripped; tests point this at a mock server.
    base_url: String,
}
29
// Empty impls: OpenAIClient uses the default prompt-building and
// response-parsing methods provided by these traits.
impl PromptBuilder for OpenAIClient {}
impl ResponseParser for OpenAIClient {}
// Supplies the retry policy consumed by the shared retry helper
// (`make_request_with_retry`, used in `chat_completion`).
impl HttpRetryClient for OpenAIClient {
    /// Maximum number of retry attempts for a failed HTTP request.
    fn retry_attempts(&self) -> u32 {
        self.retry_attempts
    }
    /// Delay between retry attempts, in milliseconds.
    fn retry_delay_ms(&self) -> u64 {
        self.retry_delay_ms
    }
}
40
#[cfg(test)]
mod tests {
    use super::*;
    use mockall::{mock, predicate::eq};
    use serde_json::json;
    use wiremock::matchers::{header, method, path};
    use wiremock::{Mock, MockServer, ResponseTemplate};

    // Generates `MockAIClient`, a mockall double implementing AIProvider,
    // so trait-level behavior can be exercised without any HTTP traffic.
    mock! {
        AIClient {}

        #[async_trait]
        impl AIProvider for AIClient {
            async fn analyze_content(&self, request: AnalysisRequest) -> crate::Result<MatchResult>;
            async fn verify_match(&self, verification: VerificationRequest) -> crate::Result<ConfidenceScore>;
        }
    }

    // The basic constructor must store every argument unchanged.
    #[tokio::test]
    async fn test_openai_client_creation() {
        let client = OpenAIClient::new("test-key".into(), "gpt-4.1-mini".into(), 0.5, 1000, 2, 100);
        assert_eq!(client.api_key, "test-key");
        assert_eq!(client.model, "gpt-4.1-mini");
        assert_eq!(client.temperature, 0.5);
        assert_eq!(client.max_tokens, 1000);
        assert_eq!(client.retry_attempts, 2);
        assert_eq!(client.retry_delay_ms, 100);
    }

    // Happy path: a 200 response with a well-formed `choices` array yields
    // the first choice's message content. Also verifies the bearer-auth
    // header and the /chat/completions path are sent.
    #[tokio::test]
    async fn test_chat_completion_success() {
        let server = MockServer::start().await;
        Mock::given(method("POST"))
            .and(path("/chat/completions"))
            .and(header("authorization", "Bearer test-key"))
            .respond_with(ResponseTemplate::new(200).set_body_json(json!({
                "choices": [{"message": {"content": "test response content"}}]
            })))
            .mount(&server)
            .await;
        let mut client =
            OpenAIClient::new("test-key".into(), "gpt-4.1-mini".into(), 0.3, 1000, 1, 0);
        // Redirect the client at the wiremock server instead of api.openai.com.
        client.base_url = server.uri();
        let messages = vec![json!({"role":"user","content":"test"})];
        let resp = client.chat_completion(messages).await.unwrap();
        assert_eq!(resp, "test response content");
    }

    // A non-2xx status must surface as an Err rather than parsed content.
    #[tokio::test]
    async fn test_chat_completion_error() {
        let server = MockServer::start().await;
        Mock::given(method("POST"))
            .and(path("/chat/completions"))
            .respond_with(ResponseTemplate::new(400).set_body_json(json!({
                "error": {"message":"Invalid API key"}
            })))
            .mount(&server)
            .await;
        let mut client =
            OpenAIClient::new("bad-key".into(), "gpt-4.1-mini".into(), 0.3, 1000, 1, 0);
        client.base_url = server.uri();
        let messages = vec![json!({"role":"user","content":"test"})];
        let result = client.chat_completion(messages).await;
        assert!(result.is_err());
    }

    // Exercises the AIProvider trait through the mockall double: the mock
    // must be called exactly once with the expected request and return the
    // canned result.
    #[tokio::test]
    async fn test_analyze_content_with_mock() {
        let mut mock = MockAIClient::new();
        let req = AnalysisRequest {
            video_files: vec!["v.mp4".into()],
            subtitle_files: vec!["s.srt".into()],
            content_samples: vec![],
        };
        let expected = MatchResult {
            matches: vec![],
            confidence: 0.5,
            reasoning: "OK".into(),
        };
        mock.expect_analyze_content()
            .with(eq(req.clone()))
            .times(1)
            .returning(move |_| Ok(expected.clone()));
        let res = mock.analyze_content(req.clone()).await.unwrap();
        assert_eq!(res.confidence, 0.5);
    }

    // Round-trip through the PromptBuilder/ResponseParser defaults: the
    // prompt must mention the input files and request JSON output, and a
    // JSON reply must parse back into a MatchResult.
    #[test]
    fn test_prompt_building_and_parsing() {
        let client = OpenAIClient::new("k".into(), "m".into(), 0.1, 1000, 0, 0);
        let request = AnalysisRequest {
            video_files: vec!["F1.mp4".into()],
            subtitle_files: vec!["S1.srt".into()],
            content_samples: vec![],
        };
        let prompt = client.build_analysis_prompt(&request);
        assert!(prompt.contains("F1.mp4"));
        assert!(prompt.contains("S1.srt"));
        assert!(prompt.contains("JSON"));
        let json_resp = r#"{ "matches": [], "confidence":0.9, "reasoning":"r" }"#;
        let mr = client.parse_match_result(json_resp).unwrap();
        assert_eq!(mr.confidence, 0.9);
    }

    // from_config must propagate config values into the client fields.
    #[test]
    fn test_openai_client_from_config() {
        let config = crate::config::AIConfig {
            provider: "openai".to_string(),
            api_key: Some("test-key".to_string()),
            model: "gpt-test".to_string(),
            base_url: "https://custom.openai.com/v1".to_string(),
            max_sample_length: 500,
            temperature: 0.7,
            max_tokens: 2000,
            retry_attempts: 2,
            retry_delay_ms: 150,
            request_timeout_seconds: 60,
            api_version: None,
        };
        let client = OpenAIClient::from_config(&config).unwrap();
        assert_eq!(client.api_key, "test-key");
        assert_eq!(client.model, "gpt-test");
        assert_eq!(client.temperature, 0.7);
        assert_eq!(client.max_tokens, 2000);
    }

    // A base URL with a non-http(s) scheme must be rejected with the
    // specific validation message.
    #[test]
    fn test_openai_client_from_config_invalid_base_url() {
        let config = crate::config::AIConfig {
            provider: "openai".to_string(),
            api_key: Some("test-key".to_string()),
            model: "gpt-test".to_string(),
            base_url: "ftp://invalid.url".to_string(),
            max_sample_length: 500,
            temperature: 0.7,
            max_tokens: 1000,
            retry_attempts: 2,
            retry_delay_ms: 150,
            request_timeout_seconds: 30,
            api_version: None,
        };
        let err = OpenAIClient::from_config(&config).unwrap_err();
        assert!(
            err.to_string()
                .contains("Base URL must use http or https protocol")
        );
    }
}
191
impl OpenAIClient {
    /// Creates a client against the public OpenAI endpoint
    /// (`https://api.openai.com/v1`) with the default 30-second timeout.
    pub fn new(
        api_key: String,
        model: String,
        temperature: f32,
        max_tokens: u32,
        retry_attempts: u32,
        retry_delay_ms: u64,
    ) -> Self {
        Self::new_with_base_url(
            api_key,
            model,
            temperature,
            max_tokens,
            retry_attempts,
            retry_delay_ms,
            "https://api.openai.com/v1".to_string(),
        )
    }

    /// Like [`OpenAIClient::new`] but targets a custom API root (e.g. a
    /// proxy or a test mock server). Keeps the default 30-second timeout.
    pub fn new_with_base_url(
        api_key: String,
        model: String,
        temperature: f32,
        max_tokens: u32,
        retry_attempts: u32,
        retry_delay_ms: u64,
        base_url: String,
    ) -> Self {
        Self::new_with_base_url_and_timeout(
            api_key,
            model,
            temperature,
            max_tokens,
            retry_attempts,
            retry_delay_ms,
            base_url,
            // Default request timeout in seconds.
            30,
        )
    }

    /// Fully-parameterised constructor; the other constructors delegate here.
    ///
    /// # Panics
    ///
    /// Panics if the underlying `reqwest::Client` cannot be built.
    #[allow(clippy::too_many_arguments)]
    pub fn new_with_base_url_and_timeout(
        api_key: String,
        model: String,
        temperature: f32,
        max_tokens: u32,
        retry_attempts: u32,
        retry_delay_ms: u64,
        base_url: String,
        request_timeout_seconds: u64,
    ) -> Self {
        let client = Client::builder()
            .timeout(Duration::from_secs(request_timeout_seconds))
            .build()
            .expect("Failed to create HTTP client");
        Self {
            client,
            api_key,
            model,
            temperature,
            max_tokens,
            retry_attempts,
            retry_delay_ms,
            // Strip a trailing '/' so "{base_url}/chat/completions" never
            // produces a double slash.
            base_url: base_url.trim_end_matches('/').to_string(),
        }
    }

    /// Builds a client from the application AI configuration.
    ///
    /// # Errors
    ///
    /// Returns a config error when the API key is absent or the base URL
    /// fails [`Self::validate_base_url`] (bad scheme or missing host).
    pub fn from_config(config: &crate::config::AIConfig) -> crate::Result<Self> {
        let api_key = config
            .api_key
            .as_ref()
            .ok_or_else(|| crate::error::SubXError::config("Missing OpenAI API Key"))?;

        // Reject malformed URLs up front so failures surface at config
        // time rather than on the first request.
        Self::validate_base_url(&config.base_url)?;

        Ok(Self::new_with_base_url_and_timeout(
            api_key.clone(),
            config.model.clone(),
            config.temperature,
            config.max_tokens,
            config.retry_attempts,
            config.retry_delay_ms,
            config.base_url.clone(),
            config.request_timeout_seconds,
        ))
    }

    /// Validates that `url` parses, uses http or https, and has a hostname.
    fn validate_base_url(url: &str) -> crate::Result<()> {
        use url::Url;
        let parsed = Url::parse(url)
            .map_err(|e| crate::error::SubXError::config(format!("Invalid base URL: {}", e)))?;

        if !matches!(parsed.scheme(), "http" | "https") {
            return Err(crate::error::SubXError::config(
                "Base URL must use http or https protocol".to_string(),
            ));
        }

        if parsed.host().is_none() {
            return Err(crate::error::SubXError::config(
                "Base URL must contain a valid hostname".to_string(),
            ));
        }

        Ok(())
    }

    /// Sends a chat-completion request and returns the content of the
    /// first choice's message.
    ///
    /// As a side effect, prints token-usage statistics via
    /// `display_ai_usage` when the response carries a complete `usage`
    /// object (all three token counts present).
    ///
    /// # Errors
    ///
    /// Returns `SubXError::AiService` for a non-2xx status (with status and
    /// body text) or when the response lacks the expected content field;
    /// transport/JSON errors propagate via `?`.
    async fn chat_completion(&self, messages: Vec<serde_json::Value>) -> Result<String> {
        let request_body = json!({
            "model": self.model,
            "messages": messages,
            "temperature": self.temperature,
            "max_tokens": self.max_tokens,
        });

        let request = self
            .client
            .post(format!("{}/chat/completions", self.base_url))
            .header("Authorization", format!("Bearer {}", self.api_key))
            .header("Content-Type", "application/json")
            .json(&request_body);
        // Retry policy comes from the HttpRetryClient impl on this type.
        let response = self.make_request_with_retry(request).await?;

        if !response.status().is_success() {
            let status = response.status();
            let error_text = response.text().await?;
            return Err(SubXError::AiService(format!(
                "OpenAI API error {}: {}",
                status, error_text
            )));
        }

        let response_json: Value = response.json().await?;
        // serde_json Value indexing yields Null (not a panic) for missing
        // fields, so any structural mismatch falls through to the error below.
        let content = response_json["choices"][0]["message"]["content"]
            .as_str()
            .ok_or_else(|| SubXError::AiService("Invalid API response format".to_string()))?;

        if let Some(usage_obj) = response_json.get("usage") {
            // Only report usage when all three counters are present.
            if let (Some(p), Some(c), Some(t)) = (
                usage_obj.get("prompt_tokens").and_then(Value::as_u64),
                usage_obj.get("completion_tokens").and_then(Value::as_u64),
                usage_obj.get("total_tokens").and_then(Value::as_u64),
            ) {
                let stats = AiUsageStats {
                    model: self.model.clone(),
                    prompt_tokens: p as u32,
                    completion_tokens: c as u32,
                    total_tokens: t as u32,
                };
                display_ai_usage(&stats);
            }
        }

        Ok(content.to_string())
    }
}
357
358#[async_trait]
359impl AIProvider for OpenAIClient {
360 async fn analyze_content(&self, request: AnalysisRequest) -> Result<MatchResult> {
361 let prompt = self.build_analysis_prompt(&request);
362 let messages = vec![
363 json!({"role": "system", "content": Self::get_analysis_system_message()}),
364 json!({"role": "user", "content": prompt}),
365 ];
366 let response = self.chat_completion(messages).await?;
367 self.parse_match_result(&response)
368 }
369
370 async fn verify_match(&self, verification: VerificationRequest) -> Result<ConfidenceScore> {
371 let prompt = self.build_verification_prompt(&verification);
372 let messages = vec![
373 json!({"role": "system", "content": Self::get_verification_system_message()}),
374 json!({"role": "user", "content": prompt}),
375 ];
376 let response = self.chat_completion(messages).await?;
377 self.parse_confidence_score(&response)
378 }
379}