1use crate::cli::display_ai_usage;
2use crate::error::SubXError;
3use crate::services::ai::hosted_hint::{append_local_hint, maybe_attach_local_hint};
4use crate::services::ai::prompts::{PromptBuilder, ResponseParser};
5use crate::services::ai::retry::HttpRetryClient;
6use crate::services::ai::{
7 AIProvider, AnalysisRequest, ConfidenceScore, MatchResult, VerificationRequest,
8};
9use async_trait::async_trait;
10use reqwest::Client;
11use serde_json::{Value, json};
12use std::time::Duration;
13use tokio::time;
14use url::{ParseError, Url};
15
/// HTTP client for the Azure OpenAI chat-completions API.
///
/// Built via [`AzureOpenAIClient::new_with_all`] or
/// [`AzureOpenAIClient::from_config`]; bundles the connection, auth, and
/// retry settings used by every request.
pub struct AzureOpenAIClient {
    // Underlying reqwest client, configured with the request timeout.
    client: Client,
    // Azure API key, or a full "Bearer ..." token (see `chat_completion`,
    // which picks the auth header based on this prefix).
    api_key: String,
    // Azure OpenAI *deployment name* (carried in the config's `model` field).
    model: String,
    // Endpoint base URL, normalized with trailing '/' stripped.
    base_url: String,
    // Value of the `api-version` query parameter sent on every request.
    api_version: String,
    temperature: f32,
    max_tokens: u32,
    // Number of retries allowed after the first failed attempt.
    retry_attempts: u32,
    // Fixed delay between retry attempts.
    retry_delay_ms: u64,
    // Per-request timeout applied when building `client`.
    request_timeout_seconds: u64,
}
29
// Manual `Debug` implementation so the API key never appears in logs or
// error output; every other field is reported as-is.
impl std::fmt::Debug for AzureOpenAIClient {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("AzureOpenAIClient")
            .field("client", &self.client)
            // Credential is intentionally redacted.
            .field("api_key", &"[REDACTED]")
            .field("model", &self.model)
            .field("base_url", &self.base_url)
            .field("api_version", &self.api_version)
            .field("temperature", &self.temperature)
            .field("max_tokens", &self.max_tokens)
            .field("retry_attempts", &self.retry_attempts)
            .field("retry_delay_ms", &self.retry_delay_ms)
            .field("request_timeout_seconds", &self.request_timeout_seconds)
            .finish()
    }
}
46
/// Fallback `api-version` query value used when the config does not set one.
const DEFAULT_AZURE_API_VERSION: &str = "2025-04-01-preview";
48
49impl AzureOpenAIClient {
50 #[allow(clippy::too_many_arguments)]
52 pub fn new_with_all(
53 api_key: String,
54 model: String,
55 base_url: String,
56 api_version: String,
57 temperature: f32,
58 max_tokens: u32,
59 retry_attempts: u32,
60 retry_delay_ms: u64,
61 request_timeout_seconds: u64,
62 ) -> Self {
63 let client = Client::builder()
64 .timeout(Duration::from_secs(request_timeout_seconds))
65 .build()
66 .expect("Failed to create HTTP client");
67 AzureOpenAIClient {
68 client,
69 api_key,
70 model,
71 base_url: base_url.trim_end_matches('/').to_string(),
72 api_version,
73 temperature,
74 max_tokens,
75 retry_attempts,
76 retry_delay_ms,
77 request_timeout_seconds,
78 }
79 }
80
81 pub fn from_config(config: &crate::config::AIConfig) -> crate::Result<Self> {
83 let api_key = config
84 .api_key
85 .as_ref()
86 .filter(|key| !key.trim().is_empty())
87 .ok_or_else(|| SubXError::config("Missing Azure OpenAI API Key".to_string()))?
88 .clone();
89 let deployment_name = config.model.clone();
91 if deployment_name.trim().is_empty() {
92 return Err(SubXError::config(
93 "Missing Azure OpenAI deployment name in model field".to_string(),
94 ));
95 }
96 let api_version = config
97 .api_version
98 .clone()
99 .unwrap_or_else(|| DEFAULT_AZURE_API_VERSION.to_string());
100
101 let parsed = match Url::parse(&config.base_url) {
103 Ok(u) => u,
104 Err(ParseError::EmptyHost) => {
105 return Err(SubXError::config(
106 "Azure OpenAI endpoint missing host".to_string(),
107 ));
108 }
109 Err(e) => {
110 return Err(SubXError::config(format!(
111 "Invalid Azure OpenAI endpoint: {}",
112 e
113 )));
114 }
115 };
116 if !matches!(parsed.scheme(), "http" | "https") {
117 return Err(SubXError::config(
118 "Azure OpenAI endpoint must use http or https".to_string(),
119 ));
120 }
121 crate::services::ai::security::warn_on_insecure_http(&parsed, &api_key);
122
123 Ok(Self::new_with_all(
124 api_key,
125 config.model.clone(),
126 config.base_url.clone(),
127 api_version,
128 config.temperature,
129 config.max_tokens,
130 config.retry_attempts,
131 config.retry_delay_ms,
132 config.request_timeout_seconds,
133 ))
134 }
135
136 async fn make_request_with_retry(
137 &self,
138 request: reqwest::RequestBuilder,
139 ) -> crate::Result<reqwest::Response> {
140 let mut attempts = 0;
141 loop {
142 let cloned = request.try_clone().ok_or_else(|| {
143 crate::error::SubXError::AiService(
144 "Request body cannot be cloned for retry".to_string(),
145 )
146 })?;
147 match cloned.send().await {
148 Ok(resp) => {
149 if attempts > 0 {
150 log::info!("Request succeeded after {} retry attempts", attempts);
151 }
152 return Ok(resp);
153 }
154 Err(e) if (attempts as u32) < self.retry_attempts => {
155 attempts += 1;
156 log::warn!(
157 "Request attempt {} failed: {}. Retrying in {}ms...",
158 attempts,
159 e,
160 self.retry_delay_ms
161 );
162 if e.is_timeout() {
163 log::warn!(
164 "This appears to be a timeout error. Consider increasing 'ai.request_timeout_seconds' in config."
165 );
166 }
167 time::sleep(Duration::from_millis(self.retry_delay_ms)).await;
168 }
169 Err(e) => {
170 log::error!(
171 "Request failed after {} attempts. Final error: {}",
172 attempts + 1,
173 e
174 );
175 if e.is_timeout() {
176 log::error!(
177 "AI service error: Request timed out after multiple attempts. Try increasing 'ai.request_timeout_seconds' configuration."
178 );
179 } else if e.is_connect() {
180 log::error!(
181 "AI service error: Connection failed. Check network connection and Azure OpenAI endpoint settings."
182 );
183 }
184 return Err(e.into());
185 }
186 }
187 }
188 }
189
190 pub async fn chat_completion(&self, messages: Vec<Value>) -> crate::Result<String> {
192 let url = format!(
193 "{}/openai/deployments/{}/chat/completions?api-version={}",
194 self.base_url, self.model, self.api_version
195 );
196 let mut req = self
197 .client
198 .post(url)
199 .header("Content-Type", "application/json");
200 if self.api_key.to_lowercase().starts_with("bearer ") {
201 req = req.header("Authorization", self.api_key.clone());
202 } else {
203 req = req.header("api-key", self.api_key.clone());
204 }
205 let body = json!({
206 "messages": messages,
207 "temperature": self.temperature,
208 "max_tokens": self.max_tokens,
209 "stream": false
210 });
211 let request = req.json(&body);
212 let mut response = match self.make_request_with_retry(request).await {
213 Ok(r) => r,
214 Err(e) => return Err(maybe_attach_local_hint(e, &self.base_url)),
215 };
216
217 const MAX_AI_RESPONSE_BYTES: u64 = 10 * 1024 * 1024; if let Some(len) = response.content_length() {
219 if len > MAX_AI_RESPONSE_BYTES {
220 return Err(SubXError::AiService(format!(
221 "AI response too large: {} bytes (limit: {} bytes)",
222 len, MAX_AI_RESPONSE_BYTES
223 )));
224 }
225 }
226
227 if !response.status().is_success() {
228 let status = response.status();
229 let text = response.text().await?;
230 let safe_body = crate::services::ai::error_sanitizer::sanitize_url_in_error(
231 &crate::services::ai::error_sanitizer::truncate_error_body(
232 &text,
233 crate::services::ai::error_sanitizer::DEFAULT_ERROR_BODY_MAX_LEN,
234 ),
235 );
236 return Err(SubXError::AiService(format!(
237 "Azure OpenAI API error {}: {}",
238 status, safe_body
239 )));
240 }
241 let mut body = Vec::new();
244 while let Some(chunk) = response.chunk().await? {
245 body.extend_from_slice(&chunk);
246 if body.len() as u64 > MAX_AI_RESPONSE_BYTES {
247 return Err(SubXError::AiService(format!(
248 "AI response too large: {} bytes read (limit: {} bytes)",
249 body.len(),
250 MAX_AI_RESPONSE_BYTES
251 )));
252 }
253 }
254 let resp_json: Value = serde_json::from_slice(&body)
255 .map_err(|e| SubXError::AiService(format!("Failed to parse AI response: {}", e)))?;
256 if let Some(usage) = resp_json.get("usage") {
257 if let (Some(p), Some(c), Some(t)) = (
258 usage.get("prompt_tokens").and_then(Value::as_u64),
259 usage.get("completion_tokens").and_then(Value::as_u64),
260 usage.get("total_tokens").and_then(Value::as_u64),
261 ) {
262 let model = resp_json
264 .get("model")
265 .and_then(Value::as_str)
266 .unwrap_or(self.model.as_str())
267 .to_string();
268 let stats = crate::services::ai::AiUsageStats {
269 model,
270 prompt_tokens: p as u32,
271 completion_tokens: c as u32,
272 total_tokens: t as u32,
273 };
274 display_ai_usage(&stats);
275 }
276 }
277 let content = resp_json["choices"][0]["message"]["content"]
278 .as_str()
279 .ok_or_else(|| {
280 SubXError::AiService(append_local_hint("Invalid API response format"))
281 })?;
282 Ok(content.to_string())
283 }
284}
285
// Prompt construction and response parsing come entirely from the traits'
// default methods; Azure OpenAI needs no overrides.
impl PromptBuilder for AzureOpenAIClient {}
impl ResponseParser for AzureOpenAIClient {}
// Exposes this client's retry settings through the shared retry abstraction.
impl HttpRetryClient for AzureOpenAIClient {
    fn retry_attempts(&self) -> u32 {
        self.retry_attempts
    }

