vibesql_executor/cache/
query_result_cache.rs

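//! Thread-safe query result cache with bounded-size eviction
//!
//! Caches actual query results (rows + schema) to avoid re-executing
//! identical read-only queries. This is particularly effective for
//! SQLLogicTest workloads with repeated query patterns.
//!
//! When the cache is full, the entry evicted to make room is whichever one
//! `HashMap` iteration yields first, i.e. an arbitrary entry rather than the
//! least recently used one.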
6
7use std::{
8    collections::{HashMap, HashSet},
9    sync::{
10        atomic::{AtomicUsize, Ordering},
11        RwLock,
12    },
13};
14
15use vibesql_storage::Row;
16
17use super::{CacheStats, QuerySignature};
18use crate::schema::CombinedSchema;
19
20/// Cached query result with metadata
21#[derive(Clone)]
22struct CachedResult {
23    /// Query result rows
24    rows: Vec<Row>,
25    /// Schema for the result
26    schema: CombinedSchema,
27    /// Tables accessed by this query (for invalidation)
28    tables: HashSet<String>,
29}
30
31/// Thread-safe cache for query results
32///
33/// Caches the actual result rows and schema from SELECT queries.
34/// Results are invalidated when any referenced table is modified.
35pub struct QueryResultCache {
36    cache: RwLock<HashMap<QuerySignature, CachedResult>>,
37    max_size: usize,
38    hits: AtomicUsize,
39    misses: AtomicUsize,
40    evictions: AtomicUsize,
41}
42
43impl QueryResultCache {
44    /// Create a new result cache with specified max size
45    pub fn new(max_size: usize) -> Self {
46        Self {
47            cache: RwLock::new(HashMap::new()),
48            max_size,
49            hits: AtomicUsize::new(0),
50            misses: AtomicUsize::new(0),
51            evictions: AtomicUsize::new(0),
52        }
53    }
54
55    /// Try to get cached result for a query
56    pub fn get(&self, signature: &QuerySignature) -> Option<(Vec<Row>, CombinedSchema)> {
57        let cache = self.cache.read().unwrap();
58        if let Some(entry) = cache.get(signature) {
59            self.hits.fetch_add(1, Ordering::Relaxed);
60            Some((entry.rows.clone(), entry.schema.clone()))
61        } else {
62            self.misses.fetch_add(1, Ordering::Relaxed);
63            None
64        }
65    }
66
67    /// Insert query result into cache with table dependencies
68    pub fn insert(
69        &self,
70        signature: QuerySignature,
71        rows: Vec<Row>,
72        schema: CombinedSchema,
73        tables: HashSet<String>,
74    ) {
75        let entry = CachedResult { rows, schema, tables };
76        let mut cache = self.cache.write().unwrap();
77
78        // Simple LRU: evict first entry if at capacity
79        if cache.len() >= self.max_size {
80            if let Some(key) = cache.keys().next().cloned() {
81                cache.remove(&key);
82                self.evictions.fetch_add(1, Ordering::Relaxed);
83            }
84        }
85
86        cache.insert(signature, entry);
87    }
88
89    /// Check if signature is cached
90    pub fn contains(&self, signature: &QuerySignature) -> bool {
91        self.cache.read().unwrap().contains_key(signature)
92    }
93
94    /// Clear all cached results
95    pub fn clear(&self) {
96        self.cache.write().unwrap().clear();
97    }
98
99    /// Invalidate all queries touching a specific table
100    ///
101    /// This should be called when a table is modified (INSERT/UPDATE/DELETE)
102    pub fn invalidate_table(&self, table: &str) {
103        let mut cache = self.cache.write().unwrap();
104        cache.retain(|_, entry| !entry.tables.iter().any(|t| t.eq_ignore_ascii_case(table)));
105    }
106
107    /// Invalidate entire cache (e.g., on any write when not tracking dependencies)
108    pub fn invalidate_all(&self) {
109        self.clear();
110    }
111
112    /// Get cache statistics
113    pub fn stats(&self) -> CacheStats {
114        let cache = self.cache.read().unwrap();
115        let hits = self.hits.load(Ordering::Relaxed);
116        let misses = self.misses.load(Ordering::Relaxed);
117        let total = hits + misses;
118        let hit_rate = if total > 0 { hits as f64 / total as f64 } else { 0.0 };
119
120        CacheStats {
121            hits,
122            misses,
123            evictions: self.evictions.load(Ordering::Relaxed),
124            size: cache.len(),
125            hit_rate,
126        }
127    }
128
129    /// Get maximum cache size
130    pub fn max_size(&self) -> usize {
131        self.max_size
132    }
133}
134
#[cfg(test)]
mod tests {
    use vibesql_types::SqlValue;

    use super::*;

    /// Wrap a list of values in a storage `Row`.
    fn make_test_row(values: Vec<SqlValue>) -> Row {
        Row::new(values)
    }

    /// A one-row, one-column result for tests that never inspect the data.
    fn one_int_row() -> Vec<Row> {
        vec![make_test_row(vec![SqlValue::Integer(1)])]
    }

    /// Build a table-dependency set from a list of names.
    fn table_set(names: &[&str]) -> HashSet<String> {
        names.iter().map(|&name| name.to_string()).collect()
    }

    /// Schema for a `users(id INTEGER NOT NULL, name VARCHAR(255) NULL)` table.
    fn make_test_schema() -> CombinedSchema {
        use vibesql_catalog::{ColumnSchema, TableSchema};
        use vibesql_types::DataType;

        let column = |name: &str, data_type: DataType, nullable: bool| ColumnSchema {
            name: name.to_string(),
            data_type,
            nullable,
            default_value: None,
            generated_expr: None,
            is_exact_integer_type: false,
            collation: None,
        };

        let columns = vec![
            column("id", DataType::Integer, false),
            column("name", DataType::Varchar { max_length: Some(255) }, true),
        ];

        let users = TableSchema::new("users".to_string(), columns);
        CombinedSchema::from_table("users".to_string(), users)
    }

    /// Cache one result that depends on `users`, invalidate `table`, and
    /// check that the entry is gone.
    fn assert_users_entry_invalidated_by(sig: QuerySignature, table: &str) {
        let cache = QueryResultCache::new(10);
        cache.insert(sig.clone(), one_int_row(), make_test_schema(), table_set(&["users"]));
        assert!(cache.contains(&sig));

        cache.invalidate_table(table);
        assert!(!cache.contains(&sig));
    }

    #[test]
    fn test_cache_hit() {
        let cache = QueryResultCache::new(10);
        let sig = QuerySignature::from_sql("SELECT * FROM users");
        let rows: Vec<Row> = [(1, "Alice"), (2, "Bob")]
            .iter()
            .map(|&(id, name)| {
                make_test_row(vec![
                    SqlValue::Integer(id),
                    SqlValue::Varchar(arcstr::ArcStr::from(name)),
                ])
            })
            .collect();

        cache.insert(sig.clone(), rows, make_test_schema(), table_set(&["users"]));

        let (cached_rows, _cached_schema) =
            cache.get(&sig).expect("freshly inserted entry should be cached");
        assert_eq!(cached_rows.len(), 2);
        assert_eq!(cached_rows[0].values.len(), 2);
    }

    #[test]
    fn test_cache_miss() {
        let cache = QueryResultCache::new(10);
        let lookup = cache.get(&QuerySignature::from_sql("SELECT * FROM users"));
        assert!(lookup.is_none());
    }

    #[test]
    fn test_lru_eviction() {
        let cache = QueryResultCache::new(2);
        let schema = make_test_schema();
        let rows = one_int_row();
        let tables = HashSet::new();

        let [sig1, sig2, sig3] = [
            QuerySignature::from_sql("SELECT * FROM users"),
            QuerySignature::from_sql("SELECT * FROM orders"),
            QuerySignature::from_sql("SELECT * FROM products"),
        ];

        // Fill the cache exactly to capacity.
        for sig in [sig1, sig2] {
            cache.insert(sig, rows.clone(), schema.clone(), tables.clone());
        }
        assert_eq!(cache.stats().size, 2);

        // One more distinct query forces a single eviction.
        cache.insert(sig3, rows, schema, tables);
        let after = cache.stats();
        assert_eq!(after.size, 2);
        assert_eq!(after.evictions, 1);
    }

    #[test]
    fn test_cache_clear() {
        let cache = QueryResultCache::new(10);
        let sig = QuerySignature::from_sql("SELECT * FROM users");

        cache.insert(sig.clone(), one_int_row(), make_test_schema(), HashSet::new());
        assert!(cache.contains(&sig));

        cache.clear();
        assert!(!cache.contains(&sig));
    }

    #[test]
    fn test_table_invalidation() {
        assert_users_entry_invalidated_by(
            QuerySignature::from_sql("SELECT * FROM users WHERE id = 1"),
            "users",
        );
    }

    #[test]
    fn test_table_invalidation_case_insensitive() {
        // Invalidate with a differently-cased table name.
        assert_users_entry_invalidated_by(QuerySignature::from_sql("SELECT * FROM users"), "USERS");
    }

    #[test]
    fn test_cache_stats() {
        let cache = QueryResultCache::new(10);
        let sig = QuerySignature::from_sql("SELECT * FROM users");
        cache.insert(sig.clone(), one_int_row(), make_test_schema(), HashSet::new());

        // Two lookups of the cached query are hits...
        for _ in 0..2 {
            let _ = cache.get(&sig);
        }
        // ...and one lookup of an unknown query is a miss.
        let _ = cache.get(&QuerySignature::from_sql("SELECT * FROM orders"));

        let stats = cache.stats();
        assert_eq!(stats.hits, 2);
        assert_eq!(stats.misses, 1);
        assert!((stats.hit_rate - 2.0 / 3.0).abs() < 0.01);
    }
}