Skip to main content

tryaudex_core/
dbaudit.rs

1use serde::{Deserialize, Serialize};
2
3use crate::audit::{AuditEntry, AuditEvent};
4use crate::error::{AvError, Result};
5use crate::session::Session;
6
7/// Database audit backend type.
8#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
9#[serde(rename_all = "lowercase")]
10pub enum DbBackend {
11    Sqlite,
12    Postgres,
13}
14
15impl std::fmt::Display for DbBackend {
16    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
17        match self {
18            Self::Sqlite => write!(f, "sqlite"),
19            Self::Postgres => write!(f, "postgres"),
20        }
21    }
22}
23
24/// Configuration for the database audit backend.
25#[derive(Debug, Clone, Serialize, Deserialize)]
26pub struct DbAuditConfig {
27    /// Backend type: "sqlite" or "postgres"
28    pub backend: DbBackend,
29    /// Connection string.
30    /// - SQLite: file path (e.g. "~/.local/share/audex/audit.db") or ":memory:"
31    /// - Postgres: connection URI (e.g. "postgres://user:pass@localhost/audex")
32    pub connection: String,
33    /// Maximum number of connections in the pool (postgres only, default: 5)
34    #[serde(default = "default_max_connections")]
35    pub max_connections: u32,
36    /// Enable WAL mode for SQLite (default: true)
37    #[serde(default = "default_wal_mode")]
38    pub wal_mode: bool,
39    /// Retention period in days (0 = keep forever, default: 0)
40    #[serde(default)]
41    pub retention_days: u32,
42}
43
/// Serde default for `DbAuditConfig::max_connections`: a pool of five connections.
fn default_max_connections() -> u32 {
    const DEFAULT_POOL_SIZE: u32 = 5;
    DEFAULT_POOL_SIZE
}

/// Serde default for `DbAuditConfig::wal_mode`: WAL stays on unless disabled.
fn default_wal_mode() -> bool {
    const WAL_ENABLED_BY_DEFAULT: bool = true;
    WAL_ENABLED_BY_DEFAULT
}
51
52impl Default for DbAuditConfig {
53    fn default() -> Self {
54        let path = dirs::data_local_dir()
55            .unwrap_or_else(|| std::path::PathBuf::from("."))
56            .join("audex")
57            .join("audit.db");
58        Self {
59            backend: DbBackend::Sqlite,
60            connection: path.to_string_lossy().to_string(),
61            max_connections: default_max_connections(),
62            wal_mode: default_wal_mode(),
63            retention_days: 0,
64        }
65    }
66}
67
/// SQLite DDL for the audit log: the `audit_entries` table plus its lookup indexes.
///
/// `timestamp` and `event_data` are plain `TEXT`. Rows written via `entry_to_row`
/// store an RFC 3339 timestamp and a JSON payload there, so `ORDER BY timestamp`
/// and range filters compare those strings byte-wise. `created_at` is filled by
/// SQLite's `datetime('now')`, which uses a different layout (`YYYY-MM-DD HH:MM:SS`).
/// Every statement uses `IF NOT EXISTS`, so re-running the script is harmless.
pub const SQLITE_SCHEMA: &str = r#"
CREATE TABLE IF NOT EXISTS audit_entries (
    id INTEGER PRIMARY KEY AUTOINCREMENT,
    timestamp TEXT NOT NULL,
    session_id TEXT NOT NULL,
    provider TEXT NOT NULL DEFAULT 'aws',
    event_type TEXT NOT NULL,
    event_data TEXT NOT NULL,
    created_at TEXT NOT NULL DEFAULT (datetime('now'))
);

CREATE INDEX IF NOT EXISTS idx_audit_session_id ON audit_entries(session_id);
CREATE INDEX IF NOT EXISTS idx_audit_timestamp ON audit_entries(timestamp);
CREATE INDEX IF NOT EXISTS idx_audit_provider ON audit_entries(provider);
CREATE INDEX IF NOT EXISTS idx_audit_event_type ON audit_entries(event_type);
"#;
85
/// PostgreSQL DDL for the audit log: the `audit_entries` table plus its lookup indexes.
///
/// Same columns as the SQLite schema, but `timestamp`/`created_at` are
/// `TIMESTAMPTZ` and the event payload is stored as `JSONB`. Every statement
/// uses `IF NOT EXISTS`, so re-running the script is harmless.
pub const POSTGRES_SCHEMA: &str = r#"
CREATE TABLE IF NOT EXISTS audit_entries (
    id BIGSERIAL PRIMARY KEY,
    timestamp TIMESTAMPTZ NOT NULL,
    session_id TEXT NOT NULL,
    provider TEXT NOT NULL DEFAULT 'aws',
    event_type TEXT NOT NULL,
    event_data JSONB NOT NULL,
    created_at TIMESTAMPTZ NOT NULL DEFAULT NOW()
);

