1use alloc::collections::{BTreeMap, VecDeque};
20use alloc::string::String;
21use alloc::vec::Vec;
22
23use spg_sql::ast::{Expr, FromClause, FromJoin, SelectItem, SelectStatement, Statement, TableRef};
24use spg_storage::ColumnSchema;
25
/// Hard upper bound on the number of cached plans.
///
/// It is both the capacity of a freshly constructed `PlanCache` and the
/// ceiling that `PlanCache::set_max_entries` clamps its argument to.
pub(crate) const PLAN_CACHE_MAX_ENTRIES: usize = 256;
31
/// A parsed statement together with the metadata stored alongside it in a
/// `PlanCache`.
#[derive(Debug, Clone)]
pub struct PreparedPlan {
    /// The parsed SQL statement this plan was built from.
    pub stmt: Statement,
    /// NOTE(review): presumably the statistics version that was current when
    /// the plan was prepared, so callers can detect stale plans — confirm
    /// against the code that creates and checks plans.
    pub statistics_version: u64,
    /// Names of the tables this plan depends on.
    /// `PlanCache::evict_referencing` drops the plan when any entry here
    /// equals the given table name. NOTE(review): presumably populated via
    /// `collect_source_tables` — confirm at the call site.
    pub source_tables: Vec<String>,
    /// NOTE(review): presumably the result-column schema reported when the
    /// statement is described — confirm against callers.
    pub describe_columns: Vec<ColumnSchema>,
}
49
/// Bounded least-recently-used cache of prepared plans, keyed by SQL text.
///
/// `entries` and `lru` are updated together by every mutating method, so each
/// key present in `entries` appears exactly once in `lru`.
#[derive(Debug, Clone)]
pub struct PlanCache {
    /// Cached plans, keyed by the exact SQL string passed to `insert`.
    entries: BTreeMap<String, PreparedPlan>,
    /// Recency order of the keys in `entries`: the front is the least
    /// recently used key (evicted first), the back the most recently used.
    lru: VecDeque<String>,
    /// Maximum number of plans retained. Starts at `PLAN_CACHE_MAX_ENTRIES`
    /// and is kept within `1..=PLAN_CACHE_MAX_ENTRIES` by `set_max_entries`.
    max_entries: usize,
}
66
67impl Default for PlanCache {
68 fn default() -> Self {
69 Self {
70 entries: BTreeMap::new(),
71 lru: VecDeque::new(),
72 max_entries: PLAN_CACHE_MAX_ENTRIES,
73 }
74 }
75}
76
77impl PlanCache {
78 pub fn new() -> Self {
79 Self::default()
80 }
81
82 pub fn set_max_entries(&mut self, n: usize) {
88 self.max_entries = n.max(1).min(PLAN_CACHE_MAX_ENTRIES);
89 }
90
91 pub fn max_entries(&self) -> usize {
92 self.max_entries
93 }
94
95 pub fn len(&self) -> usize {
96 self.entries.len()
97 }
98
99 pub fn is_empty(&self) -> bool {
100 self.entries.is_empty()
101 }
102
103 pub fn get_snapshot(&self, sql: &str) -> Option<&PreparedPlan> {
107 self.entries.get(sql)
108 }
109
110 pub fn get(&mut self, sql: &str) -> Option<&PreparedPlan> {
113 if !self.entries.contains_key(sql) {
114 return None;
115 }
116 if let Some(idx) = self.lru.iter().position(|k| k == sql) {
117 let key = self.lru.remove(idx).expect("idx came from position()");
118 self.lru.push_back(key);
119 }
120 self.entries.get(sql)
121 }
122
123 pub fn insert(&mut self, sql: String, plan: PreparedPlan) {
126 if self.entries.contains_key(&sql) {
127 if let Some(idx) = self.lru.iter().position(|k| k == &sql) {
128 let key = self.lru.remove(idx).expect("idx came from position()");
129 self.lru.push_back(key);
130 }
131 self.entries.insert(sql, plan);
132 return;
133 }
134 if self.entries.len() >= self.max_entries {
135 if let Some(oldest) = self.lru.pop_front() {
136 self.entries.remove(&oldest);
137 }
138 }
139 self.lru.push_back(sql.clone());
140 self.entries.insert(sql, plan);
141 }
142
143 pub fn clear(&mut self) {
144 self.entries.clear();
145 self.lru.clear();
146 }
147
148 pub fn evict(&mut self, sql: &str) -> Option<PreparedPlan> {
151 let plan = self.entries.remove(sql)?;
152 if let Some(idx) = self.lru.iter().position(|k| k == sql) {
153 self.lru.remove(idx);
154 }
155 Some(plan)
156 }
157
158 pub fn evict_referencing(&mut self, table: &str) -> usize {
161 let to_evict: Vec<String> = self
162 .entries
163 .iter()
164 .filter_map(|(k, p)| {
165 if p.source_tables.iter().any(|t| t == table) {
166 Some(k.clone())
167 } else {
168 None
169 }
170 })
171 .collect();
172 let n = to_evict.len();
173 for k in to_evict {
174 self.entries.remove(&k);
175 if let Some(idx) = self.lru.iter().position(|x| x == &k) {
176 self.lru.remove(idx);
177 }
178 }
179 n
180 }
181}
182
183pub fn collect_source_tables(stmt: &Statement) -> Vec<String> {
188 let mut out: Vec<String> = Vec::new();
189 match stmt {
190 Statement::Select(s) => collect_from_select(s, &mut out),
191 Statement::Insert(s) => push_unique(&mut out, &s.table),
192 Statement::Update(s) => {
193 push_unique(&mut out, &s.table);
194 if let Some(w) = &s.where_ {
195 collect_expr(w, &mut out);
196 }
197 }
198 Statement::Delete(s) => {
199 push_unique(&mut out, &s.table);
200 if let Some(w) = &s.where_ {
201 collect_expr(w, &mut out);
202 }
203 }
204 Statement::Explain(inner) => {
205 collect_from_select(&inner.inner, &mut out);
206 }
207 _ => {}
208 }
209 out.sort();
210 out.dedup();
211 out
212}
213
214fn collect_from_select(s: &SelectStatement, out: &mut Vec<String>) {
215 if let Some(from) = &s.from {
216 collect_from_clause(from, out);
217 }
218 if let Some(w) = &s.where_ {
219 collect_expr(w, out);
220 }
221 if let Some(h) = &s.having {
222 collect_expr(h, out);
223 }
224 for item in &s.items {
225 if let SelectItem::Expr { expr, .. } = item {
226 collect_expr(expr, out);
227 }
228 }
229 for (_, peer) in &s.unions {
230 collect_from_select(peer, out);
231 }
232}
233
234fn collect_from_clause(from: &FromClause, out: &mut Vec<String>) {
235 collect_table_ref(&from.primary, out);
236 for j in &from.joins {
237 collect_from_join(j, out);
238 }
239}
240
241fn collect_from_join(j: &FromJoin, out: &mut Vec<String>) {
242 collect_table_ref(&j.table, out);
243 if let Some(on) = &j.on {
244 collect_expr(on, out);
245 }
246}
247
248fn collect_table_ref(t: &TableRef, out: &mut Vec<String>) {
249 push_unique(out, &t.name);
250}
251
/// Recursively walks `e` and appends to `out` the tables referenced by any
/// subquery nested inside it (scalar subquery, `Exists`, `InSubquery`).
///
/// Literals, column references and placeholders contribute nothing; only
/// subqueries introduce table names. The match has no `_` arm, so adding an
/// `Expr` variant forces this function to be updated.
///
/// NOTE(review): fields elided with `..` (e.g. on `Cast`, `Like`,
/// `FunctionCall`, `WindowFunction`, `Extract`, `AnyAll`, `InList`) are not
/// visited. Confirm that none of them can hold a subquery.
fn collect_expr(e: &Expr, out: &mut Vec<String>) {
    match e {
        Expr::AggregateOrdered {
            call,
            order_by,
            filter,
            ..
        } => {
            collect_expr(call, out);
            for o in order_by {
                collect_expr(&o.expr, out);
            }
            // The optional `filter` expression may itself nest a subquery.
            if let Some(f) = filter {
                collect_expr(f, out);
            }
        }
        // Subquery forms: recurse into the nested SELECT.
        Expr::ScalarSubquery(inner) => collect_from_select(inner, out),
        Expr::Exists { subquery, .. } => collect_from_select(subquery, out),
        Expr::InSubquery { expr, subquery, .. } => {
            collect_expr(expr, out);
            collect_from_select(subquery, out);
        }
        // Everything below just descends into child expressions.
        Expr::Binary { lhs, rhs, .. } => {
            collect_expr(lhs, out);
            collect_expr(rhs, out);
        }
        Expr::Unary { expr, .. } => collect_expr(expr, out),
        Expr::Cast { expr, .. } => collect_expr(expr, out),
        Expr::IsNull { expr, .. } => collect_expr(expr, out),
        Expr::Like { expr, pattern, .. } => {
            collect_expr(expr, out);
            collect_expr(pattern, out);
        }
        Expr::FunctionCall { args, .. } => {
            for a in args {
                collect_expr(a, out);
            }
        }
        Expr::WindowFunction {
            args,
            partition_by,
            order_by,
            ..
        } => {
            for a in args {
                collect_expr(a, out);
            }
            for p in partition_by {
                collect_expr(p, out);
            }
            // Only the expression component of each ordering tuple is walked.
            for (o, _, _) in order_by {
                collect_expr(o, out);
            }
        }
        Expr::Extract { source, .. } => collect_expr(source, out),
        Expr::Array(items) => {
            for elem in items {
                collect_expr(elem, out);
            }
        }
        Expr::ArraySubscript { target, index } => {
            collect_expr(target, out);
            collect_expr(index, out);
        }
        Expr::AnyAll { expr, array, .. } => {
            collect_expr(expr, out);
            collect_expr(array, out);
        }
        Expr::InList { expr, list, .. } => {
            collect_expr(expr, out);
            for item in list {
                collect_expr(item, out);
            }
        }
        Expr::Case {
            operand,
            branches,
            else_branch,
        } => {
            if let Some(o) = operand {
                collect_expr(o, out);
            }
            for (w, t) in branches {
                collect_expr(w, out);
                collect_expr(t, out);
            }
            if let Some(e) = else_branch {
                collect_expr(e, out);
            }
        }
        // Leaves: no nested expressions, hence no table references.
        Expr::Literal(_) | Expr::Column(_) | Expr::Placeholder(_) => {}
    }
}
348
/// Appends an owned copy of `s` to `out` unless an equal string is already
/// present. Insertion order of first occurrences is preserved.
fn push_unique(out: &mut Vec<String>, s: &str) {
    if out.iter().all(|existing| existing.as_str() != s) {
        out.push(s.to_owned());
    }
}
354
#[cfg(test)]
mod tests {
    use super::*;
    use alloc::string::ToString;
    use spg_sql::parser::parse_statement;

