//! Azure OpenAI provider client.
//!
//! File: subx_cli/services/ai/azure_openai.rs
1use crate::cli::display_ai_usage;
2use crate::error::SubXError;
3use crate::services::ai::hosted_hint::{append_local_hint, maybe_attach_local_hint};
4use crate::services::ai::prompts::{PromptBuilder, ResponseParser};
5use crate::services::ai::retry::HttpRetryClient;
6use crate::services::ai::{
7    AIProvider, AnalysisRequest, ConfidenceScore, MatchResult, VerificationRequest,
8};
9use async_trait::async_trait;
10use reqwest::Client;
11use serde_json::{Value, json};
12use std::time::Duration;
13use tokio::time;
14use url::{ParseError, Url};
15
/// Azure OpenAI client implementation.
///
/// Talks to the Azure OpenAI Chat Completions REST API. Note that Azure
/// routes requests by *deployment* name (stored here in `model`) and
/// requires an explicit `api-version` query parameter on every call.
pub struct AzureOpenAIClient {
    // Reusable HTTP client, built with `request_timeout_seconds` as its timeout.
    client: Client,
    // Secret credential; the manual `Debug` impl below redacts it.
    api_key: String,
    // Azure deployment identifier (used as a URL path segment).
    model: String,
    // Endpoint root, stored without a trailing slash (normalized in `new_with_all`).
    base_url: String,
    // Azure REST API version string, e.g. "2025-04-01-preview".
    api_version: String,
    // Sampling temperature forwarded verbatim in each request body.
    temperature: f32,
    // Completion token cap forwarded verbatim in each request body.
    max_tokens: u32,
    // Number of retries allowed after the first failed attempt.
    retry_attempts: u32,
    // Fixed delay between retry attempts, in milliseconds.
    retry_delay_ms: u64,
    // Per-request timeout applied when building `client`.
    request_timeout_seconds: u64,
}
29
// Manual `Debug` impl instead of `#[derive(Debug)]` so the API key can never
// leak into logs or panic messages: every field is printed as-is except
// `api_key`, which is replaced with a fixed placeholder.
impl std::fmt::Debug for AzureOpenAIClient {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("AzureOpenAIClient")
            .field("client", &self.client)
            // Deliberate redaction of the secret credential.
            .field("api_key", &"[REDACTED]")
            .field("model", &self.model)
            .field("base_url", &self.base_url)
            .field("api_version", &self.api_version)
            .field("temperature", &self.temperature)
            .field("max_tokens", &self.max_tokens)
            .field("retry_attempts", &self.retry_attempts)
            .field("retry_delay_ms", &self.retry_delay_ms)
            .field("request_timeout_seconds", &self.request_timeout_seconds)
            .finish()
    }
}
46
/// API version sent to Azure when the configuration does not specify one.
const DEFAULT_AZURE_API_VERSION: &str = "2025-04-01-preview";
48
49impl AzureOpenAIClient {
50    /// Create a new AzureOpenAIClient with full configuration
51    #[allow(clippy::too_many_arguments)]
52    pub fn new_with_all(
53        api_key: String,
54        model: String,
55        base_url: String,
56        api_version: String,
57        temperature: f32,
58        max_tokens: u32,
59        retry_attempts: u32,
60        retry_delay_ms: u64,
61        request_timeout_seconds: u64,
62    ) -> Self {
63        let client = Client::builder()
64            .timeout(Duration::from_secs(request_timeout_seconds))
65            .build()
66            .expect("Failed to create HTTP client");
67        AzureOpenAIClient {
68            client,
69            api_key,
70            model,
71            base_url: base_url.trim_end_matches('/').to_string(),
72            api_version,
73            temperature,
74            max_tokens,
75            retry_attempts,
76            retry_delay_ms,
77            request_timeout_seconds,
78        }
79    }
80
81    /// Create client from AIConfig
82    pub fn from_config(config: &crate::config::AIConfig) -> crate::Result<Self> {
83        let api_key = config
84            .api_key
85            .as_ref()
86            .filter(|key| !key.trim().is_empty())
87            .ok_or_else(|| SubXError::config("Missing Azure OpenAI API Key".to_string()))?
88            .clone();
89        // Use the model value as the deployment identifier; ensure it's provided
90        let deployment_name = config.model.clone();
91        if deployment_name.trim().is_empty() {
92            return Err(SubXError::config(
93                "Missing Azure OpenAI deployment name in model field".to_string(),
94            ));
95        }
96        let api_version = config
97            .api_version
98            .clone()
99            .unwrap_or_else(|| DEFAULT_AZURE_API_VERSION.to_string());
100
101        // Validate base URL format, handle missing host specially
102        let parsed = match Url::parse(&config.base_url) {
103            Ok(u) => u,
104            Err(ParseError::EmptyHost) => {
105                return Err(SubXError::config(
106                    "Azure OpenAI endpoint missing host".to_string(),
107                ));
108            }
109            Err(e) => {
110                return Err(SubXError::config(format!(
111                    "Invalid Azure OpenAI endpoint: {}",
112                    e
113                )));
114            }
115        };
116        if !matches!(parsed.scheme(), "http" | "https") {
117            return Err(SubXError::config(
118                "Azure OpenAI endpoint must use http or https".to_string(),
119            ));
120        }
121        crate::services::ai::security::warn_on_insecure_http(&parsed, &api_key);
122
123        Ok(Self::new_with_all(
124            api_key,
125            config.model.clone(),
126            config.base_url.clone(),
127            api_version,
128            config.temperature,
129            config.max_tokens,
130            config.retry_attempts,
131            config.retry_delay_ms,
132            config.request_timeout_seconds,
133        ))
134    }
135
136    async fn make_request_with_retry(
137        &self,
138        request: reqwest::RequestBuilder,
139    ) -> crate::Result<reqwest::Response> {
140        let mut attempts = 0;
141        loop {
142            let cloned = request.try_clone().ok_or_else(|| {
143                crate::error::SubXError::AiService(
144                    "Request body cannot be cloned for retry".to_string(),
145                )
146            })?;
147            match cloned.send().await {
148                Ok(resp) => {
149                    if attempts > 0 {
150                        log::info!("Request succeeded after {} retry attempts", attempts);
151                    }
152                    return Ok(resp);
153                }
154                Err(e) if (attempts as u32) < self.retry_attempts => {
155                    attempts += 1;
156                    log::warn!(
157                        "Request attempt {} failed: {}. Retrying in {}ms...",
158                        attempts,
159                        e,
160                        self.retry_delay_ms
161                    );
162                    if e.is_timeout() {
163                        log::warn!(
164                            "This appears to be a timeout error. Consider increasing 'ai.request_timeout_seconds' in config."
165                        );
166                    }
167                    time::sleep(Duration::from_millis(self.retry_delay_ms)).await;
168                }
169                Err(e) => {
170                    log::error!(
171                        "Request failed after {} attempts. Final error: {}",
172                        attempts + 1,
173                        e
174                    );
175                    if e.is_timeout() {
176                        log::error!(
177                            "AI service error: Request timed out after multiple attempts. Try increasing 'ai.request_timeout_seconds' configuration."
178                        );
179                    } else if e.is_connect() {
180                        log::error!(
181                            "AI service error: Connection failed. Check network connection and Azure OpenAI endpoint settings."
182                        );
183                    }
184                    return Err(e.into());
185                }
186            }
187        }
188    }
189
    /// Send a raw chat completion request to the Azure OpenAI Chat Completions API.
    ///
    /// Builds the deployment-scoped URL, authenticates via either an
    /// `Authorization` header (when the configured key already starts with
    /// "bearer ", case-insensitively) or Azure's native `api-key` header, and
    /// returns `choices[0].message.content` as a plain string.
    ///
    /// # Errors
    /// Fails on transport errors (after retries), non-2xx responses,
    /// responses larger than 10 MiB, malformed JSON, or a response body that
    /// lacks the expected `choices[0].message.content` field.
    pub async fn chat_completion(&self, messages: Vec<Value>) -> crate::Result<String> {
        // Azure routes by deployment name (self.model) and requires an
        // explicit api-version query parameter.
        let url = format!(
            "{}/openai/deployments/{}/chat/completions?api-version={}",
            self.base_url, self.model, self.api_version
        );
        let mut req = self
            .client
            .post(url)
            .header("Content-Type", "application/json");
        // A key supplied as "Bearer ..." is forwarded verbatim as an
        // Authorization header; otherwise use Azure's api-key header.
        if self.api_key.to_lowercase().starts_with("bearer ") {
            req = req.header("Authorization", self.api_key.clone());
        } else {
            req = req.header("api-key", self.api_key.clone());
        }
        let body = json!({
            "messages": messages,
            "temperature": self.temperature,
            "max_tokens": self.max_tokens,
            "stream": false
        });
        let request = req.json(&body);
        // On transport failure, attach a local-provider hint when the
        // endpoint looks like a local host (see hosted_hint module).
        let mut response = match self.make_request_with_retry(request).await {
            Ok(r) => r,
            Err(e) => return Err(maybe_attach_local_hint(e, &self.base_url)),
        };

        const MAX_AI_RESPONSE_BYTES: u64 = 10 * 1024 * 1024; // 10 MiB
        // Fast reject when the server declares an oversized body up front.
        if let Some(len) = response.content_length() {
            if len > MAX_AI_RESPONSE_BYTES {
                return Err(SubXError::AiService(format!(
                    "AI response too large: {} bytes (limit: {} bytes)",
                    len, MAX_AI_RESPONSE_BYTES
                )));
            }
        }

        if !response.status().is_success() {
            let status = response.status();
            let text = response.text().await?;
            // Truncate long error bodies and sanitize embedded URLs before
            // surfacing the server's error text to the user.
            let safe_body = crate::services::ai::error_sanitizer::sanitize_url_in_error(
                &crate::services::ai::error_sanitizer::truncate_error_body(
                    &text,
                    crate::services::ai::error_sanitizer::DEFAULT_ERROR_BODY_MAX_LEN,
                ),
            );
            return Err(SubXError::AiService(format!(
                "Azure OpenAI API error {}: {}",
                status, safe_body
            )));
        }
        // Bounded chunked read to guard against oversized responses when
        // content_length() is not reported by the server.
        let mut body = Vec::new();
        while let Some(chunk) = response.chunk().await? {
            body.extend_from_slice(&chunk);
            if body.len() as u64 > MAX_AI_RESPONSE_BYTES {
                return Err(SubXError::AiService(format!(
                    "AI response too large: {} bytes read (limit: {} bytes)",
                    body.len(),
                    MAX_AI_RESPONSE_BYTES
                )));
            }
        }
        let resp_json: Value = serde_json::from_slice(&body)
            .map_err(|e| SubXError::AiService(format!("Failed to parse AI response: {}", e)))?;
        // Best-effort token-usage reporting: only shown when all three
        // counters are present; missing usage is silently ignored.
        if let Some(usage) = resp_json.get("usage") {
            if let (Some(p), Some(c), Some(t)) = (
                usage.get("prompt_tokens").and_then(Value::as_u64),
                usage.get("completion_tokens").and_then(Value::as_u64),
                usage.get("total_tokens").and_then(Value::as_u64),
            ) {
                // Get model from response JSON, fallback to self.model if missing
                let model = resp_json
                    .get("model")
                    .and_then(Value::as_str)
                    .unwrap_or(self.model.as_str())
                    .to_string();
                let stats = crate::services::ai::AiUsageStats {
                    model,
                    prompt_tokens: p as u32,
                    completion_tokens: c as u32,
                    total_tokens: t as u32,
                };
                display_ai_usage(&stats);
            }
        }
        // A well-formed reply carries the text at choices[0].message.content;
        // any other shape is reported as an invalid response (with the
        // local-provider hint appended for likely-local endpoints).
        let content = resp_json["choices"][0]["message"]["content"]
            .as_str()
            .ok_or_else(|| {
                SubXError::AiService(append_local_hint("Invalid API response format"))
            })?;
        Ok(content.to_string())
    }
284}
285
// Empty impls: this client relies entirely on the traits' default methods
// for prompt construction and response parsing.
impl PromptBuilder for AzureOpenAIClient {}
impl ResponseParser for AzureOpenAIClient {}
// Expose this client's retry configuration to the shared retry helper.
impl HttpRetryClient for AzureOpenAIClient {
    fn retry_attempts(&self) -> u32 {
        self.retry_attempts
    }