CREATE INDEX IF NOT EXISTS idx_audit_session_id ON audit_entries(session_id);
CREATE INDEX IF NOT EXISTS idx_audit_timestamp ON audit_entries(timestamp);
CREATE INDEX IF NOT EXISTS idx_audit_provider ON audit_entries(provider);
CREATE INDEX IF NOT EXISTS idx_audit_event_type ON audit_entries(event_type);
"#;
102
103/// Extract event type string from an AuditEvent.
104pub fn event_type_str(event: &AuditEvent) -> &'static str {
105    match event {
106        AuditEvent::SessionCreated { .. } => "session_created",
107        AuditEvent::CredentialsIssued { .. } => "credentials_issued",
108        AuditEvent::SessionEnded { .. } => "session_ended",
109        AuditEvent::BudgetWarning { .. } => "budget_warning",
110        AuditEvent::BudgetExceeded { .. } => "budget_exceeded",
111        AuditEvent::ResourceCreated { .. } => "resource_created",
112    }
113}
114
115/// SQL statements for common operations.
116pub struct SqlStatements;
117
118impl SqlStatements {
119    /// Insert an audit entry.
120    pub fn insert(backend: &DbBackend) -> &'static str {
121        match backend {
122            DbBackend::Sqlite => {
123                "INSERT INTO audit_entries (timestamp, session_id, provider, event_type, event_data) VALUES (?, ?, ?, ?, ?)"
124            }
125            DbBackend::Postgres => {
126                "INSERT INTO audit_entries (timestamp, session_id, provider, event_type, event_data) VALUES ($1, $2, $3, $4, $5::jsonb)"
127            }
128        }
129    }
130
131    /// Query entries by session ID.
132    pub fn by_session(backend: &DbBackend) -> &'static str {
133        match backend {
134            DbBackend::Sqlite => {
135                "SELECT timestamp, session_id, provider, event_type, event_data FROM audit_entries WHERE session_id = ? ORDER BY timestamp ASC"
136            }
137            DbBackend::Postgres => {
138                "SELECT timestamp, session_id, provider, event_type, event_data FROM audit_entries WHERE session_id = $1 ORDER BY timestamp ASC"
139            }
140        }
141    }
142
143    /// Query recent entries with optional limit.
144    pub fn recent(backend: &DbBackend) -> &'static str {
145        match backend {
146            DbBackend::Sqlite => {
147                "SELECT timestamp, session_id, provider, event_type, event_data FROM audit_entries ORDER BY timestamp DESC LIMIT ?"
148            }
149            DbBackend::Postgres => {
150                "SELECT timestamp, session_id, provider, event_type, event_data FROM audit_entries ORDER BY timestamp DESC LIMIT $1"
151            }
152        }
153    }
154
155    /// Query entries by provider.
156    pub fn by_provider(backend: &DbBackend) -> &'static str {
157        match backend {
158            DbBackend::Sqlite => {
159                "SELECT timestamp, session_id, provider, event_type, event_data FROM audit_entries WHERE provider = ? ORDER BY timestamp DESC LIMIT ?"
160            }
161            DbBackend::Postgres => {
162                "SELECT timestamp, session_id, provider, event_type, event_data FROM audit_entries WHERE provider = $1 ORDER BY timestamp DESC LIMIT $2"
163            }
164        }
165    }
166
167    /// Count entries.
168    pub fn count(backend: &DbBackend) -> &'static str {
169        match backend {
170            DbBackend::Sqlite => "SELECT COUNT(*) FROM audit_entries",
171            DbBackend::Postgres => "SELECT COUNT(*) FROM audit_entries",
172        }
173    }
174
175    /// Count entries by event type.
176    pub fn count_by_type(backend: &DbBackend) -> &'static str {
177        match backend {
178            DbBackend::Sqlite => {
179                "SELECT event_type, COUNT(*) as count FROM audit_entries GROUP BY event_type ORDER BY count DESC"
180            }
181            DbBackend::Postgres => {
182                "SELECT event_type, COUNT(*) as count FROM audit_entries GROUP BY event_type ORDER BY count DESC"
183            }
184        }
185    }
186
187    /// Delete entries older than N days.
188    pub fn prune(backend: &DbBackend) -> &'static str {
189        match backend {
190            DbBackend::Sqlite => {
191                "DELETE FROM audit_entries WHERE timestamp < datetime('now', '-' || ? || ' days')"
192            }
193            DbBackend::Postgres => {
194                "DELETE FROM audit_entries WHERE timestamp < NOW() - ($1 || ' days')::INTERVAL"
195            }
196        }
197    }
198
199    /// Search event data with text matching.
200    pub fn search(backend: &DbBackend) -> &'static str {
201        match backend {
202            DbBackend::Sqlite => {
203                "SELECT timestamp, session_id, provider, event_type, event_data FROM audit_entries WHERE event_data LIKE ? ORDER BY timestamp DESC LIMIT ?"
204            }
205            DbBackend::Postgres => {
206                "SELECT timestamp, session_id, provider, event_type, event_data FROM audit_entries WHERE event_data::text ILIKE $1 ORDER BY timestamp DESC LIMIT $2"
207            }
208        }
209    }
210}
211
212/// Convert an AuditEntry to database row values.
213pub fn entry_to_row(entry: &AuditEntry) -> Result<DbRow> {
214    let event_data = serde_json::to_string(&entry.event)
215        .map_err(|e| AvError::Sts(format!("Failed to serialize event: {}", e)))?;
216
217    // Redact secrets before storing
218    let event_data = crate::leakdetect::redact_secrets(&event_data);
219
220    Ok(DbRow {
221        timestamp: entry.timestamp.to_rfc3339(),
222        session_id: entry.session_id.clone(),
223        provider: entry.provider.clone(),
224        event_type: event_type_str(&entry.event).to_string(),
225        event_data,
226    })
227}
228
229/// Convert a Session to an AuditEntry for session_created events.
230pub fn session_created_entry(session: &Session) -> AuditEntry {
231    use crate::session::CloudProvider;
232
233    let allowed_actions: Vec<String> = match session.provider {
234        CloudProvider::Gcp => session
235            .policy
236            .actions
237            .iter()
238            .map(|a| a.to_gcp_permission())
239            .collect(),
240        CloudProvider::Azure => session
241            .policy
242            .actions
243            .iter()
244            .map(|a| a.to_azure_permission())
245            .collect(),
246        CloudProvider::Aws => session
247            .policy
248            .actions
249            .iter()
250            .map(|a| a.to_iam_action())
251            .collect(),
252    };
253
254    AuditEntry {
255        timestamp: chrono::Utc::now(),
256        session_id: session.id.clone(),
257        provider: session.provider.to_string(),
258        event: AuditEvent::SessionCreated {
259            role_arn: session.role_arn.clone(),
260            ttl_seconds: session.ttl_seconds,
261            budget: session.budget,
262            allowed_actions,
263            command: session.command.clone(),
264            agent_id: session.agent_id.clone(),
265        },
266    }
267}
268
269/// Database row representation.
270#[derive(Debug, Clone, Serialize, Deserialize)]
271pub struct DbRow {
272    pub timestamp: String,
273    pub session_id: String,
274    pub provider: String,
275    pub event_type: String,
276    pub event_data: String,
277}
278
279impl DbRow {
280    /// Convert back to an AuditEntry.
281    pub fn to_entry(&self) -> Result<AuditEntry> {
282        let timestamp = chrono::DateTime::parse_from_rfc3339(&self.timestamp)
283            .map(|dt| dt.with_timezone(&chrono::Utc))
284            .map_err(|e| AvError::Sts(format!("Invalid timestamp: {}", e)))?;
285
286        let event: AuditEvent = serde_json::from_str(&self.event_data)
287            .map_err(|e| AvError::Sts(format!("Invalid event data: {}", e)))?;
288
289        Ok(AuditEntry {
290            timestamp,
291            session_id: self.session_id.clone(),
292            provider: self.provider.clone(),
293            event,
294        })
295    }
296}
297
298/// Statistics about the audit database.
299#[derive(Debug, Clone, Serialize, Deserialize)]
300pub struct DbAuditStats {
301    pub total_entries: u64,
302    pub entries_by_type: Vec<(String, u64)>,
303    pub backend: String,
304    pub connection: String,
305}
306
307/// Migration helper — import JSONL audit log into database.
308pub fn parse_jsonl_for_import(content: &str) -> Vec<AuditEntry> {
309    content
310        .lines()
311        .filter(|l| !l.trim().is_empty())
312        .filter_map(|l| {
313            let (json_content, _) = crate::integrity::parse_line(l);
314            serde_json::from_str(json_content).ok()
315        })
316        .collect()
317}
318
319/// Generate a migration script to import existing JSONL data.
320pub fn generate_import_sql(entries: &[AuditEntry], backend: &DbBackend) -> Result<String> {
321    let mut sql = String::new();
322
323    match backend {
324        DbBackend::Sqlite => {
325            sql.push_str("BEGIN TRANSACTION;\n");
326        }
327        DbBackend::Postgres => {
328            sql.push_str("BEGIN;\n");
329        }
330    }
331
332    for entry in entries {
333        let row = entry_to_row(entry)?;
334        let escaped_data = row.event_data.replace('\'', "''");
335        sql.push_str(&format!(
336            "INSERT INTO audit_entries (timestamp, session_id, provider, event_type, event_data) VALUES ('{}', '{}', '{}', '{}', '{}');\n",
337            row.timestamp, row.session_id, row.provider, row.event_type, escaped_data
338        ));
339    }
340
341    sql.push_str("COMMIT;\n");
342    Ok(sql)
343}
344
#[cfg(test)]
mod tests {
    use super::*;
    use crate::audit::AuditEvent;

