1use serde::{Deserialize, Serialize};
2
3use crate::audit::{AuditEntry, AuditEvent};
4use crate::error::{AvError, Result};
5use crate::session::Session;
6
7#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
9#[serde(rename_all = "lowercase")]
10pub enum DbBackend {
11 Sqlite,
12 Postgres,
13}
14
15impl std::fmt::Display for DbBackend {
16 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
17 match self {
18 Self::Sqlite => write!(f, "sqlite"),
19 Self::Postgres => write!(f, "postgres"),
20 }
21 }
22}
23
24#[derive(Debug, Clone, Serialize, Deserialize)]
26pub struct DbAuditConfig {
27 pub backend: DbBackend,
29 pub connection: String,
33 #[serde(default = "default_max_connections")]
35 pub max_connections: u32,
36 #[serde(default = "default_wal_mode")]
38 pub wal_mode: bool,
39 #[serde(default)]
41 pub retention_days: u32,
42}
43
/// Serde default for `DbAuditConfig::max_connections`: five connections.
fn default_max_connections() -> u32 {
    const DEFAULT_POOL_SIZE: u32 = 5;
    DEFAULT_POOL_SIZE
}

/// Serde default for `DbAuditConfig::wal_mode`: enabled.
fn default_wal_mode() -> bool {
    const WAL_ENABLED: bool = true;
    WAL_ENABLED
}
51
52impl Default for DbAuditConfig {
53 fn default() -> Self {
54 let path = dirs::data_local_dir()
55 .unwrap_or_else(|| std::path::PathBuf::from("."))
56 .join("audex")
57 .join("audit.db");
58 Self {
59 backend: DbBackend::Sqlite,
60 connection: path.to_string_lossy().to_string(),
61 max_connections: default_max_connections(),
62 wal_mode: default_wal_mode(),
63 retention_days: 0,
64 }
65 }
66}
67
/// DDL for the SQLite audit store: the `audit_entries` table plus indexes on
/// `session_id`, `timestamp`, `provider` and `event_type`. These are the
/// columns that `SqlStatements` filters, orders or groups by.
///
/// `timestamp` is `TEXT`. Rows written through `entry_to_row` hold RFC 3339
/// strings, so ordering and range comparisons on it are lexicographic.
/// `event_data` holds the JSON-serialized event as plain text.
///
/// Every statement uses `IF NOT EXISTS`, so the script can be re-applied to an
/// existing database.
pub const SQLITE_SCHEMA: &str = r#"
CREATE TABLE IF NOT EXISTS audit_entries (
    id INTEGER PRIMARY KEY AUTOINCREMENT,
    timestamp TEXT NOT NULL,
    session_id TEXT NOT NULL,
    provider TEXT NOT NULL DEFAULT 'aws',
    event_type TEXT NOT NULL,
    event_data TEXT NOT NULL,
    created_at TEXT NOT NULL DEFAULT (datetime('now'))
);

CREATE INDEX IF NOT EXISTS idx_audit_session_id ON audit_entries(session_id);
CREATE INDEX IF NOT EXISTS idx_audit_timestamp ON audit_entries(timestamp);
CREATE INDEX IF NOT EXISTS idx_audit_provider ON audit_entries(provider);
CREATE INDEX IF NOT EXISTS idx_audit_event_type ON audit_entries(event_type);
"#;
85
/// DDL for the PostgreSQL audit store. It has the same columns and indexes as
/// `SQLITE_SCHEMA`, but uses native types:
/// - `BIGSERIAL` ids;
/// - `TIMESTAMPTZ` for `timestamp` and `created_at`;
/// - `JSONB` for `event_data`. The insert statement casts its parameter with
///   `::jsonb`.
///
/// Every statement uses `IF NOT EXISTS`, so the script can be re-applied to an
/// existing database.
pub const POSTGRES_SCHEMA: &str = r#"
CREATE TABLE IF NOT EXISTS audit_entries (
    id BIGSERIAL PRIMARY KEY,
    timestamp TIMESTAMPTZ NOT NULL,
    session_id TEXT NOT NULL,
    provider TEXT NOT NULL DEFAULT 'aws',
    event_type TEXT NOT NULL,
    event_data JSONB NOT NULL,
    created_at TIMESTAMPTZ NOT NULL DEFAULT NOW()
);