    fn retry_delay_ms(&self) -> u64 {
        self.retry_delay_ms
    }
}
297
#[cfg(test)]
mod tests {
    use super::*;
    use crate::config::Config;

    // Construction from config + URL shape: the deployment name (model field)
    // must appear in the generated chat-completions path.
    #[test]
    fn test_azure_openai_from_config_and_url_construction() {
        let mut config = Config::default();
        config.ai.provider = "azure-openai".to_string();
        config.ai.api_key = Some("test-api-key".to_string());
        config.ai.model = "deployment-name".to_string();
        config.ai.base_url = "https://example.openai.azure.com".to_string();
        config.ai.api_version = Some("2025-04-01-preview".to_string());

        let client = AzureOpenAIClient::from_config(&config.ai).unwrap();
        let url = format!(
            "{}/openai/deployments/{}/chat/completions?api-version={}",
            client.base_url, client.model, client.api_version
        );
        assert!(url.contains("deployment-name"));
    }

    // An empty model (deployment) field must be rejected at construction.
    #[test]
    fn test_missing_model_error() {
        let mut config = Config::default();
        config.ai.provider = "azure-openai".to_string();
        config.ai.api_key = Some("test-api-key".to_string());
        config.ai.model = "".to_string();
        config.ai.base_url = "https://example.openai.azure.com".to_string();

        let err = AzureOpenAIClient::from_config(&config.ai)
            .unwrap_err()
            .to_string();
        assert!(err.contains("Missing Azure OpenAI deployment name in model field"));
    }

