Skip to main content

swink_agent_adapters/
azure.rs

1//! Azure `OpenAI` / Azure AI Foundry adapter.
2//!
3//! This adapter targets Azure's OpenAI-v1-compatible chat completions surface.
4//! It reuses the shared OAI-compatible SSE parsing from [`openai_compat`] and
5//! the shared transport pipeline from [`oai_transport`].
6
7use std::pin::Pin;
8use std::time::{Duration, Instant};
9
10use futures::stream::{self, Stream, StreamExt as _};
11use serde::Deserialize;
12use tokio_util::sync::CancellationToken;
13use tracing::debug;
14
15use swink_agent::{AgentContext, AssistantMessageEvent, ModelSpec, StreamFn, StreamOptions};
16use swink_agent_auth::{ExpiringValue, SingleFlightTokenSource};
17
18use crate::classify::{HttpErrorKind, classify_with_overrides};
19use crate::oai_transport::{OaiAdapterShell, oai_send_and_parse, prepare_oai_request};
20
/// Authentication method for Azure `OpenAI` deployments.
///
/// The hand-written `Debug` implementation for this type prints `[REDACTED]`
/// in place of every field value, so formatting it never exposes credentials.
#[derive(Clone)]
pub enum AzureAuth {
    /// API key authentication via the `api-key` header.
    ApiKey(String),
    /// Azure AD / Entra ID `OAuth2` client credentials flow.
    EntraId {
        /// Tenant identifier; interpolated into the default token endpoint URL.
        tenant_id: String,
        /// Sent as the `client_id` form field of the token request.
        client_id: String,
        /// Sent as the `client_secret` form field of the token request.
        client_secret: String,
    },
}
33
34impl std::fmt::Debug for AzureAuth {
35    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
36        match self {
37            Self::ApiKey(_) => f.debug_tuple("ApiKey").field(&"[REDACTED]").finish(),
38            Self::EntraId { .. } => f
39                .debug_struct("EntraId")
40                .field("tenant_id", &"[REDACTED]")
41                .field("client_id", &"[REDACTED]")
42                .field("client_secret", &"[REDACTED]")
43                .finish(),
44        }
45    }
46}
47
/// Refresh tokens proactively 5 minutes before expiry.
///
/// Handed to `SingleFlightTokenSource::new` when the adapter is constructed.
const REFRESH_MARGIN: Duration = Duration::from_secs(300);
50
/// Response from the Microsoft identity platform token endpoint.
///
/// Only the two fields this adapter reads are declared.
#[derive(Deserialize)]
struct TokenResponse {
    /// The token value; sent as the `Authorization: Bearer` credential.
    access_token: String,
    /// Token lifetime in seconds, added to the current instant to compute expiry.
    expires_in: u64,
}
57
/// Why acquiring an Entra ID access token failed.
///
/// Every variant carries a human-readable message. The variant decides which
/// kind of error event is surfaced to the caller when auth is applied.
///
/// `Debug` is derived so the error can be logged, unwrapped, or used in
/// assertions; the payload is only the message text.
#[derive(Clone, Debug)]
enum TokenAcquireError {
    /// The token endpoint's status was classified as an authentication failure.
    Auth(String),
    /// The token endpoint's status was classified as rate limiting.
    Throttled(String),
    /// The request could not be sent, or the status was classified as a
    /// network/server failure.
    Network(String),
    /// Unclassified status, or a response body that could not be parsed.
    Other(String),
}
65
/// Streaming adapter for Azure's OpenAI-compatible chat completions endpoint.
///
/// Construct it with `AzureStreamFn::new`; requests are issued through the
/// `StreamFn` implementation for this type.
pub struct AzureStreamFn {
    /// Shared OpenAI-compatible adapter state: supplies the HTTP client, base
    /// URL, provider name and chat-completions URL used by this adapter.
    shell: OaiAdapterShell,
    /// How outgoing requests are authenticated.
    auth: AzureAuth,
    /// Source of Entra ID bearer tokens, consulted through `get_or_refresh`.
    token_source: SingleFlightTokenSource<String, TokenAcquireError>,
    /// Override token endpoint URL (for testing). `None` = use Microsoft default.
    token_endpoint_override: Option<String>,
}
73
74impl AzureStreamFn {
75    #[must_use]
76    pub fn new(base_url: impl Into<String>, auth: AzureAuth) -> Self {
77        let shell_api_key = match &auth {
78            AzureAuth::ApiKey(key) => key.clone(),
79            AzureAuth::EntraId { .. } => String::new(),
80        };
81
82        Self {
83            shell: OaiAdapterShell::new_with_path(
84                "Azure",
85                base_url,
86                shell_api_key,
87                "/chat/completions",
88            ),
89            auth,
90            token_source: SingleFlightTokenSource::new(REFRESH_MARGIN),
91            token_endpoint_override: None,
92        }
93    }
94
95    /// Set a custom token endpoint URL (for testing with wiremock).
96    #[must_use]
97    pub fn with_token_endpoint(mut self, url: impl Into<String>) -> Self {
98        self.token_endpoint_override = Some(url.into());
99        self
100    }
101}
102
103impl AzureStreamFn {
104    /// Acquire a fresh token from the Microsoft identity platform.
105    async fn acquire_token(
106        client: reqwest::Client,
107        token_url: String,
108        client_id: String,
109        client_secret: String,
110    ) -> Result<ExpiringValue<String>, TokenAcquireError> {
111        let params = [
112            ("grant_type", "client_credentials".to_string()),
113            ("client_id", client_id),
114            ("client_secret", client_secret),
115            (
116                "scope",
117                "https://cognitiveservices.azure.com/.default".to_string(),
118            ),
119        ];
120
121        let resp = client
122            .post(&token_url)
123            .form(&params)
124            .send()
125            .await
126            .map_err(|e| TokenAcquireError::Network(format!("token request failed: {e}")))?;
127
128        if !resp.status().is_success() {
129            let status = resp.status().as_u16();
130            let body = resp.text().await.unwrap_or_default();
131            return Err(match classify_token_endpoint_status(status) {
132                Some(HttpErrorKind::Auth) => TokenAcquireError::Auth(format!(
133                    "token endpoint auth error (HTTP {status}): {body}"
134                )),
135                Some(HttpErrorKind::Throttled) => TokenAcquireError::Throttled(format!(
136                    "token endpoint rate limit (HTTP {status}): {body}"
137                )),
138                Some(HttpErrorKind::Network) => TokenAcquireError::Network(format!(
139                    "token endpoint server error (HTTP {status}): {body}"
140                )),
141                None => TokenAcquireError::Other(format!(
142                    "token endpoint returned error (HTTP {status}): {body}"
143                )),
144            });
145        }
146
147        let token_resp: TokenResponse = resp.json().await.map_err(|e| {
148            TokenAcquireError::Other(format!("failed to parse token response: {e}"))
149        })?;
150
151        Ok(ExpiringValue::new(
152            token_resp.access_token,
153            Instant::now() + Duration::from_secs(token_resp.expires_in),
154        ))
155    }
156
157    /// Get a valid token, refreshing if necessary.
158    async fn get_or_refresh_token(
159        &self,
160        tenant_id: &str,
161        client_id: &str,
162        client_secret: &str,
163    ) -> Result<String, TokenAcquireError> {
164        let client = self.shell.client().clone();
165        let token_url = self.token_url(tenant_id);
166        let client_id = client_id.to_string();
167        let client_secret = client_secret.to_string();
168
169        self.token_source
170            .get_or_refresh(move || {
171                Self::acquire_token(client, token_url, client_id, client_secret)
172            })
173            .await
174    }
175
176    /// Build the token endpoint URL. Uses override if set, otherwise Microsoft default.
177    fn token_url(&self, tenant_id: &str) -> String {
178        self.token_endpoint_override.as_ref().map_or_else(
179            || format!("https://login.microsoftonline.com/{tenant_id}/oauth2/v2.0/token"),
180            Clone::clone,
181        )
182    }
183
184    /// Apply Azure-specific auth headers to the request builder.
185    async fn apply_auth(
186        &self,
187        request: reqwest::RequestBuilder,
188        options: &StreamOptions,
189    ) -> Result<reqwest::RequestBuilder, AssistantMessageEvent> {
190        match &self.auth {
191            AzureAuth::ApiKey(key) => {
192                let api_key = options.api_key.as_deref().unwrap_or(key);
193                Ok(request.header("api-key", api_key))
194            }
195            AzureAuth::EntraId {
196                tenant_id,
197                client_id,
198                client_secret,
199            } => {
200                let token = self
201                    .get_or_refresh_token(tenant_id, client_id, client_secret)
202                    .await
203                    .map_err(|e| match e {
204                        TokenAcquireError::Auth(message) => AssistantMessageEvent::error_auth(
205                            format!("Azure token error: {message}"),
206                        ),
207                        TokenAcquireError::Throttled(message) => {
208                            AssistantMessageEvent::error_throttled(format!(
209                                "Azure token error: {message}"
210                            ))
211                        }
212                        TokenAcquireError::Network(message) => {
213                            AssistantMessageEvent::error_network(format!(
214                                "Azure token error: {message}"
215                            ))
216                        }
217                        TokenAcquireError::Other(message) => {
218                            AssistantMessageEvent::error(format!("Azure token error: {message}"))
219                        }
220                    })?;
221                Ok(request.header("Authorization", format!("Bearer {token}")))
222            }
223        }
224    }
225}
226
227fn classify_token_endpoint_status(status: u16) -> Option<HttpErrorKind> {
228    match status {
229        400..=499 if status != 408 && status != 429 => Some(HttpErrorKind::Auth),
230        _ => classify_with_overrides(status, &[]),
231    }
232}
233
234impl std::fmt::Debug for AzureStreamFn {
235    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
236        f.debug_struct("AzureStreamFn")
237            .field("base_url", &self.shell.base_url())
238            .field("auth", &self.auth)
239            .finish_non_exhaustive()
240    }
241}
242
243impl StreamFn for AzureStreamFn {
244    fn stream<'a>(
245        &'a self,
246        model: &'a ModelSpec,
247        context: &'a AgentContext,
248        options: &'a StreamOptions,
249        cancellation_token: CancellationToken,
250    ) -> Pin<Box<dyn Stream<Item = AssistantMessageEvent> + Send + 'a>> {
251        Box::pin(azure_stream(
252            self,
253            model,
254            context,
255            options,
256            cancellation_token,
257        ))
258    }
259}
260
261fn azure_stream<'a>(
262    azure: &'a AzureStreamFn,
263    model: &'a ModelSpec,
264    context: &'a AgentContext,
265    options: &'a StreamOptions,
266    cancellation_token: CancellationToken,
267) -> impl Stream<Item = AssistantMessageEvent> + Send + 'a {
268    stream::once(async move {
269        let url = azure.shell.chat_completions_url();
270        debug!(
271            %url,
272            model = %model.model_id,
273            messages = context.messages.len(),
274            "sending Azure request"
275        );
276
277        let request = prepare_oai_request(azure.shell.client(), &url, model, context, options);
278        let request = match crate::base::race_pre_stream_cancellation(
279            &cancellation_token,
280            "Azure request cancelled",
281            azure.apply_auth(request, options),
282        )
283        .await
284        {
285            Ok(r) => r,
286            Err(event) => return stream::iter(crate::base::pre_stream_error(event)).left_stream(),
287        };
288
289        oai_send_and_parse(
290            request,
291            azure.shell.provider(),
292            cancellation_token,
293            options.on_raw_payload.clone(),
294            |status, body| {
295                if is_content_filter_error(body) {
296                    Some(AssistantMessageEvent::error_content_filtered(format!(
297                        "Azure content filter blocked request (HTTP {status})"
298                    )))
299                } else {
300                    None
301                }
302            },
303        )
304        .right_stream()
305    })
306    .flatten()
307}
308
309/// Check if an HTTP error body contains an Azure content filter violation.
310///
311/// Azure returns `error.code: "ContentFilterBlocked"` when the request
312/// or response triggers content safety filters.
313fn is_content_filter_error(body: &str) -> bool {
314    serde_json::from_str::<serde_json::Value>(body)
315        .ok()
316        .and_then(|v| v.get("error")?.get("code")?.as_str().map(String::from))
317        .is_some_and(|code| code == "ContentFilterBlocked")
318}
319
320const _: () = {
321    const fn assert_send_sync<T: Send + Sync>() {}
322    assert_send_sync::<AzureStreamFn>();
323};