// subx_cli/services/ai/azure_openai.rs

1use crate::cli::display_ai_usage;
2use crate::error::SubXError;
3use crate::services::ai::prompts::{
4    build_analysis_prompt_base, build_verification_prompt_base, parse_confidence_score_base,
5    parse_match_result_base,
6};
7use crate::services::ai::{
8    AIProvider, AnalysisRequest, ConfidenceScore, MatchResult, VerificationRequest,
9};
10use async_trait::async_trait;
11use reqwest::Client;
12use serde_json::{Value, json};
13use std::time::Duration;
14use tokio::time;
15use url::{ParseError, Url};
16
/// Azure OpenAI client implementation
#[derive(Debug)]
pub struct AzureOpenAIClient {
    // Reusable HTTP client, built with the configured request timeout.
    client: Client,
    // API key; either a raw Azure key or a full "Bearer ..." token
    // (the header used at request time depends on this prefix).
    api_key: String,
    // Azure deployment identifier (taken from the config's `model` field).
    model: String,
    // Endpoint base URL with any trailing slash stripped.
    base_url: String,
    // REST API version query parameter (e.g. "2025-04-01-preview").
    api_version: String,
    // Sampling temperature forwarded in the chat completion body.
    temperature: f32,
    // Upper bound on generated tokens per request.
    max_tokens: u32,
    // Number of retries allowed after a failed send.
    retry_attempts: u32,
    // Fixed delay between retry attempts, in milliseconds.
    retry_delay_ms: u64,
    // Per-request timeout applied when building the HTTP client.
    request_timeout_seconds: u64,
}
31
32const DEFAULT_AZURE_API_VERSION: &str = "2025-04-01-preview";
33
34impl AzureOpenAIClient {
35    /// Create a new AzureOpenAIClient with full configuration
36    #[allow(clippy::too_many_arguments)]
37    pub fn new_with_all(
38        api_key: String,
39        model: String,
40        base_url: String,
41        api_version: String,
42        temperature: f32,
43        max_tokens: u32,
44        retry_attempts: u32,
45        retry_delay_ms: u64,
46        request_timeout_seconds: u64,
47    ) -> Self {
48        let client = Client::builder()
49            .timeout(Duration::from_secs(request_timeout_seconds))
50            .build()
51            .expect("Failed to create HTTP client");
52        AzureOpenAIClient {
53            client,
54            api_key,
55            model,
56            base_url: base_url.trim_end_matches('/').to_string(),
57            api_version,
58            temperature,
59            max_tokens,
60            retry_attempts,
61            retry_delay_ms,
62            request_timeout_seconds,
63        }
64    }
65
66    /// Create client from AIConfig
67    pub fn from_config(config: &crate::config::AIConfig) -> crate::Result<Self> {
68        let api_key = config
69            .api_key
70            .as_ref()
71            .filter(|key| !key.trim().is_empty())
72            .ok_or_else(|| SubXError::config("Missing Azure OpenAI API Key".to_string()))?
73            .clone();
74        // Use the model value as the deployment identifier; ensure it's provided
75        let deployment_name = config.model.clone();
76        if deployment_name.trim().is_empty() {
77            return Err(SubXError::config(
78                "Missing Azure OpenAI deployment name in model field".to_string(),
79            ));
80        }
81        let api_version = config
82            .api_version
83            .clone()
84            .unwrap_or_else(|| DEFAULT_AZURE_API_VERSION.to_string());
85
86        // Validate base URL format, handle missing host specially
87        let parsed = match Url::parse(&config.base_url) {
88            Ok(u) => u,
89            Err(ParseError::EmptyHost) => {
90                return Err(SubXError::config(
91                    "Azure OpenAI endpoint missing host".to_string(),
92                ));
93            }
94            Err(e) => {
95                return Err(SubXError::config(format!(
96                    "Invalid Azure OpenAI endpoint: {}",
97                    e
98                )));
99            }
100        };
101        if !matches!(parsed.scheme(), "http" | "https") {
102            return Err(SubXError::config(
103                "Azure OpenAI endpoint must use http or https".to_string(),
104            ));
105        }
106
107        Ok(Self::new_with_all(
108            api_key,
109            config.model.clone(),
110            config.base_url.clone(),
111            api_version,
112            config.temperature,
113            config.max_tokens,
114            config.retry_attempts,
115            config.retry_delay_ms,
116            config.request_timeout_seconds,
117        ))
118    }
119
120    async fn make_request_with_retry(
121        &self,
122        request: reqwest::RequestBuilder,
123    ) -> reqwest::Result<reqwest::Response> {
124        let mut attempts = 0;
125        loop {
126            match request.try_clone().unwrap().send().await {
127                Ok(resp) => {
128                    if attempts > 0 {
129                        log::info!("Request succeeded after {} retry attempts", attempts);
130                    }
131                    return Ok(resp);
132                }
133                Err(e) if (attempts as u32) < self.retry_attempts => {
134                    attempts += 1;
135                    log::warn!(
136                        "Request attempt {} failed: {}. Retrying in {}ms...",
137                        attempts,
138                        e,
139                        self.retry_delay_ms
140                    );
141                    if e.is_timeout() {
142                        log::warn!(
143                            "This appears to be a timeout error. Consider increasing 'ai.request_timeout_seconds' in config."
144                        );
145                    }
146                    time::sleep(Duration::from_millis(self.retry_delay_ms)).await;
147                }
148                Err(e) => {
149                    log::error!(
150                        "Request failed after {} attempts. Final error: {}",
151                        attempts + 1,
152                        e
153                    );
154                    if e.is_timeout() {
155                        log::error!(
156                            "AI service error: Request timed out after multiple attempts. Try increasing 'ai.request_timeout_seconds' configuration."
157                        );
158                    } else if e.is_connect() {
159                        log::error!(
160                            "AI service error: Connection failed. Check network connection and Azure OpenAI endpoint settings."
161                        );
162                    }
163                    return Err(e);
164                }
165            }
166        }
167    }
168
169    async fn chat_completion(&self, messages: Vec<Value>) -> crate::Result<String> {
170        let url = format!(
171            "{}/openai/deployments/{}/chat/completions?api-version={}",
172            self.base_url, self.model, self.api_version
173        );
174        let mut req = self
175            .client
176            .post(url)
177            .header("Content-Type", "application/json");
178        if self.api_key.to_lowercase().starts_with("bearer ") {
179            req = req.header("Authorization", self.api_key.clone());
180        } else {
181            req = req.header("api-key", self.api_key.clone());
182        }
183        let body = json!({
184            "messages": messages,
185            "temperature": self.temperature,
186            "max_tokens": self.max_tokens,
187            "stream": false
188        });
189        let request = req.json(&body);
190        let response = self.make_request_with_retry(request).await?;
191        if !response.status().is_success() {
192            let status = response.status();
193            let text = response.text().await?;
194            return Err(SubXError::AiService(format!(
195                "Azure OpenAI API error {}: {}",
196                status, text
197            )));
198        }
199        let resp_json: Value = response.json().await?;
200        if let Some(usage) = resp_json.get("usage") {
201            if let (Some(p), Some(c), Some(t)) = (
202                usage.get("prompt_tokens").and_then(Value::as_u64),
203                usage.get("completion_tokens").and_then(Value::as_u64),
204                usage.get("total_tokens").and_then(Value::as_u64),
205            ) {
206                // Get model from response JSON, fallback to self.model if missing
207                let model = resp_json
208                    .get("model")
209                    .and_then(Value::as_str)
210                    .unwrap_or(self.model.as_str())
211                    .to_string();
212                let stats = crate::services::ai::AiUsageStats {
213                    model,
214                    prompt_tokens: p as u32,
215                    completion_tokens: c as u32,
216                    total_tokens: t as u32,
217                };
218                display_ai_usage(&stats);
219            }
220        }
221        let content = resp_json["choices"][0]["message"]["content"]
222            .as_str()
223            .ok_or_else(|| SubXError::AiService("Invalid API response format".to_string()))?;
224        Ok(content.to_string())
225    }
226
227    fn build_analysis_prompt(&self, request: &AnalysisRequest) -> String {
228        build_analysis_prompt_base(request)
229    }
230
231    fn build_verification_prompt(&self, request: &VerificationRequest) -> String {
232        build_verification_prompt_base(request)
233    }
234
235    fn parse_match_result(&self, response: &str) -> crate::Result<MatchResult> {
236        parse_match_result_base(response)
237    }
238
239    fn parse_confidence_score(&self, response: &str) -> crate::Result<ConfidenceScore> {
240        parse_confidence_score_base(response)
241    }
242}
243
#[cfg(test)]
mod tests {
    use super::*;
    use crate::config::Config;