    fn retry_delay_ms(&self) -> u64 {
        self.retry_delay_ms
    }
}
297
#[cfg(test)]
mod tests {
    use super::*;
    use crate::config::Config;

    // Happy path: a fully specified config builds a client whose deployment
    // name ends up in the constructed chat-completions URL.
    #[test]
    fn test_azure_openai_from_config_and_url_construction() {
        let mut config = Config::default();
        config.ai.provider = "azure-openai".to_string();
        config.ai.api_key = Some("test-api-key".to_string());
        config.ai.model = "deployment-name".to_string();
        config.ai.base_url = "https://example.openai.azure.com".to_string();
        config.ai.api_version = Some("2025-04-01-preview".to_string());

        let client = AzureOpenAIClient::from_config(&config.ai).unwrap();
        let url = format!(
            "{}/openai/deployments/{}/chat/completions?api-version={}",
            client.base_url, client.model, client.api_version
        );
        assert!(url.contains("deployment-name"));
    }

    // An empty `model` field must be rejected with the deployment-name error.
    #[test]
    fn test_missing_model_error() {
        let mut config = Config::default();
        config.ai.provider = "azure-openai".to_string();
        config.ai.api_key = Some("test-api-key".to_string());
        config.ai.model = "".to_string();
        config.ai.base_url = "https://example.openai.azure.com".to_string();

        let err = AzureOpenAIClient::from_config(&config.ai)
            .unwrap_err()
            .to_string();
        assert!(err.contains("Missing Azure OpenAI deployment name in model field"));
    }

