Skip to main content

tryaudex_core/
session.rs

1use std::path::PathBuf;
2use std::time::Duration;
3
4use chrono::{DateTime, Utc};
5use serde::{Deserialize, Serialize};
6use uuid::Uuid;
7
8use crate::error::{AvError, Result};
9use crate::policy::ScopedPolicy;
10
11/// Cloud provider for a session.
12#[derive(Debug, Clone, Copy, Serialize, Deserialize, PartialEq, Eq, Default)]
13#[serde(rename_all = "lowercase")]
14pub enum CloudProvider {
15    #[default]
16    Aws,
17    Gcp,
18    Azure,
19}
20
21impl std::fmt::Display for CloudProvider {
22    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
23        match self {
24            Self::Aws => write!(f, "aws"),
25            Self::Gcp => write!(f, "gcp"),
26            Self::Azure => write!(f, "azure"),
27        }
28    }
29}
30
/// An agent credential session; `status` tracks where it is in its lifecycle.
///
/// Persisted by `SessionStore` as one pretty-printed JSON file per session.
/// NOTE(review): serde serializes fields in declaration order, so reordering the
/// fields below changes the layout of the JSON written to disk.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct Session {
    /// Session identifier. The constructors use a random UUID v4 string;
    /// `SessionStore` uses it as the file stem (`<id>.json`).
    pub id: String,
    /// Creation time (UTC).
    pub created_at: DateTime<Utc>,
    /// Expiry time; the constructors set it to `created_at + ttl_seconds`.
    pub expires_at: DateTime<Utc>,
    /// Requested lifetime, truncated to whole seconds.
    pub ttl_seconds: u64,
    /// Optional budget limit (see `SessionStatus::BudgetExceeded`).
    /// NOTE(review): the unit is not visible in this file — confirm.
    pub budget: Option<f64>,
    /// Permissions scoped to this session; `SessionStore::find_reusable` compares its actions.
    pub policy: ScopedPolicy,
    /// Current lifecycle state.
    pub status: SessionStatus,
    /// AWS: IAM role ARN. GCP: service account email. Azure: subscription ID or principal.
    pub role_arn: String,
    /// Command associated with the session.
    /// NOTE(review): presumably the agent command run with these credentials — confirm.
    pub command: Vec<String>,
    /// AWS access key for this session (populated after credential creation).
    pub access_key_id: Option<String>,
    /// Cloud provider (defaults to AWS for backwards compat).
    #[serde(default)]
    pub provider: CloudProvider,
    /// Agent identity (from AUDEX_AGENT_ID env var) — tracks which AI agent/model issued the request.
    /// Omitted from the JSON when `None`.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub agent_id: Option<String>,
    /// User-supplied key:value tags for audit filtering and STS session tags.
    /// Omitted from the JSON when empty.
    #[serde(default, skip_serializing_if = "std::collections::HashMap::is_empty")]
    pub tags: std::collections::HashMap<String, String>,
}
56
57#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
58#[serde(rename_all = "snake_case")]
59pub enum SessionStatus {
60    Active,
61    Completed,
62    Expired,
63    Revoked,
64    BudgetExceeded,
65    Failed,
66}
67
68impl std::fmt::Display for SessionStatus {
69    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
70        match self {
71            Self::Active => write!(f, "active"),
72            Self::Completed => write!(f, "completed"),
73            Self::Expired => write!(f, "expired"),
74            Self::Revoked => write!(f, "revoked"),
75            Self::BudgetExceeded => write!(f, "budget_exceeded"),
76            Self::Failed => write!(f, "failed"),
77        }
78    }
79}
80
81impl Session {
82    pub fn new(
83        ttl: Duration,
84        budget: Option<f64>,
85        policy: ScopedPolicy,
86        role_arn: String,
87        command: Vec<String>,
88    ) -> Self {
89        let now = Utc::now();
90        let expires_at = now + chrono::Duration::seconds(ttl.as_secs() as i64);
91        Self {
92            id: Uuid::new_v4().to_string(),
93            created_at: now,
94            expires_at,
95            ttl_seconds: ttl.as_secs(),
96            budget,
97            policy,
98            status: SessionStatus::Active,
99            role_arn,
100            command,
101            access_key_id: None,
102            provider: CloudProvider::default(),
103            agent_id: None,
104            tags: std::collections::HashMap::new(),
105        }
106    }
107
108    pub fn new_gcp(
109        ttl: Duration,
110        budget: Option<f64>,
111        policy: ScopedPolicy,
112        service_account: String,
113        command: Vec<String>,
114    ) -> Self {
115        let now = Utc::now();
116        let expires_at = now + chrono::Duration::seconds(ttl.as_secs() as i64);
117        Self {
118            id: Uuid::new_v4().to_string(),
119            created_at: now,
120            expires_at,
121            ttl_seconds: ttl.as_secs(),
122            budget,
123            policy,
124            status: SessionStatus::Active,
125            role_arn: service_account,
126            command,
127            access_key_id: None,
128            provider: CloudProvider::Gcp,
129            agent_id: None,
130            tags: std::collections::HashMap::new(),
131        }
132    }
133
134    pub fn new_azure(
135        ttl: Duration,
136        budget: Option<f64>,
137        policy: ScopedPolicy,
138        subscription_id: String,
139        command: Vec<String>,
140    ) -> Self {
141        let now = Utc::now();
142        let expires_at = now + chrono::Duration::seconds(ttl.as_secs() as i64);
143        Self {
144            id: Uuid::new_v4().to_string(),
145            created_at: now,
146            expires_at,
147            ttl_seconds: ttl.as_secs(),
148            budget,
149            policy,
150            status: SessionStatus::Active,
151            role_arn: subscription_id,
152            command,
153            access_key_id: None,
154            provider: CloudProvider::Azure,
155            agent_id: None,
156            tags: std::collections::HashMap::new(),
157        }
158    }
159
160    pub fn is_expired(&self) -> bool {
161        Utc::now() > self.expires_at
162    }
163
164    pub fn remaining_seconds(&self) -> i64 {
165        (self.expires_at - Utc::now()).num_seconds().max(0)
166    }
167
168    pub fn complete(&mut self) {
169        self.status = SessionStatus::Completed;
170    }
171
172    pub fn expire(&mut self) {
173        self.status = SessionStatus::Expired;
174    }
175
176    pub fn revoke(&mut self) {
177        self.status = SessionStatus::Revoked;
178    }
179
180    pub fn fail(&mut self) {
181        self.status = SessionStatus::Failed;
182    }
183}
184
/// Manages session persistence on disk: one pretty-printed JSON file per
/// session, named `<session id>.json`, inside a single directory.
#[derive(Debug, Clone)]
pub struct SessionStore {
    /// Directory holding the session files.
    dir: PathBuf,
}
189
190impl SessionStore {
191    pub fn new() -> Result<Self> {
192        let dir = dirs::data_local_dir()
193            .unwrap_or_else(|| PathBuf::from("."))
194            .join("audex")
195            .join("sessions");
196        std::fs::create_dir_all(&dir)?;
197        Ok(Self { dir })
198    }
199
200    pub fn save(&self, session: &Session) -> Result<()> {
201        let path = self.dir.join(format!("{}.json", session.id));
202        let json = serde_json::to_string_pretty(session)?;
203        std::fs::write(path, json)?;
204        Ok(())
205    }
206
207    pub fn load(&self, id: &str) -> Result<Session> {
208        let path = self.dir.join(format!("{}.json", id));
209        if !path.exists() {
210            return Err(AvError::SessionNotFound { id: id.to_string() });
211        }
212        let json = std::fs::read_to_string(path)?;
213        let session: Session = serde_json::from_str(&json)?;
214        Ok(session)
215    }
216
217    pub fn list(&self) -> Result<Vec<Session>> {
218        let mut sessions = Vec::new();
219        for entry in std::fs::read_dir(&self.dir)? {
220            let entry = entry?;
221            let path = entry.path();
222            if path.extension().is_some_and(|ext| ext == "json") {
223                let json = std::fs::read_to_string(&path)?;
224                if let Ok(session) = serde_json::from_str::<Session>(&json) {
225                    sessions.push(session);
226                }
227            }
228        }
229        sessions.sort_by(|a, b| b.created_at.cmp(&a.created_at));
230        Ok(sessions)
231    }
232
233    pub fn update(&self, session: &Session) -> Result<()> {
234        self.save(session)
235    }
236
237    /// Find an active session with matching allow actions, role ARN, and enough remaining TTL.
238    /// Used for credential caching — reuse existing sessions instead of issuing new ones.
239    pub fn find_reusable(
240        &self,
241        allow_str: &str,
242        role_arn: &str,
243        min_remaining_secs: i64,
244    ) -> Result<Option<Session>> {
245        let sessions = self.list()?;
246        let mut allow_sorted: Vec<&str> = allow_str.split(',').map(|s| s.trim()).collect();
247        allow_sorted.sort();
248
249        for session in sessions {
250            if session.status != SessionStatus::Active {
251                continue;
252            }
253            if session.is_expired() {
254                continue;
255            }
256            if session.remaining_seconds() < min_remaining_secs {
257                continue;
258            }
259            if session.role_arn != role_arn {
260                continue;
261            }
262
263            // Compare actions
264            let mut session_actions: Vec<String> = session
265                .policy
266                .actions
267                .iter()
268                .map(|a| a.to_iam_action())
269                .collect();
270            session_actions.sort();
271            let session_actions_str: Vec<&str> =
272                session_actions.iter().map(|s| s.as_str()).collect();
273
274            if session_actions_str == allow_sorted {
275                return Ok(Some(session));
276            }
277        }
278        Ok(None)
279    }
280}