    // Omitted api_version falls back to DEFAULT_AZURE_API_VERSION.
    #[test]
    fn test_azure_openai_client_creation_with_defaults() {
        let mut config = Config::default();
        config.ai.provider = "azure-openai".to_string();
        config.ai.api_key = Some("test-api-key".to_string());
        config.ai.model = "deployment-name".to_string();
        config.ai.base_url = "https://example.openai.azure.com".to_string();
        // api_version defaults to DEFAULT_AZURE_API_VERSION

        let client = AzureOpenAIClient::from_config(&config.ai).unwrap();
        assert_eq!(
            client.api_version,
            super::DEFAULT_AZURE_API_VERSION.to_string()
        );
    }

    // A missing API key must be rejected at construction.
    #[test]
    fn test_azure_openai_client_missing_api_key() {
        let mut config = Config::default();
        config.ai.provider = "azure-openai".to_string();
        config.ai.api_key = None;
        config.ai.model = "deployment-name".to_string();
        config.ai.base_url = "https://example.openai.azure.com".to_string();

        let result = AzureOpenAIClient::from_config(&config.ai);
        let err = result.unwrap_err().to_string();
        assert!(err.contains("Missing Azure OpenAI API Key"));
    }

    // A syntactically invalid endpoint (relative URL) is rejected.
    #[test]
    fn test_azure_openai_client_invalid_base_url() {
        let mut config = Config::default();
        config.ai.provider = "azure-openai".to_string();
        config.ai.api_key = Some("test-api-key".to_string());
        config.ai.model = "deployment-name".to_string();
        config.ai.base_url = "invalid-url".to_string();

        let result = AzureOpenAIClient::from_config(&config.ai);
        let err = result.unwrap_err().to_string();
        assert!(err.contains("Invalid Azure OpenAI endpoint"));
    }

