//! Azure OpenAI client for the SubX AI services layer
//! (`subx_cli/services/ai/azure_openai.rs`).
use crate::cli::display_ai_usage;
use crate::error::SubXError;
use crate::services::ai::prompts::{PromptBuilder, ResponseParser};
use crate::services::ai::retry::HttpRetryClient;
use crate::services::ai::{
    AIProvider, AnalysisRequest, ConfidenceScore, MatchResult, VerificationRequest,
};
use async_trait::async_trait;
use reqwest::Client;
use serde_json::{Value, json};
use std::time::Duration;
use tokio::time;
use url::{ParseError, Url};
14
/// Azure OpenAI client implementation
///
/// Talks to an Azure OpenAI *deployment* via the REST chat-completions API.
/// Construct with [`AzureOpenAIClient::new_with_all`] or
/// [`AzureOpenAIClient::from_config`].
pub struct AzureOpenAIClient {
    // Reusable HTTP client, built with `request_timeout_seconds` as its timeout.
    client: Client,
    // Secret credential; may also be a full "Bearer ..." value (see the
    // header selection in `chat_completion`). Redacted in `Debug` output.
    api_key: String,
    // Used as the Azure *deployment name* segment of the request URL.
    model: String,
    // Endpoint root; stored with any trailing '/' stripped.
    base_url: String,
    // Value for the `api-version` query parameter (e.g. "2025-04-01-preview").
    api_version: String,
    // Sampling temperature forwarded in the request body.
    temperature: f32,
    // `max_tokens` forwarded in the request body.
    max_tokens: u32,
    // Number of retries (after the initial try) in `make_request_with_retry`.
    retry_attempts: u32,
    // Fixed delay between retries.
    retry_delay_ms: u64,
    // Per-request timeout applied when building `client`.
    request_timeout_seconds: u64,
}
28
// Manual `Debug` implementation so the API key can never leak into logs or
// panic messages: every field prints normally except `api_key`, which is
// replaced by a fixed placeholder.
impl std::fmt::Debug for AzureOpenAIClient {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("AzureOpenAIClient")
            .field("client", &self.client)
            // Deliberately redacted — the secret must not appear in output.
            .field("api_key", &"[REDACTED]")
            .field("model", &self.model)
            .field("base_url", &self.base_url)
            .field("api_version", &self.api_version)
            .field("temperature", &self.temperature)
            .field("max_tokens", &self.max_tokens)
            .field("retry_attempts", &self.retry_attempts)
            .field("retry_delay_ms", &self.retry_delay_ms)
            .field("request_timeout_seconds", &self.request_timeout_seconds)
            .finish()
    }
}
45
/// Default Azure OpenAI REST API version, used when the config omits one.
const DEFAULT_AZURE_API_VERSION: &str = "2025-04-01-preview";
47
impl AzureOpenAIClient {
    /// Create a new AzureOpenAIClient with full configuration
    ///
    /// `model` is used as the Azure *deployment name* when building request
    /// URLs; a trailing `/` on `base_url` is stripped so URL construction in
    /// `chat_completion` never produces a double slash.
    ///
    /// # Panics
    ///
    /// Panics if the underlying HTTP client cannot be constructed (see the
    /// `expect` below).
    #[allow(clippy::too_many_arguments)]
    pub fn new_with_all(
        api_key: String,
        model: String,
        base_url: String,
        api_version: String,
        temperature: f32,
        max_tokens: u32,
        retry_attempts: u32,
        retry_delay_ms: u64,
        request_timeout_seconds: u64,
    ) -> Self {
        // The timeout applies to every request issued through this client.
        let client = Client::builder()
            .timeout(Duration::from_secs(request_timeout_seconds))
            .build()
            .expect("Failed to create HTTP client");
        AzureOpenAIClient {
            client,
            api_key,
            model,
            // Normalize so path segments can be appended safely later.
            base_url: base_url.trim_end_matches('/').to_string(),
            api_version,
            temperature,
            max_tokens,
            retry_attempts,
            retry_delay_ms,
            request_timeout_seconds,
        }
    }