    /// Builds a plan around a trivial statement. Only the version and table
    /// list vary, which is all the cache logic inspects.
    fn dummy_plan(version: u64, tables: &[&str]) -> PreparedPlan {
        let stmt = parse_statement("SELECT 1").expect("trivial SELECT parses");
        PreparedPlan {
            stmt,
            statistics_version: version,
            source_tables: tables.iter().map(|s| s.to_string()).collect(),
            describe_columns: Vec::new(),
        }
    }

    #[test]
    fn new_cache_is_empty() {
        let cache = PlanCache::new();
        assert!(cache.is_empty());
        assert_eq!(cache.len(), 0);
    }

    #[test]
    fn insert_then_get_returns_the_plan() {
        let mut cache = PlanCache::new();
        cache.insert("SELECT 1".into(), dummy_plan(0, &["t"]));
        assert_eq!(cache.len(), 1);
        let plan = cache.get("SELECT 1").expect("hit");
        assert_eq!(plan.source_tables, alloc::vec!["t".to_string()]);
    }

    #[test]
    fn miss_returns_none() {
        let mut cache = PlanCache::new();
        cache.insert("SELECT 1".into(), dummy_plan(0, &[]));
        assert!(cache.get("SELECT 2").is_none());
    }

    #[test]
    fn replace_overwrites_existing_entry() {
        let mut cache = PlanCache::new();
        cache.insert("SELECT 1".into(), dummy_plan(1, &["a"]));
        cache.insert("SELECT 1".into(), dummy_plan(2, &["b"]));
        assert_eq!(cache.len(), 1);
        let plan = cache.get("SELECT 1").expect("hit");
        assert_eq!(plan.statistics_version, 2);
    }