    // When `api_version` is absent, the client falls back to
    // DEFAULT_AZURE_API_VERSION.
    #[test]
    fn test_azure_openai_client_creation_with_defaults() {
        let mut config = Config::default();
        config.ai.provider = "azure-openai".to_string();
        config.ai.api_key = Some("test-api-key".to_string());
        config.ai.model = "deployment-name".to_string();
        config.ai.base_url = "https://example.openai.azure.com".to_string();
        let client = AzureOpenAIClient::from_config(&config.ai).unwrap();
        assert_eq!(
            client.api_version,
            super::DEFAULT_AZURE_API_VERSION.to_string()
        );
    }

    // A missing (None) API key is a config error.
    #[test]
    fn test_azure_openai_client_missing_api_key() {
        let mut config = Config::default();
        config.ai.provider = "azure-openai".to_string();
        config.ai.api_key = None;
        config.ai.model = "deployment-name".to_string();
        config.ai.base_url = "https://example.openai.azure.com".to_string();

        let result = AzureOpenAIClient::from_config(&config.ai);
        let err = result.unwrap_err().to_string();
        assert!(err.contains("Missing Azure OpenAI API Key"));
    }

    // An unparsable endpoint string is rejected with the generic
    // invalid-endpoint error.
    #[test]
    fn test_azure_openai_client_invalid_base_url() {
        let mut config = Config::default();
        config.ai.provider = "azure-openai".to_string();
        config.ai.api_key = Some("test-api-key".to_string());
        config.ai.model = "deployment-name".to_string();
        config.ai.base_url = "invalid-url".to_string();

        let result = AzureOpenAIClient::from_config(&config.ai);
        let err = result.unwrap_err().to_string();
        assert!(err.contains("Invalid Azure OpenAI endpoint"));
    }

