vibesql_executor/cache/prepared_statement/
plan.rs1use std::sync::{Arc, OnceLock};
37use vibesql_ast::{DeleteStmt, Expression, FromClause, SelectItem, SelectStmt, Statement, WhereClause};
38
39#[derive(Debug, Clone)]
41pub enum CachedPlan {
42 PkPointLookup(PkPointLookupPlan),
45
46 SimpleFastPath(SimpleFastPathPlan),
50
51 PkDelete(PkDeletePlan),
54
55 Standard,
58}
59
60impl CachedPlan {
61 pub fn is_fast_path(&self) -> bool {
63 !matches!(self, CachedPlan::Standard)
64 }
65}
66
/// Cached plan for a `SELECT` accepted by `crate::select::is_simple_point_query`.
///
/// The target table's column names are resolved lazily on first use and then
/// memoized. The memo sits behind an `Arc`, so every clone of a plan shares the
/// same resolution.
#[derive(Debug, Clone)]
pub struct SimpleFastPathPlan {
    /// Name of the table in the `FROM` clause (`analyze_select` passes it
    /// upper-cased).
    pub table_name: String,

    /// Memoized resolver outcome. `None` records that resolution failed or
    /// produced no columns, so the resolver is never retried.
    resolved_columns: Arc<OnceLock<Option<Arc<[String]>>>>,
}

impl SimpleFastPathPlan {
    /// Creates a plan whose columns have not been resolved yet.
    pub fn new(table_name: String) -> Self {
        Self { table_name, resolved_columns: Arc::new(OnceLock::new()) }
    }

    /// Returns the table's column names, resolving them on first use.
    ///
    /// `resolver` runs at most once across this plan and all of its clones. If
    /// it returns `None` or an empty list, that outcome is memoized and this
    /// call and every later call return `None`.
    pub fn get_or_resolve_columns<F>(&self, resolver: F) -> Option<&Arc<[String]>>
    where
        F: FnOnce() -> Option<Vec<String>>,
    {
        // Store failure as `None` instead of an empty-slice sentinel.
        // `get_or_init` already hands back the stored value, so no second lookup.
        self.resolved_columns
            .get_or_init(|| resolver().filter(|cols| !cols.is_empty()).map(Arc::from))
            .as_ref()
    }

    /// Returns `true` once a resolution attempt has been recorded, whether or
    /// not it succeeded.
    pub fn is_resolved(&self) -> bool {
        self.resolved_columns.get().is_some()
    }
}
124
/// Cached plan for `SELECT <cols | *> FROM t WHERE c1 = ? [AND c2 = ? ...]`.
///
/// The key columns are the ones compared in the `WHERE` clause. Building the plan
/// does not consult the table schema to confirm that they form the primary key.
#[derive(Debug, Clone)]
pub struct PkPointLookupPlan {
    /// Name of the single table in the `FROM` clause (upper-cased by
    /// `try_analyze_pk_lookup`).
    pub table_name: String,

    /// Columns compared in the `WHERE` clause, ordered by placeholder index.
    pub pk_columns: Vec<String>,

    /// `(param_idx, pk_col_idx)`: bound parameter `param_idx` supplies the value
    /// for `pk_columns[pk_col_idx]`.
    pub param_to_pk_col: Vec<(usize, usize)>,

    /// Which columns the query returns.
    pub projection: ProjectionPlan,

    /// Memoized resolver outcome, shared by every clone of this plan.
    /// `None` records a failed resolution so the resolver is not retried.
    resolved: Arc<OnceLock<Option<ResolvedProjection>>>,
}

/// A [`ProjectionPlan`] bound to a concrete table by the caller's resolver.
#[derive(Debug, Clone)]
pub struct ResolvedProjection {
    /// Column positions produced by the resolver.
    /// NOTE(review): presumably indices into the table's row layout — confirm.
    pub column_indices: Vec<usize>,

    /// Output column names produced by the resolver.
    pub column_names: Arc<[String]>,
}

impl PkPointLookupPlan {
    /// Returns the resolved projection, running `resolver` at most once across
    /// this plan and all of its clones.
    ///
    /// Returns `None` in two cases:
    /// - the resolver failed (returned `None`), on this call and every later call;
    /// - the projection is an explicit column list and the resolver produced no
    ///   column names.
    ///
    /// For a wildcard projection, any successful resolution is returned as-is.
    pub fn get_or_resolve<F>(&self, resolver: F) -> Option<&ResolvedProjection>
    where
        F: FnOnce(&ProjectionPlan) -> Option<ResolvedProjection>,
    {
        let is_wildcard = matches!(self.projection, ProjectionPlan::Wildcard);
        // Failure is stored as `None`, not as an empty placeholder projection.
        // Otherwise the wildcard exemption below would turn a failed `SELECT *`
        // resolution into an apparent success with no columns.
        self.resolved
            .get_or_init(|| resolver(&self.projection))
            .as_ref()
            .filter(|r| is_wildcard || !r.column_names.is_empty())
    }

