vibesql_executor/cache/prepared_statement/
plan.rs1use std::sync::{Arc, OnceLock};
37
38use vibesql_ast::{
39 DeleteStmt, Expression, FromClause, SelectItem, SelectStmt, Statement, WhereClause,
40};
41
42#[derive(Debug, Clone)]
44pub enum CachedPlan {
45 PkPointLookup(PkPointLookupPlan),
48
49 SimpleFastPath(SimpleFastPathPlan),
53
54 PkDelete(PkDeletePlan),
57
58 Standard,
61}
62
63impl CachedPlan {
64 pub fn is_fast_path(&self) -> bool {
66 !matches!(self, CachedPlan::Standard)
67 }
68}
69
/// Plan for a single-table `SELECT` accepted by the simple fast path.
///
/// The output column names are resolved lazily, at most once, and cached in a
/// cell behind an `Arc`, so every clone of the plan shares the same result.
#[derive(Debug, Clone)]
pub struct SimpleFastPathPlan {
    /// Target table name; `analyze_select` stores it lower-cased.
    pub table_name: String,

    /// Cached output column names. A failed resolution is stored as an empty
    /// slice, so failure is cached too and no later resolver is invoked.
    resolved_columns: Arc<OnceLock<Arc<[String]>>>,
}

impl SimpleFastPathPlan {
    /// Creates a plan for `table_name` whose columns are not yet resolved.
    pub fn new(table_name: String) -> Self {
        Self { table_name, resolved_columns: Arc::new(OnceLock::new()) }
    }

    /// Returns the cached column names, calling `resolver` only if no
    /// resolution has been attempted yet.
    ///
    /// Returns `None` when the first resolver returned `None` or an empty
    /// list. That outcome is cached, so every later call also returns `None`
    /// without invoking its resolver.
    pub fn get_or_resolve_columns<F>(&self, resolver: F) -> Option<&Arc<[String]>>
    where
        F: FnOnce() -> Option<Vec<String>>,
    {
        // `get_or_init` already hands back the stored value; reuse it rather
        // than reading the cell a second time.
        let columns = self
            .resolved_columns
            .get_or_init(|| resolver().map(Arc::from).unwrap_or_else(|| Arc::from([])));
        (!columns.is_empty()).then_some(columns)
    }

    /// Reports whether a resolution has been attempted, successful or not.
    pub fn is_resolved(&self) -> bool {
        self.resolved_columns.get().is_some()
    }
}
124
/// Plan for `SELECT <projection> FROM <table> WHERE col = ? [AND col = ? ...]`.
///
/// The analyzer that builds this plan only checks the *shape* of the query;
/// it does not itself confirm that `pk_columns` form the table's primary key —
/// presumably that is verified where the plan is executed (confirm there).
#[derive(Debug, Clone)]
pub struct PkPointLookupPlan {
    /// Target table name; `try_analyze_pk_lookup` stores it lower-cased.
    pub table_name: String,

    /// Columns constrained by the `WHERE` clause, ordered by placeholder index.
    pub pk_columns: Vec<String>,

    /// `(placeholder index, position in pk_columns)` pairs.
    pub param_to_pk_col: Vec<(usize, usize)>,

    /// Which columns the query returns.
    pub projection: ProjectionPlan,

    /// Lazily resolved projection, shared by all clones of this plan.
    /// A failed resolution is stored as an empty [`ResolvedProjection`].
    resolved: Arc<OnceLock<ResolvedProjection>>,
}

/// A [`ProjectionPlan`] resolved by the caller-supplied resolver.
#[derive(Debug, Clone)]
pub struct ResolvedProjection {
    /// Column positions to read — presumably indices into the table's row
    /// layout (confirm against the resolver).
    pub column_indices: Vec<usize>,

    /// Names of the output columns.
    pub column_names: Arc<[String]>,
}

impl PkPointLookupPlan {
    /// Creates a plan whose projection has not been resolved yet.
    ///
    /// Mirrors [`SimpleFastPathPlan::new`] / [`PkDeletePlan::new`] so callers
    /// need not know about the private resolution cache.
    pub fn new(
        table_name: String,
        pk_columns: Vec<String>,
        param_to_pk_col: Vec<(usize, usize)>,
        projection: ProjectionPlan,
    ) -> Self {
        Self {
            table_name,
            pk_columns,
            param_to_pk_col,
            projection,
            resolved: Arc::new(OnceLock::new()),
        }
    }