    // Only http/https schemes are accepted; ftp must be rejected.
    #[test]
    fn test_azure_openai_client_invalid_url_scheme() {
        let mut config = Config::default();
        config.ai.provider = "azure-openai".to_string();
        config.ai.api_key = Some("test-api-key".to_string());
        config.ai.model = "deployment-name".to_string();
        config.ai.base_url = "ftp://example.openai.azure.com".to_string();

        let result = AzureOpenAIClient::from_config(&config.ai);
        let err = result.unwrap_err().to_string();
        assert!(err.contains("must use http or https"));
    }

    // "https://" with no host maps to the dedicated missing-host error.
    #[test]
    fn test_azure_openai_client_url_without_host() {
        let mut config = Config::default();
        config.ai.provider = "azure-openai".to_string();
        config.ai.api_key = Some("test-api-key".to_string());
        config.ai.model = "deployment-name".to_string();
        config.ai.base_url = "https://".to_string();

        let result = AzureOpenAIClient::from_config(&config.ai);
        let err = result.unwrap_err().to_string();
        assert!(err.contains("missing host"));
    }

    // Custom model and api_version values from the config are stored verbatim.
    #[test]
    fn test_azure_openai_with_custom_model_and_version() {
        let mock_model = "custom-model-123";
        let mock_version = "2023-12-01-preview";

        let mut config = Config::default();
        config.ai.provider = "azure-openai".to_string();
        config.ai.api_key = Some("test-api-key".to_string());
        config.ai.model = mock_model.to_string();
        config.ai.base_url = "https://custom.openai.azure.com".to_string();
        config.ai.api_version = Some(mock_version.to_string());

        let client = AzureOpenAIClient::from_config(&config.ai).unwrap();
        assert_eq!(client.model, mock_model);
        assert_eq!(client.api_version, mock_version);
    }

