1use crate::cli::display_ai_usage;
2use crate::error::SubXError;
3use crate::services::ai::prompts::{PromptBuilder, ResponseParser};
4use crate::services::ai::retry::HttpRetryClient;
5use crate::services::ai::{
6 AIProvider, AnalysisRequest, ConfidenceScore, MatchResult, VerificationRequest,
7};
8use async_trait::async_trait;
9use reqwest::Client;
10use serde_json::{Value, json};
11use std::time::Duration;
12use tokio::time;
13use url::{ParseError, Url};
14
15pub struct AzureOpenAIClient {
17 client: Client,
18 api_key: String,
19 model: String,
20 base_url: String,
21 api_version: String,
22 temperature: f32,
23 max_tokens: u32,
24 retry_attempts: u32,
25 retry_delay_ms: u64,
26 request_timeout_seconds: u64,
27}
28
29impl std::fmt::Debug for AzureOpenAIClient {
30 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
31 f.debug_struct("AzureOpenAIClient")
32 .field("client", &self.client)
33 .field("api_key", &"[REDACTED]")
34 .field("model", &self.model)
35 .field("base_url", &self.base_url)
36 .field("api_version", &self.api_version)
37 .field("temperature", &self.temperature)
38 .field("max_tokens", &self.max_tokens)
39 .field("retry_attempts", &self.retry_attempts)
40 .field("retry_delay_ms", &self.retry_delay_ms)
41 .field("request_timeout_seconds", &self.request_timeout_seconds)
42 .finish()
43 }
44}
45
/// `api-version` query value used when the configuration does not provide one.
const DEFAULT_AZURE_API_VERSION: &str = "2025-04-01-preview";
47
48impl AzureOpenAIClient {
49 #[allow(clippy::too_many_arguments)]
51 pub fn new_with_all(
52 api_key: String,
53 model: String,
54 base_url: String,
55 api_version: String,
56 temperature: f32,
57 max_tokens: u32,
58 retry_attempts: u32,
59 retry_delay_ms: u64,
60 request_timeout_seconds: u64,
61 ) -> Self {
62 let client = Client::builder()
63 .timeout(Duration::from_secs(request_timeout_seconds))
64 .build()
65 .expect("Failed to create HTTP client");
66 AzureOpenAIClient {
67 client,
68 api_key,
69 model,
70 base_url: base_url.trim_end_matches('/').to_string(),
71 api_version,
72 temperature,
73 max_tokens,
74 retry_attempts,
75 retry_delay_ms,
76 request_timeout_seconds,
77 }
78 }
79
80 pub fn from_config(config: &crate::config::AIConfig) -> crate::Result<Self> {
82 let api_key = config
83 .api_key
84 .as_ref()
85 .filter(|key| !key.trim().is_empty())
86 .ok_or_else(|| SubXError::config("Missing Azure OpenAI API Key".to_string()))?
87 .clone();
88 let deployment_name = config.model.clone();
90 if deployment_name.trim().is_empty() {
91 return Err(SubXError::config(
92 "Missing Azure OpenAI deployment name in model field".to_string(),
93 ));
94 }
95 let api_version = config
96 .api_version
97 .clone()
98 .unwrap_or_else(|| DEFAULT_AZURE_API_VERSION.to_string());
99
100 let parsed = match Url::parse(&config.base_url) {
102 Ok(u) => u,
103 Err(ParseError::EmptyHost) => {
104 return Err(SubXError::config(
105 "Azure OpenAI endpoint missing host".to_string(),
106 ));
107 }
108 Err(e) => {
109 return Err(SubXError::config(format!(
110 "Invalid Azure OpenAI endpoint: {}",
111 e
112 )));
113 }
114 };
115 if !matches!(parsed.scheme(), "http" | "https") {
116 return Err(SubXError::config(
117 "Azure OpenAI endpoint must use http or https".to_string(),
118 ));
119 }
120 crate::services::ai::security::warn_on_insecure_http(&parsed, &api_key);
121
122 Ok(Self::new_with_all(
123 api_key,
124 config.model.clone(),
125 config.base_url.clone(),
126 api_version,
127 config.temperature,
128 config.max_tokens,
129 config.retry_attempts,
130 config.retry_delay_ms,
131 config.request_timeout_seconds,
132 ))
133 }
134
135 async fn make_request_with_retry(
136 &self,
137 request: reqwest::RequestBuilder,
138 ) -> crate::Result<reqwest::Response> {
139 let mut attempts = 0;
140 loop {
141 let cloned = request.try_clone().ok_or_else(|| {
142 crate::error::SubXError::AiService(
143 "Request body cannot be cloned for retry".to_string(),
144 )
145 })?;
146 match cloned.send().await {
147 Ok(resp) => {
148 if attempts > 0 {
149 log::info!("Request succeeded after {} retry attempts", attempts);
150 }
151 return Ok(resp);
152 }
153 Err(e) if (attempts as u32) < self.retry_attempts => {
154 attempts += 1;
155 log::warn!(
156 "Request attempt {} failed: {}. Retrying in {}ms...",
157 attempts,
158 e,
159 self.retry_delay_ms
160 );
161 if e.is_timeout() {
162 log::warn!(
163 "This appears to be a timeout error. Consider increasing 'ai.request_timeout_seconds' in config."
164 );
165 }
166 time::sleep(Duration::from_millis(self.retry_delay_ms)).await;
167 }
168 Err(e) => {
169 log::error!(
170 "Request failed after {} attempts. Final error: {}",
171 attempts + 1,
172 e
173 );
174 if e.is_timeout() {
175 log::error!(
176 "AI service error: Request timed out after multiple attempts. Try increasing 'ai.request_timeout_seconds' configuration."
177 );
178 } else if e.is_connect() {
179 log::error!(
180 "AI service error: Connection failed. Check network connection and Azure OpenAI endpoint settings."
181 );
182 }
183 return Err(e.into());
184 }
185 }
186 }
187 }
188
189 async fn chat_completion(&self, messages: Vec<Value>) -> crate::Result<String> {
190 let url = format!(
191 "{}/openai/deployments/{}/chat/completions?api-version={}",
192 self.base_url, self.model, self.api_version
193 );
194 let mut req = self
195 .client
196 .post(url)
197 .header("Content-Type", "application/json");
198 if self.api_key.to_lowercase().starts_with("bearer ") {
199 req = req.header("Authorization", self.api_key.clone());
200 } else {
201 req = req.header("api-key", self.api_key.clone());
202 }
203 let body = json!({
204 "messages": messages,
205 "temperature": self.temperature,
206 "max_tokens": self.max_tokens,
207 "stream": false
208 });
209 let request = req.json(&body);
210 let mut response = self.make_request_with_retry(request).await?;
211
212 const MAX_AI_RESPONSE_BYTES: u64 = 10 * 1024 * 1024; if let Some(len) = response.content_length() {
214 if len > MAX_AI_RESPONSE_BYTES {
215 return Err(SubXError::AiService(format!(
216 "AI response too large: {} bytes (limit: {} bytes)",
217 len, MAX_AI_RESPONSE_BYTES
218 )));
219 }
220 }
221
222 if !response.status().is_success() {
223 let status = response.status();
224 let text = response.text().await?;
225 let safe_body = crate::services::ai::error_sanitizer::sanitize_url_in_error(
226 &crate::services::ai::error_sanitizer::truncate_error_body(
227 &text,
228 crate::services::ai::error_sanitizer::DEFAULT_ERROR_BODY_MAX_LEN,
229 ),
230 );
231 return Err(SubXError::AiService(format!(
232 "Azure OpenAI API error {}: {}",
233 status, safe_body
234 )));
235 }
236 let mut body = Vec::new();
239 while let Some(chunk) = response.chunk().await? {
240 body.extend_from_slice(&chunk);
241 if body.len() as u64 > MAX_AI_RESPONSE_BYTES {
242 return Err(SubXError::AiService(format!(
243 "AI response too large: {} bytes read (limit: {} bytes)",
244 body.len(),
245 MAX_AI_RESPONSE_BYTES
246 )));
247 }
248 }
249 let resp_json: Value = serde_json::from_slice(&body)
250 .map_err(|e| SubXError::AiService(format!("Failed to parse AI response: {}", e)))?;
251 if let Some(usage) = resp_json.get("usage") {
252 if let (Some(p), Some(c), Some(t)) = (
253 usage.get("prompt_tokens").and_then(Value::as_u64),
254 usage.get("completion_tokens").and_then(Value::as_u64),
255 usage.get("total_tokens").and_then(Value::as_u64),
256 ) {
257 let model = resp_json
259 .get("model")
260 .and_then(Value::as_str)
261 .unwrap_or(self.model.as_str())
262 .to_string();
263 let stats = crate::services::ai::AiUsageStats {
264 model,
265 prompt_tokens: p as u32,
266 completion_tokens: c as u32,
267 total_tokens: t as u32,
268 };
269 display_ai_usage(&stats);
270 }
271 }
272 let content = resp_json["choices"][0]["message"]["content"]
273 .as_str()
274 .ok_or_else(|| SubXError::AiService("Invalid API response format".to_string()))?;
275 Ok(content.to_string())
276 }
277}
278
279impl PromptBuilder for AzureOpenAIClient {}
280impl ResponseParser for AzureOpenAIClient {}
281impl HttpRetryClient for AzureOpenAIClient {
282 fn retry_attempts(&self) -> u32 {
283 self.retry_attempts
284 }
285
286 fn retry_delay_ms(&self) -> u64 {
287 self.retry_delay_ms
288 }
289}
290
#[cfg(test)]
mod tests {
    use super::*;
    use crate::config::Config;