    // The chat-completion URL embeds the deployment name from the config's
    // `model` field.
    #[test]
    fn test_azure_openai_from_config_and_url_construction() {
        let mut config = Config::default();
        config.ai.provider = "azure-openai".to_string();
        config.ai.api_key = Some("test-api-key".to_string());
        config.ai.model = "deployment-name".to_string();
        config.ai.base_url = "https://example.openai.azure.com".to_string();
        config.ai.api_version = Some("2025-04-01-preview".to_string());

        let client = AzureOpenAIClient::from_config(&config.ai).unwrap();
        let url = format!(
            "{}/openai/deployments/{}/chat/completions?api-version={}",
            client.base_url, client.model, client.api_version
        );
        assert!(url.contains("deployment-name"));
    }

    // An empty model field must be rejected with a dedicated error message.
    #[test]
    fn test_missing_model_error() {
        let mut config = Config::default();
        config.ai.provider = "azure-openai".to_string();
        config.ai.api_key = Some("test-api-key".to_string());
        config.ai.model = "".to_string();
        config.ai.base_url = "https://example.openai.azure.com".to_string();

        let err = AzureOpenAIClient::from_config(&config.ai)
            .unwrap_err()
            .to_string();
        assert!(err.contains("Missing Azure OpenAI deployment name in model field"));
    }