    // Trailing slashes on the endpoint are stripped during construction.
    #[test]
    fn test_azure_openai_with_trailing_slash_in_url() {
        let mut config = Config::default();
        config.ai.provider = "azure-openai".to_string();
        config.ai.api_key = Some("test-api-key".to_string());
        config.ai.model = "deployment-name".to_string();
        config.ai.base_url = "https://example.openai.azure.com/".to_string(); let client = AzureOpenAIClient::from_config(&config.ai).unwrap();
        assert_eq!(
            client.base_url,
            "https://example.openai.azure.com".to_string()
        );
    }

    // Temperature and max_tokens flow through from the config unchanged.
    #[test]
    fn test_azure_openai_with_custom_temperature_and_tokens() {
        let mut config = Config::default();
        config.ai.provider = "azure-openai".to_string();
        config.ai.api_key = Some("test-api-key".to_string());
        config.ai.model = "deployment-name".to_string();
        config.ai.base_url = "https://example.openai.azure.com".to_string();
        config.ai.temperature = 0.8;
        config.ai.max_tokens = 2000;

        let client = AzureOpenAIClient::from_config(&config.ai).unwrap();
        assert!((client.temperature - 0.8).abs() < f32::EPSILON);
        assert_eq!(client.max_tokens, 2000);
    }

    // Retry and timeout settings flow through from the config unchanged.
    #[test]
    fn test_azure_openai_with_custom_retry_and_timeout() {
        let mut config = Config::default();
        config.ai.provider = "azure-openai".to_string();
        config.ai.api_key = Some("test-api-key".to_string());
        config.ai.model = "deployment-name".to_string();
        config.ai.base_url = "https://example.openai.azure.com".to_string();
        config.ai.retry_attempts = 5;
        config.ai.retry_delay_ms = 2000;
        config.ai.request_timeout_seconds = 180;

        let client = AzureOpenAIClient::from_config(&config.ai).unwrap();
        assert_eq!(client.retry_attempts, 5);
        assert_eq!(client.retry_delay_ms, 2000);
        assert_eq!(client.request_timeout_seconds, 180);
    }

    // Direct construction via new_with_all works and Debug names the type
    // (the Debug impl is manual, so this guards its presence).
    #[test]
    fn test_azure_openai_new_with_all_parameters() {
        let client = AzureOpenAIClient::new_with_all(
            "test-api-key".to_string(),
            "gpt-test".to_string(),
            "https://example.openai.azure.com".to_string(),
            "2025-04-01-preview".to_string(),
            0.7,
            4000,
            3,
            1000,
            120,
        );
        assert!(format!("{:?}", client).contains("AzureOpenAIClient"));
    }

    // A present-but-empty API key is treated the same as a missing one.
    #[test]
    fn test_azure_openai_error_handling_empty_api_key() {
        let mut config = Config::default();
        config.ai.provider = "azure-openai".to_string();
        config.ai.api_key = Some("".to_string()); config.ai.model = "deployment-name".to_string();
        config.ai.base_url = "https://example.openai.azure.com".to_string();

        let err = AzureOpenAIClient::from_config(&config.ai)
            .unwrap_err()
            .to_string();
        assert!(err.contains("Missing Azure OpenAI API Key"));
    }