    /// Create client from AIConfig
    ///
    /// # Errors
    ///
    /// Returns a configuration error when the API key is missing or blank,
    /// the `model` (deployment name) field is empty, or `base_url` is not a
    /// valid `http`/`https` URL with a host.
    pub fn from_config(config: &crate::config::AIConfig) -> crate::Result<Self> {
        // A whitespace-only key is treated the same as a missing one.
        let api_key = config
            .api_key
            .as_ref()
            .filter(|key| !key.trim().is_empty())
            .ok_or_else(|| SubXError::config("Missing Azure OpenAI API Key".to_string()))?
            .clone();
        // Use the model value as the deployment identifier; ensure it's provided
        let deployment_name = config.model.clone();
        if deployment_name.trim().is_empty() {
            return Err(SubXError::config(
                "Missing Azure OpenAI deployment name in model field".to_string(),
            ));
        }
        // Fall back to the crate-wide default API version when unset.
        let api_version = config
            .api_version
            .clone()
            .unwrap_or_else(|| DEFAULT_AZURE_API_VERSION.to_string());

        // Validate base URL format, handle missing host specially
        let parsed = match Url::parse(&config.base_url) {
            Ok(u) => u,
            Err(ParseError::EmptyHost) => {
                return Err(SubXError::config(
                    "Azure OpenAI endpoint missing host".to_string(),
                ));
            }
            Err(e) => {
                return Err(SubXError::config(format!(
                    "Invalid Azure OpenAI endpoint: {}",
                    e
                )));
            }
        };
        // Only http/https endpoints are meaningful for the REST API.
        if !matches!(parsed.scheme(), "http" | "https") {
            return Err(SubXError::config(
                "Azure OpenAI endpoint must use http or https".to_string(),
            ));
        }
        // Delegates insecure-transport warning (key over plain http) to the
        // shared security helper.
        crate::services::ai::security::warn_on_insecure_http(&parsed, &api_key);

        Ok(Self::new_with_all(
            api_key,
            config.model.clone(),
            config.base_url.clone(),
            api_version,
            config.temperature,
            config.max_tokens,
            config.retry_attempts,
            config.retry_delay_ms,
            config.request_timeout_seconds,
        ))
    }

    /// Send `request`, retrying up to `self.retry_attempts` additional times
    /// on transport-level failures (timeouts, connection errors), sleeping
    /// `self.retry_delay_ms` between attempts.
    ///
    /// NOTE: an HTTP error status still arrives as `Ok(resp)` here and is NOT
    /// retried — status handling is the caller's responsibility.
    async fn make_request_with_retry(
        &self,
        request: reqwest::RequestBuilder,
    ) -> crate::Result<reqwest::Response> {
        let mut attempts = 0;
        loop {
            // `send` consumes the builder, so each attempt needs a clone;
            // requests with non-cloneable (streaming) bodies cannot retry.
            let cloned = request.try_clone().ok_or_else(|| {
                crate::error::SubXError::AiService(
                    "Request body cannot be cloned for retry".to_string(),
                )
            })?;
            match cloned.send().await {
                Ok(resp) => {
                    if attempts > 0 {
                        log::info!("Request succeeded after {} retry attempts", attempts);
                    }
                    return Ok(resp);
                }
                // Retry budget remaining: log, sleep, loop again.
                Err(e) if (attempts as u32) < self.retry_attempts => {
                    attempts += 1;
                    log::warn!(
                        "Request attempt {} failed: {}. Retrying in {}ms...",
                        attempts,
                        e,
                        self.retry_delay_ms
                    );
                    if e.is_timeout() {
                        log::warn!(
                            "This appears to be a timeout error. Consider increasing 'ai.request_timeout_seconds' in config."
                        );
                    }
                    time::sleep(Duration::from_millis(self.retry_delay_ms)).await;
                }
                // Budget exhausted: surface the final error with a hint.
                Err(e) => {
                    log::error!(
                        "Request failed after {} attempts. Final error: {}",
                        attempts + 1,
                        e
                    );
                    if e.is_timeout() {
                        log::error!(
                            "AI service error: Request timed out after multiple attempts. Try increasing 'ai.request_timeout_seconds' configuration."
                        );
                    } else if e.is_connect() {
                        log::error!(
                            "AI service error: Connection failed. Check network connection and Azure OpenAI endpoint settings."
                        );
                    }
                    return Err(e.into());
                }
            }
        }
    }

