Skip to main content

tryaudex_core/
dbaudit.rs

1use serde::{Deserialize, Serialize};
2
3use crate::audit::{AuditEntry, AuditEvent};
4use crate::error::{AvError, Result};
5use crate::session::Session;
6
/// Database audit backend type.
///
/// Serialized and deserialized in lowercase (`"sqlite"` / `"postgres"`)
/// because of `#[serde(rename_all = "lowercase")]`.
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
#[serde(rename_all = "lowercase")]
pub enum DbBackend {
    /// SQLite backend (statements use `?` placeholders).
    Sqlite,
    /// PostgreSQL backend (statements use `$n` placeholders).
    Postgres,
}
14
15impl std::fmt::Display for DbBackend {
16    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
17        match self {
18            Self::Sqlite => write!(f, "sqlite"),
19            Self::Postgres => write!(f, "postgres"),
20        }
21    }
22}
23
/// Configuration for the database audit backend.
///
/// Only `backend` and `connection` are required when deserializing; the
/// remaining fields fall back to their serde defaults.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct DbAuditConfig {
    /// Backend type: "sqlite" or "postgres"
    pub backend: DbBackend,
    /// Connection string.
    /// - SQLite: file path (e.g. "~/.local/share/audex/audit.db") or ":memory:"
    /// - Postgres: connection URI (e.g. "postgres://user:pass@localhost/audex")
    ///
    /// The env var `AUDEX_DB_CONNECTION` takes precedence if set, avoiding
    /// plaintext passwords in the config file. Use `resolve_connection` to
    /// obtain the effective value rather than reading this field directly.
    pub connection: String,
    /// Maximum number of connections in the pool (postgres only, default: 5)
    #[serde(default = "default_max_connections")]
    pub max_connections: u32,
    /// Enable WAL mode for SQLite (default: true)
    #[serde(default = "default_wal_mode")]
    pub wal_mode: bool,
    /// Retention period in days (0 = keep forever, default: 0)
    #[serde(default)]
    pub retention_days: u32,
}
46
/// Serde default for `DbAuditConfig::max_connections`.
fn default_max_connections() -> u32 {
    const DEFAULT_POOL_SIZE: u32 = 5;
    DEFAULT_POOL_SIZE
}
50
/// Serde default for `DbAuditConfig::wal_mode`: WAL is on unless disabled.
fn default_wal_mode() -> bool {
    const WAL_ENABLED_BY_DEFAULT: bool = true;
    WAL_ENABLED_BY_DEFAULT
}
54
55impl Default for DbAuditConfig {
56    fn default() -> Self {
57        let path = dirs::data_local_dir()
58            .unwrap_or_else(|| std::path::PathBuf::from("."))
59            .join("audex")
60            .join("audit.db");
61        Self {
62            backend: DbBackend::Sqlite,
63            connection: path.to_string_lossy().to_string(),
64            max_connections: default_max_connections(),
65            wal_mode: default_wal_mode(),
66            retention_days: 0,
67        }
68    }
69}
70
71impl DbAuditConfig {
72    /// Return the effective connection string, preferring `AUDEX_DB_CONNECTION`
73    /// env var over the config-file value so passwords need not be stored in
74    /// plaintext on disk.
75    pub fn resolve_connection(&self) -> String {
76        std::env::var("AUDEX_DB_CONNECTION").unwrap_or_else(|_| self.connection.clone())
77    }
78}
79
/// SQL schema for audit tables (SQLite dialect).
///
/// Creates the `audit_entries` table and its indexes if they do not already
/// exist. `timestamp` and `event_data` are plain `TEXT` columns here, and
/// `created_at` defaults to `datetime('now')`.
pub const SQLITE_SCHEMA: &str = r#"
CREATE TABLE IF NOT EXISTS audit_entries (
    id INTEGER PRIMARY KEY AUTOINCREMENT,
    timestamp TEXT NOT NULL,
    session_id TEXT NOT NULL,
    provider TEXT NOT NULL DEFAULT 'aws',
    event_type TEXT NOT NULL,
    event_data TEXT NOT NULL,
    created_at TEXT NOT NULL DEFAULT (datetime('now'))
);

