// subx_cli/services/ai/azure_openai.rs

1use crate::cli::display_ai_usage;
2use crate::error::SubXError;
3use crate::services::ai::prompts::{PromptBuilder, ResponseParser};
4use crate::services::ai::retry::HttpRetryClient;
5use crate::services::ai::{
6    AIProvider, AnalysisRequest, ConfidenceScore, MatchResult, VerificationRequest,
7};
8use async_trait::async_trait;
9use reqwest::Client;
10use serde_json::{Value, json};
11use std::time::Duration;
12use tokio::time;
13use url::{ParseError, Url};
14
15/// Azure OpenAI client implementation
16#[derive(Debug)]
17pub struct AzureOpenAIClient {
18    client: Client,
19    api_key: String,
20    model: String,
21    base_url: String,
22    api_version: String,
23    temperature: f32,
24    max_tokens: u32,
25    retry_attempts: u32,
26    retry_delay_ms: u64,
27    request_timeout_seconds: u64,
28}
29
/// `api-version` query value used when the configuration does not specify one.
const DEFAULT_AZURE_API_VERSION: &str = "2025-04-01-preview";
31
32impl AzureOpenAIClient {
33    /// Create a new AzureOpenAIClient with full configuration
34    #[allow(clippy::too_many_arguments)]
35    pub fn new_with_all(
36        api_key: String,
37        model: String,
38        base_url: String,
39        api_version: String,
40        temperature: f32,
41        max_tokens: u32,
42        retry_attempts: u32,
43        retry_delay_ms: u64,
44        request_timeout_seconds: u64,
45    ) -> Self {
46        let client = Client::builder()
47            .timeout(Duration::from_secs(request_timeout_seconds))
48            .build()
49            .expect("Failed to create HTTP client");
50        AzureOpenAIClient {
51            client,
52            api_key,
53            model,
54            base_url: base_url.trim_end_matches('/').to_string(),
55            api_version,
56            temperature,
57            max_tokens,
58            retry_attempts,
59            retry_delay_ms,
60            request_timeout_seconds,
61        }
62    }
63
64    /// Create client from AIConfig
65    pub fn from_config(config: &crate::config::AIConfig) -> crate::Result<Self> {
66        let api_key = config
67            .api_key
68            .as_ref()
69            .filter(|key| !key.trim().is_empty())
70            .ok_or_else(|| SubXError::config("Missing Azure OpenAI API Key".to_string()))?
71            .clone();
72        // Use the model value as the deployment identifier; ensure it's provided
73        let deployment_name = config.model.clone();
74        if deployment_name.trim().is_empty() {
75            return Err(SubXError::config(
76                "Missing Azure OpenAI deployment name in model field".to_string(),
77            ));
78        }
79        let api_version = config
80            .api_version
81            .clone()
82            .unwrap_or_else(|| DEFAULT_AZURE_API_VERSION.to_string());
83
84        // Validate base URL format, handle missing host specially
85        let parsed = match Url::parse(&config.base_url) {
86            Ok(u) => u,
87            Err(ParseError::EmptyHost) => {
88                return Err(SubXError::config(
89                    "Azure OpenAI endpoint missing host".to_string(),
90                ));
91            }
92            Err(e) => {
93                return Err(SubXError::config(format!(
94                    "Invalid Azure OpenAI endpoint: {}",
95                    e
96                )));
97            }
98        };
99        if !matches!(parsed.scheme(), "http" | "https") {
100            return Err(SubXError::config(
101                "Azure OpenAI endpoint must use http or https".to_string(),
102            ));
103        }
104
105        Ok(Self::new_with_all(
106            api_key,
107            config.model.clone(),
108            config.base_url.clone(),
109            api_version,
110            config.temperature,
111            config.max_tokens,
112            config.retry_attempts,
113            config.retry_delay_ms,
114            config.request_timeout_seconds,
115        ))
116    }
117
118    async fn make_request_with_retry(
119        &self,
120        request: reqwest::RequestBuilder,
121    ) -> reqwest::Result<reqwest::Response> {
122        let mut attempts = 0;
123        loop {
124            match request.try_clone().unwrap().send().await {
125                Ok(resp) => {
126                    if attempts > 0 {
127                        log::info!("Request succeeded after {} retry attempts", attempts);
128                    }
129                    return Ok(resp);
130                }
131                Err(e) if (attempts as u32) < self.retry_attempts => {
132                    attempts += 1;
133                    log::warn!(
134                        "Request attempt {} failed: {}. Retrying in {}ms...",
135                        attempts,
136                        e,
137                        self.retry_delay_ms
138                    );
139                    if e.is_timeout() {
140                        log::warn!(
141                            "This appears to be a timeout error. Consider increasing 'ai.request_timeout_seconds' in config."
142                        );
143                    }
144                    time::sleep(Duration::from_millis(self.retry_delay_ms)).await;
145                }
146                Err(e) => {
147                    log::error!(
148                        "Request failed after {} attempts. Final error: {}",
149                        attempts + 1,
150                        e
151                    );
152                    if e.is_timeout() {
153                        log::error!(
154                            "AI service error: Request timed out after multiple attempts. Try increasing 'ai.request_timeout_seconds' configuration."
155                        );
156                    } else if e.is_connect() {
157                        log::error!(
158                            "AI service error: Connection failed. Check network connection and Azure OpenAI endpoint settings."
159                        );
160                    }
161                    return Err(e);
162                }
163            }
164        }
165    }
166
167    async fn chat_completion(&self, messages: Vec<Value>) -> crate::Result<String> {
168        let url = format!(
169            "{}/openai/deployments/{}/chat/completions?api-version={}",
170            self.base_url, self.model, self.api_version
171        );
172        let mut req = self
173            .client
174            .post(url)
175            .header("Content-Type", "application/json");
176        if self.api_key.to_lowercase().starts_with("bearer ") {
177            req = req.header("Authorization", self.api_key.clone());
178        } else {
179            req = req.header("api-key", self.api_key.clone());
180        }
181        let body = json!({
182            "messages": messages,
183            "temperature": self.temperature,
184            "max_tokens": self.max_tokens,
185            "stream": false
186        });
187        let request = req.json(&body);
188        let response = self.make_request_with_retry(request).await?;
189        if !response.status().is_success() {
190            let status = response.status();
191            let text = response.text().await?;
192            return Err(SubXError::AiService(format!(
193                "Azure OpenAI API error {}: {}",
194                status, text
195            )));
196        }
197        let resp_json: Value = response.json().await?;
198        if let Some(usage) = resp_json.get("usage") {
199            if let (Some(p), Some(c), Some(t)) = (
200                usage.get("prompt_tokens").and_then(Value::as_u64),
201                usage.get("completion_tokens").and_then(Value::as_u64),
202                usage.get("total_tokens").and_then(Value::as_u64),
203            ) {
204                // Get model from response JSON, fallback to self.model if missing
205                let model = resp_json
206                    .get("model")
207                    .and_then(Value::as_str)
208                    .unwrap_or(self.model.as_str())
209                    .to_string();
210                let stats = crate::services::ai::AiUsageStats {
211                    model,
212                    prompt_tokens: p as u32,
213                    completion_tokens: c as u32,
214                    total_tokens: t as u32,
215                };
216                display_ai_usage(&stats);
217            }
218        }
219        let content = resp_json["choices"][0]["message"]["content"]
220            .as_str()
221            .ok_or_else(|| SubXError::AiService("Invalid API response format".to_string()))?;
222        Ok(content.to_string())
223    }
224}
225
226impl PromptBuilder for AzureOpenAIClient {}
227impl ResponseParser for AzureOpenAIClient {}
228impl HttpRetryClient for AzureOpenAIClient {
229    fn retry_attempts(&self) -> u32 {
230        self.retry_attempts
231    }
232
233    fn retry_delay_ms(&self) -> u64 {
234        self.retry_delay_ms
235    }
236}
237
#[cfg(test)]
mod tests {
    use super::*;
    use crate::config::Config;

