1use serde::{Deserialize, Serialize};
2
3use crate::audit::{AuditEntry, AuditEvent};
4use crate::error::{AvError, Result};
5use crate::session::Session;
6
7#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
9#[serde(rename_all = "lowercase")]
10pub enum DbBackend {
11 Sqlite,
12 Postgres,
13}
14
15impl std::fmt::Display for DbBackend {
16 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
17 match self {
18 Self::Sqlite => write!(f, "sqlite"),
19 Self::Postgres => write!(f, "postgres"),
20 }
21 }
22}
23
/// Settings for the database-backed audit log.
///
/// Only `backend` and `connection` are required when deserializing; the other
/// fields fall back to the serde defaults noted on each field.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct DbAuditConfig {
    /// Database engine to write audit entries to (`"sqlite"` or `"postgres"`).
    pub backend: DbBackend,
    /// Backend-specific connection target: a file path for SQLite (as in the
    /// `Default` impl), or a connection URL for Postgres
    /// (e.g. `postgres://user:pass@host:5432/db`).
    pub connection: String,
    /// Maximum number of database connections (presumably the pool size —
    /// confirm with the connection code). Defaults to 5.
    #[serde(default = "default_max_connections")]
    pub max_connections: u32,
    /// Whether to enable WAL journaling; presumably only meaningful for
    /// SQLite. Defaults to `true`.
    #[serde(default = "default_wal_mode")]
    pub wal_mode: bool,
    /// Number of days of entries to keep. Defaults to 0.
    /// NOTE(review): 0 presumably means "never prune" — confirm with the code
    /// that runs `SqlStatements::prune`.
    #[serde(default)]
    pub retention_days: u32,
}
43
/// Serde default for `DbAuditConfig::max_connections`: five connections.
fn default_max_connections() -> u32 {
    const DEFAULT_MAX_CONNECTIONS: u32 = 5;
    DEFAULT_MAX_CONNECTIONS
}
47
/// Serde default for `DbAuditConfig::wal_mode`: WAL journaling enabled.
fn default_wal_mode() -> bool {
    const WAL_ENABLED_BY_DEFAULT: bool = true;
    WAL_ENABLED_BY_DEFAULT
}
51
52impl Default for DbAuditConfig {
53 fn default() -> Self {
54 let path = dirs::data_local_dir()
55 .unwrap_or_else(|| std::path::PathBuf::from("."))
56 .join("audex")
57 .join("audit.db");
58 Self {
59 backend: DbBackend::Sqlite,
60 connection: path.to_string_lossy().to_string(),
61 max_connections: default_max_connections(),
62 wal_mode: default_wal_mode(),
63 retention_days: 0,
64 }
65 }
66}
67
/// DDL for the SQLite audit store.
///
/// Creates the `audit_entries` table plus indexes on `session_id`,
/// `timestamp`, `provider` and `event_type`. Every statement uses
/// `IF NOT EXISTS`, so the script is safe to run repeatedly. It contains
/// several `;`-separated statements, so it must be executed through a
/// multi-statement/batch API.
///
/// `timestamp` is plain TEXT. Rows written via `entry_to_row` store RFC 3339
/// UTC strings, whose text order is what the `ORDER BY timestamp` queries rely
/// on for chronological order.
pub const SQLITE_SCHEMA: &str = r#"
CREATE TABLE IF NOT EXISTS audit_entries (
 id INTEGER PRIMARY KEY AUTOINCREMENT,
 timestamp TEXT NOT NULL,
 session_id TEXT NOT NULL,
 provider TEXT NOT NULL DEFAULT 'aws',
 event_type TEXT NOT NULL,
 event_data TEXT NOT NULL,
 created_at TEXT NOT NULL DEFAULT (datetime('now'))
);

