1use serde::{Deserialize, Serialize};
2
3use crate::audit::{AuditEntry, AuditEvent};
4use crate::error::{AvError, Result};
5use crate::session::Session;
6
7#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
9#[serde(rename_all = "lowercase")]
10pub enum DbBackend {
11 Sqlite,
12 Postgres,
13}
14
15impl std::fmt::Display for DbBackend {
16 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
17 match self {
18 Self::Sqlite => write!(f, "sqlite"),
19 Self::Postgres => write!(f, "postgres"),
20 }
21 }
22}
23
24#[derive(Debug, Clone, Serialize, Deserialize)]
26pub struct DbAuditConfig {
27 pub backend: DbBackend,
29 pub connection: String,
36 #[serde(default = "default_max_connections")]
38 pub max_connections: u32,
39 #[serde(default = "default_wal_mode")]
41 pub wal_mode: bool,
42 #[serde(default)]
44 pub retention_days: u32,
45}
46
47fn default_max_connections() -> u32 {
48 5
49}
50
51fn default_wal_mode() -> bool {
52 true
53}
54
55impl Default for DbAuditConfig {
56 fn default() -> Self {
57 let path = dirs::data_local_dir()
58 .unwrap_or_else(|| std::path::PathBuf::from("."))
59 .join("audex")
60 .join("audit.db");
61 Self {
62 backend: DbBackend::Sqlite,
63 connection: path.to_string_lossy().to_string(),
64 max_connections: default_max_connections(),
65 wal_mode: default_wal_mode(),
66 retention_days: 0,
67 }
68 }
69}
70
71impl DbAuditConfig {
72 pub fn resolve_connection(&self) -> String {
76 std::env::var("AUDEX_DB_CONNECTION").unwrap_or_else(|_| self.connection.clone())
77 }
78}
79
/// DDL that creates the SQLite `audit_entries` table and its lookup indexes.
///
/// Every statement uses `IF NOT EXISTS`, so the script can be re-run against an
/// existing database.
///
/// `timestamp` is `TEXT` holding the RFC 3339 string written by
/// [`entry_to_row`], so ordering and range comparisons on it are text
/// comparisons. `event_data` holds the serialized event JSON.
///
/// Note that `created_at` is filled by SQLite's `datetime('now')`, which uses a
/// different text format (`YYYY-MM-DD HH:MM:SS`, UTC) from `timestamp`.
pub const SQLITE_SCHEMA: &str = r#"
CREATE TABLE IF NOT EXISTS audit_entries (
    id INTEGER PRIMARY KEY AUTOINCREMENT,
    timestamp TEXT NOT NULL,
    session_id TEXT NOT NULL,
    provider TEXT NOT NULL DEFAULT 'aws',
    event_type TEXT NOT NULL,
    event_data TEXT NOT NULL,
    created_at TEXT NOT NULL DEFAULT (datetime('now'))
);

CREATE INDEX IF NOT EXISTS idx_audit_session_id ON audit_entries(session_id);
CREATE INDEX IF NOT EXISTS idx_audit_timestamp ON audit_entries(timestamp);
CREATE INDEX IF NOT EXISTS idx_audit_provider ON audit_entries(provider);
CREATE INDEX IF NOT EXISTS idx_audit_event_type ON audit_entries(event_type);
"#;
97
/// DDL that creates the PostgreSQL `audit_entries` table and its lookup indexes.
///
/// It has the same columns and indexes as [`SQLITE_SCHEMA`], but uses native
/// types: `timestamp` and `created_at` are `TIMESTAMPTZ`, and `event_data` is
/// `JSONB`. It is safe to re-run because every statement uses `IF NOT EXISTS`.
pub const POSTGRES_SCHEMA: &str = r#"
CREATE TABLE IF NOT EXISTS audit_entries (
    id BIGSERIAL PRIMARY KEY,
    timestamp TIMESTAMPTZ NOT NULL,
    session_id TEXT NOT NULL,
    provider TEXT NOT NULL DEFAULT 'aws',
    event_type TEXT NOT NULL,
    event_data JSONB NOT NULL,
    created_at TIMESTAMPTZ NOT NULL DEFAULT NOW()
);