    #[test]
    fn lru_evicts_oldest_at_cap() {
        let mut cache = PlanCache::new();
        for i in 0..PLAN_CACHE_MAX_ENTRIES {
            cache.insert(alloc::format!("SELECT {i}"), dummy_plan(i as u64, &[]));
        }
        assert_eq!(cache.len(), PLAN_CACHE_MAX_ENTRIES);
        cache.insert("SELECT new".into(), dummy_plan(999, &[]));
        assert_eq!(cache.len(), PLAN_CACHE_MAX_ENTRIES);
        assert!(cache.get("SELECT 0").is_none());
        assert!(cache.get("SELECT new").is_some());
    }

    #[test]
    fn get_promotes_lru_position() {
        let mut cache = PlanCache::new();
        cache.insert("a".into(), dummy_plan(0, &[]));
        cache.insert("b".into(), dummy_plan(0, &[]));
        cache.insert("c".into(), dummy_plan(0, &[]));
        // Recency is now b (oldest), c, a (newest).
        let _ = cache.get("a");
        // Fill the cache exactly to capacity without evicting anything.
        for i in 0..(PLAN_CACHE_MAX_ENTRIES - 3) {
            cache.insert(alloc::format!("filler{i}"), dummy_plan(0, &[]));
        }
        // One more insert evicts the single least recently used entry: b.
        cache.insert("trigger".into(), dummy_plan(0, &[]));
        assert!(
            cache.get("a").is_some(),
            "a was MRU after get(); should survive"
        );
        assert!(cache.get("b").is_none(), "b should be evicted");
    }