CREATE INDEX IF NOT EXISTS idx_audit_session_id ON audit_entries(session_id);
CREATE INDEX IF NOT EXISTS idx_audit_timestamp ON audit_entries(timestamp);
CREATE INDEX IF NOT EXISTS idx_audit_provider ON audit_entries(provider);
CREATE INDEX IF NOT EXISTS idx_audit_event_type ON audit_entries(event_type);
"#;
102
103pub fn event_type_str(event: &AuditEvent) -> &'static str {
105 match event {
106 AuditEvent::SessionCreated { .. } => "session_created",
107 AuditEvent::CredentialsIssued { .. } => "credentials_issued",
108 AuditEvent::SessionEnded { .. } => "session_ended",
109 AuditEvent::BudgetWarning { .. } => "budget_warning",
110 AuditEvent::BudgetExceeded { .. } => "budget_exceeded",
111 }
112}
113
114pub struct SqlStatements;
116
117impl SqlStatements {
118 pub fn insert(backend: &DbBackend) -> &'static str {
120 match backend {
121 DbBackend::Sqlite => {
122 "INSERT INTO audit_entries (timestamp, session_id, provider, event_type, event_data) VALUES (?, ?, ?, ?, ?)"
123 }
124 DbBackend::Postgres => {
125 "INSERT INTO audit_entries (timestamp, session_id, provider, event_type, event_data) VALUES ($1, $2, $3, $4, $5::jsonb)"
126 }
127 }
128 }
129
130 pub fn by_session(backend: &DbBackend) -> &'static str {
132 match backend {
133 DbBackend::Sqlite => {
134 "SELECT timestamp, session_id, provider, event_type, event_data FROM audit_entries WHERE session_id = ? ORDER BY timestamp ASC"
135 }
136 DbBackend::Postgres => {
137 "SELECT timestamp, session_id, provider, event_type, event_data FROM audit_entries WHERE session_id = $1 ORDER BY timestamp ASC"
138 }
139 }
140 }
141
142 pub fn recent(backend: &DbBackend) -> &'static str {
144 match backend {
145 DbBackend::Sqlite => {
146 "SELECT timestamp, session_id, provider, event_type, event_data FROM audit_entries ORDER BY timestamp DESC LIMIT ?"
147 }
148 DbBackend::Postgres => {
149 "SELECT timestamp, session_id, provider, event_type, event_data FROM audit_entries ORDER BY timestamp DESC LIMIT $1"
150 }
151 }
152 }
153
154 pub fn by_provider(backend: &DbBackend) -> &'static str {
156 match backend {
157 DbBackend::Sqlite => {
158 "SELECT timestamp, session_id, provider, event_type, event_data FROM audit_entries WHERE provider = ? ORDER BY timestamp DESC LIMIT ?"
159 }
160 DbBackend::Postgres => {
161 "SELECT timestamp, session_id, provider, event_type, event_data FROM audit_entries WHERE provider = $1 ORDER BY timestamp DESC LIMIT $2"
162 }
163 }
164 }
165
166 pub fn count(backend: &DbBackend) -> &'static str {
168 match backend {
169 DbBackend::Sqlite => "SELECT COUNT(*) FROM audit_entries",
170 DbBackend::Postgres => "SELECT COUNT(*) FROM audit_entries",
171 }
172 }
173
174 pub fn count_by_type(backend: &DbBackend) -> &'static str {
176 match backend {
177 DbBackend::Sqlite => {
178 "SELECT event_type, COUNT(*) as count FROM audit_entries GROUP BY event_type ORDER BY count DESC"
179 }
180 DbBackend::Postgres => {
181 "SELECT event_type, COUNT(*) as count FROM audit_entries GROUP BY event_type ORDER BY count DESC"
182 }
183 }
184 }
185
186 pub fn prune(backend: &DbBackend) -> &'static str {
188 match backend {
189 DbBackend::Sqlite => {
190 "DELETE FROM audit_entries WHERE timestamp < datetime('now', '-' || ? || ' days')"
191 }
192 DbBackend::Postgres => {
193 "DELETE FROM audit_entries WHERE timestamp < NOW() - ($1 || ' days')::INTERVAL"
194 }
195 }
196 }
197
198 pub fn search(backend: &DbBackend) -> &'static str {
200 match backend {
201 DbBackend::Sqlite => {
202 "SELECT timestamp, session_id, provider, event_type, event_data FROM audit_entries WHERE event_data LIKE ? ORDER BY timestamp DESC LIMIT ?"
203 }
204 DbBackend::Postgres => {
205 "SELECT timestamp, session_id, provider, event_type, event_data FROM audit_entries WHERE event_data::text ILIKE $1 ORDER BY timestamp DESC LIMIT $2"
206 }
207 }
208 }
209}
210
211pub fn entry_to_row(entry: &AuditEntry) -> Result<DbRow> {
213 let event_data = serde_json::to_string(&entry.event)
214 .map_err(|e| AvError::Sts(format!("Failed to serialize event: {}", e)))?;
215
216 let event_data = crate::leakdetect::redact_secrets(&event_data);
218
219 Ok(DbRow {
220 timestamp: entry.timestamp.to_rfc3339(),
221 session_id: entry.session_id.clone(),
222 provider: entry.provider.clone(),
223 event_type: event_type_str(&entry.event).to_string(),
224 event_data,
225 })
226}
227
228pub fn session_created_entry(session: &Session) -> AuditEntry {
230 use crate::session::CloudProvider;
231
232 let allowed_actions: Vec<String> = match session.provider {
233 CloudProvider::Gcp => session.policy.actions.iter().map(|a| a.to_gcp_permission()).collect(),
234 CloudProvider::Azure => session.policy.actions.iter().map(|a| a.to_azure_permission()).collect(),
235 CloudProvider::Aws => session.policy.actions.iter().map(|a| a.to_iam_action()).collect(),
236 };
237
238 AuditEntry {
239 timestamp: chrono::Utc::now(),
240 session_id: session.id.clone(),
241 provider: session.provider.to_string(),
242 event: AuditEvent::SessionCreated {
243 role_arn: session.role_arn.clone(),
244 ttl_seconds: session.ttl_seconds,
245 budget: session.budget,
246 allowed_actions,
247 command: session.command.clone(),
248 agent_id: session.agent_id.clone(),
249 },
250 }
251}
252
/// Column-per-field form of an audit entry as stored in `audit_entries`.
///
/// Produced by `entry_to_row` and turned back into an entry by
/// `DbRow::to_entry`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct DbRow {
    /// RFC 3339 timestamp string (parsed as RFC 3339 by `to_entry`).
    pub timestamp: String,
    /// Identifier of the session the entry belongs to.
    pub session_id: String,
    /// Cloud provider name, e.g. `aws` or `gcp`.
    pub provider: String,
    /// Event tag as produced by `event_type_str`.
    pub event_type: String,
    /// JSON-serialized `AuditEvent`. `entry_to_row` redacts secrets in it
    /// before storing.
    pub event_data: String,
}
262
263impl DbRow {
264 pub fn to_entry(&self) -> Result<AuditEntry> {
266 let timestamp = chrono::DateTime::parse_from_rfc3339(&self.timestamp)
267 .map(|dt| dt.with_timezone(&chrono::Utc))
268 .map_err(|e| AvError::Sts(format!("Invalid timestamp: {}", e)))?;
269
270 let event: AuditEvent = serde_json::from_str(&self.event_data)
271 .map_err(|e| AvError::Sts(format!("Invalid event data: {}", e)))?;
272
273 Ok(AuditEntry {
274 timestamp,
275 session_id: self.session_id.clone(),
276 provider: self.provider.clone(),
277 event,
278 })
279 }
280}
281
/// Summary statistics about the database audit log.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct DbAuditStats {
    /// Total number of stored entries (presumably from `SqlStatements::count`).
    pub total_entries: u64,
    /// `(event_type, count)` pairs; presumably in the descending-count order
    /// of `SqlStatements::count_by_type` — confirm with the producer.
    pub entries_by_type: Vec<(String, u64)>,
    /// Backend name, e.g. `sqlite` or `postgres`.
    pub backend: String,
    /// Connection target the stats describe.
    // NOTE(review): for Postgres this may be a URL with an embedded password,
    // and this struct is serialized/Debug-printed as-is; confirm that callers
    // redact it before populating.
    pub connection: String,
}
290
291pub fn parse_jsonl_for_import(content: &str) -> Vec<AuditEntry> {
293 content
294 .lines()
295 .filter(|l| !l.trim().is_empty())
296 .filter_map(|l| {
297 let (json_content, _) = crate::integrity::parse_line(l);
298 serde_json::from_str(json_content).ok()
299 })
300 .collect()
301}
302
303pub fn generate_import_sql(entries: &[AuditEntry], backend: &DbBackend) -> Result<String> {
305 let mut sql = String::new();
306
307 match backend {
308 DbBackend::Sqlite => {
309 sql.push_str("BEGIN TRANSACTION;\n");
310 }
311 DbBackend::Postgres => {
312 sql.push_str("BEGIN;\n");
313 }
314 }
315
316 for entry in entries {
317 let row = entry_to_row(entry)?;
318 let escaped_data = row.event_data.replace('\'', "''");
319 sql.push_str(&format!(
320 "INSERT INTO audit_entries (timestamp, session_id, provider, event_type, event_data) VALUES ('{}', '{}', '{}', '{}', '{}');\n",
321 row.timestamp, row.session_id, row.provider, row.event_type, escaped_data
322 ));
323 }
324
325 sql.push_str("COMMIT;\n");
326 Ok(sql)
327}
328
#[cfg(test)]
mod tests {
    //! Unit tests for config parsing, event tagging, row conversion, SQL
    //! statement selection, JSONL import parsing and import-script generation.

