Skip to main content

vtcode_core/command_safety/
cache.rs

//! Caching layer for command safety decisions.
//!
//! Caches safety decisions to avoid re-evaluating the same commands.
//! When a new command is inserted into a full cache, the least-frequently-used entry
//! (the one with the lowest access count) is evicted first.
5
6use hashbrown::HashMap;
7use std::sync::Arc;
8use tokio::sync::Mutex;
9
/// A cached safety decision for a single command.
#[derive(Clone, Debug, PartialEq, Eq)]
pub struct CachedDecision {
    /// `true` if the command was judged safe.
    pub is_safe: bool,
    /// Explanation for the decision, as supplied when the entry was stored.
    pub reason: String,
    /// How many times this entry has been stored or looked up. When the cache is full,
    /// the entry with the lowest count is evicted first (least-frequently-used eviction,
    /// not least-recently-used).
    pub access_count: u64,
}
20
21/// Thread-safe cache for command safety decisions
22pub struct SafetyDecisionCache {
23    cache: Arc<Mutex<HashMap<String, CachedDecision>>>,
24    max_size: usize,
25}
26
27impl SafetyDecisionCache {
28    /// Creates a new cache with given max size
29    pub fn new(max_size: usize) -> Self {
30        Self {
31            cache: Arc::new(Mutex::new(HashMap::new())),
32            max_size,
33        }
34    }
35
36    /// Creates a default cache (1000 entries)
37    pub fn default_cache() -> Self {
38        Self::new(1000)
39    }
40
41    /// Gets a cached decision
42    pub async fn get(&self, command: &str) -> Option<CachedDecision> {
43        let mut cache = self.cache.lock().await;
44        if let Some(decision) = cache.get_mut(command) {
45            decision.access_count += 1;
46            return Some(decision.clone());
47        }
48        None
49    }
50
51    /// Sets a cached decision
52    pub async fn put(&self, command: String, is_safe: bool, reason: String) {
53        let mut cache = self.cache.lock().await;
54
55        if cache.len() >= self.max_size
56            && !cache.contains_key(&command)
57            && let Some(least_used) = cache
58                .iter()
59                .min_by_key(|(_, decision)| decision.access_count)
60                .map(|(k, _)| k.clone())
61        {
62            cache.remove(&least_used);
63        }
64
65        cache.insert(
66            command,
67            CachedDecision {
68                is_safe,
69                reason,
70                access_count: 1,
71            },
72        );
73    }
74
75    /// Clears all cached entries
76    pub async fn clear(&self) {
77        let mut cache = self.cache.lock().await;
78        cache.clear();
79    }
80
81    /// Returns current cache size
82    pub async fn size(&self) -> usize {
83        let cache = self.cache.lock().await;
84        cache.len()
85    }
86
87    /// Returns cache hit rate statistics
88    pub async fn stats(&self) -> CacheStats {
89        let cache = self.cache.lock().await;
90        let total_accesses: u64 = cache.values().map(|d| d.access_count).sum();
91        let entry_count = cache.len();
92
93        CacheStats {
94            entry_count,
95            total_accesses,
96            avg_access_per_entry: if entry_count > 0 {
97                total_accesses / entry_count as u64
98            } else {
99                0
100            },
101        }
102    }
103}
104
105impl Clone for SafetyDecisionCache {
106    fn clone(&self) -> Self {
107        Self {
108            cache: Arc::clone(&self.cache),
109            max_size: self.max_size,
110        }
111    }
112}
113
/// Aggregate access statistics for a safety-decision cache, as produced by
/// `SafetyDecisionCache::stats`.
#[derive(Debug, Clone, PartialEq, Eq, Default)]
pub struct CacheStats {
    /// Number of entries currently cached.
    pub entry_count: usize,
    /// Sum of the access counts of all current entries.
    pub total_accesses: u64,
    /// `total_accesses / entry_count`, rounded down (integer division); `0` when the
    /// cache is empty.
    pub avg_access_per_entry: u64,
}
121
#[cfg(test)]
mod tests {
    use super::*;

