1use std::path::PathBuf;
2use std::time::Duration;
3
4use chrono::{DateTime, Utc};
5use serde::{Deserialize, Serialize};
6use uuid::Uuid;
7
8use crate::error::{AvError, Result};
9use crate::policy::ScopedPolicy;
10
11#[derive(Debug, Clone, Copy, Serialize, Deserialize, PartialEq, Eq, Default)]
13#[serde(rename_all = "lowercase")]
14pub enum CloudProvider {
15 #[default]
16 Aws,
17 Gcp,
18 Azure,
19}
20
21impl std::fmt::Display for CloudProvider {
22 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
23 match self {
24 Self::Aws => write!(f, "aws"),
25 Self::Gcp => write!(f, "gcp"),
26 Self::Azure => write!(f, "azure"),
27 }
28 }
29}
30
31#[derive(Debug, Clone, Serialize, Deserialize)]
33pub struct Session {
34 pub id: String,
35 pub created_at: DateTime<Utc>,
36 pub expires_at: DateTime<Utc>,
37 pub ttl_seconds: u64,
38 pub budget: Option<f64>,
39 pub policy: ScopedPolicy,
40 pub status: SessionStatus,
41 #[serde(alias = "role_principal")]
53 pub role_arn: String,
54 pub command: Vec<String>,
55 pub access_key_id: Option<String>,
57 #[serde(default)]
59 pub provider: CloudProvider,
60 #[serde(default, skip_serializing_if = "Option::is_none")]
62 pub agent_id: Option<String>,
63 #[serde(default, skip_serializing_if = "Option::is_none")]
66 pub principal_id: Option<String>,
67 #[serde(default, skip_serializing_if = "std::collections::HashMap::is_empty")]
69 pub tags: std::collections::HashMap<String, String>,
70}
71
72#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
73#[serde(rename_all = "snake_case")]
74pub enum SessionStatus {
75 Active,
76 Completed,
77 Expired,
78 Revoked,
79 BudgetExceeded,
80 Failed,
81}
82
83impl std::fmt::Display for SessionStatus {
84 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
85 match self {
86 Self::Active => write!(f, "active"),
87 Self::Completed => write!(f, "completed"),
88 Self::Expired => write!(f, "expired"),
89 Self::Revoked => write!(f, "revoked"),
90 Self::BudgetExceeded => write!(f, "budget_exceeded"),
91 Self::Failed => write!(f, "failed"),
92 }
93 }
94}
95
96impl Session {
97 pub fn new(
98 ttl: Duration,
99 budget: Option<f64>,
100 policy: ScopedPolicy,
101 role_arn: String,
102 command: Vec<String>,
103 ) -> Self {
104 let now = Utc::now();
105 let expires_at = now + chrono::Duration::seconds(ttl.as_secs() as i64);
106 Self {
107 id: Uuid::new_v4().to_string(),
108 created_at: now,
109 expires_at,
110 ttl_seconds: ttl.as_secs(),
111 budget,
112 policy,
113 status: SessionStatus::Active,
114 role_arn,
115 command,
116 access_key_id: None,
117 provider: CloudProvider::default(),
118 agent_id: None,
119 principal_id: None,
120 tags: std::collections::HashMap::new(),
121 }
122 }
123
124 pub fn new_gcp(
125 ttl: Duration,
126 budget: Option<f64>,
127 policy: ScopedPolicy,
128 service_account: String,
129 command: Vec<String>,
130 ) -> Self {
131 let now = Utc::now();
132 let expires_at = now + chrono::Duration::seconds(ttl.as_secs() as i64);
133 Self {
134 id: Uuid::new_v4().to_string(),
135 created_at: now,
136 expires_at,
137 ttl_seconds: ttl.as_secs(),
138 budget,
139 policy,
140 status: SessionStatus::Active,
141 role_arn: service_account,
142 command,
143 access_key_id: None,
144 provider: CloudProvider::Gcp,
145 agent_id: None,
146 principal_id: None,
147 tags: std::collections::HashMap::new(),
148 }
149 }
150
151 pub fn new_azure(
152 ttl: Duration,
153 budget: Option<f64>,
154 policy: ScopedPolicy,
155 subscription_id: String,
156 command: Vec<String>,
157 ) -> Self {
158 let now = Utc::now();
159 let expires_at = now + chrono::Duration::seconds(ttl.as_secs() as i64);
160 Self {
161 id: Uuid::new_v4().to_string(),
162 created_at: now,
163 expires_at,
164 ttl_seconds: ttl.as_secs(),
165 budget,
166 policy,
167 status: SessionStatus::Active,
168 role_arn: subscription_id,
169 command,
170 access_key_id: None,
171 provider: CloudProvider::Azure,
172 agent_id: None,
173 principal_id: None,
174 tags: std::collections::HashMap::new(),
175 }
176 }
177
178 pub fn short_id(&self) -> &str {
182 self.id.get(..8).unwrap_or(&self.id)
183 }
184
185 pub fn is_expired(&self) -> bool {
186 Utc::now() > self.expires_at
187 }
188
189 pub fn remaining_seconds(&self) -> i64 {
190 (self.expires_at - Utc::now()).num_seconds().max(0)
191 }
192
193 pub fn complete(&mut self) {
194 self.status = SessionStatus::Completed;
195 }
196
197 pub fn expire(&mut self) {
198 self.status = SessionStatus::Expired;
199 }
200
201 pub fn revoke(&mut self) {
202 self.status = SessionStatus::Revoked;
203 }
204
205 pub fn fail(&mut self) {
206 self.status = SessionStatus::Failed;
207 }
208
209 pub fn principal_id(&self) -> &str {
218 self.principal_id.as_deref().unwrap_or(&self.role_arn)
219 }
220}
221
/// File-backed store that keeps each [`Session`] as `<id>.json` inside a
/// single directory.
pub struct SessionStore {
    /// Directory holding the session files.
    dir: PathBuf,
}
226
227impl SessionStore {
228 pub fn new() -> Result<Self> {
229 let base = dirs::data_local_dir().ok_or_else(|| {
230 AvError::InvalidPolicy(
231 "Could not determine local data directory. Set XDG_DATA_HOME or HOME.".to_string(),
232 )
233 })?;
234 let dir = base.join("audex").join("sessions");
235 std::fs::create_dir_all(&dir)?;
236 Ok(Self { dir })
237 }
238
/// Upper bound on the number of session files kept on disk; `save` prunes
/// the least recently modified files so a new session stays within it.
const SESSION_MAX: usize = 1_000;
244
245 fn write_file(&self, session: &Session) -> Result<()> {
246 let path = self.dir.join(format!("{}.json", session.id));
247 let json = serde_json::to_string_pretty(session)?;
248 #[cfg(unix)]
249 {
250 use std::os::unix::fs::OpenOptionsExt;
251 let mut file = std::fs::OpenOptions::new()
252 .write(true)
253 .create(true)
254 .truncate(true)
255 .mode(0o600)
256 .open(&path)?;
257 std::io::Write::write_all(&mut file, json.as_bytes())?;
258 }
259 #[cfg(not(unix))]
260 {
261 std::fs::write(&path, json)?;
262 }
263 Ok(())
264 }
265
266 pub fn save(&self, session: &Session) -> Result<()> {
267 let path = self.dir.join(format!("{}.json", session.id));
281 let is_new = !path.exists();
282
283 if is_new {
284 let entries: Result<Vec<_>> = std::fs::read_dir(&self.dir)?
285 .filter_map(|e| e.ok())
286 .filter(|e| {
287 e.path()
288 .extension()
289 .and_then(|x| x.to_str())
290 .map(|x| x == "json")
291 .unwrap_or(false)
292 })
293 .map(|e| {
294 let mtime = e
295 .metadata()
296 .and_then(|m| m.modified())
297 .unwrap_or(std::time::SystemTime::UNIX_EPOCH);
298 Ok((mtime, e.path()))
299 })
300 .collect();
301 let mut entries = entries?;
302 if entries.len() + 1 > Self::SESSION_MAX {
304 entries.sort_unstable_by_key(|(mtime, _)| *mtime);
305 let to_delete = entries.len() + 1 - Self::SESSION_MAX;
306 for (_, p) in entries.iter().take(to_delete) {
307 let _ = std::fs::remove_file(p);
308 }
309 }
310 }
311
312 self.write_file(session)
313 }
314
315 pub fn load(&self, id: &str) -> Result<Session> {
316 crate::validate::session_id(id)?;
317 let path = self.dir.join(format!("{}.json", id));
318 if !path.exists() {
319 return Err(AvError::SessionNotFound { id: id.to_string() });
320 }
321 let json = std::fs::read_to_string(path)?;
322 let session: Session = serde_json::from_str(&json)?;
323 Ok(session)
324 }
325
326 pub fn list(&self) -> Result<Vec<Session>> {
327 let mut sessions = Vec::new();
328 for entry in std::fs::read_dir(&self.dir)? {
329 let entry = entry?;
330 let path = entry.path();
331 if path.extension().is_some_and(|ext| ext == "json") {
332 let json = std::fs::read_to_string(&path)?;
333 if let Ok(session) = serde_json::from_str::<Session>(&json) {
334 sessions.push(session);
335 }
336 }
337 }
338 sessions.sort_by(|a, b| b.created_at.cmp(&a.created_at));
339 Ok(sessions)
340 }
341
342 pub fn update(&self, session: &Session) -> Result<()> {
343 self.save(session)
344 }
345
346 pub fn find_reusable(
349 &self,
350 allow_str: &str,
351 role_arn: &str,
352 min_remaining_secs: i64,
353 ) -> Result<Option<Session>> {
354 self.find_reusable_for_provider(
355 allow_str,
356 role_arn,
357 min_remaining_secs,
358 CloudProvider::Aws,
359 &[],
360 )
361 }
362
363 pub fn find_reusable_for_provider(
368 &self,
369 allow_str: &str,
370 role_arn: &str,
371 min_remaining_secs: i64,
372 provider: CloudProvider,
373 resources: &[String],
374 ) -> Result<Option<Session>> {
375 let sessions = self.list()?;
376 let mut allow_sorted: Vec<&str> = allow_str.split(',').map(|s| s.trim()).collect();
379 allow_sorted.sort();
380 allow_sorted.dedup();
381
382 let mut requested_resources: Vec<&str> = resources.iter().map(|s| s.as_str()).collect();
383 requested_resources.sort();
384 requested_resources.dedup();
385
386 for session in sessions {
387 if session.status != SessionStatus::Active {
388 continue;
389 }
390 if session.is_expired() {
391 continue;
392 }
393 if session.remaining_seconds() < min_remaining_secs {
394 continue;
395 }
396 if session.role_arn != role_arn {
397 continue;
398 }
399 if provider == CloudProvider::Azure {
405 if let Some(ref stored_pid) = session.principal_id {
406 if let Ok(current_pid) = std::env::var("AZURE_CLIENT_ID") {
407 if current_pid != *stored_pid {
408 continue;
409 }
410 }
411 }
412 let current_tenant = std::env::var("AZURE_TENANT_ID").ok();
417 let stored_tenant = session.tags.get("azure:tenant_id").cloned();
418 if stored_tenant != current_tenant {
419 continue;
420 }
421 }
422
423 let mut session_actions: Vec<String> = session
425 .policy
426 .actions
427 .iter()
428 .map(|a| match provider {
429 CloudProvider::Gcp => a.to_gcp_permission(),
430 CloudProvider::Azure => a.to_azure_permission(),
431 CloudProvider::Aws => a.to_iam_action(),
432 })
433 .collect();
434 session_actions.sort();
435 session_actions.dedup();
436 let session_actions_str: Vec<&str> =
437 session_actions.iter().map(|s| s.as_str()).collect();
438
439 if session_actions_str != allow_sorted {
440 continue;
441 }
442
443 let mut session_resources: Vec<&str> = session
447 .policy
448 .resources
449 .iter()
450 .map(|s| s.as_str())
451 .collect();
452 session_resources.sort();
453 session_resources.dedup();
454 if session_resources != requested_resources {
455 continue;
456 }
457
458 return Ok(Some(session));
459 }
460 Ok(None)
461 }
462}