    /// Parse an RFC 3339 fixture literal into a UTC timestamp.
    fn utc(rfc3339: &str) -> chrono::DateTime<chrono::Utc> {
        chrono::DateTime::parse_from_rfc3339(rfc3339)
            .expect("test fixture must be valid RFC 3339")
            .with_timezone(&chrono::Utc)
    }

    #[test]
    fn test_config_default() {
        let DbAuditConfig {
            backend,
            connection,
            max_connections,
            wal_mode,
            ..
        } = DbAuditConfig::default();
        assert_eq!(backend, DbBackend::Sqlite);
        assert!(connection.contains("audit.db"));
        assert_eq!(max_connections, 5);
        assert!(wal_mode);
    }

    #[test]
    fn test_config_deserialize_sqlite() {
        let parsed: DbAuditConfig = toml::from_str(
            r#"
backend = "sqlite"
connection = "/data/audex/audit.db"
wal_mode = true
retention_days = 90
"#,
        )
        .expect("sqlite config should parse");
        assert_eq!(parsed.backend, DbBackend::Sqlite);
        assert_eq!(parsed.connection, "/data/audex/audit.db");
        assert_eq!(parsed.retention_days, 90);
    }

    #[test]
    fn test_config_deserialize_postgres() {
        let parsed: DbAuditConfig = toml::from_str(
            r#"
backend = "postgres"
connection = "postgres://audex:password@db.internal:5432/audex"
max_connections = 10
"#,
        )
        .expect("postgres config should parse");
        assert_eq!(parsed.backend, DbBackend::Postgres);
        assert_eq!(parsed.max_connections, 10);
    }

