vtcode_core/command_safety/
cache.rs1use hashbrown::HashMap;
7use std::sync::Arc;
8use tokio::sync::Mutex;
9
/// A cached safety verdict for one command string.
#[derive(Clone, Debug, PartialEq, Eq)]
pub struct CachedDecision {
    /// The cached verdict: `true` when the command was recorded as safe.
    pub is_safe: bool,
    /// Explanation stored alongside the verdict.
    pub reason: String,
    /// How many times the entry has been touched: set to 1 when stored by
    /// `SafetyDecisionCache::put` and incremented on every `get` hit. Used to
    /// pick the eviction victim when the cache is full.
    pub access_count: u64,
}
20
21pub struct SafetyDecisionCache {
23 cache: Arc<Mutex<HashMap<String, CachedDecision>>>,
24 max_size: usize,
25}
26
27impl SafetyDecisionCache {
28 pub fn new(max_size: usize) -> Self {
30 Self {
31 cache: Arc::new(Mutex::new(HashMap::new())),
32 max_size,
33 }
34 }
35
36 pub fn default_cache() -> Self {
38 Self::new(1000)
39 }
40
41 pub async fn get(&self, command: &str) -> Option<CachedDecision> {
43 let mut cache = self.cache.lock().await;
44 if let Some(decision) = cache.get_mut(command) {
45 decision.access_count += 1;
46 return Some(decision.clone());
47 }
48 None
49 }
50
51 pub async fn put(&self, command: String, is_safe: bool, reason: String) {
53 let mut cache = self.cache.lock().await;
54
55 if cache.len() >= self.max_size
56 && !cache.contains_key(&command)
57 && let Some(least_used) = cache
58 .iter()
59 .min_by_key(|(_, decision)| decision.access_count)
60 .map(|(k, _)| k.clone())
61 {
62 cache.remove(&least_used);
63 }
64
65 cache.insert(
66 command,
67 CachedDecision {
68 is_safe,
69 reason,
70 access_count: 1,
71 },
72 );
73 }
74
75 pub async fn clear(&self) {
77 let mut cache = self.cache.lock().await;
78 cache.clear();
79 }
80
81 pub async fn size(&self) -> usize {
83 let cache = self.cache.lock().await;
84 cache.len()
85 }
86
87 pub async fn stats(&self) -> CacheStats {
89 let cache = self.cache.lock().await;
90 let total_accesses: u64 = cache.values().map(|d| d.access_count).sum();
91 let entry_count = cache.len();
92
93 CacheStats {
94 entry_count,
95 total_accesses,
96 avg_access_per_entry: if entry_count > 0 {
97 total_accesses / entry_count as u64
98 } else {
99 0
100 },
101 }
102 }
103}
104
105impl Clone for SafetyDecisionCache {
106 fn clone(&self) -> Self {
107 Self {
108 cache: Arc::clone(&self.cache),
109 max_size: self.max_size,
110 }
111 }
112}
113
/// Point-in-time summary of a `SafetyDecisionCache`'s contents.
#[derive(Debug, Clone, Default, PartialEq, Eq)]
pub struct CacheStats {
    /// Number of entries cached when the snapshot was taken.
    pub entry_count: usize,
    /// Sum of `access_count` over all entries.
    pub total_accesses: u64,
    /// `total_accesses / entry_count` (truncated), or `0` when there are no entries.
    pub avg_access_per_entry: u64,
}
121
#[cfg(test)]
mod tests {
    use super::*;

    #[tokio::test]
    async fn cache_stores_and_retrieves() {
        let cache = SafetyDecisionCache::new(10);
        cache
            .put(
                "git status".to_string(),
                true,
                "git status allowed".to_string(),
            )
            .await;

        let decision = cache
            .get("git status")
            .await
            .expect("entry was just inserted");
        assert!(decision.is_safe);
        assert_eq!(decision.reason, "git status allowed");
    }

    #[tokio::test]
    async fn cache_returns_none_for_missing_key() {
        let cache = SafetyDecisionCache::new(10);
        let decision = cache.get("missing").await;
        assert!(decision.is_none());
    }