CREATE INDEX IF NOT EXISTS idx_audit_session_id ON audit_entries(session_id);
CREATE INDEX IF NOT EXISTS idx_audit_timestamp ON audit_entries(timestamp);
CREATE INDEX IF NOT EXISTS idx_audit_provider ON audit_entries(provider);
CREATE INDEX IF NOT EXISTS idx_audit_event_type ON audit_entries(event_type);
"#;
85
/// DDL for the PostgreSQL audit store.
///
/// Same table and indexes as [`SQLITE_SCHEMA`], but using native types:
/// `BIGSERIAL` ids, `TIMESTAMPTZ` timestamps and `JSONB` event data. Every
/// statement uses `IF NOT EXISTS`, so the script is safe to run repeatedly. It
/// contains several `;`-separated statements, so it must be executed as a
/// batch.
pub const POSTGRES_SCHEMA: &str = r#"
CREATE TABLE IF NOT EXISTS audit_entries (
 id BIGSERIAL PRIMARY KEY,
 timestamp TIMESTAMPTZ NOT NULL,
 session_id TEXT NOT NULL,
 provider TEXT NOT NULL DEFAULT 'aws',
 event_type TEXT NOT NULL,
 event_data JSONB NOT NULL,
 created_at TIMESTAMPTZ NOT NULL DEFAULT NOW()
);