    #[test]
    fn test_event_type_str() {
        let cases = [
            (
                AuditEvent::SessionCreated {
                    role_arn: "arn".into(),
                    ttl_seconds: 900,
                    budget: None,
                    allowed_actions: vec![],
                    command: vec![],
                    agent_id: None,
                },
                "session_created",
            ),
            (
                AuditEvent::SessionEnded {
                    status: "ok".into(),
                    duration_seconds: 60,
                    exit_code: Some(0),
                },
                "session_ended",
            ),
            (
                AuditEvent::BudgetWarning {
                    current_spend: 1.0,
                    limit: 5.0,
                },
                "budget_warning",
            ),
        ];
        for (event, expected) in &cases {
            assert_eq!(event_type_str(event), *expected);
        }
    }

    #[test]
    fn test_entry_to_row() {
        let event = AuditEvent::SessionCreated {
            role_arn: "arn:aws:iam::123:role/Test".to_string(),
            ttl_seconds: 900,
            budget: Some(5.0),
            allowed_actions: vec!["s3:GetObject".to_string()],
            command: ["aws", "s3", "ls"].iter().map(|s| s.to_string()).collect(),
            agent_id: Some("claude".to_string()),
        };
        let entry = AuditEntry {
            timestamp: chrono::Utc::now(),
            session_id: "test-123".to_string(),
            provider: "aws".to_string(),
            event,
        };

        let row = entry_to_row(&entry).expect("entry should convert to a row");
        assert_eq!(row.session_id, "test-123");
        assert_eq!(row.provider, "aws");
        assert_eq!(row.event_type, "session_created");
        assert!(row.event_data.contains("s3:GetObject"));
    }