    /// A config with a valid Azure OpenAI setup; each test overrides only the
    /// field(s) it is exercising.
    fn azure_config() -> Config {
        let mut config = Config::default();
        config.ai.provider = "azure-openai".to_string();
        config.ai.api_key = Some("test-api-key".to_string());
        config.ai.model = "deployment-name".to_string();
        config.ai.base_url = "https://example.openai.azure.com".to_string();
        config
    }

    /// Runs `from_config` and returns its error message; panics if it succeeds.
    fn config_error(config: &Config) -> String {
        AzureOpenAIClient::from_config(&config.ai)
            .unwrap_err()
            .to_string()
    }

    #[test]
    fn test_azure_openai_from_config_and_url_construction() {
        let mut config = azure_config();
        config.ai.api_version = Some("2025-04-01-preview".to_string());

        let client = AzureOpenAIClient::from_config(&config.ai).unwrap();
        let url = format!(
            "{}/openai/deployments/{}/chat/completions?api-version={}",
            client.base_url, client.model, client.api_version
        );
        assert!(url.contains("deployment-name"));
    }

    #[test]
    fn test_missing_model_error() {
        let mut config = azure_config();
        config.ai.model = "".to_string();

        let err = config_error(&config);
        assert!(err.contains("Missing Azure OpenAI deployment name in model field"));
    }