    /// Returns the resolved projection, calling `resolver` on the first call only.
    ///
    /// If `resolver` returns `None`, an empty projection is cached in its place
    /// and no resolver is called again for this plan or its clones. The cached
    /// value is reported as `None` when it has no column names, except for a
    /// [`ProjectionPlan::Wildcard`] projection, where it is returned as-is.
    ///
    /// NOTE(review): for `Wildcard`, a failed resolution is therefore
    /// indistinguishable from a successful empty one and is reported as `Some`.
    /// Confirm the resolver never returns `None` for `Wildcard` on a genuine
    /// failure (e.g. a missing table).
    pub fn get_or_resolve<F>(&self, resolver: F) -> Option<&ResolvedProjection>
    where
        F: FnOnce(&ProjectionPlan) -> Option<ResolvedProjection>,
    {
        // Build the failure sentinel lazily: the previous eager `unwrap_or`
        // allocated an `Arc` even when resolution succeeded.
        let resolved = self.resolved.get_or_init(|| {
            resolver(&self.projection).unwrap_or_else(|| ResolvedProjection {
                column_indices: Vec::new(),
                column_names: Arc::from([]),
            })
        });

        // Reuse the reference returned by `get_or_init` instead of a second read.
        let usable = !resolved.column_names.is_empty()
            || matches!(self.projection, ProjectionPlan::Wildcard);
        usable.then_some(resolved)
    }

    /// Reports whether a resolution has been attempted, successful or not.
    pub fn is_resolved(&self) -> bool {
        self.resolved.get().is_some()
    }
}

/// The output columns a point lookup produces.
#[derive(Debug, Clone)]
pub enum ProjectionPlan {
    /// A lone `*` select list.
    Wildcard,

    /// Plain column references, in select-list order.
    Columns(Vec<ColumnProjection>),
}

/// One plain column reference from the select list.
#[derive(Debug, Clone)]
pub struct ColumnProjection {
    /// Canonical column name, as produced by the analyzer.
    pub column_name: String,