    // Only http/https schemes are accepted.
    #[test]
    fn test_azure_openai_client_invalid_url_scheme() {
        let mut config = Config::default();
        config.ai.provider = "azure-openai".to_string();
        config.ai.api_key = Some("test-api-key".to_string());
        config.ai.model = "deployment-name".to_string();
        config.ai.base_url = "ftp://example.openai.azure.com".to_string();

        let result = AzureOpenAIClient::from_config(&config.ai);
        let err = result.unwrap_err().to_string();
        assert!(err.contains("must use http or https"));
    }

    // "https://" with no host triggers the dedicated EmptyHost message.
    #[test]
    fn test_azure_openai_client_url_without_host() {
        let mut config = Config::default();
        config.ai.provider = "azure-openai".to_string();
        config.ai.api_key = Some("test-api-key".to_string());
        config.ai.model = "deployment-name".to_string();
        config.ai.base_url = "https://".to_string();

        let result = AzureOpenAIClient::from_config(&config.ai);
        let err = result.unwrap_err().to_string();
        assert!(err.contains("missing host"));
    }

    // Explicit model and api_version values are passed through unchanged.
    #[test]
    fn test_azure_openai_with_custom_model_and_version() {
        let mock_model = "custom-model-123";
        let mock_version = "2023-12-01-preview";

        let mut config = Config::default();
        config.ai.provider = "azure-openai".to_string();
        config.ai.api_key = Some("test-api-key".to_string());
        config.ai.model = mock_model.to_string();
        config.ai.base_url = "https://custom.openai.azure.com".to_string();
        config.ai.api_version = Some(mock_version.to_string());

        let client = AzureOpenAIClient::from_config(&config.ai).unwrap();
        assert_eq!(client.model, mock_model);
        assert_eq!(client.api_version, mock_version);
    }

