Skip to main content

tryaudex_core/
credentials.rs

1use std::path::PathBuf;
2use std::time::Duration;
3
4use serde::{Deserialize, Serialize};
5
6use crate::error::{AvError, Result};
7use crate::policy::ScopedPolicy;
8use crate::session::Session;
9
10/// Temporary AWS credentials returned by STS.
11#[derive(Clone, Serialize, Deserialize)]
12pub struct TempCredentials {
13    pub access_key_id: String,
14    pub secret_access_key: String,
15    pub session_token: String,
16    pub expires_at: chrono::DateTime<chrono::Utc>,
17}
18
19/// Cached GCP access token for session reuse (R6-H7).
20#[derive(Clone, Serialize, Deserialize)]
21pub struct CachedGcpToken {
22    pub access_token: String,
23    pub expires_at: chrono::DateTime<chrono::Utc>,
24    /// Whether this token was successfully downscoped via Credential Access
25    /// Boundaries. Used so a CAB-enforced session is never reused for a
26    /// request that expected enforcement when the original fell back to
27    /// advisory-only (R6-M19).
28    #[serde(default)]
29    pub scoping_enforced: bool,
30}
31
32impl std::fmt::Debug for CachedGcpToken {
33    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
34        f.debug_struct("CachedGcpToken")
35            .field("access_token", &"[REDACTED]")
36            .field("expires_at", &self.expires_at)
37            .field("scoping_enforced", &self.scoping_enforced)
38            .finish()
39    }
40}
41
42/// Cached Azure access token for session reuse (R6-H7).
43#[derive(Clone, Serialize, Deserialize)]
44pub struct CachedAzureToken {
45    pub access_token: String,
46    pub expires_at: chrono::DateTime<chrono::Utc>,
47    /// ARM scope the token was minted for. Reuse only matches on identical
48    /// scope so rotated / reused tokens target the correct audience.
49    #[serde(default)]
50    pub scope: Option<String>,
51}
52
53impl std::fmt::Debug for CachedAzureToken {
54    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
55        f.debug_struct("CachedAzureToken")
56            .field("access_token", &"[REDACTED]")
57            .field("expires_at", &self.expires_at)
58            .field("scope", &self.scope)
59            .finish()
60    }
61}
62
63impl std::fmt::Debug for TempCredentials {
64    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
65        f.debug_struct("TempCredentials")
66            .field("access_key_id", &self.access_key_id)
67            .field("secret_access_key", &"[REDACTED]")
68            .field("session_token", &"[REDACTED]")
69            .field("expires_at", &self.expires_at)
70            .finish()
71    }
72}
73
74impl TempCredentials {
75    /// Returns env vars to inject into the subprocess.
76    pub fn as_env_vars(&self) -> Vec<(&str, &str)> {
77        vec![
78            ("AWS_ACCESS_KEY_ID", &self.access_key_id),
79            ("AWS_SECRET_ACCESS_KEY", &self.secret_access_key),
80            ("AWS_SESSION_TOKEN", &self.session_token),
81        ]
82    }
83}
84
85/// Issues scoped, short-lived AWS credentials via STS AssumeRole.
86pub struct CredentialIssuer {
87    sts_client: aws_sdk_sts::Client,
88}
89
90impl CredentialIssuer {
91    pub async fn new() -> Result<Self> {
92        Self::with_region(None).await
93    }
94
95    pub async fn with_region(region: Option<&str>) -> Result<Self> {
96        let mut loader = aws_config::defaults(aws_config::BehaviorVersion::latest());
97        if let Some(region) = region {
98            loader = loader.region(aws_config::Region::new(region.to_string()));
99        }
100        let config = loader.load().await;
101        Ok(Self {
102            sts_client: aws_sdk_sts::Client::new(&config),
103        })
104    }
105
106    /// Assume a role with an inline policy that restricts permissions to the scoped policy.
107    /// Optionally attaches a permissions boundary ARN as an additional ceiling.
108    pub async fn issue(
109        &self,
110        session: &Session,
111        policy: &ScopedPolicy,
112        ttl: Duration,
113    ) -> Result<TempCredentials> {
114        self.issue_with_boundary(session, policy, ttl, None).await
115    }
116
117    /// Assume a role with an inline policy and optional permissions boundary.
118    pub async fn issue_with_boundary(
119        &self,
120        session: &Session,
121        policy: &ScopedPolicy,
122        ttl: Duration,
123        permissions_boundary: Option<&str>,
124    ) -> Result<TempCredentials> {
125        self.issue_full(session, policy, ttl, permissions_boundary, None)
126            .await
127    }
128
129    /// Assume a role with inline policy, optional boundary, and optional network conditions.
130    pub async fn issue_full(
131        &self,
132        session: &Session,
133        policy: &ScopedPolicy,
134        ttl: Duration,
135        permissions_boundary: Option<&str>,
136        network: Option<&crate::policy::NetworkPolicy>,
137    ) -> Result<TempCredentials> {
138        self.issue_full_with_tag_lock(
139            session,
140            policy,
141            ttl,
142            permissions_boundary,
143            network,
144            None,
145            None,
146        )
147        .await
148    }
149
150    /// Assume a role with inline policy plus an optional tag-lock deny that
151    /// prevents the caller from modifying/removing a specific tag key. Used
152    /// by `--tag-session` so agents can't strip the tryaudex-session marker.
153    #[allow(clippy::too_many_arguments)]
154    pub async fn issue_full_with_tag_lock(
155        &self,
156        session: &Session,
157        policy: &ScopedPolicy,
158        ttl: Duration,
159        permissions_boundary: Option<&str>,
160        network: Option<&crate::policy::NetworkPolicy>,
161        tag_lock_key: Option<&str>,
162        external_id: Option<&str>,
163    ) -> Result<TempCredentials> {
164        if let Some(net) = network {
165            net.validate()?;
166        }
167        let policy_json = match tag_lock_key {
168            Some(key) if network.is_none() => policy.to_iam_policy_json_with_tag_lock(key)?,
169            Some(key) => policy.to_iam_policy_json_with_network_and_tag_lock(network, Some(key))?,
170            None => policy.to_iam_policy_json_with_network(network)?,
171        };
172        // AWS STS AssumeRole requires 900s <= DurationSeconds <= 43200s (15 minutes to 12 hours).
173        // Clamp at both ends rather than letting STS reject the call with a cryptic message.
174        let requested = ttl.as_secs();
175        let ttl_secs = requested.clamp(900, 43200) as i32;
176        if requested < 900 {
177            tracing::warn!(
178                requested_secs = requested,
179                clamped_to_secs = ttl_secs,
180                "TTL below AWS STS minimum (900s / 15m); clamping up"
181            );
182        } else if requested > 43200 {
183            tracing::warn!(
184                requested_secs = requested,
185                clamped_to_secs = ttl_secs,
186                "TTL above AWS STS maximum (43200s / 12h); clamping down"
187            );
188        }
189
190        tracing::info!(
191            session_id = %session.id,
192            role_arn = %session.role_arn,
193            ttl_secs = ttl_secs,
194            permissions_boundary = ?permissions_boundary,
195            "Assuming role with scoped policy"
196        );
197
198        let mut request = self
199            .sts_client
200            .assume_role()
201            .role_arn(&session.role_arn)
202            .role_session_name(format!("av-{}", session.id.get(..8).unwrap_or(&session.id)))
203            .policy(&policy_json)
204            .duration_seconds(ttl_secs);
205
206        // Attach user-supplied session tags to the STS call so they flow
207        // into CloudTrail and are available for attribute-based access control.
208        for (key, value) in &session.tags {
209            request = request.tags(
210                aws_sdk_sts::types::Tag::builder()
211                    .key(key)
212                    .value(value)
213                    .build()
214                    .map_err(|e| AvError::Sts(format!("Invalid STS tag: {}", e)))?,
215            );
216        }
217
218        // Attach permissions boundary as a managed policy ARN ceiling
219        if let Some(boundary_arn) = permissions_boundary {
220            request = request.policy_arns(
221                aws_sdk_sts::types::PolicyDescriptorType::builder()
222                    .arn(boundary_arn)
223                    .build(),
224            );
225        }
226
227        // Attach external ID for cross-account role assumption
228        if let Some(ext_id) = external_id {
229            request = request.external_id(ext_id);
230        }
231
232        let result = request.send().await.map_err(|e| {
233            // AWS SDK's Display impl collapses ServiceError variants to "service error".
234            // Extract the actual error code + message so the user can act on it.
235            let detail = match e.as_service_error() {
236                Some(svc) => {
237                    let code = svc.meta().code().unwrap_or("Unknown");
238                    let message = svc.meta().message().unwrap_or("(no message)");
239                    format!("{code}: {message}")
240                }
241                None => e.to_string(),
242            };
243            AvError::Sts(detail)
244        })?;
245
246        let creds = result
247            .credentials()
248            .ok_or_else(|| AvError::Sts("No credentials returned by STS".to_string()))?;
249
250        let exp = creds.expiration();
251        let expires_at = chrono::DateTime::from_timestamp(exp.secs(), exp.subsec_nanos())
252            .unwrap_or_else(|| {
253                tracing::warn!(
254                    secs = exp.secs(),
255                    "STS returned unparseable expiration timestamp; falling back to Utc::now() + TTL"
256                );
257                chrono::Utc::now() + chrono::Duration::seconds(ttl_secs as i64)
258            });
259
260        Ok(TempCredentials {
261            access_key_id: creds.access_key_id().to_string(),
262            secret_access_key: creds.secret_access_key().to_string(),
263            session_token: creds.session_token().to_string(),
264            expires_at,
265        })
266    }
267}
268
269/// Caches credentials on disk, keyed by session ID.
270pub struct CredentialCache {
271    dir: PathBuf,
272}
273
274impl CredentialCache {
275    pub fn new() -> Result<Self> {
276        let base = dirs::data_local_dir().ok_or_else(|| {
277            AvError::InvalidPolicy(
278                "Could not determine local data directory. Set XDG_DATA_HOME or HOME.".to_string(),
279            )
280        })?;
281        let dir = base.join("audex").join("cred_cache");
282        std::fs::create_dir_all(&dir)?;
283        #[cfg(unix)]
284        {
285            use std::os::unix::fs::PermissionsExt;
286            std::fs::set_permissions(&dir, std::fs::Permissions::from_mode(0o700))?;
287        }
288        Ok(Self { dir })
289    }
290
291    /// Save credentials for a session (encrypted at rest).
292    pub fn save(&self, session_id: &str, creds: &TempCredentials) -> Result<()> {
293        crate::validate::session_id(session_id)?;
294        let path = self.dir.join(format!("{}.json", session_id));
295        crate::keystore::encrypt_to_file(&path, creds)
296    }
297
298    /// Load cached credentials for a session. Returns None if expired or missing.
299    /// Handles both encrypted and legacy plaintext files for migration.
300    pub fn load(&self, session_id: &str) -> Result<Option<TempCredentials>> {
301        crate::validate::session_id(session_id)?;
302        let path = self.dir.join(format!("{}.json", session_id));
303        let creds: Option<TempCredentials> = crate::keystore::decrypt_from_file(&path)?;
304        match creds {
305            Some(c) if c.expires_at <= chrono::Utc::now() => {
306                let _ = std::fs::remove_file(&path);
307                Ok(None)
308            }
309            other => Ok(other),
310        }
311    }
312
313    /// Remove cached credentials for a session.
314    pub fn remove(&self, session_id: &str) -> Result<()> {
315        crate::validate::session_id(session_id)?;
316        let path = self.dir.join(format!("{}.json", session_id));
317        let _ = std::fs::remove_file(path);
318        // Also clean up provider-specific caches.
319        let _ = std::fs::remove_file(self.dir.join(format!("{}.gcp.json", session_id)));
320        let _ = std::fs::remove_file(self.dir.join(format!("{}.azure.json", session_id)));
321        Ok(())
322    }
323
324    /// Save a GCP token for a session (encrypted at rest).
325    pub fn save_gcp(&self, session_id: &str, token: &CachedGcpToken) -> Result<()> {
326        crate::validate::session_id(session_id)?;
327        let path = self.dir.join(format!("{}.gcp.json", session_id));
328        crate::keystore::encrypt_to_file(&path, token)
329    }
330
331    /// Load a cached GCP token. Returns None if expired or missing.
332    pub fn load_gcp(&self, session_id: &str) -> Result<Option<CachedGcpToken>> {
333        crate::validate::session_id(session_id)?;
334        let path = self.dir.join(format!("{}.gcp.json", session_id));
335        let token: Option<CachedGcpToken> = crate::keystore::decrypt_from_file(&path)?;
336        match token {
337            Some(t) if t.expires_at <= chrono::Utc::now() => {
338                let _ = std::fs::remove_file(&path);
339                Ok(None)
340            }
341            other => Ok(other),
342        }
343    }
344
345    /// Save an Azure token for a session (encrypted at rest).
346    pub fn save_azure(&self, session_id: &str, token: &CachedAzureToken) -> Result<()> {
347        crate::validate::session_id(session_id)?;
348        let path = self.dir.join(format!("{}.azure.json", session_id));
349        crate::keystore::encrypt_to_file(&path, token)
350    }
351
352    /// Load a cached Azure token. Returns None if expired or missing.
353    pub fn load_azure(&self, session_id: &str) -> Result<Option<CachedAzureToken>> {
354        crate::validate::session_id(session_id)?;
355        let path = self.dir.join(format!("{}.azure.json", session_id));
356        let token: Option<CachedAzureToken> = crate::keystore::decrypt_from_file(&path)?;
357        match token {
358            Some(t) if t.expires_at <= chrono::Utc::now() => {
359                let _ = std::fs::remove_file(&path);
360                Ok(None)
361            }
362            other => Ok(other),
363        }
364    }
365}