    /// Returns `true` once a resolution attempt has been recorded, whether or
    /// not it succeeded.
    pub fn is_resolved(&self) -> bool {
        self.resolved.get().is_some()
    }
}

/// Shape of a fast-path `SELECT` list.
#[derive(Debug, Clone)]
pub enum ProjectionPlan {
    /// A lone `*`.
    Wildcard,

    /// Plain column references, in select-list order.
    Columns(Vec<ColumnProjection>),
}

/// One plain column reference in a fast-path `SELECT` list.
#[derive(Debug, Clone)]
pub struct ColumnProjection {
    /// Referenced column name; `analyze_select_list` drops any table qualifier.
    pub column_name: String,

    /// Alias given with `AS`, if any.
    pub alias: Option<String>,
}
205
206#[derive(Debug, Clone)]
216pub struct PkDeletePlan {
217 pub table_name: String,
219
220 pub pk_columns: Vec<String>,
222
223 pub param_to_pk_col: Vec<(usize, usize)>,
226
227 fast_path_valid: Arc<OnceLock<bool>>,
230}
231
232impl PkDeletePlan {
233 pub fn new(table_name: String, pk_columns: Vec<String>, param_to_pk_col: Vec<(usize, usize)>) -> Self {
235 Self {
236 table_name,
237 pk_columns,
238 param_to_pk_col,
239 fast_path_valid: Arc::new(OnceLock::new()),
240 }
241 }
242
243 pub fn build_pk_values(&self, params: &[vibesql_types::SqlValue]) -> Vec<vibesql_types::SqlValue> {
245 let mut pk_values = vec![vibesql_types::SqlValue::Null; self.pk_columns.len()];
246 for &(param_idx, pk_col_idx) in &self.param_to_pk_col {
247 if param_idx < params.len() && pk_col_idx < pk_values.len() {
248 pk_values[pk_col_idx] = params[param_idx].clone();
249 }
250 }
251 pk_values
252 }
253
254 pub fn is_fast_path_valid(&self) -> Option<bool> {
256 self.fast_path_valid.get().copied()
257 }
258
259 pub fn set_fast_path_valid(&self, valid: bool) -> bool {
262 *self.fast_path_valid.get_or_init(|| valid)
263 }
264}
265
266pub fn analyze_for_plan(stmt: &Statement) -> CachedPlan {
268 match stmt {
269 Statement::Select(select) => analyze_select(select),
270 Statement::Delete(delete) => analyze_delete(delete),
271 _ => CachedPlan::Standard,
272 }
273}
274
275fn analyze_select(stmt: &SelectStmt) -> CachedPlan {
277 if let Some(plan) = try_analyze_pk_lookup(stmt) {
279 return CachedPlan::PkPointLookup(plan);
280 }
281
282 if crate::select::is_simple_point_query(stmt) {
285 if let Some(table_name) = extract_single_table_name(stmt) {
286 return CachedPlan::SimpleFastPath(SimpleFastPathPlan::new(
287 table_name.to_uppercase(),
288 ));
289 }
290 }
291
292 CachedPlan::Standard
293}
294
295fn analyze_delete(stmt: &DeleteStmt) -> CachedPlan {
297 let where_clause = match &stmt.where_clause {
299 Some(WhereClause::Condition(expr)) => expr,
300 _ => return CachedPlan::Standard,
301 };
302
303 let param_mappings = match extract_pk_param_mappings(where_clause) {
305 Some(mappings) if !mappings.is_empty() => mappings,
306 _ => return CachedPlan::Standard,
307 };
308
309 let pk_columns: Vec<String> = param_mappings.iter().map(|(_, col)| col.clone()).collect();
311 let param_to_pk_col: Vec<(usize, usize)> = param_mappings
312 .iter()
313 .enumerate()
314 .map(|(pk_idx, (param_idx, _))| (*param_idx, pk_idx))
315 .collect();
316
317 CachedPlan::PkDelete(PkDeletePlan::new(
318 stmt.table_name.to_uppercase(),
319 pk_columns,
320 param_to_pk_col,
321 ))
322}
323
324fn try_analyze_pk_lookup(stmt: &SelectStmt) -> Option<PkPointLookupPlan> {
326 let table_name = match &stmt.from {
328 Some(FromClause::Table { name, alias: None, .. }) => name.clone(),
329 _ => return None,
330 };
331
332 if stmt.with_clause.is_some()
334 || stmt.set_operation.is_some()
335 || stmt.group_by.is_some()
336 || stmt.having.is_some()
337 || stmt.distinct
338 || stmt.order_by.is_some()
339 || stmt.limit.is_some()
340 || stmt.offset.is_some()
341 || stmt.into_table.is_some()
342 || stmt.into_variables.is_some()
343 {
344 return None;
345 }
346
347 let projection = analyze_select_list(&stmt.select_list)?;
349
350 let where_clause = stmt.where_clause.as_ref()?;
352
353 let param_mappings = extract_pk_param_mappings(where_clause)?;
355 if param_mappings.is_empty() {
356 return None;
357 }
358
359 let pk_columns: Vec<String> = param_mappings.iter().map(|(_, col)| col.clone()).collect();
363 let param_to_pk_col: Vec<(usize, usize)> = param_mappings
364 .iter()
365 .enumerate()
366 .map(|(pk_idx, (param_idx, _))| (*param_idx, pk_idx))
367 .collect();
368
369 Some(PkPointLookupPlan {
370 table_name: table_name.to_uppercase(),
371 pk_columns,
372 param_to_pk_col,
373 projection,
374 resolved: Arc::new(OnceLock::new()),
375 })
376}
377
378fn extract_single_table_name(stmt: &SelectStmt) -> Option<String> {
380 match &stmt.from {
381 Some(FromClause::Table { name, .. }) => Some(name.clone()),
382 _ => None,
383 }
384}
385
386fn analyze_select_list(select_list: &[SelectItem]) -> Option<ProjectionPlan> {
388 if select_list.len() == 1 {
389 if let SelectItem::Wildcard { .. } = &select_list[0] {
390 return Some(ProjectionPlan::Wildcard);
391 }
392 }
393
394 let mut columns = Vec::with_capacity(select_list.len());
396 for item in select_list {
397 match item {
398 SelectItem::Wildcard { .. } => {
399 return None;
401 }
402 SelectItem::QualifiedWildcard { .. } => {
403 return None;
405 }
406 SelectItem::Expression { expr, alias } => {
407 let column_name = match expr {
408 Expression::ColumnRef { column, table: None } => column.clone(),
409 Expression::ColumnRef { column, table: Some(_) } => {
410 column.clone()
412 }
413 _ => {
414 return None;
416 }
417 };
418 columns.push(ColumnProjection { column_name, alias: alias.clone() });
419 }
420 }
421 }
422
423 Some(ProjectionPlan::Columns(columns))
424}
425
426fn extract_pk_param_mappings(expr: &Expression) -> Option<Vec<(usize, String)>> {
431 let mut mappings = Vec::new();
432 collect_pk_param_mappings(expr, &mut mappings)?;
433
434 mappings.sort_by_key(|(idx, _)| *idx);
436
437 Some(mappings)
438}
439
440fn collect_pk_param_mappings(expr: &Expression, mappings: &mut Vec<(usize, String)>) -> Option<()> {
442 match expr {
443 Expression::BinaryOp { left, op, right } => {
444 match op {
445 vibesql_ast::BinaryOperator::And => {
446 collect_pk_param_mappings(left, mappings)?;
448 collect_pk_param_mappings(right, mappings)?;
449 }
450 vibesql_ast::BinaryOperator::Equal => {
451 if let Some(mapping) = extract_column_placeholder_pair(left, right) {
453 mappings.push(mapping);
454 } else {
455 return None;
457 }
458 }
459 _ => {
460 return None;
462 }
463 }
464 }
465 _ => {
466 return None;
468 }
469 }
470 Some(())
471}
472
473fn extract_column_placeholder_pair(
475 left: &Expression,
476 right: &Expression,
477) -> Option<(usize, String)> {
478 if let Expression::ColumnRef { column, .. } = left {
480 if let Expression::Placeholder(idx) = right {
481 return Some((*idx, column.clone()));
482 }
483 }
484
485 if let Expression::Placeholder(idx) = left {
487 if let Expression::ColumnRef { column, .. } = right {
488 return Some((*idx, column.clone()));
489 }
490 }
491
492 None
493}
494
/// Plan-selection tests: each parses a SQL string and checks which
/// [`CachedPlan`] variant `analyze_for_plan` picks, plus key plan fields.
#[cfg(test)]
mod tests {
    use super::*;
    use vibesql_parser::Parser;

