use crate::cli::display_ai_usage;
use crate::error::SubXError;
use crate::services::ai::prompts::{PromptBuilder, ResponseParser};
use crate::services::ai::retry::HttpRetryClient;
use crate::services::ai::{
    AIProvider, AnalysisRequest, ConfidenceScore, MatchResult, VerificationRequest,
};
use async_trait::async_trait;
use reqwest::Client;
use serde_json::{Value, json};
use std::time::Duration;
use tokio::time;
use url::{ParseError, Url};

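/// Client for the Azure OpenAI chat completions API.
///
/// Wraps a `reqwest::Client` together with the deployment name (stored in
/// `model`), the endpoint, the API version, sampling parameters, and the
/// retry/timeout settings taken from the application's AI configuration.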
#[derive(Debug)]
pub struct AzureOpenAIClient {
    client: Client,
    api_key: String,
    model: String,
    base_url: String,
    api_version: String,
    temperature: f32,
    max_tokens: u32,
    retry_attempts: u32,
    retry_delay_ms: u64,
    request_timeout_seconds: u64,
}

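/// Default API version used when the configuration does not provide one.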
const DEFAULT_AZURE_API_VERSION: &str = "2025-04-01-preview";

impl AzureOpenAIClient {
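    /// Builds a client from explicit values. `base_url` is normalised by
    /// trimming any trailing slash, and the underlying HTTP client is created
    /// with `request_timeout_seconds` as its overall timeout.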
    #[allow(clippy::too_many_arguments)]
    pub fn new_with_all(
        api_key: String,
        model: String,
        base_url: String,
        api_version: String,
        temperature: f32,
        max_tokens: u32,
        retry_attempts: u32,
        retry_delay_ms: u64,
        request_timeout_seconds: u64,
    ) -> Self {
        let client = Client::builder()
            .timeout(Duration::from_secs(request_timeout_seconds))
            .build()
            .expect("Failed to create HTTP client");
        AzureOpenAIClient {
            client,
            api_key,
            model,
            base_url: base_url.trim_end_matches('/').to_string(),
            api_version,
            temperature,
            max_tokens,
            retry_attempts,
            retry_delay_ms,
            request_timeout_seconds,
        }
    }

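    /// Builds a client from the application's AI configuration.
    ///
    /// Validates that an API key and a deployment name (the `model` field) are
    /// present and that `base_url` parses as an http/https URL with a host;
    /// falls back to [`DEFAULT_AZURE_API_VERSION`] when no API version is set.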
    pub fn from_config(config: &crate::config::AIConfig) -> crate::Result<Self> {
        let api_key = config
            .api_key
            .as_ref()
            .filter(|key| !key.trim().is_empty())
            .ok_or_else(|| SubXError::config("Missing Azure OpenAI API Key".to_string()))?
            .clone();
        let deployment_name = config.model.clone();
        if deployment_name.trim().is_empty() {
            return Err(SubXError::config(
                "Missing Azure OpenAI deployment name in model field".to_string(),
            ));
        }
        let api_version = config
            .api_version
            .clone()
            .unwrap_or_else(|| DEFAULT_AZURE_API_VERSION.to_string());

        let parsed = match Url::parse(&config.base_url) {
            Ok(u) => u,
            Err(ParseError::EmptyHost) => {
                return Err(SubXError::config(
                    "Azure OpenAI endpoint missing host".to_string(),
                ));
            }
            Err(e) => {
                return Err(SubXError::config(format!(
                    "Invalid Azure OpenAI endpoint: {}",
                    e
                )));
            }
        };
        if !matches!(parsed.scheme(), "http" | "https") {
            return Err(SubXError::config(
                "Azure OpenAI endpoint must use http or https".to_string(),
            ));
        }

        Ok(Self::new_with_all(
            api_key,
            config.model.clone(),
            config.base_url.clone(),
            api_version,
            config.temperature,
            config.max_tokens,
            config.retry_attempts,
            config.retry_delay_ms,
            config.request_timeout_seconds,
        ))
    }

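    /// Sends the request, retrying up to `retry_attempts` times with a fixed
    /// delay of `retry_delay_ms` between attempts. Timeout and connection
    /// failures are logged with configuration hints before the final error is
    /// returned.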
    async fn make_request_with_retry(
        &self,
        request: reqwest::RequestBuilder,
    ) -> reqwest::Result<reqwest::Response> {
        let mut attempts = 0;
        loop {
            match request.try_clone().unwrap().send().await {
                Ok(resp) => {
                    if attempts > 0 {
                        log::info!("Request succeeded after {} retry attempts", attempts);
                    }
                    return Ok(resp);
                }
                Err(e) if (attempts as u32) < self.retry_attempts => {
                    attempts += 1;
                    log::warn!(
                        "Request attempt {} failed: {}. Retrying in {}ms...",
                        attempts,
                        e,
                        self.retry_delay_ms
                    );
                    if e.is_timeout() {
                        log::warn!(
                            "This appears to be a timeout error. Consider increasing 'ai.request_timeout_seconds' in config."
                        );
                    }
                    time::sleep(Duration::from_millis(self.retry_delay_ms)).await;
                }
                Err(e) => {
                    log::error!(
                        "Request failed after {} attempts. Final error: {}",
                        attempts + 1,
                        e
                    );
                    if e.is_timeout() {
                        log::error!(
                            "AI service error: Request timed out after multiple attempts. Try increasing 'ai.request_timeout_seconds' configuration."
                        );
                    } else if e.is_connect() {
                        log::error!(
                            "AI service error: Connection failed. Check network connection and Azure OpenAI endpoint settings."
                        );
                    }
                    return Err(e);
                }
            }
        }
    }

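    /// Calls the deployment's `chat/completions` endpoint and returns the
    /// content of the first choice. Authentication uses the `api-key` header,
    /// or forwards the key as an `Authorization` header when it already starts
    /// with `Bearer `. Token usage reported by the API is shown via
    /// `display_ai_usage`.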
    async fn chat_completion(&self, messages: Vec<Value>) -> crate::Result<String> {
        let url = format!(
            "{}/openai/deployments/{}/chat/completions?api-version={}",
            self.base_url, self.model, self.api_version
        );
        let mut req = self
            .client
            .post(url)
            .header("Content-Type", "application/json");
        if self.api_key.to_lowercase().starts_with("bearer ") {
            req = req.header("Authorization", self.api_key.clone());
        } else {
            req = req.header("api-key", self.api_key.clone());
        }
        let body = json!({
            "messages": messages,
            "temperature": self.temperature,
            "max_tokens": self.max_tokens,
            "stream": false
        });
        let request = req.json(&body);
        let response = self.make_request_with_retry(request).await?;
        if !response.status().is_success() {
            let status = response.status();
            let text = response.text().await?;
            return Err(SubXError::AiService(format!(
                "Azure OpenAI API error {}: {}",
                status, text
            )));
        }
        let resp_json: Value = response.json().await?;
        if let Some(usage) = resp_json.get("usage") {
            if let (Some(p), Some(c), Some(t)) = (
                usage.get("prompt_tokens").and_then(Value::as_u64),
                usage.get("completion_tokens").and_then(Value::as_u64),
                usage.get("total_tokens").and_then(Value::as_u64),
            ) {
                let model = resp_json
                    .get("model")
                    .and_then(Value::as_str)
                    .unwrap_or(self.model.as_str())
                    .to_string();
                let stats = crate::services::ai::AiUsageStats {
                    model,
                    prompt_tokens: p as u32,
                    completion_tokens: c as u32,
                    total_tokens: t as u32,
                };
                display_ai_usage(&stats);
            }
        }
        let content = resp_json["choices"][0]["message"]["content"]
            .as_str()
            .ok_or_else(|| SubXError::AiService("Invalid API response format".to_string()))?;
        Ok(content.to_string())
    }
}

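// Prompt construction and response parsing come from the default methods on
// `PromptBuilder` and `ResponseParser`; `HttpRetryClient` exposes the retry
// settings configured on this client.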
impl PromptBuilder for AzureOpenAIClient {}
impl ResponseParser for AzureOpenAIClient {}
impl HttpRetryClient for AzureOpenAIClient {
    fn retry_attempts(&self) -> u32 {
        self.retry_attempts
    }

    fn retry_delay_ms(&self) -> u64 {
        self.retry_delay_ms
    }
}

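// The tests below cover configuration parsing and URL validation only; they
// issue no network requests.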
#[cfg(test)]
mod tests {
    use super::*;
    use crate::config::Config;

    #[test]
    fn test_azure_openai_from_config_and_url_construction() {
        let mut config = Config::default();
        config.ai.provider = "azure-openai".to_string();
        config.ai.api_key = Some("test-api-key".to_string());
        config.ai.model = "deployment-name".to_string();
        config.ai.base_url = "https://example.openai.azure.com".to_string();
        config.ai.api_version = Some("2025-04-01-preview".to_string());

        let client = AzureOpenAIClient::from_config(&config.ai).unwrap();
        let url = format!(
            "{}/openai/deployments/{}/chat/completions?api-version={}",
            client.base_url, client.model, client.api_version
        );
        assert!(url.contains("deployment-name"));
    }

    #[test]
    fn test_missing_model_error() {
        let mut config = Config::default();
        config.ai.provider = "azure-openai".to_string();
        config.ai.api_key = Some("test-api-key".to_string());
        config.ai.model = "".to_string();
        config.ai.base_url = "https://example.openai.azure.com".to_string();

        let err = AzureOpenAIClient::from_config(&config.ai)
            .unwrap_err()
            .to_string();
        assert!(err.contains("Missing Azure OpenAI deployment name in model field"));
    }

    #[test]
    fn test_azure_openai_client_creation_with_defaults() {
        let mut config = Config::default();
        config.ai.provider = "azure-openai".to_string();
        config.ai.api_key = Some("test-api-key".to_string());
        config.ai.model = "deployment-name".to_string();
        config.ai.base_url = "https://example.openai.azure.com".to_string();

        let client = AzureOpenAIClient::from_config(&config.ai).unwrap();
        assert_eq!(
            client.api_version,
            super::DEFAULT_AZURE_API_VERSION.to_string()
        );
    }

    #[test]
    fn test_azure_openai_client_missing_api_key() {
        let mut config = Config::default();
        config.ai.provider = "azure-openai".to_string();
        config.ai.api_key = None;
        config.ai.model = "deployment-name".to_string();
        config.ai.base_url = "https://example.openai.azure.com".to_string();

        let result = AzureOpenAIClient::from_config(&config.ai);
        let err = result.unwrap_err().to_string();
        assert!(err.contains("Missing Azure OpenAI API Key"));
    }

    #[test]
    fn test_azure_openai_client_invalid_base_url() {
        let mut config = Config::default();
        config.ai.provider = "azure-openai".to_string();
        config.ai.api_key = Some("test-api-key".to_string());
        config.ai.model = "deployment-name".to_string();
        config.ai.base_url = "invalid-url".to_string();

        let result = AzureOpenAIClient::from_config(&config.ai);
        let err = result.unwrap_err().to_string();
        assert!(err.contains("Invalid Azure OpenAI endpoint"));
    }

    #[test]
    fn test_azure_openai_client_invalid_url_scheme() {
        let mut config = Config::default();
        config.ai.provider = "azure-openai".to_string();
        config.ai.api_key = Some("test-api-key".to_string());
        config.ai.model = "deployment-name".to_string();
        config.ai.base_url = "ftp://example.openai.azure.com".to_string();

        let result = AzureOpenAIClient::from_config(&config.ai);
        let err = result.unwrap_err().to_string();
        assert!(err.contains("must use http or https"));
    }

    #[test]
    fn test_azure_openai_client_url_without_host() {
        let mut config = Config::default();
        config.ai.provider = "azure-openai".to_string();
        config.ai.api_key = Some("test-api-key".to_string());
        config.ai.model = "deployment-name".to_string();
        config.ai.base_url = "https://".to_string();

        let result = AzureOpenAIClient::from_config(&config.ai);
        let err = result.unwrap_err().to_string();
        assert!(err.contains("missing host"));
    }

    #[test]
    fn test_azure_openai_with_custom_model_and_version() {
        let mock_model = "custom-model-123";
        let mock_version = "2023-12-01-preview";

        let mut config = Config::default();
        config.ai.provider = "azure-openai".to_string();
        config.ai.api_key = Some("test-api-key".to_string());
        config.ai.model = mock_model.to_string();
        config.ai.base_url = "https://custom.openai.azure.com".to_string();
        config.ai.api_version = Some(mock_version.to_string());

        let client = AzureOpenAIClient::from_config(&config.ai).unwrap();
        assert_eq!(client.model, mock_model);
        assert_eq!(client.api_version, mock_version);
    }

    #[test]
    fn test_azure_openai_with_trailing_slash_in_url() {
        let mut config = Config::default();
        config.ai.provider = "azure-openai".to_string();
        config.ai.api_key = Some("test-api-key".to_string());
        config.ai.model = "deployment-name".to_string();
        config.ai.base_url = "https://example.openai.azure.com/".to_string();

        let client = AzureOpenAIClient::from_config(&config.ai).unwrap();
        assert_eq!(
            client.base_url,
            "https://example.openai.azure.com".to_string()
        );
    }

    #[test]
    fn test_azure_openai_with_custom_temperature_and_tokens() {
        let mut config = Config::default();
        config.ai.provider = "azure-openai".to_string();
        config.ai.api_key = Some("test-api-key".to_string());
        config.ai.model = "deployment-name".to_string();
        config.ai.base_url = "https://example.openai.azure.com".to_string();
        config.ai.temperature = 0.8;
        config.ai.max_tokens = 2000;

        let client = AzureOpenAIClient::from_config(&config.ai).unwrap();
        assert!((client.temperature - 0.8).abs() < f32::EPSILON);
        assert_eq!(client.max_tokens, 2000);
    }

    #[test]
    fn test_azure_openai_with_custom_retry_and_timeout() {
        let mut config = Config::default();
        config.ai.provider = "azure-openai".to_string();
        config.ai.api_key = Some("test-api-key".to_string());
        config.ai.model = "deployment-name".to_string();
        config.ai.base_url = "https://example.openai.azure.com".to_string();
        config.ai.retry_attempts = 5;
        config.ai.retry_delay_ms = 2000;
        config.ai.request_timeout_seconds = 180;

        let client = AzureOpenAIClient::from_config(&config.ai).unwrap();
        assert_eq!(client.retry_attempts, 5);
        assert_eq!(client.retry_delay_ms, 2000);
        assert_eq!(client.request_timeout_seconds, 180);
    }

    #[test]
    fn test_azure_openai_new_with_all_parameters() {
        let client = AzureOpenAIClient::new_with_all(
            "test-api-key".to_string(),
            "gpt-test".to_string(),
            "https://example.openai.azure.com".to_string(),
            "2025-04-01-preview".to_string(),
            0.7,
            4000,
            3,
            1000,
            120,
        );
        assert!(format!("{:?}", client).contains("AzureOpenAIClient"));
    }

    #[test]
    fn test_azure_openai_error_handling_empty_api_key() {
        let mut config = Config::default();
        config.ai.provider = "azure-openai".to_string();
        config.ai.api_key = Some("".to_string());
        config.ai.model = "deployment-name".to_string();
        config.ai.base_url = "https://example.openai.azure.com".to_string();

        let err = AzureOpenAIClient::from_config(&config.ai)
            .unwrap_err()
            .to_string();
        assert!(err.contains("Missing Azure OpenAI API Key"));
    }
}

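// `AIProvider` implementation: each operation builds a prompt, sends it as a
// chat completion, and parses the model's reply.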
#[async_trait]
impl AIProvider for AzureOpenAIClient {
    async fn analyze_content(&self, request: AnalysisRequest) -> crate::Result<MatchResult> {
        let prompt = self.build_analysis_prompt(&request);
        let messages = vec![
            json!({"role": "system", "content": "You are a professional subtitle matching assistant that can analyze the correspondence between video and subtitle files."}),
            json!({"role": "user", "content": prompt}),
        ];
        let resp = self.chat_completion(messages).await?;
        self.parse_match_result(&resp)
    }

    async fn verify_match(
        &self,
        verification: VerificationRequest,
    ) -> crate::Result<ConfidenceScore> {
        let prompt = self.build_verification_prompt(&verification);
        let messages = vec![
            json!({"role": "system", "content": "Please evaluate the confidence level of subtitle matching and provide a score between 0-1."}),
            json!({"role": "user", "content": prompt}),
        ];
        let resp = self.chat_completion(messages).await?;
        self.parse_confidence_score(&resp)
    }
}