    // When api_version is absent, the client falls back to the default.
    #[test]
    fn test_azure_openai_client_creation_with_defaults() {
        let mut config = Config::default();
        config.ai.provider = "azure-openai".to_string();
        config.ai.api_key = Some("test-api-key".to_string());
        config.ai.model = "deployment-name".to_string();
        config.ai.base_url = "https://example.openai.azure.com".to_string();
        // api_version defaults to DEFAULT_AZURE_API_VERSION

        let client = AzureOpenAIClient::from_config(&config.ai).unwrap();
        assert_eq!(
            client.api_version,
            super::DEFAULT_AZURE_API_VERSION.to_string()
        );
    }

    // A missing API key is a configuration error.
    #[test]
    fn test_azure_openai_client_missing_api_key() {
        let mut config = Config::default();
        config.ai.provider = "azure-openai".to_string();
        config.ai.api_key = None;
        config.ai.model = "deployment-name".to_string();
        config.ai.base_url = "https://example.openai.azure.com".to_string();

        let result = AzureOpenAIClient::from_config(&config.ai);
        let err = result.unwrap_err().to_string();
        assert!(err.contains("Missing Azure OpenAI API Key"));
    }

    // An unparseable endpoint URL is rejected with the parse error attached.
    #[test]
    fn test_azure_openai_client_invalid_base_url() {
        let mut config = Config::default();
        config.ai.provider = "azure-openai".to_string();
        config.ai.api_key = Some("test-api-key".to_string());
        config.ai.model = "deployment-name".to_string();
        config.ai.base_url = "invalid-url".to_string();

        let result = AzureOpenAIClient::from_config(&config.ai);
        let err = result.unwrap_err().to_string();
        assert!(err.contains("Invalid Azure OpenAI endpoint"));
    }

    // Only http and https schemes are accepted for the endpoint.
    #[test]
    fn test_azure_openai_client_invalid_url_scheme() {
        let mut config = Config::default();
        config.ai.provider = "azure-openai".to_string();
        config.ai.api_key = Some("test-api-key".to_string());
        config.ai.model = "deployment-name".to_string();
        config.ai.base_url = "ftp://example.openai.azure.com".to_string();

        let result = AzureOpenAIClient::from_config(&config.ai);
        let err = result.unwrap_err().to_string();
        assert!(err.contains("must use http or https"));
    }

    // "https://" alone parses as EmptyHost and gets the specialized message.
    #[test]
    fn test_azure_openai_client_url_without_host() {
        let mut config = Config::default();
        config.ai.provider = "azure-openai".to_string();
        config.ai.api_key = Some("test-api-key".to_string());
        config.ai.model = "deployment-name".to_string();
        config.ai.base_url = "https://".to_string();

        let result = AzureOpenAIClient::from_config(&config.ai);
        let err = result.unwrap_err().to_string();
        assert!(err.contains("missing host"));
    }

    // Explicit model and api_version values are carried through unchanged.
    #[test]
    fn test_azure_openai_with_custom_model_and_version() {
        let mock_model = "custom-model-123";
        let mock_version = "2023-12-01-preview";

        let mut config = Config::default();
        config.ai.provider = "azure-openai".to_string();
        config.ai.api_key = Some("test-api-key".to_string());
        config.ai.model = mock_model.to_string();
        config.ai.base_url = "https://custom.openai.azure.com".to_string();
        config.ai.api_version = Some(mock_version.to_string());

        let client = AzureOpenAIClient::from_config(&config.ai).unwrap();
        assert_eq!(client.model, mock_model);
        assert_eq!(client.api_version, mock_version);
    }

