1use crate::cli::display_ai_usage;
2use crate::error::SubXError;
3use crate::services::ai::prompts::{
4 build_analysis_prompt_base, build_verification_prompt_base, parse_confidence_score_base,
5 parse_match_result_base,
6};
7use crate::services::ai::{
8 AIProvider, AnalysisRequest, ConfidenceScore, MatchResult, VerificationRequest,
9};
10use async_trait::async_trait;
11use reqwest::Client;
12use serde_json::{Value, json};
13use std::time::Duration;
14use tokio::time;
15use url::{ParseError, Url};
16
17#[derive(Debug)]
19pub struct AzureOpenAIClient {
20 client: Client,
21 api_key: String,
22 model: String,
23 base_url: String,
24 api_version: String,
25 temperature: f32,
26 max_tokens: u32,
27 retry_attempts: u32,
28 retry_delay_ms: u64,
29 request_timeout_seconds: u64,
30}
31
/// Azure OpenAI REST `api-version` used when the configuration does not
/// provide `ai.api_version`.
const DEFAULT_AZURE_API_VERSION: &str = "2025-04-01-preview";
33
34impl AzureOpenAIClient {
35 #[allow(clippy::too_many_arguments)]
37 pub fn new_with_all(
38 api_key: String,
39 model: String,
40 base_url: String,
41 api_version: String,
42 temperature: f32,
43 max_tokens: u32,
44 retry_attempts: u32,
45 retry_delay_ms: u64,
46 request_timeout_seconds: u64,
47 ) -> Self {
48 let client = Client::builder()
49 .timeout(Duration::from_secs(request_timeout_seconds))
50 .build()
51 .expect("Failed to create HTTP client");
52 AzureOpenAIClient {
53 client,
54 api_key,
55 model,
56 base_url: base_url.trim_end_matches('/').to_string(),
57 api_version,
58 temperature,
59 max_tokens,
60 retry_attempts,
61 retry_delay_ms,
62 request_timeout_seconds,
63 }
64 }
65
66 pub fn from_config(config: &crate::config::AIConfig) -> crate::Result<Self> {
68 let api_key = config
69 .api_key
70 .as_ref()
71 .filter(|key| !key.trim().is_empty())
72 .ok_or_else(|| SubXError::config("Missing Azure OpenAI API Key".to_string()))?
73 .clone();
74 let deployment_name = config.model.clone();
76 if deployment_name.trim().is_empty() {
77 return Err(SubXError::config(
78 "Missing Azure OpenAI deployment name in model field".to_string(),
79 ));
80 }
81 let api_version = config
82 .api_version
83 .clone()
84 .unwrap_or_else(|| DEFAULT_AZURE_API_VERSION.to_string());
85
86 let parsed = match Url::parse(&config.base_url) {
88 Ok(u) => u,
89 Err(ParseError::EmptyHost) => {
90 return Err(SubXError::config(
91 "Azure OpenAI endpoint missing host".to_string(),
92 ));
93 }
94 Err(e) => {
95 return Err(SubXError::config(format!(
96 "Invalid Azure OpenAI endpoint: {}",
97 e
98 )));
99 }
100 };
101 if !matches!(parsed.scheme(), "http" | "https") {
102 return Err(SubXError::config(
103 "Azure OpenAI endpoint must use http or https".to_string(),
104 ));
105 }
106
107 Ok(Self::new_with_all(
108 api_key,
109 config.model.clone(),
110 config.base_url.clone(),
111 api_version,
112 config.temperature,
113 config.max_tokens,
114 config.retry_attempts,
115 config.retry_delay_ms,
116 config.request_timeout_seconds,
117 ))
118 }
119
120 async fn make_request_with_retry(
121 &self,
122 request: reqwest::RequestBuilder,
123 ) -> reqwest::Result<reqwest::Response> {
124 let mut attempts = 0;
125 loop {
126 match request.try_clone().unwrap().send().await {
127 Ok(resp) => {
128 if attempts > 0 {
129 log::info!("Request succeeded after {} retry attempts", attempts);
130 }
131 return Ok(resp);
132 }
133 Err(e) if (attempts as u32) < self.retry_attempts => {
134 attempts += 1;
135 log::warn!(
136 "Request attempt {} failed: {}. Retrying in {}ms...",
137 attempts,
138 e,
139 self.retry_delay_ms
140 );
141 if e.is_timeout() {
142 log::warn!(
143 "This appears to be a timeout error. Consider increasing 'ai.request_timeout_seconds' in config."
144 );
145 }
146 time::sleep(Duration::from_millis(self.retry_delay_ms)).await;
147 }
148 Err(e) => {
149 log::error!(
150 "Request failed after {} attempts. Final error: {}",
151 attempts + 1,
152 e
153 );
154 if e.is_timeout() {
155 log::error!(
156 "AI service error: Request timed out after multiple attempts. Try increasing 'ai.request_timeout_seconds' configuration."
157 );
158 } else if e.is_connect() {
159 log::error!(
160 "AI service error: Connection failed. Check network connection and Azure OpenAI endpoint settings."
161 );
162 }
163 return Err(e);
164 }
165 }
166 }
167 }
168
169 async fn chat_completion(&self, messages: Vec<Value>) -> crate::Result<String> {
170 let url = format!(
171 "{}/openai/deployments/{}/chat/completions?api-version={}",
172 self.base_url, self.model, self.api_version
173 );
174 let mut req = self
175 .client
176 .post(url)
177 .header("Content-Type", "application/json");
178 if self.api_key.to_lowercase().starts_with("bearer ") {
179 req = req.header("Authorization", self.api_key.clone());
180 } else {
181 req = req.header("api-key", self.api_key.clone());
182 }
183 let body = json!({
184 "messages": messages,
185 "temperature": self.temperature,
186 "max_tokens": self.max_tokens,
187 "stream": false
188 });
189 let request = req.json(&body);
190 let response = self.make_request_with_retry(request).await?;
191 if !response.status().is_success() {
192 let status = response.status();
193 let text = response.text().await?;
194 return Err(SubXError::AiService(format!(
195 "Azure OpenAI API error {}: {}",
196 status, text
197 )));
198 }
199 let resp_json: Value = response.json().await?;
200 if let Some(usage) = resp_json.get("usage") {
201 if let (Some(p), Some(c), Some(t)) = (
202 usage.get("prompt_tokens").and_then(Value::as_u64),
203 usage.get("completion_tokens").and_then(Value::as_u64),
204 usage.get("total_tokens").and_then(Value::as_u64),
205 ) {
206 let model = resp_json
208 .get("model")
209 .and_then(Value::as_str)
210 .unwrap_or(self.model.as_str())
211 .to_string();
212 let stats = crate::services::ai::AiUsageStats {
213 model,
214 prompt_tokens: p as u32,
215 completion_tokens: c as u32,
216 total_tokens: t as u32,
217 };
218 display_ai_usage(&stats);
219 }
220 }
221 let content = resp_json["choices"][0]["message"]["content"]
222 .as_str()
223 .ok_or_else(|| SubXError::AiService("Invalid API response format".to_string()))?;
224 Ok(content.to_string())
225 }
226
227 fn build_analysis_prompt(&self, request: &AnalysisRequest) -> String {
228 build_analysis_prompt_base(request)
229 }
230
231 fn build_verification_prompt(&self, request: &VerificationRequest) -> String {
232 build_verification_prompt_base(request)
233 }
234
235 fn parse_match_result(&self, response: &str) -> crate::Result<MatchResult> {
236 parse_match_result_base(response)
237 }
238
239 fn parse_confidence_score(&self, response: &str) -> crate::Result<ConfidenceScore> {
240 parse_confidence_score_base(response)
241 }
242}
243
#[cfg(test)]
mod tests {
    use super::*;
    use crate::config::Config;

