Skip to main content

tryaudex_core/
session.rs

1use std::path::PathBuf;
2use std::time::Duration;
3
4use chrono::{DateTime, Utc};
5use serde::{Deserialize, Serialize};
6use uuid::Uuid;
7
8use crate::error::{AvError, Result};
9use crate::policy::ScopedPolicy;
10
11/// Cloud provider for a session.
12#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Default)]
13#[serde(rename_all = "lowercase")]
14pub enum CloudProvider {
15    #[default]
16    Aws,
17    Gcp,
18    Azure,
19}
20
21impl std::fmt::Display for CloudProvider {
22    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
23        match self {
24            Self::Aws => write!(f, "aws"),
25            Self::Gcp => write!(f, "gcp"),
26            Self::Azure => write!(f, "azure"),
27        }
28    }
29}
30
31/// Represents an active agent credential session.
32#[derive(Debug, Clone, Serialize, Deserialize)]
33pub struct Session {
34    pub id: String,
35    pub created_at: DateTime<Utc>,
36    pub expires_at: DateTime<Utc>,
37    pub ttl_seconds: u64,
38    pub budget: Option<f64>,
39    pub policy: ScopedPolicy,
40    pub status: SessionStatus,
41    /// AWS: IAM role ARN. GCP: service account email. Azure: subscription ID or principal.
42    pub role_arn: String,
43    pub command: Vec<String>,
44    /// AWS access key for this session (populated after credential creation).
45    pub access_key_id: Option<String>,
46    /// Cloud provider (defaults to AWS for backwards compat).
47    #[serde(default)]
48    pub provider: CloudProvider,
49    /// Agent identity (from AUDEX_AGENT_ID env var) — tracks which AI agent/model issued the request.
50    #[serde(default, skip_serializing_if = "Option::is_none")]
51    pub agent_id: Option<String>,
52}
53
54#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
55#[serde(rename_all = "snake_case")]
56pub enum SessionStatus {
57    Active,
58    Completed,
59    Expired,
60    Revoked,
61    BudgetExceeded,
62    Failed,
63}
64
65impl std::fmt::Display for SessionStatus {
66    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
67        match self {
68            Self::Active => write!(f, "active"),
69            Self::Completed => write!(f, "completed"),
70            Self::Expired => write!(f, "expired"),
71            Self::Revoked => write!(f, "revoked"),
72            Self::BudgetExceeded => write!(f, "budget_exceeded"),
73            Self::Failed => write!(f, "failed"),
74        }
75    }
76}
77
78impl Session {
79    pub fn new(
80        ttl: Duration,
81        budget: Option<f64>,
82        policy: ScopedPolicy,
83        role_arn: String,
84        command: Vec<String>,
85    ) -> Self {
86        let now = Utc::now();
87        let expires_at = now + chrono::Duration::seconds(ttl.as_secs() as i64);
88        Self {
89            id: Uuid::new_v4().to_string(),
90            created_at: now,
91            expires_at,
92            ttl_seconds: ttl.as_secs(),
93            budget,
94            policy,
95            status: SessionStatus::Active,
96            role_arn,
97            command,
98            access_key_id: None,
99            provider: CloudProvider::default(),
100            agent_id: None,
101        }
102    }
103
104    pub fn new_gcp(
105        ttl: Duration,
106        budget: Option<f64>,
107        policy: ScopedPolicy,
108        service_account: String,
109        command: Vec<String>,
110    ) -> Self {
111        let now = Utc::now();
112        let expires_at = now + chrono::Duration::seconds(ttl.as_secs() as i64);
113        Self {
114            id: Uuid::new_v4().to_string(),
115            created_at: now,
116            expires_at,
117            ttl_seconds: ttl.as_secs(),
118            budget,
119            policy,
120            status: SessionStatus::Active,
121            role_arn: service_account,
122            command,
123            access_key_id: None,
124            provider: CloudProvider::Gcp,
125            agent_id: None,
126        }
127    }
128
129    pub fn new_azure(
130        ttl: Duration,
131        budget: Option<f64>,
132        policy: ScopedPolicy,
133        subscription_id: String,
134        command: Vec<String>,
135    ) -> Self {
136        let now = Utc::now();
137        let expires_at = now + chrono::Duration::seconds(ttl.as_secs() as i64);
138        Self {
139            id: Uuid::new_v4().to_string(),
140            created_at: now,
141            expires_at,
142            ttl_seconds: ttl.as_secs(),
143            budget,
144            policy,
145            status: SessionStatus::Active,
146            role_arn: subscription_id,
147            command,
148            access_key_id: None,
149            provider: CloudProvider::Azure,
150            agent_id: None,
151        }
152    }
153
154    pub fn is_expired(&self) -> bool {
155        Utc::now() > self.expires_at
156    }
157
158    pub fn remaining_seconds(&self) -> i64 {
159        (self.expires_at - Utc::now()).num_seconds().max(0)
160    }
161
162    pub fn complete(&mut self) {
163        self.status = SessionStatus::Completed;
164    }
165
166    pub fn expire(&mut self) {
167        self.status = SessionStatus::Expired;
168    }
169
170    pub fn revoke(&mut self) {
171        self.status = SessionStatus::Revoked;
172    }
173
174    pub fn fail(&mut self) {
175        self.status = SessionStatus::Failed;
176    }
177}
178
/// Manages session persistence on disk.
///
/// Each session is stored as `<id>.json` inside `dir`.
#[derive(Debug, Clone)]
pub struct SessionStore {
    /// Directory that holds the per-session JSON files.
    dir: PathBuf,
}
183
184impl SessionStore {
185    pub fn new() -> Result<Self> {
186        let dir = dirs::data_local_dir()
187            .unwrap_or_else(|| PathBuf::from("."))
188            .join("audex")
189            .join("sessions");
190        std::fs::create_dir_all(&dir)?;
191        Ok(Self { dir })
192    }
193
194    pub fn save(&self, session: &Session) -> Result<()> {
195        let path = self.dir.join(format!("{}.json", session.id));
196        let json = serde_json::to_string_pretty(session)?;
197        std::fs::write(path, json)?;
198        Ok(())
199    }
200
201    pub fn load(&self, id: &str) -> Result<Session> {
202        let path = self.dir.join(format!("{}.json", id));
203        if !path.exists() {
204            return Err(AvError::SessionNotFound { id: id.to_string() });
205        }
206        let json = std::fs::read_to_string(path)?;
207        let session: Session = serde_json::from_str(&json)?;
208        Ok(session)
209    }
210
211    pub fn list(&self) -> Result<Vec<Session>> {
212        let mut sessions = Vec::new();
213        for entry in std::fs::read_dir(&self.dir)? {
214            let entry = entry?;
215            let path = entry.path();
216            if path.extension().is_some_and(|ext| ext == "json") {
217                let json = std::fs::read_to_string(&path)?;
218                if let Ok(session) = serde_json::from_str::<Session>(&json) {
219                    sessions.push(session);
220                }
221            }
222        }
223        sessions.sort_by(|a, b| b.created_at.cmp(&a.created_at));
224        Ok(sessions)
225    }
226
227    pub fn update(&self, session: &Session) -> Result<()> {
228        self.save(session)
229    }
230
231    /// Find an active session with matching allow actions, role ARN, and enough remaining TTL.
232    /// Used for credential caching — reuse existing sessions instead of issuing new ones.
233    pub fn find_reusable(
234        &self,
235        allow_str: &str,
236        role_arn: &str,
237        min_remaining_secs: i64,
238    ) -> Result<Option<Session>> {
239        let sessions = self.list()?;
240        let mut allow_sorted: Vec<&str> = allow_str.split(',').map(|s| s.trim()).collect();
241        allow_sorted.sort();
242
243        for session in sessions {
244            if session.status != SessionStatus::Active {
245                continue;
246            }
247            if session.is_expired() {
248                continue;
249            }
250            if session.remaining_seconds() < min_remaining_secs {
251                continue;
252            }
253            if session.role_arn != role_arn {
254                continue;
255            }
256
257            // Compare actions
258            let mut session_actions: Vec<String> = session
259                .policy
260                .actions
261                .iter()
262                .map(|a| a.to_iam_action())
263                .collect();
264            session_actions.sort();
265            let session_actions_str: Vec<&str> =
266                session_actions.iter().map(|s| s.as_str()).collect();
267
268            if session_actions_str == allow_sorted {
269                return Ok(Some(session));
270            }
271        }
272        Ok(None)
273    }
274}