CREATE INDEX IF NOT EXISTS idx_audit_session_id ON audit_entries(session_id);
CREATE INDEX IF NOT EXISTS idx_audit_timestamp ON audit_entries(timestamp);
CREATE INDEX IF NOT EXISTS idx_audit_provider ON audit_entries(provider);
CREATE INDEX IF NOT EXISTS idx_audit_event_type ON audit_entries(event_type);
"#;
114
115pub fn event_type_str(event: &AuditEvent) -> &'static str {
117 match event {
118 AuditEvent::SessionCreated { .. } => "session_created",
119 AuditEvent::CredentialsIssued { .. } => "credentials_issued",
120 AuditEvent::SessionEnded { .. } => "session_ended",
121 AuditEvent::BudgetWarning { .. } => "budget_warning",
122 AuditEvent::BudgetExceeded { .. } => "budget_exceeded",
123 AuditEvent::ResourceCreated { .. } => "resource_created",
124 AuditEvent::PolicyAdvisoryOnly { .. } => "policy_advisory_only",
125 }
126}
127
128pub struct SqlStatements;
130
131impl SqlStatements {
132 pub fn insert(backend: &DbBackend) -> &'static str {
134 match backend {
135 DbBackend::Sqlite => {
136 "INSERT INTO audit_entries (timestamp, session_id, provider, event_type, event_data) VALUES (?, ?, ?, ?, ?)"
137 }
138 DbBackend::Postgres => {
139 "INSERT INTO audit_entries (timestamp, session_id, provider, event_type, event_data) VALUES ($1, $2, $3, $4, $5::jsonb)"
140 }
141 }
142 }
143
144 pub fn by_session(backend: &DbBackend) -> &'static str {
146 match backend {
147 DbBackend::Sqlite => {
148 "SELECT timestamp, session_id, provider, event_type, event_data FROM audit_entries WHERE session_id = ? ORDER BY timestamp ASC"
149 }
150 DbBackend::Postgres => {
151 "SELECT timestamp, session_id, provider, event_type, event_data FROM audit_entries WHERE session_id = $1 ORDER BY timestamp ASC"
152 }
153 }
154 }
155
156 pub fn recent(backend: &DbBackend) -> &'static str {
158 match backend {
159 DbBackend::Sqlite => {
160 "SELECT timestamp, session_id, provider, event_type, event_data FROM audit_entries ORDER BY timestamp DESC LIMIT ?"
161 }
162 DbBackend::Postgres => {
163 "SELECT timestamp, session_id, provider, event_type, event_data FROM audit_entries ORDER BY timestamp DESC LIMIT $1"
164 }
165 }
166 }
167
168 pub fn by_provider(backend: &DbBackend) -> &'static str {
170 match backend {
171 DbBackend::Sqlite => {
172 "SELECT timestamp, session_id, provider, event_type, event_data FROM audit_entries WHERE provider = ? ORDER BY timestamp DESC LIMIT ?"
173 }
174 DbBackend::Postgres => {
175 "SELECT timestamp, session_id, provider, event_type, event_data FROM audit_entries WHERE provider = $1 ORDER BY timestamp DESC LIMIT $2"
176 }
177 }
178 }
179
180 pub fn count(backend: &DbBackend) -> &'static str {
182 match backend {
183 DbBackend::Sqlite => "SELECT COUNT(*) FROM audit_entries",
184 DbBackend::Postgres => "SELECT COUNT(*) FROM audit_entries",
185 }
186 }
187
188 pub fn count_by_type(backend: &DbBackend) -> &'static str {
190 match backend {
191 DbBackend::Sqlite => {
192 "SELECT event_type, COUNT(*) as count FROM audit_entries GROUP BY event_type ORDER BY count DESC"
193 }
194 DbBackend::Postgres => {
195 "SELECT event_type, COUNT(*) as count FROM audit_entries GROUP BY event_type ORDER BY count DESC"
196 }
197 }
198 }
199
200 pub fn prune(backend: &DbBackend) -> &'static str {
202 match backend {
203 DbBackend::Sqlite => {
204 "DELETE FROM audit_entries WHERE timestamp < datetime('now', '-' || ? || ' days')"
205 }
206 DbBackend::Postgres => {
207 "DELETE FROM audit_entries WHERE timestamp < NOW() - ($1 || ' days')::INTERVAL"
208 }
209 }
210 }
211
212 pub fn search(backend: &DbBackend) -> &'static str {
214 match backend {
215 DbBackend::Sqlite => {
216 "SELECT timestamp, session_id, provider, event_type, event_data FROM audit_entries WHERE event_data LIKE ? ORDER BY timestamp DESC LIMIT ?"
217 }
218 DbBackend::Postgres => {
219 "SELECT timestamp, session_id, provider, event_type, event_data FROM audit_entries WHERE event_data::text ILIKE $1 ORDER BY timestamp DESC LIMIT $2"
220 }
221 }
222 }
223}
224
225pub fn entry_to_row(entry: &AuditEntry) -> Result<DbRow> {
227 let event_data = serde_json::to_string(&entry.event)
228 .map_err(|e| AvError::Sts(format!("Failed to serialize event: {}", e)))?;
229
230 let event_data = crate::leakdetect::redact_secrets(&event_data);
232
233 Ok(DbRow {
234 timestamp: entry.timestamp.to_rfc3339(),
235 session_id: entry.session_id.clone(),
236 provider: entry.provider.clone(),
237 event_type: event_type_str(&entry.event).to_string(),
238 event_data,
239 })
240}
241
242pub fn session_created_entry(session: &Session) -> AuditEntry {
244 use crate::session::CloudProvider;
245
246 let allowed_actions: Vec<String> = match session.provider {
247 CloudProvider::Gcp => session
248 .policy
249 .actions
250 .iter()
251 .map(|a| a.to_gcp_permission())
252 .collect(),
253 CloudProvider::Azure => session
254 .policy
255 .actions
256 .iter()
257 .map(|a| a.to_azure_permission())
258 .collect(),
259 CloudProvider::Aws => session
260 .policy
261 .actions
262 .iter()
263 .map(|a| a.to_iam_action())
264 .collect(),
265 };
266
267 AuditEntry {
268 timestamp: chrono::Utc::now(),
269 session_id: session.id.clone(),
270 provider: session.provider.to_string(),
271 event: AuditEvent::SessionCreated {
272 role_arn: session.role_arn.clone(),
273 ttl_seconds: session.ttl_seconds,
274 budget: session.budget,
275 allowed_actions,
276 command: session.command.clone(),
277 agent_id: session.agent_id.clone(),
278 },
279 }
280}
281
282#[derive(Debug, Clone, Serialize, Deserialize)]
284pub struct DbRow {
285 pub timestamp: String,
286 pub session_id: String,
287 pub provider: String,
288 pub event_type: String,
289 pub event_data: String,
290}
291
292impl DbRow {
293 pub fn to_entry(&self) -> Result<AuditEntry> {
295 let timestamp = chrono::DateTime::parse_from_rfc3339(&self.timestamp)
296 .map(|dt| dt.with_timezone(&chrono::Utc))
297 .map_err(|e| AvError::Sts(format!("Invalid timestamp: {}", e)))?;
298
299 let event: AuditEvent = serde_json::from_str(&self.event_data)
300 .map_err(|e| AvError::Sts(format!("Invalid event data: {}", e)))?;
301
302 Ok(AuditEntry {
303 timestamp,
304 session_id: self.session_id.clone(),
305 provider: self.provider.clone(),
306 event,
307 })
308 }
309}
310
311#[derive(Debug, Clone, Serialize, Deserialize)]
313pub struct DbAuditStats {
314 pub total_entries: u64,
315 pub entries_by_type: Vec<(String, u64)>,
316 pub backend: String,
317 pub connection: String,
318}
319
320pub fn parse_jsonl_for_import(content: &str) -> Vec<AuditEntry> {
322 content
323 .lines()
324 .filter(|l| !l.trim().is_empty())
325 .filter_map(|l| {
326 let (json_content, _) = crate::integrity::parse_line(l);
327 serde_json::from_str(json_content).ok()
328 })
329 .collect()
330}
331
332fn sql_escape(s: &str) -> String {
337 s.replace('\\', "\\\\").replace('\'', "''")
338}
339
340pub fn generate_import_sql(entries: &[AuditEntry], backend: &DbBackend) -> Result<String> {
342 let mut sql = String::new();
343
344 match backend {
345 DbBackend::Sqlite => {
346 sql.push_str("BEGIN TRANSACTION;\n");
347 }
348 DbBackend::Postgres => {
349 sql.push_str("BEGIN;\n");
350 }
351 }
352
353 for entry in entries {
354 let row = entry_to_row(entry)?;
355 sql.push_str(&format!(
356 "INSERT INTO audit_entries (timestamp, session_id, provider, event_type, event_data) VALUES ('{}', '{}', '{}', '{}', '{}');\n",
357 sql_escape(&row.timestamp),
358 sql_escape(&row.session_id),
359 sql_escape(&row.provider),
360 sql_escape(&row.event_type),
361 sql_escape(&row.event_data),
362 ));
363 }
364
365 sql.push_str("COMMIT;\n");
366 Ok(sql)
367}
368
#[cfg(test)]
mod tests {
    //! Unit tests for config parsing, row conversion, the per-backend SQL text,
    //! and the offline JSONL → SQL import helpers.