    /// Parses `sql` (panicking on parse errors) and runs plan analysis on it.
    fn parse_to_plan(sql: &str) -> CachedPlan {
        let stmt = Parser::parse_sql(sql).unwrap();
        analyze_for_plan(&stmt)
    }

    // Single `col = ?` with `SELECT *`. These assertions expect identifiers to
    // come back upper-cased from the parser / analysis.
    #[test]
    fn test_simple_pk_lookup() {
        let plan = parse_to_plan("SELECT * FROM users WHERE id = ?");
        match plan {
            CachedPlan::PkPointLookup(p) => {
                assert_eq!(p.table_name, "USERS");
                assert_eq!(p.pk_columns, vec!["ID"]);
                assert_eq!(p.param_to_pk_col, vec![(0, 0)]);
                assert!(matches!(p.projection, ProjectionPlan::Wildcard));
            }
            _ => panic!("Expected PkPointLookup"),
        }
    }

    // Two AND-joined equalities: columns listed in placeholder order.
    #[test]
    fn test_composite_pk_lookup() {
        let plan = parse_to_plan("SELECT * FROM orders WHERE customer_id = ? AND order_id = ?");
        match plan {
            CachedPlan::PkPointLookup(p) => {
                assert_eq!(p.table_name, "ORDERS");
                assert_eq!(p.pk_columns, vec!["CUSTOMER_ID", "ORDER_ID"]);
                assert_eq!(p.param_to_pk_col.len(), 2);
            }
            _ => panic!("Expected PkPointLookup"),
        }
    }