    #[test]
    fn test_azure_openai_client_creation_with_defaults() {
        let config = azure_config();

        let client = AzureOpenAIClient::from_config(&config.ai).unwrap();
        assert_eq!(
            client.api_version,
            super::DEFAULT_AZURE_API_VERSION.to_string()
        );
    }

    #[test]
    fn test_azure_openai_client_missing_api_key() {
        let mut config = azure_config();
        config.ai.api_key = None;

        let err = config_error(&config);
        assert!(err.contains("Missing Azure OpenAI API Key"));
    }

    #[test]
    fn test_azure_openai_client_invalid_base_url() {
        let mut config = azure_config();
        config.ai.base_url = "invalid-url".to_string();

        let err = config_error(&config);
        assert!(err.contains("Invalid Azure OpenAI endpoint"));
    }

    #[test]
    fn test_azure_openai_client_invalid_url_scheme() {
        let mut config = azure_config();
        config.ai.base_url = "ftp://example.openai.azure.com".to_string();

        let err = config_error(&config);
        assert!(err.contains("must use http or https"));
    }

    #[test]
    fn test_azure_openai_client_url_without_host() {
        let mut config = azure_config();
        config.ai.base_url = "https://".to_string();

        let err = config_error(&config);
        assert!(err.contains("missing host"));
    }