    // `super::*` also brings in the parent's `AuditEntry` / `AuditEvent` imports.
    use super::*;

    /// Parses a fixed RFC 3339 timestamp into UTC so tests are deterministic.
    fn utc(rfc3339: &str) -> chrono::DateTime<chrono::Utc> {
        chrono::DateTime::parse_from_rfc3339(rfc3339)
            .expect("test timestamp must be valid RFC 3339")
            .with_timezone(&chrono::Utc)
    }

    /// Builds a successful `SessionEnded` entry at a fixed time.
    fn session_ended_entry(session_id: &str, provider: &str) -> AuditEntry {
        AuditEntry {
            timestamp: utc("2026-01-01T00:00:00Z"),
            session_id: session_id.to_string(),
            provider: provider.to_string(),
            event: AuditEvent::SessionEnded {
                status: "ok".to_string(),
                duration_seconds: 60,
                exit_code: Some(0),
            },
        }
    }

    #[test]
    fn test_config_default() {
        let config = DbAuditConfig::default();
        assert_eq!(config.backend, DbBackend::Sqlite);
        assert!(config.connection.contains("audit.db"));
        assert_eq!(config.max_connections, 5);
        assert!(config.wal_mode);
    }

    #[test]
    fn test_config_deserialize_sqlite() {
        let toml_str = r#"
backend = "sqlite"
connection = "/data/audex/audit.db"
wal_mode = true
retention_days = 90
"#;
        let config: DbAuditConfig = toml::from_str(toml_str).unwrap();
        assert_eq!(config.backend, DbBackend::Sqlite);
        assert_eq!(config.connection, "/data/audex/audit.db");
        assert_eq!(config.retention_days, 90);
        // `max_connections` was omitted, so the serde default applies.
        assert_eq!(config.max_connections, 5);
    }

