vibesql_executor/cache/prepared_statement/
plan.rs1use std::sync::{Arc, OnceLock};
37use vibesql_ast::{Expression, FromClause, SelectItem, SelectStmt, Statement};
38
39#[derive(Debug, Clone)]
41pub enum CachedPlan {
42 PkPointLookup(PkPointLookupPlan),
45
46 Standard,
49}
50
51impl CachedPlan {
52 pub fn is_fast_path(&self) -> bool {
54 !matches!(self, CachedPlan::Standard)
55 }
56}
57
58#[derive(Debug, Clone)]
60pub struct PkPointLookupPlan {
61 pub table_name: String,
63
64 pub pk_columns: Vec<String>,
66
67 pub param_to_pk_col: Vec<(usize, usize)>,
70
71 pub projection: ProjectionPlan,
73
74 resolved: Arc<OnceLock<ResolvedProjection>>,
77}
78
/// Outcome of resolving a [`ProjectionPlan`], as produced by the resolver
/// closure handed to `PkPointLookupPlan::get_or_resolve`.
#[derive(Debug, Clone)]
pub struct ResolvedProjection {
    /// Column indices supplied by the resolver.
    // NOTE(review): presumably positions within the table's row layout — confirm against the caller.
    pub column_indices: Vec<usize>,
    /// Output column names. Stored as `Arc<[String]>`, so cloning this struct
    /// shares the names instead of copying them.
    pub column_names: Arc<[String]>,
}
89
90impl PkPointLookupPlan {
91 pub fn get_or_resolve<F>(&self, resolver: F) -> Option<&ResolvedProjection>
96 where
97 F: FnOnce(&ProjectionPlan) -> Option<ResolvedProjection>,
98 {
99 self.resolved.get_or_init(|| {
100 resolver(&self.projection).unwrap_or(ResolvedProjection {
104 column_indices: vec![],
105 column_names: Arc::from([]),
106 })
107 });
108
109 self.resolved.get().filter(|r| !r.column_names.is_empty() || matches!(self.projection, ProjectionPlan::Wildcard))
111 }
112
113 pub fn is_resolved(&self) -> bool {
115 self.resolved.get().is_some()
116 }
117}
118
119#[derive(Debug, Clone)]
121pub enum ProjectionPlan {
122 Wildcard,
124
125 Columns(Vec<ColumnProjection>),
127}
128
/// One column reference taken from a select list.
#[derive(Debug, Clone)]
pub struct ColumnProjection {
    /// Column name copied from the column reference; any table qualifier on
    /// the reference is not retained.
    pub column_name: String,

    /// Alias attached to the select item, if there was one.
    pub alias: Option<String>,
}
138
139pub fn analyze_for_plan(stmt: &Statement) -> CachedPlan {
141 match stmt {
142 Statement::Select(select) => analyze_select(select),
143 _ => CachedPlan::Standard,
144 }
145}
146
147fn analyze_select(stmt: &SelectStmt) -> CachedPlan {
149 let table_name = match &stmt.from {
151 Some(FromClause::Table { name, alias: None, .. }) => name.clone(),
152 Some(FromClause::Table { name: _, alias: Some(_), .. }) => {
153 return CachedPlan::Standard;
155 }
156 _ => return CachedPlan::Standard,
157 };
158
159 if stmt.with_clause.is_some()
161 || stmt.set_operation.is_some()
162 || stmt.group_by.is_some()
163 || stmt.having.is_some()
164 || stmt.distinct
165 || stmt.order_by.is_some()
166 || stmt.limit.is_some()
167 || stmt.offset.is_some()
168 || stmt.into_table.is_some()
169 || stmt.into_variables.is_some()
170 {
171 return CachedPlan::Standard;
172 }
173
174 let projection = match analyze_select_list(&stmt.select_list) {
176 Some(p) => p,
177 None => return CachedPlan::Standard,
178 };
179
180 let where_clause = match &stmt.where_clause {
182 Some(w) => w,
183 None => return CachedPlan::Standard,
184 };
185
186 let param_mappings = match extract_pk_param_mappings(where_clause) {
188 Some(m) if !m.is_empty() => m,
189 _ => return CachedPlan::Standard,
190 };
191
192 let pk_columns: Vec<String> = param_mappings.iter().map(|(_, col)| col.clone()).collect();
196 let param_to_pk_col: Vec<(usize, usize)> = param_mappings
197 .iter()
198 .enumerate()
199 .map(|(pk_idx, (param_idx, _))| (*param_idx, pk_idx))
200 .collect();
201
202 CachedPlan::PkPointLookup(PkPointLookupPlan {
203 table_name: table_name.to_uppercase(),
204 pk_columns,
205 param_to_pk_col,
206 projection,
207 resolved: Arc::new(OnceLock::new()),
208 })
209}
210
211fn analyze_select_list(select_list: &[SelectItem]) -> Option<ProjectionPlan> {
213 if select_list.len() == 1 {
214 if let SelectItem::Wildcard { .. } = &select_list[0] {
215 return Some(ProjectionPlan::Wildcard);
216 }
217 }
218
219 let mut columns = Vec::with_capacity(select_list.len());
221 for item in select_list {
222 match item {
223 SelectItem::Wildcard { .. } => {
224 return None;
226 }
227 SelectItem::QualifiedWildcard { .. } => {
228 return None;
230 }
231 SelectItem::Expression { expr, alias } => {
232 let column_name = match expr {
233 Expression::ColumnRef { column, table: None } => column.clone(),
234 Expression::ColumnRef { column, table: Some(_) } => {
235 column.clone()
237 }
238 _ => {
239 return None;
241 }
242 };
243 columns.push(ColumnProjection { column_name, alias: alias.clone() });
244 }
245 }
246 }
247
248 Some(ProjectionPlan::Columns(columns))
249}
250
251fn extract_pk_param_mappings(expr: &Expression) -> Option<Vec<(usize, String)>> {
256 let mut mappings = Vec::new();
257 collect_pk_param_mappings(expr, &mut mappings)?;
258
259 mappings.sort_by_key(|(idx, _)| *idx);
261
262 Some(mappings)
263}
264
265fn collect_pk_param_mappings(expr: &Expression, mappings: &mut Vec<(usize, String)>) -> Option<()> {
267 match expr {
268 Expression::BinaryOp { left, op, right } => {
269 match op {
270 vibesql_ast::BinaryOperator::And => {
271 collect_pk_param_mappings(left, mappings)?;
273 collect_pk_param_mappings(right, mappings)?;
274 }
275 vibesql_ast::BinaryOperator::Equal => {
276 if let Some(mapping) = extract_column_placeholder_pair(left, right) {
278 mappings.push(mapping);
279 } else {
280 return None;
282 }
283 }
284 _ => {
285 return None;
287 }
288 }
289 }
290 _ => {
291 return None;
293 }
294 }
295 Some(())
296}
297
298fn extract_column_placeholder_pair(
300 left: &Expression,
301 right: &Expression,
302) -> Option<(usize, String)> {
303 if let Expression::ColumnRef { column, .. } = left {
305 if let Expression::Placeholder(idx) = right {
306 return Some((*idx, column.clone()));
307 }
308 }
309
310 if let Expression::Placeholder(idx) = left {
312 if let Expression::ColumnRef { column, .. } = right {
313 return Some((*idx, column.clone()));
314 }
315 }
316
317 None
318}
319
#[cfg(test)]
mod tests {
    use super::*;
    use vibesql_parser::Parser;