CREATE INDEX IF NOT EXISTS idx_audit_session_id ON audit_entries(session_id);
CREATE INDEX IF NOT EXISTS idx_audit_timestamp ON audit_entries(timestamp);
CREATE INDEX IF NOT EXISTS idx_audit_provider ON audit_entries(provider);
CREATE INDEX IF NOT EXISTS idx_audit_event_type ON audit_entries(event_type);
"#;
97
/// SQL schema for audit tables (PostgreSQL dialect).
///
/// Same table and index names as the SQLite schema, but `timestamp` and
/// `created_at` are `TIMESTAMPTZ` and `event_data` is `JSONB`.
pub const POSTGRES_SCHEMA: &str = r#"
CREATE TABLE IF NOT EXISTS audit_entries (
    id BIGSERIAL PRIMARY KEY,
    timestamp TIMESTAMPTZ NOT NULL,
    session_id TEXT NOT NULL,
    provider TEXT NOT NULL DEFAULT 'aws',
    event_type TEXT NOT NULL,
    event_data JSONB NOT NULL,
    created_at TIMESTAMPTZ NOT NULL DEFAULT NOW()
);

CREATE INDEX IF NOT EXISTS idx_audit_session_id ON audit_entries(session_id);
CREATE INDEX IF NOT EXISTS idx_audit_timestamp ON audit_entries(timestamp);
CREATE INDEX IF NOT EXISTS idx_audit_provider ON audit_entries(provider);
CREATE INDEX IF NOT EXISTS idx_audit_event_type ON audit_entries(event_type);
"#;
114
115/// Extract event type string from an AuditEvent.
116pub fn event_type_str(event: &AuditEvent) -> &'static str {
117    match event {
118        AuditEvent::SessionCreated { .. } => "session_created",
119        AuditEvent::CredentialsIssued { .. } => "credentials_issued",
120        AuditEvent::SessionEnded { .. } => "session_ended",
121        AuditEvent::BudgetWarning { .. } => "budget_warning",
122        AuditEvent::BudgetExceeded { .. } => "budget_exceeded",
123        AuditEvent::ResourceCreated { .. } => "resource_created",
124        AuditEvent::PolicyAdvisoryOnly { .. } => "policy_advisory_only",
125    }
126}
127
/// SQL statements for common operations.
///
/// Unit struct used purely as a namespace: its associated functions return
/// the backend-specific statement text for a given `DbBackend`.
pub struct SqlStatements;
130
131impl SqlStatements {
132    /// Insert an audit entry.
133    pub fn insert(backend: &DbBackend) -> &'static str {
134        match backend {
135            DbBackend::Sqlite => {
136                "INSERT INTO audit_entries (timestamp, session_id, provider, event_type, event_data) VALUES (?, ?, ?, ?, ?)"
137            }
138            DbBackend::Postgres => {
139                "INSERT INTO audit_entries (timestamp, session_id, provider, event_type, event_data) VALUES ($1, $2, $3, $4, $5::jsonb)"
140            }
141        }
142    }
143
144    /// Query entries by session ID.
145    pub fn by_session(backend: &DbBackend) -> &'static str {
146        match backend {
147            DbBackend::Sqlite => {
148                "SELECT timestamp, session_id, provider, event_type, event_data FROM audit_entries WHERE session_id = ? ORDER BY timestamp ASC"
149            }
150            DbBackend::Postgres => {
151                "SELECT timestamp, session_id, provider, event_type, event_data FROM audit_entries WHERE session_id = $1 ORDER BY timestamp ASC"
152            }
153        }
154    }
155
156    /// Query recent entries with optional limit.
157    pub fn recent(backend: &DbBackend) -> &'static str {
158        match backend {
159            DbBackend::Sqlite => {
160                "SELECT timestamp, session_id, provider, event_type, event_data FROM audit_entries ORDER BY timestamp DESC LIMIT ?"
161            }
162            DbBackend::Postgres => {
163                "SELECT timestamp, session_id, provider, event_type, event_data FROM audit_entries ORDER BY timestamp DESC LIMIT $1"
164            }
165        }
166    }
167
168    /// Query entries by provider.
169    pub fn by_provider(backend: &DbBackend) -> &'static str {
170        match backend {
171            DbBackend::Sqlite => {
172                "SELECT timestamp, session_id, provider, event_type, event_data FROM audit_entries WHERE provider = ? ORDER BY timestamp DESC LIMIT ?"
173            }
174            DbBackend::Postgres => {
175                "SELECT timestamp, session_id, provider, event_type, event_data FROM audit_entries WHERE provider = $1 ORDER BY timestamp DESC LIMIT $2"
176            }
177        }
178    }
179
180    /// Count entries.
181    pub fn count(backend: &DbBackend) -> &'static str {
182        match backend {
183            DbBackend::Sqlite => "SELECT COUNT(*) FROM audit_entries",
184            DbBackend::Postgres => "SELECT COUNT(*) FROM audit_entries",
185        }
186    }
187
188    /// Count entries by event type.
189    pub fn count_by_type(backend: &DbBackend) -> &'static str {
190        match backend {
191            DbBackend::Sqlite => {
192                "SELECT event_type, COUNT(*) as count FROM audit_entries GROUP BY event_type ORDER BY count DESC"
193            }
194            DbBackend::Postgres => {
195                "SELECT event_type, COUNT(*) as count FROM audit_entries GROUP BY event_type ORDER BY count DESC"
196            }
197        }
198    }
199
200    /// Delete entries older than N days.
201    pub fn prune(backend: &DbBackend) -> &'static str {
202        match backend {
203            DbBackend::Sqlite => {
204                "DELETE FROM audit_entries WHERE timestamp < datetime('now', '-' || ? || ' days')"
205            }
206            DbBackend::Postgres => {
207                "DELETE FROM audit_entries WHERE timestamp < NOW() - ($1 || ' days')::INTERVAL"
208            }
209        }
210    }
211
212    /// Search event data with text matching.
213    pub fn search(backend: &DbBackend) -> &'static str {
214        match backend {
215            DbBackend::Sqlite => {
216                "SELECT timestamp, session_id, provider, event_type, event_data FROM audit_entries WHERE event_data LIKE ? ORDER BY timestamp DESC LIMIT ?"
217            }
218            DbBackend::Postgres => {
219                "SELECT timestamp, session_id, provider, event_type, event_data FROM audit_entries WHERE event_data::text ILIKE $1 ORDER BY timestamp DESC LIMIT $2"
220            }
221        }
222    }
223}
224
225/// Convert an AuditEntry to database row values.
226pub fn entry_to_row(entry: &AuditEntry) -> Result<DbRow> {
227    let event_data = serde_json::to_string(&entry.event)
228        .map_err(|e| AvError::Sts(format!("Failed to serialize event: {}", e)))?;
229
230    // Redact secrets before storing
231    let event_data = crate::leakdetect::redact_secrets(&event_data);
232
233    Ok(DbRow {
234        timestamp: entry.timestamp.to_rfc3339(),
235        session_id: entry.session_id.clone(),
236        provider: entry.provider.clone(),
237        event_type: event_type_str(&entry.event).to_string(),
238        event_data,
239    })
240}
241
242/// Convert a Session to an AuditEntry for session_created events.
243pub fn session_created_entry(session: &Session) -> AuditEntry {
244    use crate::session::CloudProvider;
245
246    let allowed_actions: Vec<String> = match session.provider {
247        CloudProvider::Gcp => session
248            .policy
249            .actions
250            .iter()
251            .map(|a| a.to_gcp_permission())
252            .collect(),
253        CloudProvider::Azure => session
254            .policy
255            .actions
256            .iter()
257            .map(|a| a.to_azure_permission())
258            .collect(),
259        CloudProvider::Aws => session
260            .policy
261            .actions
262            .iter()
263            .map(|a| a.to_iam_action())
264            .collect(),
265    };
266
267    AuditEntry {
268        timestamp: chrono::Utc::now(),
269        session_id: session.id.clone(),
270        provider: session.provider.to_string(),
271        event: AuditEvent::SessionCreated {
272            role_arn: session.role_arn.clone(),
273            ttl_seconds: session.ttl_seconds,
274            budget: session.budget,
275            allowed_actions,
276            command: session.command.clone(),
277            agent_id: session.agent_id.clone(),
278        },
279    }
280}
281
/// Database row representation.
///
/// Every column is carried as a string; `entry_to_row` produces these values
/// and `DbRow::to_entry` parses them back.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct DbRow {
    /// Event time as an RFC 3339 string.
    pub timestamp: String,
    /// Identifier of the session the event belongs to.
    pub session_id: String,
    /// Provider name as stored on the audit entry.
    pub provider: String,
    /// Event type label (see `event_type_str`).
    pub event_type: String,
    /// JSON-serialized event payload.
    pub event_data: String,
}
291
292impl DbRow {
293    /// Convert back to an AuditEntry.
294    pub fn to_entry(&self) -> Result<AuditEntry> {
295        let timestamp = chrono::DateTime::parse_from_rfc3339(&self.timestamp)
296            .map(|dt| dt.with_timezone(&chrono::Utc))
297            .map_err(|e| AvError::Sts(format!("Invalid timestamp: {}", e)))?;
298
299        let event: AuditEvent = serde_json::from_str(&self.event_data)
300            .map_err(|e| AvError::Sts(format!("Invalid event data: {}", e)))?;
301
302        Ok(AuditEntry {
303            timestamp,
304            session_id: self.session_id.clone(),
305            provider: self.provider.clone(),
306            event,
307        })
308    }
309}
310
/// Statistics about the audit database.
///
/// Plain serializable data holder; nothing in this module populates it.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct DbAuditStats {
    /// Total number of audit entries.
    pub total_entries: u64,
    /// `(event_type, count)` pairs.
    // NOTE(review): presumably filled from `SqlStatements::count_by_type` — confirm at the call site.
    pub entries_by_type: Vec<(String, u64)>,
    /// Backend name (the tests use the lowercase form, e.g. "sqlite").
    pub backend: String,
    /// Connection string or path of the database.
    // NOTE(review): may contain credentials for Postgres URIs — confirm callers sanitize before display.
    pub connection: String,
}
319
320/// Migration helper — import JSONL audit log into database.
321pub fn parse_jsonl_for_import(content: &str) -> Vec<AuditEntry> {
322    content
323        .lines()
324        .filter(|l| !l.trim().is_empty())
325        .filter_map(|l| {
326            let (json_content, _) = crate::integrity::parse_line(l);
327            serde_json::from_str(json_content).ok()
328        })
329        .collect()
330}
331
/// Escape a string for inclusion inside a single-quoted SQL literal.
///
/// Each single quote is doubled (standard SQL) and each backslash is doubled
/// (for PostgreSQL configurations with `standard_conforming_strings = off`).
/// The surrounding quotes are not added.
fn sql_escape(s: &str) -> String {
    let mut escaped = String::with_capacity(s.len());
    for ch in s.chars() {
        match ch {
            '\\' => escaped.push_str("\\\\"),
            '\'' => escaped.push_str("''"),
            other => escaped.push(other),
        }
    }
    escaped
}
339
340/// Generate a migration script to import existing JSONL data.
341pub fn generate_import_sql(entries: &[AuditEntry], backend: &DbBackend) -> Result<String> {
342    let mut sql = String::new();
343
344    match backend {
345        DbBackend::Sqlite => {
346            sql.push_str("BEGIN TRANSACTION;\n");
347        }
348        DbBackend::Postgres => {
349            sql.push_str("BEGIN;\n");
350        }
351    }
352
353    for entry in entries {
354        let row = entry_to_row(entry)?;
355        sql.push_str(&format!(
356            "INSERT INTO audit_entries (timestamp, session_id, provider, event_type, event_data) VALUES ('{}', '{}', '{}', '{}', '{}');\n",
357            sql_escape(&row.timestamp),
358            sql_escape(&row.session_id),
359            sql_escape(&row.provider),
360            sql_escape(&row.event_type),
361            sql_escape(&row.event_data),
362        ));
363    }
364
365    sql.push_str("COMMIT;\n");
366    Ok(sql)
367}
368
// Unit tests for the configuration types, row conversion helpers, statement
// text and import-script generation. No database connection is opened; the
// SQL is only inspected as text.
#[cfg(test)]
mod tests {
    use super::*;
    use crate::audit::AuditEvent;