    #[test]
    fn test_config_deserialize_postgres() {
        let toml_str = r#"
backend = "postgres"
connection = "postgres://audex:password@db.internal:5432/audex"
max_connections = 10
"#;
        let config: DbAuditConfig = toml::from_str(toml_str).unwrap();
        assert_eq!(config.backend, DbBackend::Postgres);
        assert_eq!(config.max_connections, 10);
        // Omitted optional fields fall back to their serde defaults.
        assert!(config.wal_mode);
        assert_eq!(config.retention_days, 0);
    }

    #[test]
    fn test_event_type_str() {
        assert_eq!(
            event_type_str(&AuditEvent::SessionCreated {
                role_arn: "arn".into(),
                ttl_seconds: 900,
                budget: None,
                allowed_actions: vec![],
                command: vec![],
                agent_id: None,
            }),
            "session_created"
        );
        assert_eq!(
            event_type_str(&AuditEvent::SessionEnded {
                status: "ok".into(),
                duration_seconds: 60,
                exit_code: Some(0),
            }),
            "session_ended"
        );
        assert_eq!(
            event_type_str(&AuditEvent::BudgetWarning {
                current_spend: 1.0,
                limit: 5.0,
            }),
            "budget_warning"
        );
        assert_eq!(
            event_type_str(&AuditEvent::CredentialsIssued {
                access_key_id: "AKID".into(),
                expires_at: utc("2026-01-01T00:15:00Z"),
            }),
            "credentials_issued"
        );
    }

    #[test]
    fn test_entry_to_row() {
        let entry = AuditEntry {
            timestamp: chrono::Utc::now(),
            session_id: "test-123".to_string(),
            provider: "aws".to_string(),
            event: AuditEvent::SessionCreated {
                role_arn: "arn:aws:iam::123:role/Test".to_string(),
                ttl_seconds: 900,
                budget: Some(5.0),
                allowed_actions: vec!["s3:GetObject".to_string()],
                command: vec!["aws".to_string(), "s3".to_string(), "ls".to_string()],
                agent_id: Some("claude".to_string()),
            },
        };
        let row = entry_to_row(&entry).unwrap();
        assert_eq!(row.session_id, "test-123");
        assert_eq!(row.provider, "aws");
        assert_eq!(row.event_type, "session_created");
        assert!(row.event_data.contains("s3:GetObject"));
    }