    /// Output alias attached to the select item, if any.
    pub alias: Option<String>,
}
207
208#[derive(Debug, Clone)]
218pub struct PkDeletePlan {
219 pub table_name: String,
221
222 pub pk_columns: Vec<String>,
224
225 pub param_to_pk_col: Vec<(usize, usize)>,
228
229 fast_path_valid: Arc<OnceLock<bool>>,
232}
233
234impl PkDeletePlan {
235 pub fn new(
237 table_name: String,
238 pk_columns: Vec<String>,
239 param_to_pk_col: Vec<(usize, usize)>,
240 ) -> Self {
241 Self { table_name, pk_columns, param_to_pk_col, fast_path_valid: Arc::new(OnceLock::new()) }
242 }
243
244 pub fn build_pk_values(
246 &self,
247 params: &[vibesql_types::SqlValue],
248 ) -> Vec<vibesql_types::SqlValue> {
249 let mut pk_values = vec![vibesql_types::SqlValue::Null; self.pk_columns.len()];
250 for &(param_idx, pk_col_idx) in &self.param_to_pk_col {
251 if param_idx < params.len() && pk_col_idx < pk_values.len() {
252 pk_values[pk_col_idx] = params[param_idx].clone();
253 }
254 }
255 pk_values
256 }
257
258 pub fn is_fast_path_valid(&self) -> Option<bool> {
260 self.fast_path_valid.get().copied()
261 }
262
263 pub fn set_fast_path_valid(&self, valid: bool) -> bool {
266 *self.fast_path_valid.get_or_init(|| valid)
267 }
268}
269
270pub fn analyze_for_plan(stmt: &Statement) -> CachedPlan {
272 match stmt {
273 Statement::Select(select) => analyze_select(select),
274 Statement::Delete(delete) => analyze_delete(delete),
275 _ => CachedPlan::Standard,
276 }
277}
278
279fn analyze_select(stmt: &SelectStmt) -> CachedPlan {
281 if let Some(plan) = try_analyze_pk_lookup(stmt) {
283 return CachedPlan::PkPointLookup(plan);
284 }
285
286 if crate::select::is_simple_point_query(stmt) {
289 if let Some(table_name) = extract_single_table_name(stmt) {
290 return CachedPlan::SimpleFastPath(SimpleFastPathPlan::new(table_name.to_lowercase()));
291 }
292 }
293
294 CachedPlan::Standard
295}
296
297fn analyze_delete(stmt: &DeleteStmt) -> CachedPlan {
299 let where_clause = match &stmt.where_clause {
301 Some(WhereClause::Condition(expr)) => expr,
302 _ => return CachedPlan::Standard,
303 };
304
305 let param_mappings = match extract_pk_param_mappings(where_clause) {
307 Some(mappings) if !mappings.is_empty() => mappings,
308 _ => return CachedPlan::Standard,
309 };
310
311 let pk_columns: Vec<String> = param_mappings.iter().map(|(_, col)| col.clone()).collect();
313 let param_to_pk_col: Vec<(usize, usize)> = param_mappings
314 .iter()
315 .enumerate()
316 .map(|(pk_idx, (param_idx, _))| (*param_idx, pk_idx))
317 .collect();
318
319 CachedPlan::PkDelete(PkDeletePlan::new(
320 stmt.table_name.to_lowercase(),
321 pk_columns,
322 param_to_pk_col,
323 ))
324}
325
326fn try_analyze_pk_lookup(stmt: &SelectStmt) -> Option<PkPointLookupPlan> {
328 let table_name = match &stmt.from {
330 Some(FromClause::Table { name, alias: None, .. }) => name.clone(),
331 _ => return None,
332 };
333
334 if stmt.with_clause.is_some()
336 || stmt.set_operation.is_some()
337 || stmt.group_by.is_some()
338 || stmt.having.is_some()
339 || stmt.distinct
340 || stmt.order_by.is_some()
341 || stmt.limit.is_some()
342 || stmt.offset.is_some()
343 || stmt.into_table.is_some()
344 || stmt.into_variables.is_some()
345 {
346 return None;
347 }
348
349 let projection = analyze_select_list(&stmt.select_list)?;
351
352 let where_clause = stmt.where_clause.as_ref()?;
354
355 let param_mappings = extract_pk_param_mappings(where_clause)?;
357 if param_mappings.is_empty() {
358 return None;
359 }
360
361 let pk_columns: Vec<String> = param_mappings.iter().map(|(_, col)| col.clone()).collect();
365 let param_to_pk_col: Vec<(usize, usize)> = param_mappings
366 .iter()
367 .enumerate()
368 .map(|(pk_idx, (param_idx, _))| (*param_idx, pk_idx))
369 .collect();
370
371 Some(PkPointLookupPlan {
372 table_name: table_name.to_lowercase(),
373 pk_columns,
374 param_to_pk_col,
375 projection,
376 resolved: Arc::new(OnceLock::new()),
377 })
378}
379
380fn extract_single_table_name(stmt: &SelectStmt) -> Option<String> {
382 match &stmt.from {
383 Some(FromClause::Table { name, .. }) => Some(name.clone()),
384 _ => None,
385 }
386}
387
388fn analyze_select_list(select_list: &[SelectItem]) -> Option<ProjectionPlan> {
390 if select_list.len() == 1 {
391 if let SelectItem::Wildcard { .. } = &select_list[0] {
392 return Some(ProjectionPlan::Wildcard);
393 }
394 }
395
396 let mut columns = Vec::with_capacity(select_list.len());
398 for item in select_list {
399 match item {
400 SelectItem::Wildcard { .. } => {
401 return None;
403 }
404 SelectItem::QualifiedWildcard { .. } => {
405 return None;
407 }
408 SelectItem::Expression { expr, alias, .. } => {
409 let column_name = match expr {
410 Expression::ColumnRef(col_id) => col_id.column_canonical().to_string(),
411 _ => {
412 return None;
414 }
415 };
416 columns.push(ColumnProjection { column_name, alias: alias.clone() });
417 }
418 }
419 }
420
421 Some(ProjectionPlan::Columns(columns))
422}
423
424fn extract_pk_param_mappings(expr: &Expression) -> Option<Vec<(usize, String)>> {
429 let mut mappings = Vec::new();
430 collect_pk_param_mappings(expr, &mut mappings)?;
431
432 mappings.sort_by_key(|(idx, _)| *idx);
434
435 Some(mappings)
436}
437
438fn collect_pk_param_mappings(expr: &Expression, mappings: &mut Vec<(usize, String)>) -> Option<()> {
440 match expr {
441 Expression::BinaryOp { left, op, right } => {
442 match op {
443 vibesql_ast::BinaryOperator::And => {
444 collect_pk_param_mappings(left, mappings)?;
446 collect_pk_param_mappings(right, mappings)?;
447 }
448 vibesql_ast::BinaryOperator::Equal => {
449 if let Some(mapping) = extract_column_placeholder_pair(left, right) {
451 mappings.push(mapping);
452 } else {
453 return None;
455 }
456 }
457 _ => {
458 return None;
460 }
461 }
462 }
463 _ => {
464 return None;
466 }
467 }
468 Some(())
469}
470
471fn extract_column_placeholder_pair(
473 left: &Expression,
474 right: &Expression,
475) -> Option<(usize, String)> {
476 if let Expression::ColumnRef(col_id) = left {
478 if let Expression::Placeholder(idx) = right {
479 return Some((*idx, col_id.column_canonical().to_string()));
480 }
481 }
482
483 if let Expression::Placeholder(idx) = left {
485 if let Expression::ColumnRef(col_id) = right {
486 return Some((*idx, col_id.column_canonical().to_string()));
487 }
488 }
489
490 None
491}
492
#[cfg(test)]
mod tests {
    use vibesql_parser::Parser;

