vibesql_executor/cache/
query_result_cache.rs

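//! Thread-safe, size-bounded query result cache
//!
//! Caches actual query results (rows + schema) to avoid re-executing
//! identical read-only queries. This is particularly effective for
//! SQLLogicTest workloads with repeated query patterns.
//!
//! Note: when the cache is full, the entry to evict is taken from the
//! underlying hash map's iteration order, which is arbitrary; access
//! recency is not tracked, so eviction is not a true LRU policy.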
6
7use std::{
8    collections::{HashMap, HashSet},
9    sync::{
10        atomic::{AtomicUsize, Ordering},
11        RwLock,
12    },
13};
14
15use vibesql_storage::Row;
16
17use super::{CacheStats, QuerySignature};
18use crate::schema::CombinedSchema;
19
20/// Cached query result with metadata
21#[derive(Clone)]
22struct CachedResult {
23    /// Query result rows
24    rows: Vec<Row>,
25    /// Schema for the result
26    schema: CombinedSchema,
27    /// Tables accessed by this query (for invalidation)
28    tables: HashSet<String>,
29}
30
31/// Thread-safe cache for query results
32///
33/// Caches the actual result rows and schema from SELECT queries.
34/// Results are invalidated when any referenced table is modified.
35pub struct QueryResultCache {
36    cache: RwLock<HashMap<QuerySignature, CachedResult>>,
37    max_size: usize,
38    hits: AtomicUsize,
39    misses: AtomicUsize,
40    evictions: AtomicUsize,
41}
42
43impl QueryResultCache {
44    /// Create a new result cache with specified max size
45    pub fn new(max_size: usize) -> Self {
46        Self {
47            cache: RwLock::new(HashMap::new()),
48            max_size,
49            hits: AtomicUsize::new(0),
50            misses: AtomicUsize::new(0),
51            evictions: AtomicUsize::new(0),
52        }
53    }
54
55    /// Try to get cached result for a query
56    pub fn get(&self, signature: &QuerySignature) -> Option<(Vec<Row>, CombinedSchema)> {
57        let cache = self.cache.read().unwrap();
58        if let Some(entry) = cache.get(signature) {
59            self.hits.fetch_add(1, Ordering::Relaxed);
60            Some((entry.rows.clone(), entry.schema.clone()))
61        } else {
62            self.misses.fetch_add(1, Ordering::Relaxed);
63            None
64        }
65    }
66
67    /// Insert query result into cache with table dependencies
68    pub fn insert(
69        &self,
70        signature: QuerySignature,
71        rows: Vec<Row>,
72        schema: CombinedSchema,
73        tables: HashSet<String>,
74    ) {
75        let entry = CachedResult { rows, schema, tables };
76        let mut cache = self.cache.write().unwrap();
77
78        // Simple LRU: evict first entry if at capacity
79        if cache.len() >= self.max_size {
80            if let Some(key) = cache.keys().next().cloned() {
81                cache.remove(&key);
82                self.evictions.fetch_add(1, Ordering::Relaxed);
83            }
84        }
85
86        cache.insert(signature, entry);
87    }
88
89    /// Check if signature is cached
90    pub fn contains(&self, signature: &QuerySignature) -> bool {
91        self.cache.read().unwrap().contains_key(signature)
92    }
93
94    /// Clear all cached results
95    pub fn clear(&self) {
96        self.cache.write().unwrap().clear();
97    }
98
99    /// Invalidate all queries touching a specific table
100    ///
101    /// This should be called when a table is modified (INSERT/UPDATE/DELETE)
102    pub fn invalidate_table(&self, table: &str) {
103        let mut cache = self.cache.write().unwrap();
104        cache.retain(|_, entry| !entry.tables.iter().any(|t| t.eq_ignore_ascii_case(table)));
105    }
106
107    /// Invalidate entire cache (e.g., on any write when not tracking dependencies)
108    pub fn invalidate_all(&self) {
109        self.clear();
110    }
111
112    /// Get cache statistics
113    pub fn stats(&self) -> CacheStats {
114        let cache = self.cache.read().unwrap();
115        let hits = self.hits.load(Ordering::Relaxed);
116        let misses = self.misses.load(Ordering::Relaxed);
117        let total = hits + misses;
118        let hit_rate = if total > 0 { hits as f64 / total as f64 } else { 0.0 };
119
120        CacheStats {
121            hits,
122            misses,
123            evictions: self.evictions.load(Ordering::Relaxed),
124            size: cache.len(),
125            hit_rate,
126        }
127    }
128
129    /// Get maximum cache size
130    pub fn max_size(&self) -> usize {
131        self.max_size
132    }
133}
134
#[cfg(test)]
mod tests {
    use super::*;
    use vibesql_types::SqlValue;

    /// Wraps the given values in a storage row.
    fn make_test_row(values: Vec<SqlValue>) -> Row {
        Row::new(values)
    }

    /// Builds the combined schema of a `users` table with a non-nullable
    /// integer `id` column and a nullable `VARCHAR(255)` `name` column.
    fn make_test_schema() -> CombinedSchema {
        use vibesql_catalog::{ColumnSchema, TableSchema};
        use vibesql_types::DataType;

        let id_column = ColumnSchema {
            name: "id".to_string(),
            data_type: DataType::Integer,
            nullable: false,
            default_value: None,
        };
        let name_column = ColumnSchema {
            name: "name".to_string(),
            data_type: DataType::Varchar { max_length: Some(255) },
            nullable: true,
            default_value: None,
        };

        let users = TableSchema::new("users".to_string(), vec![id_column, name_column]);
        CombinedSchema::from_table("users".to_string(), users)
    }

    /// A result set consisting of one row holding the integer `1`.
    fn single_int_rows() -> Vec<Row> {
        vec![make_test_row(vec![SqlValue::Integer(1)])]
    }

    /// Builds a table-dependency set from a list of table names.
    fn table_set(names: &[&str]) -> HashSet<String> {
        names.iter().map(|name| name.to_string()).collect()
    }

    #[test]
    fn test_cache_hit() {
        let cache = QueryResultCache::new(10);
        let sig = QuerySignature::from_sql("SELECT * FROM users");
        let alice =
            make_test_row(vec![SqlValue::Integer(1), SqlValue::Varchar(arcstr::ArcStr::from("Alice"))]);
        let bob =
            make_test_row(vec![SqlValue::Integer(2), SqlValue::Varchar(arcstr::ArcStr::from("Bob"))]);

        cache.insert(sig.clone(), vec![alice, bob], make_test_schema(), table_set(&["users"]));

        let lookup = cache.get(&sig);
        assert!(lookup.is_some());

        let (cached_rows, _cached_schema) = lookup.unwrap();
        assert_eq!(cached_rows.len(), 2);
        assert_eq!(cached_rows[0].values.len(), 2);
    }

    #[test]
    fn test_cache_miss() {
        let cache = QueryResultCache::new(10);
        let sig = QuerySignature::from_sql("SELECT * FROM users");

        // Nothing was inserted, so the lookup must come back empty.
        assert!(cache.get(&sig).is_none());
    }

    #[test]
    fn test_lru_eviction() {
        let cache = QueryResultCache::new(2);

        // Fill the cache exactly to capacity.
        for sql in ["SELECT * FROM users", "SELECT * FROM orders"].iter() {
            cache.insert(
                QuerySignature::from_sql(sql),
                single_int_rows(),
                make_test_schema(),
                table_set(&[]),
            );
        }
        assert_eq!(cache.stats().size, 2);

        // A third distinct query forces exactly one eviction.
        cache.insert(
            QuerySignature::from_sql("SELECT * FROM products"),
            single_int_rows(),
            make_test_schema(),
            table_set(&[]),
        );
        let stats = cache.stats();
        assert_eq!(stats.size, 2);
        assert_eq!(stats.evictions, 1);
    }

    #[test]
    fn test_cache_clear() {
        let cache = QueryResultCache::new(10);
        let sig = QuerySignature::from_sql("SELECT * FROM users");

        cache.insert(sig.clone(), single_int_rows(), make_test_schema(), table_set(&[]));
        assert!(cache.contains(&sig));

        cache.clear();
        assert!(!cache.contains(&sig));
    }

    #[test]
    fn test_table_invalidation() {
        let cache = QueryResultCache::new(10);
        let sig = QuerySignature::from_sql("SELECT * FROM users WHERE id = 1");

        cache.insert(sig.clone(), single_int_rows(), make_test_schema(), table_set(&["users"]));
        assert!(cache.contains(&sig));

        cache.invalidate_table("users");
        assert!(!cache.contains(&sig));
    }

    #[test]
    fn test_table_invalidation_case_insensitive() {
        let cache = QueryResultCache::new(10);
        let sig = QuerySignature::from_sql("SELECT * FROM users");

        cache.insert(sig.clone(), single_int_rows(), make_test_schema(), table_set(&["users"]));
        assert!(cache.contains(&sig));

        // The dependency was recorded in lower case; invalidate in upper case.
        cache.invalidate_table("USERS");
        assert!(!cache.contains(&sig));
    }

    #[test]
    fn test_cache_stats() {
        let cache = QueryResultCache::new(10);
        let cached_sig = QuerySignature::from_sql("SELECT * FROM users");
        let absent_sig = QuerySignature::from_sql("SELECT * FROM orders");

        cache.insert(cached_sig.clone(), single_int_rows(), make_test_schema(), table_set(&[]));

        // Two hits on the cached query ...
        for _ in 0..2 {
            let _ = cache.get(&cached_sig);
        }
        // ... and one miss on a query that was never inserted.
        let _ = cache.get(&absent_sig);

        let stats = cache.stats();
        assert_eq!(stats.hits, 2);
        assert_eq!(stats.misses, 1);
        assert!((stats.hit_rate - 2.0 / 3.0).abs() < 0.01);
    }
}