CREATE INDEX IF NOT EXISTS idx_audit_session_id ON audit_entries(session_id);
CREATE INDEX IF NOT EXISTS idx_audit_timestamp ON audit_entries(timestamp);
CREATE INDEX IF NOT EXISTS idx_audit_provider ON audit_entries(provider);
CREATE INDEX IF NOT EXISTS idx_audit_event_type ON audit_entries(event_type);
"#;
102
103pub fn event_type_str(event: &AuditEvent) -> &'static str {
105 match event {
106 AuditEvent::SessionCreated { .. } => "session_created",
107 AuditEvent::CredentialsIssued { .. } => "credentials_issued",
108 AuditEvent::SessionEnded { .. } => "session_ended",
109 AuditEvent::BudgetWarning { .. } => "budget_warning",
110 AuditEvent::BudgetExceeded { .. } => "budget_exceeded",
111 AuditEvent::ResourceCreated { .. } => "resource_created",
112 }
113}
114
/// Namespace for the backend-specific SQL used by the audit store.
///
/// Stateless: each statement is an associated function that takes the target
/// [`DbBackend`] and returns a `'static` SQL string using that backend's
/// placeholder style (`?` for SQLite, `$n` for Postgres).
pub struct SqlStatements;
117
118impl SqlStatements {
119 pub fn insert(backend: &DbBackend) -> &'static str {
121 match backend {
122 DbBackend::Sqlite => {
123 "INSERT INTO audit_entries (timestamp, session_id, provider, event_type, event_data) VALUES (?, ?, ?, ?, ?)"
124 }
125 DbBackend::Postgres => {
126 "INSERT INTO audit_entries (timestamp, session_id, provider, event_type, event_data) VALUES ($1, $2, $3, $4, $5::jsonb)"
127 }
128 }
129 }
130
131 pub fn by_session(backend: &DbBackend) -> &'static str {
133 match backend {
134 DbBackend::Sqlite => {
135 "SELECT timestamp, session_id, provider, event_type, event_data FROM audit_entries WHERE session_id = ? ORDER BY timestamp ASC"
136 }
137 DbBackend::Postgres => {
138 "SELECT timestamp, session_id, provider, event_type, event_data FROM audit_entries WHERE session_id = $1 ORDER BY timestamp ASC"
139 }
140 }
141 }
142
143 pub fn recent(backend: &DbBackend) -> &'static str {
145 match backend {
146 DbBackend::Sqlite => {
147 "SELECT timestamp, session_id, provider, event_type, event_data FROM audit_entries ORDER BY timestamp DESC LIMIT ?"
148 }
149 DbBackend::Postgres => {
150 "SELECT timestamp, session_id, provider, event_type, event_data FROM audit_entries ORDER BY timestamp DESC LIMIT $1"
151 }
152 }
153 }
154
155 pub fn by_provider(backend: &DbBackend) -> &'static str {
157 match backend {
158 DbBackend::Sqlite => {
159 "SELECT timestamp, session_id, provider, event_type, event_data FROM audit_entries WHERE provider = ? ORDER BY timestamp DESC LIMIT ?"
160 }
161 DbBackend::Postgres => {
162 "SELECT timestamp, session_id, provider, event_type, event_data FROM audit_entries WHERE provider = $1 ORDER BY timestamp DESC LIMIT $2"
163 }
164 }
165 }
166
167 pub fn count(backend: &DbBackend) -> &'static str {
169 match backend {
170 DbBackend::Sqlite => "SELECT COUNT(*) FROM audit_entries",
171 DbBackend::Postgres => "SELECT COUNT(*) FROM audit_entries",
172 }
173 }
174
175 pub fn count_by_type(backend: &DbBackend) -> &'static str {
177 match backend {
178 DbBackend::Sqlite => {
179 "SELECT event_type, COUNT(*) as count FROM audit_entries GROUP BY event_type ORDER BY count DESC"
180 }
181 DbBackend::Postgres => {
182 "SELECT event_type, COUNT(*) as count FROM audit_entries GROUP BY event_type ORDER BY count DESC"
183 }
184 }
185 }
186
187 pub fn prune(backend: &DbBackend) -> &'static str {
189 match backend {
190 DbBackend::Sqlite => {
191 "DELETE FROM audit_entries WHERE timestamp < datetime('now', '-' || ? || ' days')"
192 }
193 DbBackend::Postgres => {
194 "DELETE FROM audit_entries WHERE timestamp < NOW() - ($1 || ' days')::INTERVAL"
195 }
196 }
197 }
198
199 pub fn search(backend: &DbBackend) -> &'static str {
201 match backend {
202 DbBackend::Sqlite => {
203 "SELECT timestamp, session_id, provider, event_type, event_data FROM audit_entries WHERE event_data LIKE ? ORDER BY timestamp DESC LIMIT ?"
204 }
205 DbBackend::Postgres => {
206 "SELECT timestamp, session_id, provider, event_type, event_data FROM audit_entries WHERE event_data::text ILIKE $1 ORDER BY timestamp DESC LIMIT $2"
207 }
208 }
209 }
210}
211
212pub fn entry_to_row(entry: &AuditEntry) -> Result<DbRow> {
214 let event_data = serde_json::to_string(&entry.event)
215 .map_err(|e| AvError::Sts(format!("Failed to serialize event: {}", e)))?;
216
217 let event_data = crate::leakdetect::redact_secrets(&event_data);
219
220 Ok(DbRow {
221 timestamp: entry.timestamp.to_rfc3339(),
222 session_id: entry.session_id.clone(),
223 provider: entry.provider.clone(),
224 event_type: event_type_str(&entry.event).to_string(),
225 event_data,
226 })
227}
228
229pub fn session_created_entry(session: &Session) -> AuditEntry {
231 use crate::session::CloudProvider;
232
233 let allowed_actions: Vec<String> = match session.provider {
234 CloudProvider::Gcp => session
235 .policy
236 .actions
237 .iter()
238 .map(|a| a.to_gcp_permission())
239 .collect(),
240 CloudProvider::Azure => session
241 .policy
242 .actions
243 .iter()
244 .map(|a| a.to_azure_permission())
245 .collect(),
246 CloudProvider::Aws => session
247 .policy
248 .actions
249 .iter()
250 .map(|a| a.to_iam_action())
251 .collect(),
252 };
253
254 AuditEntry {
255 timestamp: chrono::Utc::now(),
256 session_id: session.id.clone(),
257 provider: session.provider.to_string(),
258 event: AuditEvent::SessionCreated {
259 role_arn: session.role_arn.clone(),
260 ttl_seconds: session.ttl_seconds,
261 budget: session.budget,
262 allowed_actions,
263 command: session.command.clone(),
264 agent_id: session.agent_id.clone(),
265 },
266 }
267}
268
/// String-typed form of an [`AuditEntry`], mapped onto the `audit_entries`
/// columns.
///
/// Produced by [`entry_to_row`] and converted back by [`DbRow::to_entry`].
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct DbRow {
    /// RFC 3339 timestamp (written with `to_rfc3339`, parsed back with
    /// `parse_from_rfc3339`).
    pub timestamp: String,
    /// Identifier of the session the entry belongs to.
    pub session_id: String,
    /// Cloud provider name, e.g. `"aws"` or `"gcp"`.
    pub provider: String,
    /// Label from [`event_type_str`], e.g. `"session_created"`.
    pub event_type: String,
    /// JSON-serialized [`AuditEvent`], passed through `redact_secrets` by
    /// `entry_to_row`.
    pub event_data: String,
}
278
279impl DbRow {
280 pub fn to_entry(&self) -> Result<AuditEntry> {
282 let timestamp = chrono::DateTime::parse_from_rfc3339(&self.timestamp)
283 .map(|dt| dt.with_timezone(&chrono::Utc))
284 .map_err(|e| AvError::Sts(format!("Invalid timestamp: {}", e)))?;
285
286 let event: AuditEvent = serde_json::from_str(&self.event_data)
287 .map_err(|e| AvError::Sts(format!("Invalid event data: {}", e)))?;
288
289 Ok(AuditEntry {
290 timestamp,
291 session_id: self.session_id.clone(),
292 provider: self.provider.clone(),
293 event,
294 })
295 }
296}
297
/// Summary of the audit database's contents.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct DbAuditStats {
    /// Total number of rows (presumably from `SqlStatements::count`).
    pub total_entries: u64,
    /// `(event_type, count)` pairs (presumably from
    /// `SqlStatements::count_by_type`, which orders by count descending).
    pub entries_by_type: Vec<(String, u64)>,
    /// Backend name, e.g. `"sqlite"` (presumably `DbBackend`'s `Display` form).
    pub backend: String,
    /// Connection string the stats describe.
    /// NOTE(review): Postgres connection URLs can embed a password, so check
    /// before serializing this anywhere user-visible.
    pub connection: String,
}
306
307pub fn parse_jsonl_for_import(content: &str) -> Vec<AuditEntry> {
309 content
310 .lines()
311 .filter(|l| !l.trim().is_empty())
312 .filter_map(|l| {
313 let (json_content, _) = crate::integrity::parse_line(l);
314 serde_json::from_str(json_content).ok()
315 })
316 .collect()
317}
318
319pub fn generate_import_sql(entries: &[AuditEntry], backend: &DbBackend) -> Result<String> {
321 let mut sql = String::new();
322
323 match backend {
324 DbBackend::Sqlite => {
325 sql.push_str("BEGIN TRANSACTION;\n");
326 }
327 DbBackend::Postgres => {
328 sql.push_str("BEGIN;\n");
329 }
330 }
331
332 for entry in entries {
333 let row = entry_to_row(entry)?;
334 let escaped_data = row.event_data.replace('\'', "''");
335 sql.push_str(&format!(
336 "INSERT INTO audit_entries (timestamp, session_id, provider, event_type, event_data) VALUES ('{}', '{}', '{}', '{}', '{}');\n",
337 row.timestamp, row.session_id, row.provider, row.event_type, escaped_data
338 ));
339 }
340
341 sql.push_str("COMMIT;\n");
342 Ok(sql)
343}
344
// Unit tests for config parsing, event labels, row conversion, SQL statement
// selection and JSONL import/export helpers.
#[cfg(test)]
mod tests {
    use super::*;
    use crate::audit::AuditEvent;