    /// Perform a non-streaming chat-completion call against the configured
    /// Azure deployment and return the first choice's message content.
    ///
    /// Responses are capped at 10 MiB — checked both against the advertised
    /// `Content-Length` and while reading the body chunk-by-chunk. Token
    /// usage, when present in the response, is printed via `display_ai_usage`.
    async fn chat_completion(&self, messages: Vec<Value>) -> crate::Result<String> {
        // Azure routes by deployment name (stored in `self.model`), with the
        // API version passed as a query parameter.
        let url = format!(
            "{}/openai/deployments/{}/chat/completions?api-version={}",
            self.base_url, self.model, self.api_version
        );
        let mut req = self
            .client
            .post(url)
            .header("Content-Type", "application/json");
        // A key written as "Bearer ..." is sent as a standard Authorization
        // header; otherwise use Azure's `api-key` header.
        if self.api_key.to_lowercase().starts_with("bearer ") {
            req = req.header("Authorization", self.api_key.clone());
        } else {
            req = req.header("api-key", self.api_key.clone());
        }
        let body = json!({
            "messages": messages,
            "temperature": self.temperature,
            "max_tokens": self.max_tokens,
            "stream": false
        });
        let request = req.json(&body);
        let mut response = self.make_request_with_retry(request).await?;

        const MAX_AI_RESPONSE_BYTES: u64 = 10 * 1024 * 1024; // 10 MiB
        // Fast-path rejection when the server declares an oversized body.
        if let Some(len) = response.content_length() {
            if len > MAX_AI_RESPONSE_BYTES {
                return Err(SubXError::AiService(format!(
                    "AI response too large: {} bytes (limit: {} bytes)",
                    len, MAX_AI_RESPONSE_BYTES
                )));
            }
        }

        if !response.status().is_success() {
            let status = response.status();
            let text = response.text().await?;
            // Truncate and scrub URLs from the error body before surfacing it.
            let safe_body = crate::services::ai::error_sanitizer::sanitize_url_in_error(
                &crate::services::ai::error_sanitizer::truncate_error_body(
                    &text,
                    crate::services::ai::error_sanitizer::DEFAULT_ERROR_BODY_MAX_LEN,
                ),
            );
            return Err(SubXError::AiService(format!(
                "Azure OpenAI API error {}: {}",
                status, safe_body
            )));
        }
        // Bounded chunked read to guard against oversized responses when
        // content_length() is not reported by the server.
        let mut body = Vec::new();
        while let Some(chunk) = response.chunk().await? {
            body.extend_from_slice(&chunk);
            if body.len() as u64 > MAX_AI_RESPONSE_BYTES {
                return Err(SubXError::AiService(format!(
                    "AI response too large: {} bytes read (limit: {} bytes)",
                    body.len(),
                    MAX_AI_RESPONSE_BYTES
                )));
            }
        }
        let resp_json: Value = serde_json::from_slice(&body)
            .map_err(|e| SubXError::AiService(format!("Failed to parse AI response: {}", e)))?;
        // Usage reporting is best-effort: skipped silently when any of the
        // three token counters is absent.
        if let Some(usage) = resp_json.get("usage") {
            if let (Some(p), Some(c), Some(t)) = (
                usage.get("prompt_tokens").and_then(Value::as_u64),
                usage.get("completion_tokens").and_then(Value::as_u64),
                usage.get("total_tokens").and_then(Value::as_u64),
            ) {
                // Get model from response JSON, fallback to self.model if missing
                let model = resp_json
                    .get("model")
                    .and_then(Value::as_str)
                    .unwrap_or(self.model.as_str())
                    .to_string();
                let stats = crate::services::ai::AiUsageStats {
                    model,
                    prompt_tokens: p as u32,
                    completion_tokens: c as u32,
                    total_tokens: t as u32,
                };
                display_ai_usage(&stats);
            }
        }
        // Only the first choice's text is returned to the caller.
        let content = resp_json["choices"][0]["message"]["content"]
            .as_str()
            .ok_or_else(|| SubXError::AiService("Invalid API response format".to_string()))?;
        Ok(content.to_string())
    }
}
278
// Marker impls: adopt the default prompt-building and response-parsing
// methods from the shared traits; nothing is overridden for Azure.
impl PromptBuilder for AzureOpenAIClient {}
impl ResponseParser for AzureOpenAIClient {}
// Expose this client's retry configuration through the shared
// `HttpRetryClient` abstraction.
impl HttpRetryClient for AzureOpenAIClient {
    /// Number of retry attempts configured for this client.
    fn retry_attempts(&self) -> u32 {
        self.retry_attempts
    }