    // Default config targets SQLite at a path ending in `audit.db`.
    #[test]
    fn test_config_default() {
        let config = DbAuditConfig::default();
        assert_eq!(config.backend, DbBackend::Sqlite);
        assert!(config.connection.contains("audit.db"));
        assert_eq!(config.max_connections, 5);
        assert!(config.wal_mode);
    }

    // `max_connections` is omitted here and must fall back to its serde default.
    #[test]
    fn test_config_deserialize_sqlite() {
        let toml_str = r#"
backend = "sqlite"
connection = "/data/audex/audit.db"
wal_mode = true
retention_days = 90
"#;
        let config: DbAuditConfig = toml::from_str(toml_str).unwrap();
        assert_eq!(config.backend, DbBackend::Sqlite);
        assert_eq!(config.connection, "/data/audex/audit.db");
        assert_eq!(config.retention_days, 90);
    }

    // `wal_mode` and `retention_days` are omitted here and use their defaults.
    #[test]
    fn test_config_deserialize_postgres() {
        let toml_str = r#"
backend = "postgres"
connection = "postgres://audex:password@db.internal:5432/audex"
max_connections = 10
"#;
        let config: DbAuditConfig = toml::from_str(toml_str).unwrap();
        assert_eq!(config.backend, DbBackend::Postgres);
        assert_eq!(config.max_connections, 10);
    }