    // Trailing slash in the endpoint is stripped during construction.
    #[test]
    fn test_azure_openai_with_trailing_slash_in_url() {
        let mut config = Config::default();
        config.ai.provider = "azure-openai".to_string();
        config.ai.api_key = Some("test-api-key".to_string());
        config.ai.model = "deployment-name".to_string();
        config.ai.base_url = "https://example.openai.azure.com/".to_string(); // Trailing slash

        let client = AzureOpenAIClient::from_config(&config.ai).unwrap();
        assert_eq!(
            client.base_url,
            "https://example.openai.azure.com".to_string()
        );
    }

    // temperature / max_tokens settings are carried through from config.
    #[test]
    fn test_azure_openai_with_custom_temperature_and_tokens() {
        let mut config = Config::default();
        config.ai.provider = "azure-openai".to_string();
        config.ai.api_key = Some("test-api-key".to_string());
        config.ai.model = "deployment-name".to_string();
        config.ai.base_url = "https://example.openai.azure.com".to_string();
        config.ai.temperature = 0.8;
        config.ai.max_tokens = 2000;

        let client = AzureOpenAIClient::from_config(&config.ai).unwrap();
        assert!((client.temperature - 0.8).abs() < f32::EPSILON);
        assert_eq!(client.max_tokens, 2000);
    }

    // Retry / timeout settings are carried through from config.
    #[test]
    fn test_azure_openai_with_custom_retry_and_timeout() {
        let mut config = Config::default();
        config.ai.provider = "azure-openai".to_string();
        config.ai.api_key = Some("test-api-key".to_string());
        config.ai.model = "deployment-name".to_string();
        config.ai.base_url = "https://example.openai.azure.com".to_string();
        config.ai.retry_attempts = 5;
        config.ai.retry_delay_ms = 2000;
        config.ai.request_timeout_seconds = 180;

        let client = AzureOpenAIClient::from_config(&config.ai).unwrap();
        assert_eq!(client.retry_attempts, 5);
        assert_eq!(client.retry_delay_ms, 2000);
        assert_eq!(client.request_timeout_seconds, 180);
    }

    // Direct construction via new_with_all; Debug formatting must not panic.
    #[test]
    fn test_azure_openai_new_with_all_parameters() {
        let client = AzureOpenAIClient::new_with_all(
            "test-api-key".to_string(),
            "gpt-test".to_string(),
            "https://example.openai.azure.com".to_string(),
            "2025-04-01-preview".to_string(),
            0.7,
            4000,
            3,
            1000,
            120,
        );
        assert!(format!("{:?}", client).contains("AzureOpenAIClient"));
    }

    // A present-but-empty API key is treated the same as a missing one.
    #[test]
    fn test_azure_openai_error_handling_empty_api_key() {
        let mut config = Config::default();
        config.ai.provider = "azure-openai".to_string();
        config.ai.api_key = Some("".to_string()); // Empty string
        config.ai.model = "deployment-name".to_string();
        config.ai.base_url = "https://example.openai.azure.com".to_string();

        let err = AzureOpenAIClient::from_config(&config.ai)
            .unwrap_err()
            .to_string();
        assert!(err.contains("Missing Azure OpenAI API Key"));
    }