    use super::*;
    use crate::audit::AuditEvent;

    // The default config targets a SQLite `audit.db` file and uses the serde
    // default pool size and WAL setting.
    #[test]
    fn test_config_default() {
        let config = DbAuditConfig::default();
        assert_eq!(config.backend, DbBackend::Sqlite);
        assert!(config.connection.contains("audit.db"));
        assert_eq!(config.max_connections, 5);
        assert!(config.wal_mode);
    }

    // A SQLite config deserializes from TOML even though `max_connections`
    // is omitted.
    #[test]
    fn test_config_deserialize_sqlite() {
        let toml_str = r#"
backend = "sqlite"
connection = "/data/audex/audit.db"
wal_mode = true
retention_days = 90
"#;
        let config: DbAuditConfig = toml::from_str(toml_str).unwrap();
        assert_eq!(config.backend, DbBackend::Sqlite);
        assert_eq!(config.connection, "/data/audex/audit.db");
        assert_eq!(config.retention_days, 90);
    }

    // A Postgres config with an explicit pool size; `wal_mode` and
    // `retention_days` are omitted.
    #[test]
    fn test_config_deserialize_postgres() {
        let toml_str = r#"
backend = "postgres"
connection = "postgres://audex:password@db.internal:5432/audex"
max_connections = 10
"#;
        let config: DbAuditConfig = toml::from_str(toml_str).unwrap();
        assert_eq!(config.backend, DbBackend::Postgres);
        assert_eq!(config.max_connections, 10);
    }