    /// Parses `sql` and classifies the resulting statement.
    fn parse_to_plan(sql: &str) -> CachedPlan {
        let stmt = Parser::parse_sql(sql).unwrap();
        analyze_for_plan(&stmt)
    }

    /// Classifies `sql` and unwraps the point-lookup plan, panicking otherwise.
    fn expect_pk_lookup(sql: &str) -> PkPointLookupPlan {
        match parse_to_plan(sql) {
            CachedPlan::PkPointLookup(plan) => plan,
            _ => panic!("Expected PkPointLookup"),
        }
    }

    /// Asserts that `sql` is classified as the standard plan.
    fn assert_standard(sql: &str) {
        let plan = parse_to_plan(sql);
        assert!(matches!(plan, CachedPlan::Standard));
    }

    #[test]
    fn test_simple_pk_lookup() {
        let p = expect_pk_lookup("SELECT * FROM users WHERE id = ?");
        assert_eq!(p.table_name, "USERS");
        assert_eq!(p.pk_columns, vec!["ID"]);
        assert_eq!(p.param_to_pk_col, vec![(0, 0)]);
        assert!(matches!(p.projection, ProjectionPlan::Wildcard));
    }

    #[test]
    fn test_composite_pk_lookup() {
        let p = expect_pk_lookup("SELECT * FROM orders WHERE customer_id = ? AND order_id = ?");
        assert_eq!(p.table_name, "ORDERS");
        assert_eq!(p.pk_columns, vec!["CUSTOMER_ID", "ORDER_ID"]);
        assert_eq!(p.param_to_pk_col.len(), 2);
    }

    #[test]
    fn test_projected_columns() {
        let p = expect_pk_lookup("SELECT name, email FROM users WHERE id = ?");
        let ProjectionPlan::Columns(cols) = p.projection else {
            panic!("Expected Columns projection");
        };
        assert_eq!(cols.len(), 2);
        assert_eq!(cols[0].column_name, "NAME");
        assert_eq!(cols[1].column_name, "EMAIL");
    }

    #[test]
    fn test_not_cacheable_join() {
        assert_standard("SELECT * FROM users u JOIN orders o ON u.id = o.user_id WHERE u.id = ?");
    }

    #[test]
    fn test_not_cacheable_aggregate() {
        assert_standard("SELECT COUNT(*) FROM users WHERE id = ?");
    }

    #[test]
    fn test_not_cacheable_order_by() {
        assert_standard("SELECT * FROM users WHERE id = ? ORDER BY name");
    }

    #[test]
    fn test_not_cacheable_or() {
        assert_standard("SELECT * FROM users WHERE id = ? OR name = ?");
    }

    #[test]
    fn test_not_cacheable_literal() {
        assert_standard("SELECT * FROM users WHERE id = 1");
    }
}