    // An explicit column list becomes `ProjectionPlan::Columns` in list order.
    #[test]
    fn test_projected_columns() {
        let plan = parse_to_plan("SELECT name, email FROM users WHERE id = ?");
        match plan {
            CachedPlan::PkPointLookup(p) => {
                match p.projection {
                    ProjectionPlan::Columns(cols) => {
                        assert_eq!(cols.len(), 2);
                        assert_eq!(cols[0].column_name, "NAME");
                        assert_eq!(cols[1].column_name, "EMAIL");
                    }
                    _ => panic!("Expected Columns projection"),
                }
            }
            _ => panic!("Expected PkPointLookup"),
        }
    }

    // Joins (and aliased tables) are outside every fast path.
    #[test]
    fn test_not_cacheable_join() {
        let plan =
            parse_to_plan("SELECT * FROM users u JOIN orders o ON u.id = o.user_id WHERE u.id = ?");
        assert!(matches!(plan, CachedPlan::Standard));
    }

    // A non-column select item disqualifies the PK lookup.
    #[test]
    fn test_not_cacheable_aggregate() {
        let plan = parse_to_plan("SELECT COUNT(*) FROM users WHERE id = ?");
        assert!(matches!(plan, CachedPlan::Standard));
    }

    // ORDER BY is one of the clauses the PK lookup rejects outright.
    #[test]
    fn test_not_cacheable_order_by() {
        let plan = parse_to_plan("SELECT * FROM users WHERE id = ? ORDER BY name");
        assert!(matches!(plan, CachedPlan::Standard));
    }

    // Only AND may join the equalities; OR falls back to Standard.
    #[test]
    fn test_not_cacheable_or() {
        let plan = parse_to_plan("SELECT * FROM users WHERE id = ? OR name = ?");
        assert!(matches!(plan, CachedPlan::Standard));
    }

    // A literal instead of a placeholder misses the PK lookup; this test expects
    // `is_simple_point_query` to accept the statement instead.
    #[test]
    fn test_literal_gets_simple_fast_path() {
        let plan = parse_to_plan("SELECT * FROM users WHERE id = 1");
        assert!(matches!(plan, CachedPlan::SimpleFastPath(_)));
    }

    // DELETE with a single `col = ?` becomes a PkDelete plan.
    #[test]
    fn test_delete_pk_lookup() {
        let plan = parse_to_plan("DELETE FROM sbtest1 WHERE id = ?");
        match plan {
            CachedPlan::PkDelete(p) => {
                assert_eq!(p.table_name, "SBTEST1");
                assert_eq!(p.pk_columns, vec!["ID"]);
                assert_eq!(p.param_to_pk_col, vec![(0, 0)]);
            }
            other => panic!("Expected PkDelete, got {:?}", other),
        }
    }

    // A DELETE without WHERE never takes the fast path.
    #[test]
    fn test_delete_without_where_not_fast_path() {
        let plan = parse_to_plan("DELETE FROM users");
        assert!(matches!(plan, CachedPlan::Standard));
    }
}