    const API_KEY: &str = "test-api-key";
    const DEPLOYMENT: &str = "deployment-name";
    const ENDPOINT: &str = "https://example.openai.azure.com";

    /// Build a `Config` whose `ai` section targets Azure OpenAI.
    fn azure_config(api_key: Option<&str>, model: &str, base_url: &str) -> Config {
        let mut config = Config::default();
        config.ai.provider = "azure-openai".to_string();
        config.ai.api_key = api_key.map(str::to_string);
        config.ai.model = model.to_string();
        config.ai.base_url = base_url.to_string();
        config
    }

    /// Run `from_config`, expect it to fail, and return the error text.
    fn config_error(config: &Config) -> String {
        AzureOpenAIClient::from_config(&config.ai)
            .unwrap_err()
            .to_string()
    }

    #[test]
    fn test_azure_openai_from_config_and_url_construction() {
        let mut config = azure_config(Some(API_KEY), DEPLOYMENT, ENDPOINT);
        config.ai.api_version = Some("2025-04-01-preview".to_string());

        let client = AzureOpenAIClient::from_config(&config.ai).unwrap();
        let url = format!(
            "{}/openai/deployments/{}/chat/completions?api-version={}",
            client.base_url, client.model, client.api_version
        );
        assert!(url.contains("deployment-name"));
    }

    #[test]
    fn test_missing_model_error() {
        let config = azure_config(Some(API_KEY), "", ENDPOINT);
        let err = config_error(&config);
        assert!(err.contains("Missing Azure OpenAI deployment name in model field"));
    }

    #[test]
    fn test_azure_openai_client_creation_with_defaults() {
        // api_version is left unset, so DEFAULT_AZURE_API_VERSION applies.
        let config = azure_config(Some(API_KEY), DEPLOYMENT, ENDPOINT);
        let client = AzureOpenAIClient::from_config(&config.ai).unwrap();
        assert_eq!(
            client.api_version,
            super::DEFAULT_AZURE_API_VERSION.to_string()
        );
    }