    /// Endpoint shared by most of the tests below.
    const TEST_ENDPOINT: &str = "https://example.openai.azure.com";
    /// Deployment name shared by most of the tests below.
    const TEST_DEPLOYMENT: &str = "deployment-name";

    /// Returns a default [`Config`] switched to the Azure OpenAI provider with
    /// the given credentials; each test then adjusts the field it exercises.
    fn azure_config(api_key: Option<&str>, model: &str, base_url: &str) -> Config {
        let mut config = Config::default();
        config.ai.provider = "azure-openai".to_string();
        config.ai.api_key = api_key.map(String::from);
        config.ai.model = model.to_string();
        config.ai.base_url = base_url.to_string();
        config
    }

    /// The common "valid key + deployment + endpoint" configuration.
    fn valid_config() -> Config {
        azure_config(Some("test-api-key"), TEST_DEPLOYMENT, TEST_ENDPOINT)
    }

    /// Runs `from_config`, expecting failure, and returns the rendered error.
    fn from_config_error(config: &Config) -> String {
        AzureOpenAIClient::from_config(&config.ai)
            .unwrap_err()
            .to_string()
    }

    #[test]
    fn test_azure_openai_from_config_and_url_construction() {
        let mut config = valid_config();
        config.ai.api_version = Some("2025-04-01-preview".to_string());

        let client = AzureOpenAIClient::from_config(&config.ai).unwrap();
        let url = format!(
            "{}/openai/deployments/{}/chat/completions?api-version={}",
            client.base_url, client.model, client.api_version
        );
        assert!(url.contains("deployment-name"));
    }

    #[test]
    fn test_missing_model_error() {
        let config = azure_config(Some("test-api-key"), "", TEST_ENDPOINT);

        let message = from_config_error(&config);
        assert!(message.contains("Missing Azure OpenAI deployment name in model field"));
    }

    #[test]
    fn test_azure_openai_client_creation_with_defaults() {
        let config = valid_config();

        let client = AzureOpenAIClient::from_config(&config.ai).unwrap();
        assert_eq!(
            client.api_version,
            super::DEFAULT_AZURE_API_VERSION.to_string()
        );
    }