    use super::*;

    /// Parses one SQL statement and returns the plan chosen for it.
    fn parse_to_plan(sql: &str) -> CachedPlan {
        let parsed = Parser::parse_sql(sql).unwrap();
        analyze_for_plan(&parsed)
    }

    /// Asserts that `sql` is not given any fast-path plan.
    fn assert_standard(sql: &str) {
        let plan = parse_to_plan(sql);
        assert!(matches!(plan, CachedPlan::Standard));
    }

    #[test]
    fn test_simple_pk_lookup() {
        let CachedPlan::PkPointLookup(p) = parse_to_plan("SELECT * FROM users WHERE id = ?")
        else {
            panic!("Expected PkPointLookup");
        };
        assert_eq!(p.table_name, "users");
        assert_eq!(p.pk_columns, vec!["id"]);
        assert_eq!(p.param_to_pk_col, vec![(0, 0)]);
        assert!(matches!(p.projection, ProjectionPlan::Wildcard));
    }

    #[test]
    fn test_composite_pk_lookup() {
        let sql = "SELECT * FROM orders WHERE customer_id = ? AND order_id = ?";
        let CachedPlan::PkPointLookup(p) = parse_to_plan(sql) else {
            panic!("Expected PkPointLookup");
        };
        assert_eq!(p.table_name, "orders");
        assert_eq!(p.pk_columns, vec!["customer_id", "order_id"]);
        assert_eq!(p.param_to_pk_col.len(), 2);
    }

    #[test]
    fn test_projected_columns() {
        let sql = "SELECT name, email FROM users WHERE id = ?";
        let CachedPlan::PkPointLookup(p) = parse_to_plan(sql) else {
            panic!("Expected PkPointLookup");
        };
        let ProjectionPlan::Columns(cols) = p.projection else {
            panic!("Expected Columns projection");
        };
        assert_eq!(cols.len(), 2);
        assert_eq!(cols[0].column_name, "name");
        assert_eq!(cols[1].column_name, "email");
    }

    #[test]
    fn test_not_cacheable_join() {
        assert_standard("SELECT * FROM users u JOIN orders o ON u.id = o.user_id WHERE u.id = ?");
    }

    #[test]
    fn test_not_cacheable_aggregate() {
        assert_standard("SELECT COUNT(*) FROM users WHERE id = ?");
    }

    #[test]
    fn test_not_cacheable_order_by() {
        assert_standard("SELECT * FROM users WHERE id = ? ORDER BY name");
    }

    #[test]
    fn test_not_cacheable_or() {
        assert_standard("SELECT * FROM users WHERE id = ? OR name = ?");
    }

    #[test]
    fn test_literal_gets_simple_fast_path() {
        // A literal instead of a placeholder misses the key-lookup plan but
        // still qualifies for the simple fast path.
        let plan = parse_to_plan("SELECT * FROM users WHERE id = 1");
        assert!(matches!(plan, CachedPlan::SimpleFastPath(_)));
    }

    #[test]
    fn test_delete_pk_lookup() {
        let p = match parse_to_plan("DELETE FROM sbtest1 WHERE id = ?") {
            CachedPlan::PkDelete(p) => p,
            other => panic!("Expected PkDelete, got {:?}", other),
        };
        assert_eq!(p.table_name, "sbtest1");
        assert_eq!(p.pk_columns, vec!["id"]);
        assert_eq!(p.param_to_pk_col, vec![(0, 0)]);
    }

    #[test]
    fn test_delete_without_where_not_fast_path() {
        assert_standard("DELETE FROM users");
    }
}