1use std::{
8 collections::{HashMap, HashSet},
9 sync::{
10 atomic::{AtomicUsize, Ordering},
11 RwLock,
12 },
13};
14
15use vibesql_storage::Row;
16
17use super::{CacheStats, QuerySignature};
18use crate::schema::CombinedSchema;
19
20#[derive(Clone)]
22struct CachedResult {
23 rows: Vec<Row>,
25 schema: CombinedSchema,
27 tables: HashSet<String>,
29}
30
31pub struct QueryResultCache {
36 cache: RwLock<HashMap<QuerySignature, CachedResult>>,
37 max_size: usize,
38 hits: AtomicUsize,
39 misses: AtomicUsize,
40 evictions: AtomicUsize,
41}
42
43impl QueryResultCache {
44 pub fn new(max_size: usize) -> Self {
46 Self {
47 cache: RwLock::new(HashMap::new()),
48 max_size,
49 hits: AtomicUsize::new(0),
50 misses: AtomicUsize::new(0),
51 evictions: AtomicUsize::new(0),
52 }
53 }
54
55 pub fn get(&self, signature: &QuerySignature) -> Option<(Vec<Row>, CombinedSchema)> {
57 let cache = self.cache.read().unwrap();
58 if let Some(entry) = cache.get(signature) {
59 self.hits.fetch_add(1, Ordering::Relaxed);
60 Some((entry.rows.clone(), entry.schema.clone()))
61 } else {
62 self.misses.fetch_add(1, Ordering::Relaxed);
63 None
64 }
65 }
66
67 pub fn insert(
69 &self,
70 signature: QuerySignature,
71 rows: Vec<Row>,
72 schema: CombinedSchema,
73 tables: HashSet<String>,
74 ) {
75 let entry = CachedResult { rows, schema, tables };
76 let mut cache = self.cache.write().unwrap();
77
78 if cache.len() >= self.max_size {
80 if let Some(key) = cache.keys().next().cloned() {
81 cache.remove(&key);
82 self.evictions.fetch_add(1, Ordering::Relaxed);
83 }
84 }
85
86 cache.insert(signature, entry);
87 }
88
89 pub fn contains(&self, signature: &QuerySignature) -> bool {
91 self.cache.read().unwrap().contains_key(signature)
92 }
93
94 pub fn clear(&self) {
96 self.cache.write().unwrap().clear();
97 }
98
99 pub fn invalidate_table(&self, table: &str) {
103 let mut cache = self.cache.write().unwrap();
104 cache.retain(|_, entry| !entry.tables.iter().any(|t| t.eq_ignore_ascii_case(table)));
105 }
106
107 pub fn invalidate_all(&self) {
109 self.clear();
110 }
111
112 pub fn stats(&self) -> CacheStats {
114 let cache = self.cache.read().unwrap();
115 let hits = self.hits.load(Ordering::Relaxed);
116 let misses = self.misses.load(Ordering::Relaxed);
117 let total = hits + misses;
118 let hit_rate = if total > 0 { hits as f64 / total as f64 } else { 0.0 };
119
120 CacheStats {
121 hits,
122 misses,
123 evictions: self.evictions.load(Ordering::Relaxed),
124 size: cache.len(),
125 hit_rate,
126 }
127 }
128
129 pub fn max_size(&self) -> usize {
131 self.max_size
132 }
133}
134
#[cfg(test)]
mod tests {
    use super::*;
    use vibesql_types::SqlValue;

    /// SQL text shared by most tests.
    const USERS_QUERY: &str = "SELECT * FROM users";

    /// Wraps `values` in a storage row.
    fn row(values: Vec<SqlValue>) -> Row {
        Row::new(values)
    }

    /// A one-row, one-column result containing the integer `1`.
    fn single_int_rows() -> Vec<Row> {
        vec![row(vec![SqlValue::Integer(1)])]
    }

    /// Builds an owned set of table names from string slices.
    fn table_set(names: &[&str]) -> HashSet<String> {
        names.iter().map(|&name| name.to_owned()).collect()
    }