    #[test]
    fn test_azure_openai_client_missing_api_key() {
        let config = azure_config(None, TEST_DEPLOYMENT, TEST_ENDPOINT);

        let message = from_config_error(&config);
        assert!(message.contains("Missing Azure OpenAI API Key"));
    }

    #[test]
    fn test_azure_openai_client_invalid_base_url() {
        let config = azure_config(Some("test-api-key"), TEST_DEPLOYMENT, "invalid-url");

        let message = from_config_error(&config);
        assert!(message.contains("Invalid Azure OpenAI endpoint"));
    }

    #[test]
    fn test_azure_openai_client_invalid_url_scheme() {
        let config = azure_config(
            Some("test-api-key"),
            TEST_DEPLOYMENT,
            "ftp://example.openai.azure.com",
        );

        let message = from_config_error(&config);
        assert!(message.contains("must use http or https"));
    }

    #[test]
    fn test_azure_openai_client_url_without_host() {
        let config = azure_config(Some("test-api-key"), TEST_DEPLOYMENT, "https://");

        let message = from_config_error(&config);
        assert!(message.contains("missing host"));
    }

    #[test]
    fn test_azure_openai_with_custom_model_and_version() {
        let mock_model = "custom-model-123";
        let mock_version = "2023-12-01-preview";

        let mut config = azure_config(
            Some("test-api-key"),
            mock_model,
            "https://custom.openai.azure.com",
        );
        config.ai.api_version = Some(mock_version.to_string());

        let client = AzureOpenAIClient::from_config(&config.ai).unwrap();
        assert_eq!(client.model, mock_model);
        assert_eq!(client.api_version, mock_version);
    }

    #[test]
    fn test_azure_openai_with_trailing_slash_in_url() {
        let config = azure_config(
            Some("test-api-key"),
            TEST_DEPLOYMENT,
            "https://example.openai.azure.com/",
        );

        let client = AzureOpenAIClient::from_config(&config.ai).unwrap();
        assert_eq!(
            client.base_url,
            "https://example.openai.azure.com".to_string()
        );
    }

    #[test]
    fn test_azure_openai_with_custom_temperature_and_tokens() {
        let mut config = valid_config();
        config.ai.temperature = 0.8;
        config.ai.max_tokens = 2000;

        let client = AzureOpenAIClient::from_config(&config.ai).unwrap();
        assert!((client.temperature - 0.8).abs() < f32::EPSILON);
        assert_eq!(client.max_tokens, 2000);
    }

    #[test]
    fn test_azure_openai_with_custom_retry_and_timeout() {
        let mut config = valid_config();
        config.ai.retry_attempts = 5;
        config.ai.retry_delay_ms = 2000;
        config.ai.request_timeout_seconds = 180;

        let client = AzureOpenAIClient::from_config(&config.ai).unwrap();
        assert_eq!(client.retry_attempts, 5);
        assert_eq!(client.retry_delay_ms, 2000);
        assert_eq!(client.request_timeout_seconds, 180);
    }

    #[test]
    fn test_azure_openai_new_with_all_parameters() {
        let client = AzureOpenAIClient::new_with_all(
            "test-api-key".to_string(),
            "gpt-test".to_string(),
            TEST_ENDPOINT.to_string(),
            "2025-04-01-preview".to_string(),
            0.7,
            4000,
            3,
            1000,
            120,
        );
        let rendered = format!("{:?}", client);
        assert!(rendered.contains("AzureOpenAIClient"));
    }

    #[test]
    fn test_azure_openai_error_handling_empty_api_key() {
        let config = azure_config(Some(""), TEST_DEPLOYMENT, TEST_ENDPOINT);

        let message = from_config_error(&config);
        assert!(message.contains("Missing Azure OpenAI API Key"));
    }
}
442
443#[async_trait]
444impl AIProvider for AzureOpenAIClient {
445 async fn analyze_content(&self, request: AnalysisRequest) -> crate::Result<MatchResult> {
446 let prompt = self.build_analysis_prompt(&request);
447 let messages = vec![
448 json!({"role": "system", "content": "You are a professional subtitle matching assistant that can analyze the correspondence between video and subtitle files."}),
449 json!({"role": "user", "content": prompt}),
450 ];
451 let resp = self.chat_completion(messages).await?;
452 self.parse_match_result(&resp)
453 }
454
455 async fn verify_match(
456 &self,
457 verification: VerificationRequest,
458 ) -> crate::Result<ConfidenceScore> {
459 let prompt = self.build_verification_prompt(&verification);
460 let messages = vec![
461 json!({"role": "system", "content": "Please evaluate the confidence level of subtitle matching and provide a score between 0-1."}),
462 json!({"role": "user", "content": prompt}),
463 ];
464 let resp = self.chat_completion(messages).await?;
465 self.parse_confidence_score(&resp)
466 }
467}