    #[test]
    fn test_row_roundtrip() {
        let entry = AuditEntry {
            timestamp: chrono::Utc::now(),
            session_id: "rt-456".to_string(),
            provider: "gcp".to_string(),
            event: AuditEvent::SessionEnded {
                status: "completed".to_string(),
                duration_seconds: 120,
                exit_code: Some(0),
            },
        };
        let row = entry_to_row(&entry).unwrap();
        let restored = row.to_entry().unwrap();
        assert_eq!(restored.session_id, "rt-456");
        assert_eq!(restored.provider, "gcp");
        // RFC 3339 output keeps all sub-second digits, so the instant survives exactly.
        assert_eq!(restored.timestamp, entry.timestamp);
        assert_eq!(event_type_str(&restored.event), "session_ended");
    }

    #[test]
    fn test_sql_statements_sqlite() {
        let b = DbBackend::Sqlite;
        assert!(SqlStatements::insert(&b).contains("?"));
        assert!(SqlStatements::by_session(&b).contains("session_id = ?"));
        assert!(SqlStatements::recent(&b).contains("LIMIT ?"));
        assert!(SqlStatements::prune(&b).contains("datetime"));
    }

    #[test]
    fn test_sql_statements_postgres() {
        let b = DbBackend::Postgres;
        assert!(SqlStatements::insert(&b).contains("$1"));
        assert!(SqlStatements::by_session(&b).contains("session_id = $1"));
        assert!(SqlStatements::recent(&b).contains("LIMIT $1"));
        assert!(SqlStatements::prune(&b).contains("INTERVAL"));
    }

    #[test]
    fn test_parse_jsonl_for_import() {
        let jsonl = r#"{"timestamp":"2026-01-01T00:00:00Z","session_id":"s1","provider":"aws","event":{"type":"session_ended","status":"ok","duration_seconds":60,"exit_code":0}}"#;
        let entries = parse_jsonl_for_import(jsonl);
        assert_eq!(entries.len(), 1);
        assert_eq!(entries[0].session_id, "s1");

        // Blank and whitespace-only lines around the record are ignored.
        let padded = format!("\n   \n{}\n\n", jsonl);
        assert_eq!(parse_jsonl_for_import(&padded).len(), 1);
    }

    #[test]
    fn test_generate_import_sql_sqlite() {
        let entries = vec![session_ended_entry("s1", "aws")];
        let sql = generate_import_sql(&entries, &DbBackend::Sqlite).unwrap();
        assert!(sql.contains("BEGIN TRANSACTION"));
        assert!(sql.contains("INSERT INTO audit_entries"));
        assert!(sql.contains("COMMIT"));
    }

    #[test]
    fn test_generate_import_sql_postgres() {
        let entries = vec![AuditEntry {
            timestamp: utc("2026-01-01T00:00:00Z"),
            session_id: "s2".to_string(),
            provider: "gcp".to_string(),
            event: AuditEvent::CredentialsIssued {
                access_key_id: "AKID".to_string(),
                expires_at: utc("2026-01-01T00:15:00Z"),
            },
        }];
        let sql = generate_import_sql(&entries, &DbBackend::Postgres).unwrap();
        assert!(sql.contains("BEGIN;"));
        assert!(sql.contains("COMMIT;"));
    }

    #[test]
    fn test_db_backend_display() {
        assert_eq!(DbBackend::Sqlite.to_string(), "sqlite");
        assert_eq!(DbBackend::Postgres.to_string(), "postgres");
    }

    #[test]
    fn test_db_audit_stats_serialization() {
        let stats = DbAuditStats {
            total_entries: 1000,
            entries_by_type: vec![
                ("session_created".to_string(), 400),
                ("credentials_issued".to_string(), 400),
                ("session_ended".to_string(), 200),
            ],
            backend: "sqlite".to_string(),
            connection: "/data/audit.db".to_string(),
        };
        let json = serde_json::to_string(&stats).unwrap();
        assert!(json.contains("1000"));
        assert!(json.contains("session_created"));
    }

    #[test]
    fn test_sql_escape_prevents_injection() {
        assert_eq!(sql_escape("normal"), "normal");
        assert_eq!(sql_escape("it's"), "it''s");
        assert_eq!(
            sql_escape("'); DROP TABLE audit_entries; --"),
            "''); DROP TABLE audit_entries; --"
        );
    }

    #[test]
    fn test_generate_import_sql_escapes_all_fields() {
        let entries = vec![session_ended_entry("it's-a-trap", "aw's")];
        let sql = generate_import_sql(&entries, &DbBackend::Sqlite).unwrap();
        // Every quote in every field must be doubled.
        assert!(sql.contains("it''s-a-trap"));
        assert!(sql.contains("aw''s"));
        // The raw, unescaped value must never terminate a literal early.
        assert!(!sql.contains("it's-a-trap',"));
    }
}