Skip to main content

tryaudex_core/
session.rs

1use std::path::PathBuf;
2use std::time::Duration;
3
4use chrono::{DateTime, Utc};
5use serde::{Deserialize, Serialize};
6use uuid::Uuid;
7
8use crate::error::{AvError, Result};
9use crate::policy::ScopedPolicy;
10
11/// Cloud provider for a session.
12#[derive(Debug, Clone, Copy, Serialize, Deserialize, PartialEq, Eq, Default)]
13#[serde(rename_all = "lowercase")]
14pub enum CloudProvider {
15    #[default]
16    Aws,
17    Gcp,
18    Azure,
19}
20
21impl std::fmt::Display for CloudProvider {
22    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
23        match self {
24            Self::Aws => write!(f, "aws"),
25            Self::Gcp => write!(f, "gcp"),
26            Self::Azure => write!(f, "azure"),
27        }
28    }
29}
30
31/// Represents an active agent credential session.
32#[derive(Debug, Clone, Serialize, Deserialize)]
33pub struct Session {
34    pub id: String,
35    pub created_at: DateTime<Utc>,
36    pub expires_at: DateTime<Utc>,
37    pub ttl_seconds: u64,
38    pub budget: Option<f64>,
39    pub policy: ScopedPolicy,
40    pub status: SessionStatus,
41    /// Provider-agnostic principal identifier despite the AWS-centric field name.
42    ///
43    /// - **AWS**: IAM role ARN (e.g. `arn:aws:iam::123456789012:role/MyRole`)
44    /// - **GCP**: Service account email (e.g. `my-sa@project.iam.gserviceaccount.com`)
45    /// - **Azure**: Subscription ID or principal (e.g. `00000000-0000-0000-0000-000000000000`)
46    ///
47    /// The field is named `role_arn` for historical reasons and backwards
48    /// serialization compatibility. New code should use the [`Session::principal_id`]
49    /// helper method instead of accessing this field directly.
50    /// JSON deserialization also accepts the alias `"role_principal"` to ease
51    /// migration in tooling that wants provider-neutral field names.
52    #[serde(alias = "role_principal")]
53    pub role_arn: String,
54    pub command: Vec<String>,
55    /// AWS access key for this session (populated after credential creation).
56    pub access_key_id: Option<String>,
57    /// Cloud provider (defaults to AWS for backwards compat).
58    #[serde(default)]
59    pub provider: CloudProvider,
60    /// Agent identity (from AUDEX_AGENT_ID env var) — tracks which AI agent/model issued the request.
61    #[serde(default, skip_serializing_if = "Option::is_none")]
62    pub agent_id: Option<String>,
63    /// Azure principal ID (client ID or user object ID) to disambiguate sessions
64    /// on the same subscription but different identities.
65    #[serde(default, skip_serializing_if = "Option::is_none")]
66    pub principal_id: Option<String>,
67    /// User-supplied key:value tags for audit filtering and STS session tags.
68    #[serde(default, skip_serializing_if = "std::collections::HashMap::is_empty")]
69    pub tags: std::collections::HashMap<String, String>,
70}
71
72#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
73#[serde(rename_all = "snake_case")]
74pub enum SessionStatus {
75    Active,
76    Completed,
77    Expired,
78    Revoked,
79    BudgetExceeded,
80    Failed,
81}
82
83impl std::fmt::Display for SessionStatus {
84    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
85        match self {
86            Self::Active => write!(f, "active"),
87            Self::Completed => write!(f, "completed"),
88            Self::Expired => write!(f, "expired"),
89            Self::Revoked => write!(f, "revoked"),
90            Self::BudgetExceeded => write!(f, "budget_exceeded"),
91            Self::Failed => write!(f, "failed"),
92        }
93    }
94}
95
96impl Session {
97    pub fn new(
98        ttl: Duration,
99        budget: Option<f64>,
100        policy: ScopedPolicy,
101        role_arn: String,
102        command: Vec<String>,
103    ) -> Self {
104        let now = Utc::now();
105        let expires_at = now + chrono::Duration::seconds(ttl.as_secs() as i64);
106        Self {
107            id: Uuid::new_v4().to_string(),
108            created_at: now,
109            expires_at,
110            ttl_seconds: ttl.as_secs(),
111            budget,
112            policy,
113            status: SessionStatus::Active,
114            role_arn,
115            command,
116            access_key_id: None,
117            provider: CloudProvider::default(),
118            agent_id: None,
119            principal_id: None,
120            tags: std::collections::HashMap::new(),
121        }
122    }
123
124    pub fn new_gcp(
125        ttl: Duration,
126        budget: Option<f64>,
127        policy: ScopedPolicy,
128        service_account: String,
129        command: Vec<String>,
130    ) -> Self {
131        let now = Utc::now();
132        let expires_at = now + chrono::Duration::seconds(ttl.as_secs() as i64);
133        Self {
134            id: Uuid::new_v4().to_string(),
135            created_at: now,
136            expires_at,
137            ttl_seconds: ttl.as_secs(),
138            budget,
139            policy,
140            status: SessionStatus::Active,
141            role_arn: service_account,
142            command,
143            access_key_id: None,
144            provider: CloudProvider::Gcp,
145            agent_id: None,
146            principal_id: None,
147            tags: std::collections::HashMap::new(),
148        }
149    }
150
151    pub fn new_azure(
152        ttl: Duration,
153        budget: Option<f64>,
154        policy: ScopedPolicy,
155        subscription_id: String,
156        command: Vec<String>,
157    ) -> Self {
158        let now = Utc::now();
159        let expires_at = now + chrono::Duration::seconds(ttl.as_secs() as i64);
160        Self {
161            id: Uuid::new_v4().to_string(),
162            created_at: now,
163            expires_at,
164            ttl_seconds: ttl.as_secs(),
165            budget,
166            policy,
167            status: SessionStatus::Active,
168            role_arn: subscription_id,
169            command,
170            access_key_id: None,
171            provider: CloudProvider::Azure,
172            agent_id: None,
173            principal_id: None,
174            tags: std::collections::HashMap::new(),
175        }
176    }
177
178    /// Return the first 8 characters of the session ID for display, or the
179    /// full ID if it is shorter than 8 characters. Avoids panics from direct
180    /// `session.id[..8]` slicing on corrupt/hand-edited session files.
181    pub fn short_id(&self) -> &str {
182        self.id.get(..8).unwrap_or(&self.id)
183    }
184
185    pub fn is_expired(&self) -> bool {
186        Utc::now() > self.expires_at
187    }
188
189    pub fn remaining_seconds(&self) -> i64 {
190        (self.expires_at - Utc::now()).num_seconds().max(0)
191    }
192
193    pub fn complete(&mut self) {
194        self.status = SessionStatus::Completed;
195    }
196
197    pub fn expire(&mut self) {
198        self.status = SessionStatus::Expired;
199    }
200
201    pub fn revoke(&mut self) {
202        self.status = SessionStatus::Revoked;
203    }
204
205    pub fn fail(&mut self) {
206        self.status = SessionStatus::Failed;
207    }
208
209    /// Returns the provider-agnostic principal identifier for this session.
210    ///
211    /// This is a migration-path accessor for the `role_arn` field, which stores
212    /// different things depending on the cloud provider (IAM role ARN for AWS,
213    /// service account email for GCP, subscription/principal ID for Azure).
214    /// When the optional `principal_id` field is populated (e.g. for Azure
215    /// service-principal disambiguation), that value is returned instead.
216    /// Prefer this method over accessing `role_arn` directly in new code.
217    pub fn principal_id(&self) -> &str {
218        self.principal_id.as_deref().unwrap_or(&self.role_arn)
219    }
220}
221
/// Manages session persistence on disk: one `<session id>.json` file per
/// session inside `dir`.
#[derive(Debug)]
pub struct SessionStore {
    /// Directory holding the session JSON files.
    dir: PathBuf,
}
226
227impl SessionStore {
228    pub fn new() -> Result<Self> {
229        let base = dirs::data_local_dir().ok_or_else(|| {
230            AvError::InvalidPolicy(
231                "Could not determine local data directory. Set XDG_DATA_HOME or HOME.".to_string(),
232            )
233        })?;
234        let dir = base.join("audex").join("sessions");
235        std::fs::create_dir_all(&dir)?;
236        Ok(Self { dir })
237    }
238
239    /// Maximum number of session files retained on disk. When a new save
240    /// would push the directory past this cap, the oldest files (by mtime)
241    /// are evicted one-at-a-time to stay exactly at `SESSION_MAX` and
242    /// prevent unbounded disk growth (R3-L12, R6-H22).
243    const SESSION_MAX: usize = 1000;
244
245    fn write_file(&self, session: &Session) -> Result<()> {
246        let path = self.dir.join(format!("{}.json", session.id));
247        let json = serde_json::to_string_pretty(session)?;
248        #[cfg(unix)]
249        {
250            use std::os::unix::fs::OpenOptionsExt;
251            let mut file = std::fs::OpenOptions::new()
252                .write(true)
253                .create(true)
254                .truncate(true)
255                .mode(0o600)
256                .open(&path)?;
257            std::io::Write::write_all(&mut file, json.as_bytes())?;
258        }
259        #[cfg(not(unix))]
260        {
261            std::fs::write(&path, json)?;
262        }
263        Ok(())
264    }
265
266    pub fn save(&self, session: &Session) -> Result<()> {
267        // R6-H22: previously the prune check fired at `entries.len() >=
268        // SESSION_MAX` and then blindly deleted back to `PRUNE_TARGET - 1`
269        // (899), which meant every save above 1000 evicted 101 files —
270        // even in the common case where we were only one over the limit.
271        // Combined with `update()` aliasing `save()`, a single session
272        // update could churn 101 unrelated sessions out of existence.
273        //
274        // Only prune when the new file would actually push us over the
275        // cap, and only delete the excess down to SESSION_MAX so steady-
276        // state writes evict exactly one file per new session.  Also
277        // skip eviction when the session file already exists (update
278        // path): updating an existing record must never delete other
279        // sessions.
280        let path = self.dir.join(format!("{}.json", session.id));
281        let is_new = !path.exists();
282
283        if is_new {
284            let entries: Result<Vec<_>> = std::fs::read_dir(&self.dir)?
285                .filter_map(|e| e.ok())
286                .filter(|e| {
287                    e.path()
288                        .extension()
289                        .and_then(|x| x.to_str())
290                        .map(|x| x == "json")
291                        .unwrap_or(false)
292                })
293                .map(|e| {
294                    let mtime = e
295                        .metadata()
296                        .and_then(|m| m.modified())
297                        .unwrap_or(std::time::SystemTime::UNIX_EPOCH);
298                    Ok((mtime, e.path()))
299                })
300                .collect();
301            let mut entries = entries?;
302            // Evict just enough to fit the new record under the cap.
303            if entries.len() + 1 > Self::SESSION_MAX {
304                entries.sort_unstable_by_key(|(mtime, _)| *mtime);
305                let to_delete = entries.len() + 1 - Self::SESSION_MAX;
306                for (_, p) in entries.iter().take(to_delete) {
307                    let _ = std::fs::remove_file(p);
308                }
309            }
310        }
311
312        self.write_file(session)
313    }
314
315    pub fn load(&self, id: &str) -> Result<Session> {
316        crate::validate::session_id(id)?;
317        let path = self.dir.join(format!("{}.json", id));
318        if !path.exists() {
319            return Err(AvError::SessionNotFound { id: id.to_string() });
320        }
321        let json = std::fs::read_to_string(path)?;
322        let session: Session = serde_json::from_str(&json)?;
323        Ok(session)
324    }
325
326    pub fn list(&self) -> Result<Vec<Session>> {
327        let mut sessions = Vec::new();
328        for entry in std::fs::read_dir(&self.dir)? {
329            let entry = entry?;
330            let path = entry.path();
331            if path.extension().is_some_and(|ext| ext == "json") {
332                let json = std::fs::read_to_string(&path)?;
333                if let Ok(session) = serde_json::from_str::<Session>(&json) {
334                    sessions.push(session);
335                }
336            }
337        }
338        sessions.sort_by(|a, b| b.created_at.cmp(&a.created_at));
339        Ok(sessions)
340    }
341
342    pub fn update(&self, session: &Session) -> Result<()> {
343        self.save(session)
344    }
345
346    /// Find an active session with matching allow actions, role ARN, and enough remaining TTL.
347    /// Used for credential caching — reuse existing sessions instead of issuing new ones.
348    pub fn find_reusable(
349        &self,
350        allow_str: &str,
351        role_arn: &str,
352        min_remaining_secs: i64,
353    ) -> Result<Option<Session>> {
354        self.find_reusable_for_provider(
355            allow_str,
356            role_arn,
357            min_remaining_secs,
358            CloudProvider::Aws,
359            &[],
360        )
361    }
362
363    /// Provider-aware session reuse lookup. Formats stored actions using the
364    /// correct delimiter for the provider so GCP/Azure sessions can be matched.
365    /// For Azure, `principal_id` disambiguates sessions on the same subscription
366    /// but different identities (service principals, users).
367    pub fn find_reusable_for_provider(
368        &self,
369        allow_str: &str,
370        role_arn: &str,
371        min_remaining_secs: i64,
372        provider: CloudProvider,
373        resources: &[String],
374    ) -> Result<Option<Session>> {
375        let sessions = self.list()?;
376        // R6-H5: sort + dedup so semantically equivalent inputs
377        // (reordered or repeated entries) still match the same session.
378        let mut allow_sorted: Vec<&str> = allow_str.split(',').map(|s| s.trim()).collect();
379        allow_sorted.sort();
380        allow_sorted.dedup();
381
382        let mut requested_resources: Vec<&str> = resources.iter().map(|s| s.as_str()).collect();
383        requested_resources.sort();
384        requested_resources.dedup();
385
386        for session in sessions {
387            if session.status != SessionStatus::Active {
388                continue;
389            }
390            if session.is_expired() {
391                continue;
392            }
393            if session.remaining_seconds() < min_remaining_secs {
394                continue;
395            }
396            if session.role_arn != role_arn {
397                continue;
398            }
399            // For Azure, also require matching principal to avoid reusing
400            // credentials from a different identity on the same subscription.
401            // Only enforced when AZURE_CLIENT_ID is set (service principal flow).
402            // Users via `az login` don't have this env var, so we skip the
403            // check and rely on subscription (role_arn) matching alone.
404            if provider == CloudProvider::Azure {
405                if let Some(ref stored_pid) = session.principal_id {
406                    if let Ok(current_pid) = std::env::var("AZURE_CLIENT_ID") {
407                        if current_pid != *stored_pid {
408                            continue;
409                        }
410                    }
411                }
412                // R6-H11: match on tenant ID when available. A subscription
413                // migrated between tenants (e.g. CSP transfer) keeps the same
414                // subscription GUID but changes the token audience; reusing
415                // the old token against the new tenant leaks directory data.
416                let current_tenant = std::env::var("AZURE_TENANT_ID").ok();
417                let stored_tenant = session.tags.get("azure:tenant_id").cloned();
418                if stored_tenant != current_tenant {
419                    continue;
420                }
421            }
422
423            // Compare actions using provider-appropriate format
424            let mut session_actions: Vec<String> = session
425                .policy
426                .actions
427                .iter()
428                .map(|a| match provider {
429                    CloudProvider::Gcp => a.to_gcp_permission(),
430                    CloudProvider::Azure => a.to_azure_permission(),
431                    CloudProvider::Aws => a.to_iam_action(),
432                })
433                .collect();
434            session_actions.sort();
435            session_actions.dedup();
436            let session_actions_str: Vec<&str> =
437                session_actions.iter().map(|s| s.as_str()).collect();
438
439            if session_actions_str != allow_sorted {
440                continue;
441            }
442
443            // Also require matching resource restrictions so a session scoped to
444            // a specific ARN is never handed out for a request that expects a
445            // different (or wider) set of resources.
446            let mut session_resources: Vec<&str> = session
447                .policy
448                .resources
449                .iter()
450                .map(|s| s.as_str())
451                .collect();
452            session_resources.sort();
453            session_resources.dedup();
454            if session_resources != requested_resources {
455                continue;
456            }
457
458            return Ok(Some(session));
459        }
460        Ok(None)
461    }
462}