    // A representative subset of variants maps to the expected snake_case tags.
    #[test]
    fn test_event_type_str() {
        assert_eq!(
            event_type_str(&AuditEvent::SessionCreated {
                role_arn: "arn".into(),
                ttl_seconds: 900,
                budget: None,
                allowed_actions: vec![],
                command: vec![],
                agent_id: None,
            }),
            "session_created"
        );
        assert_eq!(
            event_type_str(&AuditEvent::SessionEnded {
                status: "ok".into(),
                duration_seconds: 60,
                exit_code: Some(0),
            }),
            "session_ended"
        );
        assert_eq!(
            event_type_str(&AuditEvent::BudgetWarning {
                current_spend: 1.0,
                limit: 5.0,
            }),
            "budget_warning"
        );
    }

    // Row conversion copies the identifying columns, derives `event_type`
    // from the event, and keeps the event's content in `event_data`.
    #[test]
    fn test_entry_to_row() {
        let entry = AuditEntry {
            timestamp: chrono::Utc::now(),
            session_id: "test-123".to_string(),
            provider: "aws".to_string(),
            event: AuditEvent::SessionCreated {
                role_arn: "arn:aws:iam::123:role/Test".to_string(),
                ttl_seconds: 900,
                budget: Some(5.0),
                allowed_actions: vec!["s3:GetObject".to_string()],
                command: vec!["aws".to_string(), "s3".to_string(), "ls".to_string()],
                agent_id: Some("claude".to_string()),
            },
        };
        let row = entry_to_row(&entry).unwrap();
        assert_eq!(row.session_id, "test-123");
        assert_eq!(row.provider, "aws");
        assert_eq!(row.event_type, "session_created");
        assert!(row.event_data.contains("s3:GetObject"));
    }