    // Bind a listener to reserve a free port, then drop it so connecting is
    // refused; the connection error against a loopback endpoint must carry
    // the local-provider (ollama / ai.provider) hint.
    #[tokio::test]
    async fn test_hosted_hint_connection_refused_loopback() {
        let listener = tokio::net::TcpListener::bind("127.0.0.1:0").await.unwrap();
        let port = listener.local_addr().unwrap().port();
        drop(listener);
        let client = AzureOpenAIClient::new_with_all(
            "k".into(),
            "dep".into(),
            format!("http://127.0.0.1:{}", port),
            "2025-04-01-preview".into(),
            0.0,
            16,
            0,
            0,
            1,
        );
        let err = client
            .chat_completion(vec![json!({"role":"user","content":"x"})])
            .await
            .unwrap_err();
        let msg = err.to_string();
        assert!(
            msg.contains("ollama") && msg.contains("ai.provider"),
            "expected local-provider hint: {msg}"
        );
    }

    // An HTTP 200 whose body is not OpenAI-shaped must produce the
    // "Invalid API response format" error with the local-provider hint
    // appended (local servers often answer 200 with foreign JSON).
    #[tokio::test]
    async fn test_hosted_hint_http_200_non_openai_body() {
        use wiremock::matchers::{method, path};
        use wiremock::{Mock, MockServer, ResponseTemplate};
        let server = MockServer::start().await;
        Mock::given(method("POST"))
            .and(path("/openai/deployments/dep/chat/completions"))
            .respond_with(ResponseTemplate::new(200).set_body_json(json!({ "hello": "world" })))
            .mount(&server)
            .await;
        let client = AzureOpenAIClient::new_with_all(
            "k".into(),
            "dep".into(),
            server.uri(),
            "2025-04-01-preview".into(),
            0.0,
            16,
            0,
            0,
            5,
        );
        let err = client
            .chat_completion(vec![json!({"role":"user","content":"x"})])
            .await
            .unwrap_err();
        let msg = err.to_string();
        assert!(
            msg.contains("Invalid API response format")
                && msg.contains("ollama")
                && msg.contains("ai.provider"),
            "expected hint-bearing parse-shape error: {msg}"
        );
    }

    // 192.0.2.1 is a non-loopback documentation address (TEST-NET-1): the
    // failure must NOT carry the local-provider hint.
    #[tokio::test]
    async fn test_hosted_hint_not_emitted_for_public_host() {
        let client = AzureOpenAIClient::new_with_all(
            "k".into(),
            "dep".into(),
            "https://192.0.2.1".into(),
            "2025-04-01-preview".into(),
            0.0,
            16,
            0,
            0,
            1,
        );
        let err = client
            .chat_completion(vec![json!({"role":"user","content":"x"})])
            .await
            .unwrap_err();
        let msg = err.to_string();
        assert!(
            !msg.contains("ollama"),
            "public-host failure must NOT carry the hint: {msg}"
        );
    }
}
588
589#[async_trait]
590impl AIProvider for AzureOpenAIClient {
591 async fn analyze_content(&self, request: AnalysisRequest) -> crate::Result<MatchResult> {
592 let prompt = self.build_analysis_prompt(&request);
593 let messages = vec![
594 json!({"role": "system", "content": "You are a professional subtitle matching assistant that can analyze the correspondence between video and subtitle files."}),
595 json!({"role": "user", "content": prompt}),
596 ];
597 let resp = self.chat_completion(messages).await?;
598 self.parse_match_result(&resp)
599 }
600
601 async fn verify_match(
602 &self,
603 verification: VerificationRequest,
604 ) -> crate::Result<ConfidenceScore> {
605 let prompt = self.build_verification_prompt(&verification);
606 let messages = vec![
607 json!({"role": "system", "content": "Please evaluate the confidence level of subtitle matching and provide a score between 0-1."}),
608 json!({"role": "user", "content": prompt}),
609 ];
610 let resp = self.chat_completion(messages).await?;
611 self.parse_confidence_score(&resp)
612 }
613
614 async fn chat_completion(&self, messages: Vec<Value>) -> crate::Result<String> {
615 AzureOpenAIClient::chat_completion(self, messages).await
616 }
617}