    #[test]
    fn get_snapshot_does_not_promote_lru_position() {
        let mut cache = PlanCache::new();
        cache.set_max_entries(2);
        cache.insert("a".into(), dummy_plan(0, &[]));
        cache.insert("b".into(), dummy_plan(0, &[]));
        assert!(cache.get_snapshot("a").is_some());
        cache.insert("c".into(), dummy_plan(0, &[]));
        assert!(
            cache.get_snapshot("a").is_none(),
            "get_snapshot must not refresh recency, so a stays LRU"
        );
        assert!(cache.get_snapshot("b").is_some());
        assert!(cache.get_snapshot("c").is_some());
    }

    #[test]
    fn set_max_entries_clamps_to_valid_range() {
        let mut cache = PlanCache::new();
        assert_eq!(cache.max_entries(), PLAN_CACHE_MAX_ENTRIES);
        cache.set_max_entries(0);
        assert_eq!(cache.max_entries(), 1);
        cache.set_max_entries(PLAN_CACHE_MAX_ENTRIES + 1);
        assert_eq!(cache.max_entries(), PLAN_CACHE_MAX_ENTRIES);
        cache.set_max_entries(10);
        assert_eq!(cache.max_entries(), 10);
    }

    #[test]
    fn clear_drops_everything() {
        let mut cache = PlanCache::new();
        cache.insert("a".into(), dummy_plan(0, &[]));
        cache.insert("b".into(), dummy_plan(0, &[]));
        cache.clear();
        assert!(cache.is_empty());
        assert!(cache.get("a").is_none());
    }

    #[test]
    fn evict_removes_single_entry() {
        let mut cache = PlanCache::new();
        cache.insert("a".into(), dummy_plan(7, &[]));
        cache.insert("b".into(), dummy_plan(0, &[]));
        let plan = cache.evict("a").expect("a was cached");
        assert_eq!(plan.statistics_version, 7);
        assert!(cache.evict("a").is_none());
        assert_eq!(cache.len(), 1);
        assert!(cache.get("b").is_some());
    }

    #[test]
    fn evict_referencing_drops_only_matching_plans() {
        let mut cache = PlanCache::new();
        cache.insert("a".into(), dummy_plan(0, &["users"]));
        cache.insert("b".into(), dummy_plan(0, &["orders"]));
        cache.insert("c".into(), dummy_plan(0, &["users", "orders"]));
        let n = cache.evict_referencing("users");
        assert_eq!(n, 2);
        assert!(cache.get("a").is_none());
        assert!(cache.get("b").is_some());
        assert!(cache.get("c").is_none());
    }

    #[test]
    fn evict_referencing_unknown_table_is_noop() {
        let mut cache = PlanCache::new();
        cache.insert("a".into(), dummy_plan(0, &["users"]));
        assert_eq!(cache.evict_referencing("missing"), 0);
        assert_eq!(cache.len(), 1);
    }

    #[test]
    fn evict_referencing_keeps_lru_in_sync() {
        let mut cache = PlanCache::new();
        cache.set_max_entries(2);
        cache.insert("a".into(), dummy_plan(0, &["users"]));
        cache.insert("b".into(), dummy_plan(0, &["orders"]));
        assert_eq!(cache.evict_referencing("users"), 1);
        // A stale "a" left in the LRU queue would be popped instead of "b",
        // letting the cache grow past its cap.
        cache.insert("c".into(), dummy_plan(0, &[]));
        cache.insert("d".into(), dummy_plan(0, &[]));
        assert_eq!(cache.len(), 2);
        assert!(cache.get_snapshot("b").is_none());
        assert!(cache.get_snapshot("c").is_some());
        assert!(cache.get_snapshot("d").is_some());
    }

    #[test]
    fn collect_source_tables_from_simple_select() {
        let stmt = parse_statement("SELECT a, b FROM t1 WHERE x = 1").expect("parses");
        let tables = collect_source_tables(&stmt);
        assert_eq!(tables, alloc::vec!["t1".to_string()]);
    }

    #[test]
    fn collect_source_tables_from_join() {
        let stmt =
            parse_statement("SELECT * FROM t1 JOIN t2 ON t1.a = t2.b JOIN t3 ON t2.c = t3.d")
                .expect("parses");
        let tables = collect_source_tables(&stmt);
        assert_eq!(
            tables,
            alloc::vec!["t1".to_string(), "t2".to_string(), "t3".to_string()]
        );
    }

    #[test]
    fn collect_source_tables_dedupes_self_join() {
        let stmt = parse_statement("SELECT * FROM t1 a JOIN t1 b ON a.x = b.y").expect("parses");
        let tables = collect_source_tables(&stmt);
        assert_eq!(tables, alloc::vec!["t1".to_string()]);
    }
}
488}