    /// §3.6 — connection refused against `127.0.0.1` MUST surface the hint.
    // Binds then drops a listener to obtain a port that is known to refuse
    // connections, keeping the test hermetic.
    #[tokio::test]
    async fn test_hosted_hint_connection_refused_loopback() {
        let listener = tokio::net::TcpListener::bind("127.0.0.1:0").await.unwrap();
        let port = listener.local_addr().unwrap().port();
        drop(listener);
        let client = AzureOpenAIClient::new_with_all(
            "k".into(),
            "dep".into(),
            format!("http://127.0.0.1:{}", port),
            "2025-04-01-preview".into(),
            0.0,
            16,
            0,
            0,
            1,
        );
        let err = client
            .chat_completion(vec![json!({"role":"user","content":"x"})])
            .await
            .unwrap_err();
        let msg = err.to_string();
        assert!(
            msg.contains("ollama") && msg.contains("ai.provider"),
            "expected local-provider hint: {msg}"
        );
    }

    /// §3.6 — HTTP 200 with non-OpenAI body must surface the hint via the
    /// parse-shape branch.
    #[tokio::test]
    async fn test_hosted_hint_http_200_non_openai_body() {
        use wiremock::matchers::{method, path};
        use wiremock::{Mock, MockServer, ResponseTemplate};
        let server = MockServer::start().await;
        Mock::given(method("POST"))
            // Azure URL shape:
            // {base}/openai/deployments/{model}/chat/completions
            .and(path("/openai/deployments/dep/chat/completions"))
            .respond_with(ResponseTemplate::new(200).set_body_json(json!({ "hello": "world" })))
            .mount(&server)
            .await;
        let client = AzureOpenAIClient::new_with_all(
            "k".into(),
            "dep".into(),
            server.uri(),
            "2025-04-01-preview".into(),
            0.0,
            16,
            0,
            0,
            5,
        );
        let err = client
            .chat_completion(vec![json!({"role":"user","content":"x"})])
            .await
            .unwrap_err();
        let msg = err.to_string();
        assert!(
            msg.contains("Invalid API response format")
                && msg.contains("ollama")
                && msg.contains("ai.provider"),
            "expected hint-bearing parse-shape error: {msg}"
        );
    }

    /// §3.6 negative — a public host MUST NOT surface the hint. We use
    /// TEST-NET-1 (RFC 5737) so the test is hermetic.
    #[tokio::test]
    async fn test_hosted_hint_not_emitted_for_public_host() {
        let client = AzureOpenAIClient::new_with_all(
            "k".into(),
            "dep".into(),
            "https://192.0.2.1".into(),
            "2025-04-01-preview".into(),
            0.0,
            16,
            0,
            0,
            1,
        );
        let err = client
            .chat_completion(vec![json!({"role":"user","content":"x"})])
            .await
            .unwrap_err();
        let msg = err.to_string();
        assert!(
            !msg.contains("ollama"),
            "public-host failure must NOT carry the hint: {msg}"
        );
    }
}
588
#[async_trait]
impl AIProvider for AzureOpenAIClient {
    /// Analyze video/subtitle correspondence: build the analysis prompt,
    /// run one chat completion, and parse the reply into a `MatchResult`.
    async fn analyze_content(&self, request: AnalysisRequest) -> crate::Result<MatchResult> {
        let prompt = self.build_analysis_prompt(&request);
        let messages = vec![
            json!({"role": "system", "content": "You are a professional subtitle matching assistant that can analyze the correspondence between video and subtitle files."}),
            json!({"role": "user", "content": prompt}),
        ];
        let resp = self.chat_completion(messages).await?;
        self.parse_match_result(&resp)
    }

    /// Ask the model to score an existing match; the reply is parsed into a
    /// `ConfidenceScore` (expected range 0-1 per the system prompt).
    async fn verify_match(
        &self,
        verification: VerificationRequest,
    ) -> crate::Result<ConfidenceScore> {
        let prompt = self.build_verification_prompt(&verification);
        let messages = vec![
            json!({"role": "system", "content": "Please evaluate the confidence level of subtitle matching and provide a score between 0-1."}),
            json!({"role": "user", "content": prompt}),
        ];
        let resp = self.chat_completion(messages).await?;
        self.parse_confidence_score(&resp)
    }

    /// Trait entry point; delegates to the inherent `chat_completion` above.
    async fn chat_completion(&self, messages: Vec<Value>) -> crate::Result<String> {
        AzureOpenAIClient::chat_completion(self, messages).await
    }
}