1use std::{
8 collections::{HashMap, HashSet},
9 sync::{
10 atomic::{AtomicUsize, Ordering},
11 RwLock,
12 },
13};
14
15use vibesql_storage::Row;
16
17use super::{CacheStats, QuerySignature};
18use crate::schema::CombinedSchema;
19
20#[derive(Clone)]
22struct CachedResult {
23 rows: Vec<Row>,
25 schema: CombinedSchema,
27 tables: HashSet<String>,
29}
30
31pub struct QueryResultCache {
36 cache: RwLock<HashMap<QuerySignature, CachedResult>>,
37 max_size: usize,
38 hits: AtomicUsize,
39 misses: AtomicUsize,
40 evictions: AtomicUsize,
41}
42
43impl QueryResultCache {
44 pub fn new(max_size: usize) -> Self {
46 Self {
47 cache: RwLock::new(HashMap::new()),
48 max_size,
49 hits: AtomicUsize::new(0),
50 misses: AtomicUsize::new(0),
51 evictions: AtomicUsize::new(0),
52 }
53 }
54
55 pub fn get(&self, signature: &QuerySignature) -> Option<(Vec<Row>, CombinedSchema)> {
57 let cache = self.cache.read().unwrap();
58 if let Some(entry) = cache.get(signature) {
59 self.hits.fetch_add(1, Ordering::Relaxed);
60 Some((entry.rows.clone(), entry.schema.clone()))
61 } else {
62 self.misses.fetch_add(1, Ordering::Relaxed);
63 None
64 }
65 }
66
67 pub fn insert(
69 &self,
70 signature: QuerySignature,
71 rows: Vec<Row>,
72 schema: CombinedSchema,
73 tables: HashSet<String>,
74 ) {
75 let entry = CachedResult { rows, schema, tables };
76 let mut cache = self.cache.write().unwrap();
77
78 if cache.len() >= self.max_size {
80 if let Some(key) = cache.keys().next().cloned() {
81 cache.remove(&key);
82 self.evictions.fetch_add(1, Ordering::Relaxed);
83 }
84 }
85
86 cache.insert(signature, entry);
87 }
88
89 pub fn contains(&self, signature: &QuerySignature) -> bool {
91 self.cache.read().unwrap().contains_key(signature)
92 }
93
94 pub fn clear(&self) {
96 self.cache.write().unwrap().clear();
97 }
98
99 pub fn invalidate_table(&self, table: &str) {
103 let mut cache = self.cache.write().unwrap();
104 cache.retain(|_, entry| !entry.tables.iter().any(|t| t.eq_ignore_ascii_case(table)));
105 }
106
107 pub fn invalidate_all(&self) {
109 self.clear();
110 }
111
112 pub fn stats(&self) -> CacheStats {
114 let cache = self.cache.read().unwrap();
115 let hits = self.hits.load(Ordering::Relaxed);
116 let misses = self.misses.load(Ordering::Relaxed);
117 let total = hits + misses;
118 let hit_rate = if total > 0 { hits as f64 / total as f64 } else { 0.0 };
119
120 CacheStats {
121 hits,
122 misses,
123 evictions: self.evictions.load(Ordering::Relaxed),
124 size: cache.len(),
125 hit_rate,
126 }
127 }
128
129 pub fn max_size(&self) -> usize {
131 self.max_size
132 }
133}
134
#[cfg(test)]
mod tests {
    use vibesql_types::SqlValue;

    use super::*;

    /// Wraps raw column values in a storage `Row`.
    fn make_test_row(values: Vec<SqlValue>) -> Row {
        Row::new(values)
    }

    /// Builds the schema of `users(id INTEGER NOT NULL, name VARCHAR(255) NULL)`.
    fn make_test_schema() -> CombinedSchema {
        use vibesql_catalog::{ColumnSchema, TableSchema};
        use vibesql_types::DataType;

        // Every test column shares the same defaults apart from name/type/nullability.
        let column = |name: &str, data_type: DataType, nullable: bool| ColumnSchema {
            name: name.to_string(),
            data_type,
            nullable,
            default_value: None,
            generated_expr: None,
            is_exact_integer_type: false,
            collation: None,
        };

        let table = TableSchema::new(
            "users".to_string(),
            vec![
                column("id", DataType::Integer, false),
                column("name", DataType::Varchar { max_length: Some(255) }, true),
            ],
        );
        CombinedSchema::from_table("users".to_string(), table)
    }