    /// Combined schema for a `users` table with columns `id` and `name`.
    fn users_schema() -> CombinedSchema {
        use vibesql_catalog::{ColumnSchema, TableSchema};
        use vibesql_types::DataType;

        let id_column = ColumnSchema {
            name: "id".to_string(),
            data_type: DataType::Integer,
            nullable: false,
            default_value: None,
        };
        let name_column = ColumnSchema {
            name: "name".to_string(),
            data_type: DataType::Varchar { max_length: Some(255) },
            nullable: true,
            default_value: None,
        };

        let table = TableSchema::new("users".to_string(), vec![id_column, name_column]);
        CombinedSchema::from_table("users".to_string(), table)
    }

    #[test]
    fn test_cache_hit() {
        let cache = QueryResultCache::new(10);
        let sig = QuerySignature::from_sql(USERS_QUERY);
        let rows = vec![
            row(vec![SqlValue::Integer(1), SqlValue::Varchar("Alice".to_string())]),
            row(vec![SqlValue::Integer(2), SqlValue::Varchar("Bob".to_string())]),
        ];

        cache.insert(sig.clone(), rows, users_schema(), table_set(&["users"]));

        let (cached_rows, _) = cache.get(&sig).expect("entry was just inserted");
        assert_eq!(cached_rows.len(), 2);
        assert_eq!(cached_rows[0].values.len(), 2);
    }

    #[test]
    fn test_cache_miss() {
        let cache = QueryResultCache::new(10);
        let never_inserted = QuerySignature::from_sql(USERS_QUERY);

        assert!(cache.get(&never_inserted).is_none());
    }

    #[test]
    fn test_lru_eviction() {
        let cache = QueryResultCache::new(2);
        let schema = users_schema();
        let rows = single_int_rows();
        let untracked = table_set(&[]);

        cache.insert(
            QuerySignature::from_sql(USERS_QUERY),
            rows.clone(),
            schema.clone(),
            untracked.clone(),
        );
        cache.insert(
            QuerySignature::from_sql("SELECT * FROM orders"),
            rows.clone(),
            schema.clone(),
            untracked.clone(),
        );
        assert_eq!(cache.stats().size, 2);

        // A third distinct query overflows the capacity of two.
        cache.insert(QuerySignature::from_sql("SELECT * FROM products"), rows, schema, untracked);

        let stats = cache.stats();
        assert_eq!(stats.size, 2);
        assert_eq!(stats.evictions, 1);
    }

    #[test]
    fn test_cache_clear() {
        let cache = QueryResultCache::new(10);
        let sig = QuerySignature::from_sql(USERS_QUERY);

        cache.insert(sig.clone(), single_int_rows(), users_schema(), table_set(&[]));
        assert!(cache.contains(&sig));

        cache.clear();
        assert!(!cache.contains(&sig));
    }

    #[test]
    fn test_table_invalidation() {
        let cache = QueryResultCache::new(10);
        let sig = QuerySignature::from_sql("SELECT * FROM users WHERE id = 1");

        cache.insert(sig.clone(), single_int_rows(), users_schema(), table_set(&["users"]));
        assert!(cache.contains(&sig));

        cache.invalidate_table("users");
        assert!(!cache.contains(&sig));
    }

    #[test]
    fn test_table_invalidation_case_insensitive() {
        let cache = QueryResultCache::new(10);
        let sig = QuerySignature::from_sql(USERS_QUERY);

        cache.insert(sig.clone(), single_int_rows(), users_schema(), table_set(&["users"]));
        assert!(cache.contains(&sig));

        // The stored name is lowercase; invalidate with an uppercase spelling.
        cache.invalidate_table("USERS");
        assert!(!cache.contains(&sig));
    }

    #[test]
    fn test_cache_stats() {
        let cache = QueryResultCache::new(10);
        let cached = QuerySignature::from_sql(USERS_QUERY);

        cache.insert(cached.clone(), single_int_rows(), users_schema(), table_set(&[]));

        // Two hits on the cached query...
        for _ in 0..2 {
            let _ = cache.get(&cached);
        }
        // ...and one miss on a query that was never inserted.
        let _ = cache.get(&QuerySignature::from_sql("SELECT * FROM orders"));

        let stats = cache.stats();
        assert_eq!(stats.hits, 2);
        assert_eq!(stats.misses, 1);
        assert!((stats.hit_rate - 2.0 / 3.0).abs() < 0.01);
    }
}