    #[tokio::test]
    async fn cache_tracks_access_count() {
        let cache = SafetyDecisionCache::new(10);
        cache
            .put("cmd".to_string(), true, "allowed".to_string())
            .await;

        // `put` counts as the first access; each `get` hit adds one.
        let d1 = cache.get("cmd").await.unwrap();
        assert_eq!(d1.access_count, 2);

        let d2 = cache.get("cmd").await.unwrap();
        assert_eq!(d2.access_count, 3);
    }

    #[tokio::test]
    async fn cache_respects_max_size() {
        let cache = SafetyDecisionCache::new(3);

        for cmd in ["cmd1", "cmd2", "cmd3"] {
            cache.put(cmd.to_string(), true, "allowed".to_string()).await;
        }
        assert_eq!(cache.size().await, 3);

        cache
            .put("cmd4".to_string(), true, "allowed".to_string())
            .await;
        assert_eq!(cache.size().await, 3);
        assert!(cache.get("cmd4").await.is_some());
    }

    #[tokio::test]
    async fn cache_evicts_least_accessed_entry() {
        let cache = SafetyDecisionCache::new(3);
        for cmd in ["cmd1", "cmd2", "cmd3"] {
            cache.put(cmd.to_string(), true, "allowed".to_string()).await;
        }
        // Raise cmd1 and cmd2 above cmd3's count so cmd3 is the unique minimum.
        assert!(cache.get("cmd1").await.is_some());
        assert!(cache.get("cmd2").await.is_some());

        cache
            .put("cmd4".to_string(), false, "denied".to_string())
            .await;

        assert_eq!(cache.size().await, 3);
        assert!(cache.get("cmd3").await.is_none());
        assert!(cache.get("cmd1").await.is_some());
        assert!(cache.get("cmd2").await.is_some());
        let newest = cache
            .get("cmd4")
            .await
            .expect("newly inserted entry is retained");
        assert!(!newest.is_safe);
    }

    #[tokio::test]
    async fn cache_overwrite_does_not_evict() {
        let cache = SafetyDecisionCache::new(2);
        cache
            .put("cmd1".to_string(), true, "allowed".to_string())
            .await;
        cache
            .put("cmd2".to_string(), true, "allowed".to_string())
            .await;

        cache
            .put("cmd1".to_string(), false, "now denied".to_string())
            .await;

        assert_eq!(cache.size().await, 2);
        let updated = cache.get("cmd1").await.unwrap();
        assert!(!updated.is_safe);
        assert_eq!(updated.reason, "now denied");
        assert!(cache.get("cmd2").await.is_some());
    }

    #[tokio::test]
    async fn cache_clears() {
        let cache = SafetyDecisionCache::new(10);
        cache
            .put("cmd".to_string(), true, "allowed".to_string())
            .await;
        assert_eq!(cache.size().await, 1);

        cache.clear().await;
        assert_eq!(cache.size().await, 0);
    }

    #[tokio::test]
    async fn cache_stats() {
        let cache = SafetyDecisionCache::new(10);
        cache
            .put("cmd1".to_string(), true, "allowed".to_string())
            .await;
        cache
            .put("cmd2".to_string(), true, "allowed".to_string())
            .await;

        let _d1 = cache.get("cmd1").await;
        let _d2 = cache.get("cmd2").await;
        let _d3 = cache.get("cmd2").await;

        // cmd1: 1 put + 1 get = 2; cmd2: 1 put + 2 gets = 3.
        let stats = cache.stats().await;
        assert_eq!(stats.entry_count, 2);
        assert_eq!(stats.total_accesses, 5);
        assert_eq!(stats.avg_access_per_entry, 2);
    }

    #[tokio::test]
    async fn cache_stats_on_empty_cache() {
        let stats = SafetyDecisionCache::new(10).stats().await;
        assert_eq!(stats.entry_count, 0);
        assert_eq!(stats.total_accesses, 0);
        assert_eq!(stats.avg_access_per_entry, 0);
    }
}