//! In-memory audit trail of command-safety allow/deny decisions
//! (`vtcode_core/command_safety/audit.rs`).

use serde::{Deserialize, Serialize};
use std::sync::Arc;
use tokio::sync::Mutex;

10#[derive(Debug, Clone, Serialize, Deserialize)]
12pub struct AuditEntry {
13 pub command: Vec<String>,
15 pub allowed: bool,
17 pub reason: String,
19 pub decision_type: String,
21 pub timestamp: String,
23}
24
25impl AuditEntry {
26 pub fn new(command: Vec<String>, allowed: bool, reason: String, decision_type: String) -> Self {
28 let timestamp = std::time::SystemTime::now()
30 .duration_since(std::time::UNIX_EPOCH)
31 .map(|d| d.as_secs().to_string())
32 .unwrap_or_else(|_| "unknown".to_string());
33
34 Self {
35 command,
36 allowed,
37 reason,
38 decision_type,
39 timestamp,
40 }
41 }
42}
43
44pub struct SafetyAuditLogger {
46 entries: Arc<Mutex<Vec<AuditEntry>>>,
47 enabled: bool,
48}
49
50impl SafetyAuditLogger {
51 pub fn new(enabled: bool) -> Self {
53 Self {
54 entries: Arc::new(Mutex::new(Vec::new())),
55 enabled,
56 }
57 }
58
59 pub async fn log(&self, entry: AuditEntry) {
61 if self.enabled {
62 let mut entries = self.entries.lock().await;
63 entries.push(entry);
64 }
65 }
66
67 pub async fn entries(&self) -> Vec<AuditEntry> {
69 let entries = self.entries.lock().await;
70 entries.clone()
71 }
72
73 pub async fn entries_for_command(&self, cmd: &str) -> Vec<AuditEntry> {
75 let entries = self.entries.lock().await;
76 entries
77 .iter()
78 .filter(|e| e.command.join(" ").contains(cmd))
79 .cloned()
80 .collect()
81 }
82
83 pub async fn denied_entries(&self) -> Vec<AuditEntry> {
85 let entries = self.entries.lock().await;
86 entries.iter().filter(|e| !e.allowed).cloned().collect()
87 }
88
89 pub async fn clear(&self) {
91 let mut entries = self.entries.lock().await;
92 entries.clear();
93 }
94
95 pub async fn count(&self) -> usize {
97 let entries = self.entries.lock().await;
98 entries.len()
99 }
100}
101
102impl Clone for SafetyAuditLogger {
103 fn clone(&self) -> Self {
104 Self {
105 entries: Arc::clone(&self.entries),
106 enabled: self.enabled,
107 }
108 }
109}
110
#[cfg(test)]
mod tests {
    use super::*;

    // Purely synchronous: no async runtime needed.
    #[test]
    fn creates_audit_entry() {
        let cmd = vec!["git".to_string(), "status".to_string()];
        let entry = AuditEntry::new(
            cmd,
            true,
            "git status allowed".to_string(),
            "Allow".to_string(),
        );
        assert!(entry.allowed);
        assert!(!entry.timestamp.is_empty());
    }

    #[tokio::test]
    async fn logs_entries() {
        let logger = SafetyAuditLogger::new(true);
        let cmd = vec!["git".to_string(), "status".to_string()];
        let entry = AuditEntry::new(
            cmd,
            true,
            "git status allowed".to_string(),
            "Allow".to_string(),
        );

        logger.log(entry).await;
        assert_eq!(logger.count().await, 1);
    }

    #[tokio::test]
    async fn filters_denied_entries() {
        let logger = SafetyAuditLogger::new(true);

        let cmd1 = vec!["git".to_string(), "status".to_string()];
        logger
            .log(AuditEntry::new(
                cmd1,
                true,
                "allowed".to_string(),
                "Allow".to_string(),
            ))
            .await;

        let cmd2 = vec!["git".to_string(), "reset".to_string()];
        logger
            .log(AuditEntry::new(
                cmd2,
                false,
                "denied".to_string(),
                "Deny".to_string(),
            ))
            .await;

        let denied = logger.denied_entries().await;
        assert_eq!(denied.len(), 1);
        assert!(!denied[0].allowed);
    }

    #[tokio::test]
    async fn filters_entries_by_command_substring() {
        let logger = SafetyAuditLogger::new(true);
        logger
            .log(AuditEntry::new(
                vec!["git".to_string(), "status".to_string()],
                true,
                "allowed".to_string(),
                "Allow".to_string(),
            ))
            .await;
        logger
            .log(AuditEntry::new(
                vec!["rm".to_string(), "-rf".to_string()],
                false,
                "denied".to_string(),
                "Deny".to_string(),
            ))
            .await;

        let matches = logger.entries_for_command("git status").await;
        assert_eq!(matches.len(), 1);
        assert_eq!(
            matches[0].command,
            vec!["git".to_string(), "status".to_string()]
        );
        assert!(logger.entries_for_command("docker").await.is_empty());
    }

    #[tokio::test]
    async fn disabled_logger_ignores_entries() {
        let logger = SafetyAuditLogger::new(false);
        let cmd = vec!["git".to_string(), "status".to_string()];
        let entry = AuditEntry::new(cmd, true, "allowed".to_string(), "Allow".to_string());

        logger.log(entry).await;
        assert_eq!(logger.count().await, 0);
    }

    #[tokio::test]
    async fn clear_empties_the_log() {
        let logger = SafetyAuditLogger::new(true);
        let cmd = vec!["git".to_string(), "status".to_string()];
        logger
            .log(AuditEntry::new(cmd, true, "allowed".to_string(), "Allow".to_string()))
            .await;
        assert_eq!(logger.count().await, 1);

        logger.clear().await;
        assert_eq!(logger.count().await, 0);
        assert!(logger.entries().await.is_empty());
    }

    #[tokio::test]
    async fn clones_share_same_entries() {
        let logger1 = SafetyAuditLogger::new(true);
        let logger2 = logger1.clone();

        let cmd = vec!["git".to_string(), "status".to_string()];
        let entry = AuditEntry::new(cmd, true, "allowed".to_string(), "Allow".to_string());

        logger1.log(entry).await;

        assert_eq!(logger1.count().await, 1);
        // The clone observes the entry logged through the original handle.
        assert_eq!(logger2.count().await, 1);
    }
}