    // The default config targets a local SQLite file with the serde defaults.
    #[test]
    fn test_config_default() {
        let config = DbAuditConfig::default();
        assert_eq!(config.backend, DbBackend::Sqlite);
        assert!(config.connection.contains("audit.db"));
        assert_eq!(config.max_connections, 5);
        assert!(config.wal_mode);
    }

    // A SQLite TOML config with an explicit retention period.
    #[test]
    fn test_config_deserialize_sqlite() {
        let toml_str = r#"
backend = "sqlite"
connection = "/data/audex/audit.db"
wal_mode = true
retention_days = 90
"#;
        let config: DbAuditConfig = toml::from_str(toml_str).unwrap();
        assert_eq!(config.backend, DbBackend::Sqlite);
        assert_eq!(config.connection, "/data/audex/audit.db");
        assert_eq!(config.retention_days, 90);
    }

    // A Postgres TOML config overriding the default connection limit.
    #[test]
    fn test_config_deserialize_postgres() {
        let toml_str = r#"
backend = "postgres"
connection = "postgres://audex:password@db.internal:5432/audex"
max_connections = 10
"#;
        let config: DbAuditConfig = toml::from_str(toml_str).unwrap();
        assert_eq!(config.backend, DbBackend::Postgres);
        assert_eq!(config.max_connections, 10);
    }