    // Spot-checks three of the variant-to-label mappings.
    #[test]
    fn test_event_type_str() {
        assert_eq!(
            event_type_str(&AuditEvent::SessionCreated {
                role_arn: "arn".into(),
                ttl_seconds: 900,
                budget: None,
                allowed_actions: vec![],
                command: vec![],
                agent_id: None,
            }),
            "session_created"
        );
        assert_eq!(
            event_type_str(&AuditEvent::SessionEnded {
                status: "ok".into(),
                duration_seconds: 60,
                exit_code: Some(0),
            }),
            "session_ended"
        );
        assert_eq!(
            event_type_str(&AuditEvent::BudgetWarning {
                current_spend: 1.0,
                limit: 5.0,
            }),
            "budget_warning"
        );
    }

    // Row columns mirror the entry fields, and the JSON payload keeps the actions.
    #[test]
    fn test_entry_to_row() {
        let entry = AuditEntry {
            timestamp: chrono::Utc::now(),
            session_id: "test-123".to_string(),
            provider: "aws".to_string(),
            event: AuditEvent::SessionCreated {
                role_arn: "arn:aws:iam::123:role/Test".to_string(),
                ttl_seconds: 900,
                budget: Some(5.0),
                allowed_actions: vec!["s3:GetObject".to_string()],
                command: vec!["aws".to_string(), "s3".to_string(), "ls".to_string()],
                agent_id: Some("claude".to_string()),
            },
        };
        let row = entry_to_row(&entry).unwrap();
        assert_eq!(row.session_id, "test-123");
        assert_eq!(row.provider, "aws");
        assert_eq!(row.event_type, "session_created");
        assert!(row.event_data.contains("s3:GetObject"));
    }