    #[test]
    fn test_azure_openai_client_missing_api_key() {
        let config = azure_config(None, DEPLOYMENT, ENDPOINT);
        let err = config_error(&config);
        assert!(err.contains("Missing Azure OpenAI API Key"));
    }

    #[test]
    fn test_azure_openai_client_invalid_base_url() {
        let config = azure_config(Some(API_KEY), DEPLOYMENT, "invalid-url");
        let err = config_error(&config);
        assert!(err.contains("Invalid Azure OpenAI endpoint"));
    }

    #[test]
    fn test_azure_openai_client_invalid_url_scheme() {
        let config = azure_config(Some(API_KEY), DEPLOYMENT, "ftp://example.openai.azure.com");
        let err = config_error(&config);
        assert!(err.contains("must use http or https"));
    }

    #[test]
    fn test_azure_openai_client_url_without_host() {
        let config = azure_config(Some(API_KEY), DEPLOYMENT, "https://");
        let err = config_error(&config);
        assert!(err.contains("missing host"));
    }

    #[test]
    fn test_azure_openai_with_custom_model_and_version() {
        let mock_model = "custom-model-123";
        let mock_version = "2023-12-01-preview";

        let mut config = azure_config(Some(API_KEY), mock_model, "https://custom.openai.azure.com");
        config.ai.api_version = Some(mock_version.to_string());

        let client = AzureOpenAIClient::from_config(&config.ai).unwrap();
        assert_eq!(client.model, mock_model);
        assert_eq!(client.api_version, mock_version);
    }

    #[test]
    fn test_azure_openai_with_trailing_slash_in_url() {
        let config = azure_config(
            Some(API_KEY),
            DEPLOYMENT,
            "https://example.openai.azure.com/", // Trailing slash
        );
        let client = AzureOpenAIClient::from_config(&config.ai).unwrap();
        assert_eq!(
            client.base_url,
            "https://example.openai.azure.com".to_string()
        );
    }

    #[test]
    fn test_azure_openai_with_custom_temperature_and_tokens() {
        let mut config = azure_config(Some(API_KEY), DEPLOYMENT, ENDPOINT);
        config.ai.temperature = 0.8;
        config.ai.max_tokens = 2000;

        let client = AzureOpenAIClient::from_config(&config.ai).unwrap();
        assert!((client.temperature - 0.8).abs() < f32::EPSILON);
        assert_eq!(client.max_tokens, 2000);
    }

    #[test]
    fn test_azure_openai_with_custom_retry_and_timeout() {
        let mut config = azure_config(Some(API_KEY), DEPLOYMENT, ENDPOINT);
        config.ai.retry_attempts = 5;
        config.ai.retry_delay_ms = 2000;
        config.ai.request_timeout_seconds = 180;

        let client = AzureOpenAIClient::from_config(&config.ai).unwrap();
        assert_eq!(client.retry_attempts, 5);
        assert_eq!(client.retry_delay_ms, 2000);
        assert_eq!(client.request_timeout_seconds, 180);
    }

    #[test]
    fn test_azure_openai_new_with_all_parameters() {
        let client = AzureOpenAIClient::new_with_all(
            API_KEY.to_string(),
            "gpt-test".to_string(),
            ENDPOINT.to_string(),
            "2025-04-01-preview".to_string(),
            0.7,
            4000,
            3,
            1000,
            120,
        );
        let rendered = format!("{:?}", client);
        assert!(rendered.contains("AzureOpenAIClient"));
    }

    #[test]
    fn test_azure_openai_error_handling_empty_api_key() {
        let config = azure_config(Some(""), DEPLOYMENT, ENDPOINT); // Empty string
        let err = config_error(&config);
        assert!(err.contains("Missing Azure OpenAI API Key"));
    }
}
436
437#[async_trait]
438impl AIProvider for AzureOpenAIClient {
439    async fn analyze_content(&self, request: AnalysisRequest) -> crate::Result<MatchResult> {
440        let prompt = self.build_analysis_prompt(&request);
441        let messages = vec![
442            json!({"role": "system", "content": "You are a professional subtitle matching assistant that can analyze the correspondence between video and subtitle files."}),
443            json!({"role": "user", "content": prompt}),
444        ];
445        let resp = self.chat_completion(messages).await?;
446        self.parse_match_result(&resp)
447    }
448
449    async fn verify_match(
450        &self,
451        verification: VerificationRequest,
452    ) -> crate::Result<ConfidenceScore> {
453        let prompt = self.build_verification_prompt(&verification);
454        let messages = vec![
455            json!({"role": "system", "content": "Please evaluate the confidence level of subtitle matching and provide a score between 0-1."}),
456            json!({"role": "user", "content": prompt}),
457        ];
458        let resp = self.chat_completion(messages).await?;
459        self.parse_confidence_score(&resp)
460    }
461}