    // Spot-checks the stored `event_type` labels for a few variants.
    #[test]
    fn test_event_type_str() {
        assert_eq!(
            event_type_str(&AuditEvent::SessionCreated {
                role_arn: "arn".into(),
                ttl_seconds: 900,
                budget: None,
                allowed_actions: vec![],
                command: vec![],
                agent_id: None,
            }),
            "session_created"
        );
        assert_eq!(
            event_type_str(&AuditEvent::SessionEnded {
                status: "ok".into(),
                duration_seconds: 60,
                exit_code: Some(0),
            }),
            "session_ended"
        );
        assert_eq!(
            event_type_str(&AuditEvent::BudgetWarning {
                current_spend: 1.0,
                limit: 5.0,
            }),
            "budget_warning"
        );
    }

    // Row fields are copied through and the event is serialized into event_data.
    #[test]
    fn test_entry_to_row() {
        let entry = AuditEntry {
            timestamp: chrono::Utc::now(),
            session_id: "test-123".to_string(),
            provider: "aws".to_string(),
            event: AuditEvent::SessionCreated {
                role_arn: "arn:aws:iam::123:role/Test".to_string(),
                ttl_seconds: 900,
                budget: Some(5.0),
                allowed_actions: vec!["s3:GetObject".to_string()],
                command: vec!["aws".to_string(), "s3".to_string(), "ls".to_string()],
                agent_id: Some("claude".to_string()),
            },
        };
        let row = entry_to_row(&entry).unwrap();
        assert_eq!(row.session_id, "test-123");
        assert_eq!(row.provider, "aws");
        assert_eq!(row.event_type, "session_created");
        assert!(row.event_data.contains("s3:GetObject"));
    }

    // entry_to_row followed by to_entry preserves the identifying fields.
    #[test]
    fn test_row_roundtrip() {
        let entry = AuditEntry {
            timestamp: chrono::Utc::now(),
            session_id: "rt-456".to_string(),
            provider: "gcp".to_string(),
            event: AuditEvent::SessionEnded {
                status: "completed".to_string(),
                duration_seconds: 120,
                exit_code: Some(0),
            },
        };
        let row = entry_to_row(&entry).unwrap();
        let restored = row.to_entry().unwrap();
        assert_eq!(restored.session_id, "rt-456");
        assert_eq!(restored.provider, "gcp");
    }