    // entry -> row -> entry preserves the identifying fields.
    #[test]
    fn test_row_roundtrip() {
        let entry = AuditEntry {
            timestamp: chrono::Utc::now(),
            session_id: "rt-456".to_string(),
            provider: "gcp".to_string(),
            event: AuditEvent::SessionEnded {
                status: "completed".to_string(),
                duration_seconds: 120,
                exit_code: Some(0),
            },
        };
        let row = entry_to_row(&entry).unwrap();
        let restored = row.to_entry().unwrap();
        assert_eq!(restored.session_id, "rt-456");
        assert_eq!(restored.provider, "gcp");
    }

    // SQLite statements use positional `?` placeholders.
    #[test]
    fn test_sql_statements_sqlite() {
        let b = DbBackend::Sqlite;
        assert!(SqlStatements::insert(&b).contains("?"));
        assert!(SqlStatements::by_session(&b).contains("session_id = ?"));
        assert!(SqlStatements::recent(&b).contains("LIMIT ?"));
        assert!(SqlStatements::prune(&b).contains("datetime"));
    }

    // Postgres statements use numbered `$n` placeholders.
    #[test]
    fn test_sql_statements_postgres() {
        let b = DbBackend::Postgres;
        assert!(SqlStatements::insert(&b).contains("$1"));
        assert!(SqlStatements::by_session(&b).contains("session_id = $1"));
        assert!(SqlStatements::recent(&b).contains("LIMIT $1"));
        assert!(SqlStatements::prune(&b).contains("INTERVAL"));
    }