    #[tokio::test]
    async fn cache_stores_and_retrieves() {
        let cache = SafetyDecisionCache::new(10);
        cache
            .put(
                "git status".to_string(),
                true,
                "git status allowed".to_string(),
            )
            .await;

        let decision = cache.get("git status").await;
        assert!(decision.is_some());
        assert!(decision.unwrap().is_safe);
    }

    #[tokio::test]
    async fn cache_returns_none_for_missing_key() {
        let cache = SafetyDecisionCache::new(10);
        let decision = cache.get("missing").await;
        assert!(decision.is_none());
    }

    #[tokio::test]
    async fn cache_tracks_access_count() {
        let cache = SafetyDecisionCache::new(10);
        cache
            .put("cmd".to_string(), true, "allowed".to_string())
            .await;

        let d1 = cache.get("cmd").await.unwrap();
        assert_eq!(d1.access_count, 2);

        let d2 = cache.get("cmd").await.unwrap();
        assert_eq!(d2.access_count, 3);
    }

    #[tokio::test]
    async fn cache_respects_max_size() {
        let cache = SafetyDecisionCache::new(3);

        cache
            .put("cmd1".to_string(), true, "allowed".to_string())
            .await;
        cache
            .put("cmd2".to_string(), true, "allowed".to_string())
            .await;
        cache
            .put("cmd3".to_string(), true, "allowed".to_string())
            .await;

        assert_eq!(cache.size().await, 3);

        // Adding a 4th entry should evict the least-used
        cache
            .put("cmd4".to_string(), true, "allowed".to_string())
            .await;
        assert_eq!(cache.size().await, 3);
    }

    #[tokio::test]
    async fn cache_evicts_least_frequently_used_entry() {
        let cache = SafetyDecisionCache::new(3);
        for cmd in ["cmd1", "cmd2", "cmd3"] {
            cache
                .put(cmd.to_string(), true, "allowed".to_string())
                .await;
        }

        // Bump cmd1 and cmd2 so cmd3 is the unique lowest-count entry.
        assert!(cache.get("cmd1").await.is_some());
        assert!(cache.get("cmd2").await.is_some());

        cache
            .put("cmd4".to_string(), true, "allowed".to_string())
            .await;

        assert_eq!(cache.size().await, 3);
        assert!(cache.get("cmd3").await.is_none());
        assert!(cache.get("cmd1").await.is_some());
        assert!(cache.get("cmd2").await.is_some());
        assert!(cache.get("cmd4").await.is_some());
    }

    #[tokio::test]
    async fn cache_overwrite_at_capacity_does_not_evict() {
        let cache = SafetyDecisionCache::new(2);
        cache
            .put("a".to_string(), true, "allowed".to_string())
            .await;
        cache
            .put("b".to_string(), true, "allowed".to_string())
            .await;

        // Re-putting an existing key replaces it in place without evicting anything.
        cache
            .put("a".to_string(), false, "denied".to_string())
            .await;

        assert_eq!(cache.size().await, 2);
        assert!(cache.get("b").await.is_some());
        let a = cache.get("a").await.unwrap();
        assert!(!a.is_safe);
        assert_eq!(a.reason, "denied");
        // The replaced entry's count was reset to 1, plus this lookup.
        assert_eq!(a.access_count, 2);
    }

    #[tokio::test]
    async fn cache_clears() {
        let cache = SafetyDecisionCache::new(10);
        cache
            .put("cmd".to_string(), true, "allowed".to_string())
            .await;
        assert_eq!(cache.size().await, 1);

        cache.clear().await;
        assert_eq!(cache.size().await, 0);
    }

    #[tokio::test]
    async fn cache_stats() {
        let cache = SafetyDecisionCache::new(10);
        cache
            .put("cmd1".to_string(), true, "allowed".to_string())
            .await;
        cache
            .put("cmd2".to_string(), true, "allowed".to_string())
            .await;

        let _d1 = cache.get("cmd1").await;
        let _d2 = cache.get("cmd2").await;
        let _d3 = cache.get("cmd2").await;

        let stats = cache.stats().await;
        assert_eq!(stats.entry_count, 2);
        assert_eq!(stats.total_accesses, 5); // 1+1 initial puts + 1+2 gets
        assert_eq!(stats.avg_access_per_entry, 2); // 5 / 2, rounded down
    }
}