    // entry -> row -> entry preserves the session id and provider.
    #[test]
    fn test_row_roundtrip() {
        let entry = AuditEntry {
            timestamp: chrono::Utc::now(),
            session_id: "rt-456".to_string(),
            provider: "gcp".to_string(),
            event: AuditEvent::SessionEnded {
                status: "completed".to_string(),
                duration_seconds: 120,
                exit_code: Some(0),
            },
        };
        let row = entry_to_row(&entry).unwrap();
        let restored = row.to_entry().unwrap();
        assert_eq!(restored.session_id, "rt-456");
        assert_eq!(restored.provider, "gcp");
    }

    // SQLite statements use `?` placeholders and SQLite's `datetime` function.
    #[test]
    fn test_sql_statements_sqlite() {
        let b = DbBackend::Sqlite;
        assert!(SqlStatements::insert(&b).contains("?"));
        assert!(SqlStatements::by_session(&b).contains("session_id = ?"));
        assert!(SqlStatements::recent(&b).contains("LIMIT ?"));
        assert!(SqlStatements::prune(&b).contains("datetime"));
    }

    // Postgres statements use numbered `$n` placeholders and INTERVAL math.
    #[test]
    fn test_sql_statements_postgres() {
        let b = DbBackend::Postgres;
        assert!(SqlStatements::insert(&b).contains("$1"));
        assert!(SqlStatements::by_session(&b).contains("session_id = $1"));
        assert!(SqlStatements::recent(&b).contains("LIMIT $1"));
        assert!(SqlStatements::prune(&b).contains("INTERVAL"));
    }

    // A single well-formed JSONL line parses into exactly one entry.
    #[test]
    fn test_parse_jsonl_for_import() {
        let jsonl = r#"{"timestamp":"2026-01-01T00:00:00Z","session_id":"s1","provider":"aws","event":{"type":"session_ended","status":"ok","duration_seconds":60,"exit_code":0}}"#;
        let entries = parse_jsonl_for_import(jsonl);
        assert_eq!(entries.len(), 1);
        assert_eq!(entries[0].session_id, "s1");
    }

    // The SQLite import script wraps its INSERTs in BEGIN TRANSACTION/COMMIT.
    #[test]
    fn test_generate_import_sql_sqlite() {
        let entries = vec![AuditEntry {
            timestamp: chrono::DateTime::parse_from_rfc3339("2026-01-01T00:00:00Z")
                .unwrap()
                .with_timezone(&chrono::Utc),
            session_id: "s1".to_string(),
            provider: "aws".to_string(),
            event: AuditEvent::SessionEnded {
                status: "ok".to_string(),
                duration_seconds: 60,
                exit_code: Some(0),
            },
        }];
        let sql = generate_import_sql(&entries, &DbBackend::Sqlite).unwrap();
        assert!(sql.contains("BEGIN TRANSACTION"));
        assert!(sql.contains("INSERT INTO audit_entries"));
        assert!(sql.contains("COMMIT"));
    }

    // The Postgres import script wraps its INSERTs in BEGIN;/COMMIT;.
    #[test]
    fn test_generate_import_sql_postgres() {
        let entries = vec![AuditEntry {
            timestamp: chrono::DateTime::parse_from_rfc3339("2026-01-01T00:00:00Z")
                .unwrap()
                .with_timezone(&chrono::Utc),
            session_id: "s2".to_string(),
            provider: "gcp".to_string(),
            event: AuditEvent::CredentialsIssued {
                access_key_id: "AKID".to_string(),
                expires_at: chrono::DateTime::parse_from_rfc3339("2026-01-01T00:15:00Z")
                    .unwrap()
                    .with_timezone(&chrono::Utc),
            },
        }];
        let sql = generate_import_sql(&entries, &DbBackend::Postgres).unwrap();
        assert!(sql.contains("BEGIN;"));
        assert!(sql.contains("COMMIT;"));
    }

    // `Display` yields the lowercase backend names.
    #[test]
    fn test_db_backend_display() {
        assert_eq!(DbBackend::Sqlite.to_string(), "sqlite");
        assert_eq!(DbBackend::Postgres.to_string(), "postgres");
    }

    // Stats serialize to JSON containing the total and the per-type tags.
    #[test]
    fn test_db_audit_stats_serialization() {
        let stats = DbAuditStats {
            total_entries: 1000,
            entries_by_type: vec![
                ("session_created".to_string(), 400),
                ("credentials_issued".to_string(), 400),
                ("session_ended".to_string(), 200),
            ],
            backend: "sqlite".to_string(),
            connection: "/data/audit.db".to_string(),
        };
        let json = serde_json::to_string(&stats).unwrap();
        assert!(json.contains("1000"));
        assert!(json.contains("session_created"));
    }
}