    #[test]
    fn test_row_roundtrip() {
        let original = AuditEntry {
            timestamp: chrono::Utc::now(),
            session_id: "rt-456".to_string(),
            provider: "gcp".to_string(),
            event: AuditEvent::SessionEnded {
                status: "completed".to_string(),
                duration_seconds: 120,
                exit_code: Some(0),
            },
        };

        let restored = entry_to_row(&original)
            .and_then(|row| row.to_entry())
            .expect("entry should survive a row round trip");
        assert_eq!(restored.session_id, "rt-456");
        assert_eq!(restored.provider, "gcp");
    }

    #[test]
    fn test_sql_statements_sqlite() {
        let b = DbBackend::Sqlite;
        let checks = [
            (SqlStatements::insert(&b), "?"),
            (SqlStatements::by_session(&b), "session_id = ?"),
            (SqlStatements::recent(&b), "LIMIT ?"),
            (SqlStatements::prune(&b), "datetime"),
        ];
        for &(sql, needle) in &checks {
            assert!(sql.contains(needle), "{:?} should contain {:?}", sql, needle);
        }
    }

    #[test]
    fn test_sql_statements_postgres() {
        let b = DbBackend::Postgres;
        let checks = [
            (SqlStatements::insert(&b), "$1"),
            (SqlStatements::by_session(&b), "session_id = $1"),
            (SqlStatements::recent(&b), "LIMIT $1"),
            (SqlStatements::prune(&b), "INTERVAL"),
        ];
        for &(sql, needle) in &checks {
            assert!(sql.contains(needle), "{:?} should contain {:?}", sql, needle);
        }
    }

    #[test]
    fn test_parse_jsonl_for_import() {
        let jsonl = r#"{"timestamp":"2026-01-01T00:00:00Z","session_id":"s1","provider":"aws","event":{"type":"session_ended","status":"ok","duration_seconds":60,"exit_code":0}}"#;
        let entries = parse_jsonl_for_import(jsonl);
        assert_eq!(entries.len(), 1);
        assert_eq!(entries[0].session_id, "s1");
    }

    #[test]
    fn test_generate_import_sql_sqlite() {
        let entries = vec![AuditEntry {
            timestamp: utc("2026-01-01T00:00:00Z"),
            session_id: "s1".to_string(),
            provider: "aws".to_string(),
            event: AuditEvent::SessionEnded {
                status: "ok".to_string(),
                duration_seconds: 60,
                exit_code: Some(0),
            },
        }];

        let sql = generate_import_sql(&entries, &DbBackend::Sqlite)
            .expect("sqlite import script should build");
        for fragment in ["BEGIN TRANSACTION", "INSERT INTO audit_entries", "COMMIT"].iter() {
            assert!(sql.contains(fragment), "missing {:?}", fragment);
        }
    }

    #[test]
    fn test_generate_import_sql_postgres() {
        let entries = vec![AuditEntry {
            timestamp: utc("2026-01-01T00:00:00Z"),
            session_id: "s2".to_string(),
            provider: "gcp".to_string(),
            event: AuditEvent::CredentialsIssued {
                access_key_id: "AKID".to_string(),
                expires_at: utc("2026-01-01T00:15:00Z"),
            },
        }];

        let sql = generate_import_sql(&entries, &DbBackend::Postgres)
            .expect("postgres import script should build");
        assert!(sql.contains("BEGIN;"));
        assert!(sql.contains("COMMIT;"));
    }

    #[test]
    fn test_db_backend_display() {
        for (backend, name) in [(DbBackend::Sqlite, "sqlite"), (DbBackend::Postgres, "postgres")].iter() {
            assert_eq!(backend.to_string(), *name);
        }
    }

    #[test]
    fn test_db_audit_stats_serialization() {
        let stats = DbAuditStats {
            total_entries: 1000,
            entries_by_type: vec![
                ("session_created".to_string(), 400),
                ("credentials_issued".to_string(), 400),
                ("session_ended".to_string(), 200),
            ],
            backend: "sqlite".to_string(),
            connection: "/data/audit.db".to_string(),
        };

        let json = serde_json::to_string(&stats).expect("stats should serialize");
        assert!(json.contains("1000"));
        assert!(json.contains("session_created"));
    }
}