    // Only checks SQLite placeholder style and dialect markers, not full SQL.
    #[test]
    fn test_sql_statements_sqlite() {
        let b = DbBackend::Sqlite;
        assert!(SqlStatements::insert(&b).contains("?"));
        assert!(SqlStatements::by_session(&b).contains("session_id = ?"));
        assert!(SqlStatements::recent(&b).contains("LIMIT ?"));
        assert!(SqlStatements::prune(&b).contains("datetime"));
    }

    // Only checks Postgres `$n` placeholders and dialect markers, not full SQL.
    #[test]
    fn test_sql_statements_postgres() {
        let b = DbBackend::Postgres;
        assert!(SqlStatements::insert(&b).contains("$1"));
        assert!(SqlStatements::by_session(&b).contains("session_id = $1"));
        assert!(SqlStatements::recent(&b).contains("LIMIT $1"));
        assert!(SqlStatements::prune(&b).contains("INTERVAL"));
    }

    // A single plain JSON line imports as one entry.
    #[test]
    fn test_parse_jsonl_for_import() {
        let jsonl = r#"{"timestamp":"2026-01-01T00:00:00Z","session_id":"s1","provider":"aws","event":{"type":"session_ended","status":"ok","duration_seconds":60,"exit_code":0}}"#;
        let entries = parse_jsonl_for_import(jsonl);
        assert_eq!(entries.len(), 1);
        assert_eq!(entries[0].session_id, "s1");
    }

    // SQLite import script is wrapped in BEGIN TRANSACTION / COMMIT.
    #[test]
    fn test_generate_import_sql_sqlite() {
        let entries = vec![AuditEntry {
            timestamp: chrono::DateTime::parse_from_rfc3339("2026-01-01T00:00:00Z")
                .unwrap()
                .with_timezone(&chrono::Utc),
            session_id: "s1".to_string(),
            provider: "aws".to_string(),
            event: AuditEvent::SessionEnded {
                status: "ok".to_string(),
                duration_seconds: 60,
                exit_code: Some(0),
            },
        }];
        let sql = generate_import_sql(&entries, &DbBackend::Sqlite).unwrap();
        assert!(sql.contains("BEGIN TRANSACTION"));
        assert!(sql.contains("INSERT INTO audit_entries"));
        assert!(sql.contains("COMMIT"));
    }

    // Postgres import script is wrapped in BEGIN; / COMMIT;.
    #[test]
    fn test_generate_import_sql_postgres() {
        let entries = vec![AuditEntry {
            timestamp: chrono::DateTime::parse_from_rfc3339("2026-01-01T00:00:00Z")
                .unwrap()
                .with_timezone(&chrono::Utc),
            session_id: "s2".to_string(),
            provider: "gcp".to_string(),
            event: AuditEvent::CredentialsIssued {
                access_key_id: "AKID".to_string(),
                expires_at: chrono::DateTime::parse_from_rfc3339("2026-01-01T00:15:00Z")
                    .unwrap()
                    .with_timezone(&chrono::Utc),
            },
        }];
        let sql = generate_import_sql(&entries, &DbBackend::Postgres).unwrap();
        assert!(sql.contains("BEGIN;"));
        assert!(sql.contains("COMMIT;"));
    }

    // Display output uses the lowercase backend names.
    #[test]
    fn test_db_backend_display() {
        assert_eq!(DbBackend::Sqlite.to_string(), "sqlite");
        assert_eq!(DbBackend::Postgres.to_string(), "postgres");
    }

    // Stats serialize to JSON containing the totals and type labels.
    #[test]
    fn test_db_audit_stats_serialization() {
        let stats = DbAuditStats {
            total_entries: 1000,
            entries_by_type: vec![
                ("session_created".to_string(), 400),
                ("credentials_issued".to_string(), 400),
                ("session_ended".to_string(), 200),
            ],
            backend: "sqlite".to_string(),
            connection: "/data/audit.db".to_string(),
        };
        let json = serde_json::to_string(&stats).unwrap();
        assert!(json.contains("1000"));
        assert!(json.contains("session_created"));
    }
}