    /// Delay between retries, in milliseconds.
    fn retry_delay_ms(&self) -> u64 {
        self.retry_delay_ms
    }
}
290
#[cfg(test)]
mod tests {
    use super::*;
    use crate::config::Config;

    /// Build a `Config` pre-populated with valid Azure OpenAI settings.
    ///
    /// Each test overrides only the field(s) it exercises instead of
    /// repeating the same four lines of setup boilerplate.
    fn azure_config() -> Config {
        let mut config = Config::default();
        config.ai.provider = "azure-openai".to_string();
        config.ai.api_key = Some("test-api-key".to_string());
        config.ai.model = "deployment-name".to_string();
        config.ai.base_url = "https://example.openai.azure.com".to_string();
        config
    }

    #[test]
    fn test_azure_openai_from_config_and_url_construction() {
        let mut config = azure_config();
        config.ai.api_version = Some("2025-04-01-preview".to_string());

        let client = AzureOpenAIClient::from_config(&config.ai).unwrap();
        // Reconstruct the request URL the same way chat_completion does.
        let url = format!(
            "{}/openai/deployments/{}/chat/completions?api-version={}",
            client.base_url, client.model, client.api_version
        );
        assert!(url.contains("deployment-name"));
    }

    #[test]
    fn test_missing_model_error() {
        let mut config = azure_config();
        config.ai.model = "".to_string();

        let err = AzureOpenAIClient::from_config(&config.ai)
            .unwrap_err()
            .to_string();
        assert!(err.contains("Missing Azure OpenAI deployment name in model field"));
    }

    #[test]
    fn test_azure_openai_client_creation_with_defaults() {
        // api_version is left unset so it falls back to DEFAULT_AZURE_API_VERSION.
        let config = azure_config();

        let client = AzureOpenAIClient::from_config(&config.ai).unwrap();
        assert_eq!(
            client.api_version,
            super::DEFAULT_AZURE_API_VERSION.to_string()
        );
    }

    #[test]
    fn test_azure_openai_client_missing_api_key() {
        let mut config = azure_config();
        config.ai.api_key = None;

        let result = AzureOpenAIClient::from_config(&config.ai);
        let err = result.unwrap_err().to_string();
        assert!(err.contains("Missing Azure OpenAI API Key"));
    }

    #[test]
    fn test_azure_openai_client_invalid_base_url() {
        let mut config = azure_config();
        config.ai.base_url = "invalid-url".to_string();

        let result = AzureOpenAIClient::from_config(&config.ai);
        let err = result.unwrap_err().to_string();
        assert!(err.contains("Invalid Azure OpenAI endpoint"));
    }

    #[test]
    fn test_azure_openai_client_invalid_url_scheme() {
        let mut config = azure_config();
        config.ai.base_url = "ftp://example.openai.azure.com".to_string();

        let result = AzureOpenAIClient::from_config(&config.ai);
        let err = result.unwrap_err().to_string();
        assert!(err.contains("must use http or https"));
    }