    #[test]
    fn test_azure_openai_with_custom_model_and_version() {
        let mock_model = "custom-model-123";
        let mock_version = "2023-12-01-preview";

        let mut config = azure_config();
        config.ai.model = mock_model.to_string();
        config.ai.base_url = "https://custom.openai.azure.com".to_string();
        config.ai.api_version = Some(mock_version.to_string());

        let client = AzureOpenAIClient::from_config(&config.ai).unwrap();
        assert_eq!(client.model, mock_model);
        assert_eq!(client.api_version, mock_version);
    }

    #[test]
    fn test_azure_openai_with_trailing_slash_in_url() {
        let mut config = azure_config();
        config.ai.base_url = "https://example.openai.azure.com/".to_string();

        let client = AzureOpenAIClient::from_config(&config.ai).unwrap();
        assert_eq!(
            client.base_url,
            "https://example.openai.azure.com".to_string()
        );
    }

    #[test]
    fn test_azure_openai_with_custom_temperature_and_tokens() {
        let mut config = azure_config();
        config.ai.temperature = 0.8;
        config.ai.max_tokens = 2000;

        let client = AzureOpenAIClient::from_config(&config.ai).unwrap();
        assert!((client.temperature - 0.8).abs() < f32::EPSILON);
        assert_eq!(client.max_tokens, 2000);
    }

    #[test]
    fn test_azure_openai_with_custom_retry_and_timeout() {
        let mut config = azure_config();
        config.ai.retry_attempts = 5;
        config.ai.retry_delay_ms = 2000;
        config.ai.request_timeout_seconds = 180;

        let client = AzureOpenAIClient::from_config(&config.ai).unwrap();
        assert_eq!(client.retry_attempts, 5);
        assert_eq!(client.retry_delay_ms, 2000);
        assert_eq!(client.request_timeout_seconds, 180);
    }

    #[test]
    fn test_azure_openai_new_with_all_parameters() {
        let client = AzureOpenAIClient::new_with_all(
            "test-api-key".to_string(),
            "gpt-test".to_string(),
            "https://example.openai.azure.com".to_string(),
            "2025-04-01-preview".to_string(),
            0.7,
            4000,
            3,
            1000,
            120,
        );
        assert!(format!("{:?}", client).contains("AzureOpenAIClient"));
    }

    #[test]
    fn test_azure_openai_error_handling_empty_api_key() {
        let mut config = azure_config();
        config.ai.api_key = Some("".to_string());

        let err = config_error(&config);
        assert!(err.contains("Missing Azure OpenAI API Key"));
    }
}
489
490#[async_trait]
491impl AIProvider for AzureOpenAIClient {
492 async fn analyze_content(&self, request: AnalysisRequest) -> crate::Result<MatchResult> {
493 let prompt = self.build_analysis_prompt(&request);
494 let messages = vec![
495 json!({"role": "system", "content": "You are a professional subtitle matching assistant that can analyze the correspondence between video and subtitle files."}),
496 json!({"role": "user", "content": prompt}),
497 ];
498 let resp = self.chat_completion(messages).await?;
499 self.parse_match_result(&resp)
500 }
501
502 async fn verify_match(
503 &self,
504 verification: VerificationRequest,
505 ) -> crate::Result<ConfidenceScore> {
506 let prompt = self.build_verification_prompt(&verification);
507 let messages = vec![
508 json!({"role": "system", "content": "Please evaluate the confidence level of subtitle matching and provide a score between 0-1."}),
509 json!({"role": "user", "content": prompt}),
510 ];
511 let resp = self.chat_completion(messages).await?;
512 self.parse_confidence_score(&resp)
513 }
514}