    // A trailing slash on the endpoint is normalized away at construction.
    #[test]
    fn test_azure_openai_with_trailing_slash_in_url() {
        let mut config = Config::default();
        config.ai.provider = "azure-openai".to_string();
        config.ai.api_key = Some("test-api-key".to_string());
        config.ai.model = "deployment-name".to_string();
        config.ai.base_url = "https://example.openai.azure.com/".to_string(); // Trailing slash

        let client = AzureOpenAIClient::from_config(&config.ai).unwrap();
        assert_eq!(
            client.base_url,
            "https://example.openai.azure.com".to_string()
        );
    }

    // Generation parameters are copied from the config verbatim.
    #[test]
    fn test_azure_openai_with_custom_temperature_and_tokens() {
        let mut config = Config::default();
        config.ai.provider = "azure-openai".to_string();
        config.ai.api_key = Some("test-api-key".to_string());
        config.ai.model = "deployment-name".to_string();
        config.ai.base_url = "https://example.openai.azure.com".to_string();
        config.ai.temperature = 0.8;
        config.ai.max_tokens = 2000;

        let client = AzureOpenAIClient::from_config(&config.ai).unwrap();
        assert!((client.temperature - 0.8).abs() < f32::EPSILON);
        assert_eq!(client.max_tokens, 2000);
    }

    // Retry/timeout knobs are copied from the config verbatim.
    #[test]
    fn test_azure_openai_with_custom_retry_and_timeout() {
        let mut config = Config::default();
        config.ai.provider = "azure-openai".to_string();
        config.ai.api_key = Some("test-api-key".to_string());
        config.ai.model = "deployment-name".to_string();
        config.ai.base_url = "https://example.openai.azure.com".to_string();
        config.ai.retry_attempts = 5;
        config.ai.retry_delay_ms = 2000;
        config.ai.request_timeout_seconds = 180;

        let client = AzureOpenAIClient::from_config(&config.ai).unwrap();
        assert_eq!(client.retry_attempts, 5);
        assert_eq!(client.retry_delay_ms, 2000);
        assert_eq!(client.request_timeout_seconds, 180);
    }

    // Direct construction via new_with_all builds a usable (Debug-printable)
    // client without going through config validation.
    #[test]
    fn test_azure_openai_new_with_all_parameters() {
        let client = AzureOpenAIClient::new_with_all(
            "test-api-key".to_string(),
            "gpt-test".to_string(),
            "https://example.openai.azure.com".to_string(),
            "2025-04-01-preview".to_string(),
            0.7,
            4000,
            3,
            1000,
            120,
        );
        assert!(format!("{:?}", client).contains("AzureOpenAIClient"));
    }

    // An empty-string API key is treated the same as a missing one.
    #[test]
    fn test_azure_openai_error_handling_empty_api_key() {
        let mut config = Config::default();
        config.ai.provider = "azure-openai".to_string();
        config.ai.api_key = Some("".to_string()); // Empty string
        config.ai.model = "deployment-name".to_string();
        config.ai.base_url = "https://example.openai.azure.com".to_string();

        let err = AzureOpenAIClient::from_config(&config.ai)
            .unwrap_err()
            .to_string();
        assert!(err.contains("Missing Azure OpenAI API Key"));
    }
}
442
443#[async_trait]
444impl AIProvider for AzureOpenAIClient {
445    async fn analyze_content(&self, request: AnalysisRequest) -> crate::Result<MatchResult> {
446        let prompt = self.build_analysis_prompt(&request);
447        let messages = vec![
448            json!({"role": "system", "content": "You are a professional subtitle matching assistant that can analyze the correspondence between video and subtitle files."}),
449            json!({"role": "user", "content": prompt}),
450        ];
451        let resp = self.chat_completion(messages).await?;
452        self.parse_match_result(&resp)
453    }
454
455    async fn verify_match(
456        &self,
457        verification: VerificationRequest,
458    ) -> crate::Result<ConfidenceScore> {
459        let prompt = self.build_verification_prompt(&verification);
460        let messages = vec![
461            json!({"role": "system", "content": "Please evaluate the confidence level of subtitle matching and provide a score between 0-1."}),
462            json!({"role": "user", "content": prompt}),
463        ];
464        let resp = self.chat_completion(messages).await?;
465        self.parse_confidence_score(&resp)
466    }
467}