    /// Collects table names into the owned set that `insert` expects.
    fn table_set(names: &[&str]) -> HashSet<String> {
        names.iter().map(|name| name.to_string()).collect()
    }

    /// Returns a capacity-10 cache holding a single one-row entry for `sql`.
    fn cache_with_entry(sql: &str, tables: &[&str]) -> (QueryResultCache, QuerySignature) {
        let cache = QueryResultCache::new(10);
        let signature = QuerySignature::from_sql(sql);
        cache.insert(
            signature.clone(),
            vec![make_test_row(vec![SqlValue::Integer(1)])],
            make_test_schema(),
            table_set(tables),
        );
        (cache, signature)
    }

    #[test]
    fn test_cache_hit() {
        let cache = QueryResultCache::new(10);
        let signature = QuerySignature::from_sql("SELECT * FROM users");
        let rows: Vec<Row> = [(1, "Alice"), (2, "Bob")]
            .iter()
            .map(|&(id, name)| {
                make_test_row(vec![
                    SqlValue::Integer(id),
                    SqlValue::Varchar(arcstr::ArcStr::from(name)),
                ])
            })
            .collect();

        cache.insert(signature.clone(), rows, make_test_schema(), table_set(&["users"]));

        let (cached_rows, _cached_schema) = cache.get(&signature).expect("entry should be cached");
        assert_eq!(cached_rows.len(), 2);
        assert_eq!(cached_rows[0].values.len(), 2);
    }

    #[test]
    fn test_cache_miss() {
        let cache = QueryResultCache::new(10);
        let lookup = cache.get(&QuerySignature::from_sql("SELECT * FROM users"));
        assert!(lookup.is_none());
    }

    #[test]
    fn test_lru_eviction() {
        let cache = QueryResultCache::new(2);
        let schema = make_test_schema();
        let rows = vec![make_test_row(vec![SqlValue::Integer(1)])];

        for sql in ["SELECT * FROM users", "SELECT * FROM orders"] {
            cache.insert(
                QuerySignature::from_sql(sql),
                rows.clone(),
                schema.clone(),
                HashSet::new(),
            );
        }
        assert_eq!(cache.stats().size, 2);

        cache.insert(
            QuerySignature::from_sql("SELECT * FROM products"),
            rows,
            schema,
            HashSet::new(),
        );
        let after = cache.stats();
        assert_eq!(after.size, 2);
        assert_eq!(after.evictions, 1);
    }

    #[test]
    fn test_cache_clear() {
        let (cache, signature) = cache_with_entry("SELECT * FROM users", &[]);
        assert!(cache.contains(&signature));

        cache.clear();
        assert!(!cache.contains(&signature));
    }

    #[test]
    fn test_table_invalidation() {
        let (cache, signature) = cache_with_entry("SELECT * FROM users WHERE id = 1", &["users"]);
        assert!(cache.contains(&signature));

        cache.invalidate_table("users");
        assert!(!cache.contains(&signature));
    }

    #[test]
    fn test_table_invalidation_case_insensitive() {
        let (cache, signature) = cache_with_entry("SELECT * FROM users", &["users"]);
        assert!(cache.contains(&signature));

        // Upper-case name must still match the lower-case dependency.
        cache.invalidate_table("USERS");
        assert!(!cache.contains(&signature));
    }

    #[test]
    fn test_cache_stats() {
        let (cache, signature) = cache_with_entry("SELECT * FROM users", &[]);

        // Two hits on the cached signature...
        for _ in 0..2 {
            let _ = cache.get(&signature);
        }
        // ...and one miss on an uncached one.
        let _ = cache.get(&QuerySignature::from_sql("SELECT * FROM orders"));

        let stats = cache.stats();
        assert_eq!(stats.hits, 2);
        assert_eq!(stats.misses, 1);
        assert!((stats.hit_rate - 2.0 / 3.0).abs() < 0.01);
    }
}