    // A single well-formed JSONL line yields exactly one entry.
    #[test]
    fn test_parse_jsonl_for_import() {
        let jsonl = r#"{"timestamp":"2026-01-01T00:00:00Z","session_id":"s1","provider":"aws","event":{"type":"session_ended","status":"ok","duration_seconds":60,"exit_code":0}}"#;
        let entries = parse_jsonl_for_import(jsonl);
        assert_eq!(entries.len(), 1);
        assert_eq!(entries[0].session_id, "s1");
    }

    // The SQLite script is wrapped in BEGIN TRANSACTION / COMMIT.
    #[test]
    fn test_generate_import_sql_sqlite() {
        let entries = vec![AuditEntry {
            timestamp: chrono::DateTime::parse_from_rfc3339("2026-01-01T00:00:00Z")
                .unwrap()
                .with_timezone(&chrono::Utc),
            session_id: "s1".to_string(),
            provider: "aws".to_string(),
            event: AuditEvent::SessionEnded {
                status: "ok".to_string(),
                duration_seconds: 60,
                exit_code: Some(0),
            },
        }];
        let sql = generate_import_sql(&entries, &DbBackend::Sqlite).unwrap();
        assert!(sql.contains("BEGIN TRANSACTION"));
        assert!(sql.contains("INSERT INTO audit_entries"));
        assert!(sql.contains("COMMIT"));
    }

    // The Postgres script is wrapped in BEGIN / COMMIT.
    #[test]
    fn test_generate_import_sql_postgres() {
        let entries = vec![AuditEntry {
            timestamp: chrono::DateTime::parse_from_rfc3339("2026-01-01T00:00:00Z")
                .unwrap()
                .with_timezone(&chrono::Utc),
            session_id: "s2".to_string(),
            provider: "gcp".to_string(),
            event: AuditEvent::CredentialsIssued {
                access_key_id: "AKID".to_string(),
                expires_at: chrono::DateTime::parse_from_rfc3339("2026-01-01T00:15:00Z")
                    .unwrap()
                    .with_timezone(&chrono::Utc),
            },
        }];
        let sql = generate_import_sql(&entries, &DbBackend::Postgres).unwrap();
        assert!(sql.contains("BEGIN;"));
        assert!(sql.contains("COMMIT;"));
    }

    // Display output matches the lowercase serde names.
    #[test]
    fn test_db_backend_display() {
        assert_eq!(DbBackend::Sqlite.to_string(), "sqlite");
        assert_eq!(DbBackend::Postgres.to_string(), "postgres");
    }

    // Stats serialize to JSON containing the totals and type labels.
    #[test]
    fn test_db_audit_stats_serialization() {
        let stats = DbAuditStats {
            total_entries: 1000,
            entries_by_type: vec![
                ("session_created".to_string(), 400),
                ("credentials_issued".to_string(), 400),
                ("session_ended".to_string(), 200),
            ],
            backend: "sqlite".to_string(),
            connection: "/data/audit.db".to_string(),
        };
        let json = serde_json::to_string(&stats).unwrap();
        assert!(json.contains("1000"));
        assert!(json.contains("session_created"));
    }

    // Single quotes are doubled so a value cannot terminate its literal early.
    #[test]
    fn test_sql_escape_prevents_injection() {
        assert_eq!(sql_escape("normal"), "normal");
        assert_eq!(sql_escape("it's"), "it''s");
        assert_eq!(
            sql_escape("'); DROP TABLE audit_entries; --"),
            "''); DROP TABLE audit_entries; --"
        );
    }

    // Escaping applies to every interpolated column, not just the event payload.
    #[test]
    fn test_generate_import_sql_escapes_all_fields() {
        let entries = vec![AuditEntry {
            timestamp: chrono::DateTime::parse_from_rfc3339("2026-01-01T00:00:00Z")
                .unwrap()
                .with_timezone(&chrono::Utc),
            session_id: "it's-a-trap".to_string(),
            provider: "aw's".to_string(),
            event: AuditEvent::SessionEnded {
                status: "ok".to_string(),
                duration_seconds: 60,
                exit_code: Some(0),
            },
        }];
        let sql = generate_import_sql(&entries, &DbBackend::Sqlite).unwrap();
        // Single quotes in session_id and provider must be doubled
        assert!(sql.contains("it''s-a-trap"));
        assert!(sql.contains("aw''s"));
        // Must not contain unescaped injection
        assert!(!sql.contains("it's-a-trap',"));
    }
}