vibesql_executor/cache/
query_plan_cache.rs1use std::{
8 collections::{HashMap, HashSet},
9 sync::{
10 atomic::{AtomicUsize, Ordering},
11 RwLock,
12 },
13};
14
15use super::QuerySignature;
16
/// Point-in-time snapshot of the counters kept by [`QueryPlanCache`], as
/// returned by [`QueryPlanCache::stats`].
///
/// `Default` yields the snapshot of a fresh cache (all zeros, `hit_rate` 0.0);
/// `PartialEq` allows comparing snapshots directly in tests and diagnostics.
#[derive(Clone, Debug, Default, PartialEq)]
pub struct CacheStats {
    /// Number of lookups that found a cached entry.
    pub hits: usize,
    /// Number of lookups that found nothing.
    pub misses: usize,
    /// Number of entries removed to make room for new ones. Removals caused by
    /// clearing or table invalidation are not counted here.
    pub evictions: usize,
    /// Number of entries cached at the time of the snapshot.
    pub size: usize,
    /// `hits / (hits + misses)`, or `0.0` if no lookup has happened yet.
    pub hit_rate: f64,
}
26
/// A single cached statement together with the names of the tables it
/// depends on.
///
/// The table set lets `QueryPlanCache::invalidate_table` drop every entry
/// that touches a modified table.
#[derive(Clone)]
struct CachedEntry {
    /// SQL text handed back to callers on a cache hit.
    sql: String,
    /// Table names this statement depends on; compared ASCII
    /// case-insensitively during invalidation.
    tables: HashSet<String>,
}
33
34pub struct QueryPlanCache {
37 cache: RwLock<HashMap<QuerySignature, CachedEntry>>,
38 max_size: usize,
39 hits: AtomicUsize,
40 misses: AtomicUsize,
41 evictions: AtomicUsize,
42}
43
44impl QueryPlanCache {
45 pub fn new(max_size: usize) -> Self {
47 Self {
48 cache: RwLock::new(HashMap::new()),
49 max_size,
50 hits: AtomicUsize::new(0),
51 misses: AtomicUsize::new(0),
52 evictions: AtomicUsize::new(0),
53 }
54 }
55
56 pub fn get(&self, signature: &QuerySignature) -> Option<String> {
58 let cache = self.cache.read().unwrap();
59 if let Some(entry) = cache.get(signature) {
60 self.hits.fetch_add(1, Ordering::Relaxed);
61 Some(entry.sql.clone())
62 } else {
63 self.misses.fetch_add(1, Ordering::Relaxed);
64 None
65 }
66 }
67
68 pub fn insert(&self, signature: QuerySignature, sql: String) {
71 self.insert_with_tables(signature, sql, HashSet::new());
72 }
73
74 pub fn insert_with_tables(
76 &self,
77 signature: QuerySignature,
78 sql: String,
79 tables: HashSet<String>,
80 ) {
81 let entry = CachedEntry { sql, tables };
82 let mut cache = self.cache.write().unwrap();
83
84 if cache.len() >= self.max_size {
85 if let Some(key) = cache.keys().next().cloned() {
87 cache.remove(&key);
88 self.evictions.fetch_add(1, Ordering::Relaxed);
89 }
90 }
91
92 cache.insert(signature, entry);
93 }
94
95 pub fn contains(&self, signature: &QuerySignature) -> bool {
97 self.cache.read().unwrap().contains_key(signature)
98 }
99
100 pub fn clear(&self) {
102 self.cache.write().unwrap().clear();
103 }
104
105 pub fn invalidate_table(&self, table: &str) {
107 let mut cache = self.cache.write().unwrap();
108 cache.retain(|_, entry| !entry.tables.iter().any(|t| t.eq_ignore_ascii_case(table)));
109 }
110
111 pub fn stats(&self) -> CacheStats {
113 let cache = self.cache.read().unwrap();
114 let hits = self.hits.load(Ordering::Relaxed);
115 let misses = self.misses.load(Ordering::Relaxed);
116 let total = hits + misses;
117 let hit_rate = if total > 0 { hits as f64 / total as f64 } else { 0.0 };
118
119 CacheStats {
120 hits,
121 misses,
122 evictions: self.evictions.load(Ordering::Relaxed),
123 size: cache.len(),
124 hit_rate,
125 }
126 }
127
128 pub fn max_size(&self) -> usize {
130 self.max_size
131 }
132}
133
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_cache_hit() {
        let cache = QueryPlanCache::new(10);
        let sig = QuerySignature::from_sql("SELECT * FROM users");
        let expected = String::from("select * from users");

        cache.insert(sig.clone(), expected.clone());

        // A stored entry comes back verbatim.
        assert_eq!(cache.get(&sig), Some(expected));
    }

    #[test]
    fn test_cache_miss() {
        let cache = QueryPlanCache::new(10);
        let never_inserted = QuerySignature::from_sql("SELECT * FROM users");

        assert_eq!(cache.get(&never_inserted), None);
    }

    #[test]
    fn test_lru_eviction() {
        let cache = QueryPlanCache::new(2);

        let users = QuerySignature::from_sql("SELECT * FROM users");
        let orders = QuerySignature::from_sql("SELECT * FROM orders");
        let products = QuerySignature::from_sql("SELECT * FROM products");

        cache.insert(users, "select * from users".to_string());
        cache.insert(orders, "select * from orders".to_string());
        assert_eq!(cache.stats().size, 2, "both entries fit within capacity");

        cache.insert(products, "select * from products".to_string());
        let after_overflow = cache.stats();
        assert_eq!(after_overflow.size, 2, "capacity is never exceeded");
        assert_eq!(after_overflow.evictions, 1, "exactly one entry made room");
    }

    #[test]
    fn test_cache_clear() {
        let cache = QueryPlanCache::new(10);
        let sig = QuerySignature::from_sql("SELECT * FROM users");

        cache.insert(sig.clone(), "select * from users".to_string());
        assert!(cache.contains(&sig), "entry present before clear");

        cache.clear();
        assert!(!cache.contains(&sig), "entry gone after clear");
    }

    #[test]
    fn test_table_invalidation() {
        let cache = QueryPlanCache::new(10);
        let sig = QuerySignature::from_sql("SELECT * FROM users WHERE id = 1");
        let dependencies: std::collections::HashSet<String> =
            std::iter::once("users".to_string()).collect();

        cache.insert_with_tables(
            sig.clone(),
            "select * from users where id = 1".to_string(),
            dependencies,
        );
        assert!(cache.contains(&sig), "entry present before invalidation");

        cache.invalidate_table("users");
        assert!(!cache.contains(&sig), "entry dropped once its table changes");
    }

    #[test]
    fn test_cache_stats() {
        let cache = QueryPlanCache::new(10);
        let cached = QuerySignature::from_sql("SELECT * FROM users");

        cache.insert(cached.clone(), "select * from users".to_string());

        // Two hits on the cached query...
        for _ in 0..2 {
            let _ = cache.get(&cached);
        }

        // ...and one miss on a query that was never inserted.
        let uncached = QuerySignature::from_sql("SELECT * FROM orders");
        let _ = cache.get(&uncached);

        let stats = cache.stats();
        assert_eq!((stats.hits, stats.misses), (2, 1));
        assert!((stats.hit_rate - 2.0 / 3.0).abs() < 0.01);
    }
}