1use std::sync::Arc;
2use std::time::Duration;
3
4use azure_core::credentials::{AccessToken, TokenCredential};
5use azure_identity::AzureCliCredential;
6use chrono::{DateTime, Utc};
7use serde::{Deserialize, Serialize};
8
9use crate::error::{AvError, Result};
10
11#[derive(Clone, Serialize, Deserialize)]
14pub struct AzureTempCredentials {
15 #[serde(skip_serializing)]
16 pub access_token: String,
17 pub expires_at: DateTime<Utc>,
18}
19
20impl std::fmt::Debug for AzureTempCredentials {
21 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
22 f.debug_struct("AzureTempCredentials")
23 .field("access_token", &"[REDACTED]")
24 .field("expires_at", &self.expires_at)
25 .finish()
26 }
27}
28
29pub struct AzureCredentialIssuer {
35 credential: Arc<AzureCliCredential>,
36}
37
38impl AzureCredentialIssuer {
39 pub fn new() -> Result<Self> {
42 let credential = AzureCliCredential::new(None).map_err(|e| {
43 AvError::Azure(format!(
44 "Failed to create Azure credential. Run `az login` first. Error: {}",
45 e
46 ))
47 })?;
48
49 Ok(Self { credential })
50 }
51
52 pub async fn issue(&self, ttl: Duration) -> Result<AzureTempCredentials> {
66 let arm_scope = arm_scope_for_cloud()?;
67 let scopes = &[arm_scope.as_str()];
68
69 let token: AccessToken = self
70 .credential
71 .get_token(scopes, None)
72 .await
73 .map_err(|e| AvError::Azure(format!("Failed to get Azure access token: {}", e)))?;
74
75 let token_expires_at = DateTime::from_timestamp(token.expires_on.unix_timestamp(), 0)
77 .ok_or_else(|| {
78 AvError::Azure(format!(
79 "Failed to convert Azure token expiry (unix timestamp {})",
80 token.expires_on.unix_timestamp()
81 ))
82 })?;
83
84 let requested_expires_at =
89 Utc::now() + chrono::Duration::from_std(ttl).unwrap_or(chrono::Duration::seconds(3600));
90 let expires_at = token_expires_at.min(requested_expires_at);
91
92 if requested_expires_at < token_expires_at {
93 tracing::warn!(
94 requested_ttl_secs = ttl.as_secs(),
95 token_lifetime_secs = (token_expires_at - Utc::now()).num_seconds(),
96 "Azure token TTL cannot be shortened at issuance; session expires_at \
97 set to requested TTL. The underlying token remains valid until its \
98 natural expiry — revoke it explicitly if needed."
99 );
100 }
101
102 Ok(AzureTempCredentials {
103 access_token: token.token.secret().to_string(),
104 expires_at,
105 })
106 }
107
108 pub async fn issue_with_optional_scope(
112 &self,
113 scope: Option<&str>,
114 ttl: Duration,
115 ) -> Result<AzureTempCredentials> {
116 match scope {
117 Some(s) => self.issue_for_scope(s, ttl).await,
118 None => self.issue(ttl).await,
119 }
120 }
121
122 pub async fn issue_for_scope(
126 &self,
127 scope: &str,
128 ttl: Duration,
129 ) -> Result<AzureTempCredentials> {
130 let parsed = url::Url::parse(scope).map_err(|e| {
140 AvError::Azure(format!(
141 "Invalid Azure scope '{}': {}. Expected an https:// URL \
142 (e.g. 'https://management.azure.com/.default')",
143 scope, e
144 ))
145 })?;
146 if parsed.scheme() != "https" {
147 return Err(AvError::Azure(format!(
148 "Invalid Azure scope '{}': scheme must be https, got '{}'",
149 scope,
150 parsed.scheme()
151 )));
152 }
153 let host = parsed
154 .host_str()
155 .ok_or_else(|| AvError::Azure(format!("Azure scope '{}' has no host", scope)))?
156 .to_ascii_lowercase();
157 let path = parsed.path();
158 if path.split('/').any(|seg| seg == "..") || path.contains('%') {
159 return Err(AvError::Azure(format!(
160 "Azure scope '{}' contains path traversal or percent-encoded \
161 segments; reject to avoid SSRF to unintended hosts",
162 scope
163 )));
164 }
165 if !AZURE_ALLOWED_SCOPE_DOMAINS
166 .iter()
167 .any(|allowed| host == *allowed || host.ends_with(&format!(".{}", allowed)))
168 {
169 return Err(AvError::Azure(format!(
170 "Azure scope '{}' targets an unrecognised domain '{}'. \
171 Allowed domains: {}",
172 scope,
173 host,
174 AZURE_ALLOWED_SCOPE_DOMAINS.join(", ")
175 )));
176 }
177
178 let scopes = &[scope];
179
180 let token: AccessToken = self.credential.get_token(scopes, None).await.map_err(|e| {
181 AvError::Azure(format!(
182 "Failed to get Azure access token for {}: {}",
183 scope, e
184 ))
185 })?;
186
187 let token_expires_at = DateTime::from_timestamp(token.expires_on.unix_timestamp(), 0)
188 .ok_or_else(|| {
189 AvError::Azure(format!(
190 "Failed to convert Azure token expiry (unix timestamp {})",
191 token.expires_on.unix_timestamp()
192 ))
193 })?;
194
195 let requested_expires_at =
196 Utc::now() + chrono::Duration::from_std(ttl).unwrap_or(chrono::Duration::seconds(3600));
197 let expires_at = token_expires_at.min(requested_expires_at);
198
199 Ok(AzureTempCredentials {
200 access_token: token.token.secret().to_string(),
201 expires_at,
202 })
203 }
204}
205
/// Domains a caller-supplied scope may target in `issue_for_scope`.
///
/// A scope host is accepted if it equals an entry or ends with `.` followed by an entry. The
/// order here is the order shown in the "Allowed domains" error message.
// NOTE(review): `management.microsoftazure.de` looks like the ARM host of the retired
// AzureGermanCloud that `arm_scope_for_cloud` rejects; confirm whether it should stay listed.
const AZURE_ALLOWED_SCOPE_DOMAINS: &[&str] = &[
    // Azure Resource Manager endpoints.
    "management.azure.com",
    "management.chinacloudapi.cn",
    "management.microsoftazure.de",
    "management.usgovcloudapi.net",
    // Data-plane and other service audiences.
    "vault.azure.net",
    "database.windows.net",
    "graph.microsoft.com",
    "storage.azure.com",
    "dev.azure.com",
];
222
223fn arm_scope_for_cloud() -> Result<String> {
234 let cloud = std::env::var("AZURE_CLOUD")
235 .or_else(|_| std::env::var("AZURE_ENVIRONMENT"))
236 .unwrap_or_else(|_| "AzureCloud".to_string());
237
238 match cloud.as_str() {
239 "AzureUSGovernment" => Ok("https://management.usgovcloudapi.net/.default".to_string()),
240 "AzureChinaCloud" => Ok("https://management.chinacloudapi.cn/.default".to_string()),
241 "AzureGermanCloud" => Err(AvError::Azure(
242 "AzureGermanCloud (Microsoft Cloud Deutschland) was retired on \
243 29 October 2021 and its ARM endpoint no longer issues tokens. \
244 Set AZURE_CLOUD=AzureCloud (or unset the variable) to use the \
245 standard Azure global cloud."
246 .to_string(),
247 )),
248 _ => Ok("https://management.azure.com/.default".to_string()),
249 }
250}