    #[test]
    fn test_azure_openai_client_url_without_host() {
        let mut config = azure_config();
        config.ai.base_url = "https://".to_string();

        let result = AzureOpenAIClient::from_config(&config.ai);
        let err = result.unwrap_err().to_string();
        assert!(err.contains("missing host"));
    }

    #[test]
    fn test_azure_openai_with_custom_model_and_version() {
        let mock_model = "custom-model-123";
        let mock_version = "2023-12-01-preview";

        let mut config = azure_config();
        config.ai.model = mock_model.to_string();
        config.ai.base_url = "https://custom.openai.azure.com".to_string();
        config.ai.api_version = Some(mock_version.to_string());

        let client = AzureOpenAIClient::from_config(&config.ai).unwrap();
        assert_eq!(client.model, mock_model);
        assert_eq!(client.api_version, mock_version);
    }

    #[test]
    fn test_azure_openai_with_trailing_slash_in_url() {
        let mut config = azure_config();
        config.ai.base_url = "https://example.openai.azure.com/".to_string(); // Trailing slash

        let client = AzureOpenAIClient::from_config(&config.ai).unwrap();
        assert_eq!(
            client.base_url,
            "https://example.openai.azure.com".to_string()
        );
    }

    #[test]
    fn test_azure_openai_with_custom_temperature_and_tokens() {
        let mut config = azure_config();
        config.ai.temperature = 0.8;
        config.ai.max_tokens = 2000;

        let client = AzureOpenAIClient::from_config(&config.ai).unwrap();
        assert!((client.temperature - 0.8).abs() < f32::EPSILON);
        assert_eq!(client.max_tokens, 2000);
    }

    #[test]
    fn test_azure_openai_with_custom_retry_and_timeout() {
        let mut config = azure_config();
        config.ai.retry_attempts = 5;
        config.ai.retry_delay_ms = 2000;
        config.ai.request_timeout_seconds = 180;

        let client = AzureOpenAIClient::from_config(&config.ai).unwrap();
        assert_eq!(client.retry_attempts, 5);
        assert_eq!(client.retry_delay_ms, 2000);
        assert_eq!(client.request_timeout_seconds, 180);
    }

    #[test]
    fn test_azure_openai_new_with_all_parameters() {
        let client = AzureOpenAIClient::new_with_all(
            "test-api-key".to_string(),
            "gpt-test".to_string(),
            "https://example.openai.azure.com".to_string(),
            "2025-04-01-preview".to_string(),
            0.7,
            4000,
            3,
            1000,
            120,
        );
        assert!(format!("{:?}", client).contains("AzureOpenAIClient"));
    }

    #[test]
    fn test_azure_openai_error_handling_empty_api_key() {
        let mut config = azure_config();
        config.ai.api_key = Some("".to_string()); // Empty string

        let err = AzureOpenAIClient::from_config(&config.ai)
            .unwrap_err()
            .to_string();
        assert!(err.contains("Missing Azure OpenAI API Key"));
    }
}
489
490#[async_trait]
491impl AIProvider for AzureOpenAIClient {
492    async fn analyze_content(&self, request: AnalysisRequest) -> crate::Result<MatchResult> {
493        let prompt = self.build_analysis_prompt(&request);
494        let messages = vec![
495            json!({"role": "system", "content": "You are a professional subtitle matching assistant that can analyze the correspondence between video and subtitle files."}),
496            json!({"role": "user", "content": prompt}),
497        ];
498        let resp = self.chat_completion(messages).await?;
499        self.parse_match_result(&resp)
500    }
501
502    async fn verify_match(
503        &self,
504        verification: VerificationRequest,
505    ) -> crate::Result<ConfidenceScore> {
506        let prompt = self.build_verification_prompt(&verification);
507        let messages = vec![
508            json!({"role": "system", "content": "Please evaluate the confidence level of subtitle matching and provide a score between 0-1."}),
509            json!({"role": "user", "content": prompt}),
510        ];
511        let resp = self.chat_completion(messages).await?;
512        self.parse_confidence_score(&resp)
513    }
514}