vibesql_executor/cache/prepared_statement/
plan.rs1use vibesql_ast::{Expression, FromClause, SelectItem, SelectStmt, Statement};
37
38#[derive(Debug, Clone)]
40pub enum CachedPlan {
41 PkPointLookup(PkPointLookupPlan),
44
45 Standard,
48}
49
50impl CachedPlan {
51 pub fn is_fast_path(&self) -> bool {
53 !matches!(self, CachedPlan::Standard)
54 }
55}
56
57#[derive(Debug, Clone)]
59pub struct PkPointLookupPlan {
60 pub table_name: String,
62
63 pub pk_columns: Vec<String>,
65
66 pub param_to_pk_col: Vec<(usize, usize)>,
69
70 pub projection: ProjectionPlan,
72}
73
74#[derive(Debug, Clone)]
76pub enum ProjectionPlan {
77 Wildcard,
79
80 Columns(Vec<ColumnProjection>),
82}
83
/// One projected column of a fast-path `SELECT`.
///
/// Derives `PartialEq`/`Eq` so projections can be compared directly (for
/// example in assertions) instead of field by field.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct ColumnProjection {
    /// Name of the referenced column. When the select item was written with a
    /// table qualifier, only the column part is kept.
    pub column_name: String,

    /// Alias attached to the select item, if one was given.
    pub alias: Option<String>,
}
93
94pub fn analyze_for_plan(stmt: &Statement) -> CachedPlan {
96 match stmt {
97 Statement::Select(select) => analyze_select(select),
98 _ => CachedPlan::Standard,
99 }
100}
101
102fn analyze_select(stmt: &SelectStmt) -> CachedPlan {
104 let table_name = match &stmt.from {
106 Some(FromClause::Table { name, alias: None, .. }) => name.clone(),
107 Some(FromClause::Table { name: _, alias: Some(_), .. }) => {
108 return CachedPlan::Standard;
110 }
111 _ => return CachedPlan::Standard,
112 };
113
114 if stmt.with_clause.is_some()
116 || stmt.set_operation.is_some()
117 || stmt.group_by.is_some()
118 || stmt.having.is_some()
119 || stmt.distinct
120 || stmt.order_by.is_some()
121 || stmt.limit.is_some()
122 || stmt.offset.is_some()
123 || stmt.into_table.is_some()
124 || stmt.into_variables.is_some()
125 {
126 return CachedPlan::Standard;
127 }
128
129 let projection = match analyze_select_list(&stmt.select_list) {
131 Some(p) => p,
132 None => return CachedPlan::Standard,
133 };
134
135 let where_clause = match &stmt.where_clause {
137 Some(w) => w,
138 None => return CachedPlan::Standard,
139 };
140
141 let param_mappings = match extract_pk_param_mappings(where_clause) {
143 Some(m) if !m.is_empty() => m,
144 _ => return CachedPlan::Standard,
145 };
146
147 let pk_columns: Vec<String> = param_mappings.iter().map(|(_, col)| col.clone()).collect();
151 let param_to_pk_col: Vec<(usize, usize)> = param_mappings
152 .iter()
153 .enumerate()
154 .map(|(pk_idx, (param_idx, _))| (*param_idx, pk_idx))
155 .collect();
156
157 CachedPlan::PkPointLookup(PkPointLookupPlan {
158 table_name: table_name.to_uppercase(),
159 pk_columns,
160 param_to_pk_col,
161 projection,
162 })
163}
164
165fn analyze_select_list(select_list: &[SelectItem]) -> Option<ProjectionPlan> {
167 if select_list.len() == 1 {
168 if let SelectItem::Wildcard { .. } = &select_list[0] {
169 return Some(ProjectionPlan::Wildcard);
170 }
171 }
172
173 let mut columns = Vec::with_capacity(select_list.len());
175 for item in select_list {
176 match item {
177 SelectItem::Wildcard { .. } => {
178 return None;
180 }
181 SelectItem::QualifiedWildcard { .. } => {
182 return None;
184 }
185 SelectItem::Expression { expr, alias } => {
186 let column_name = match expr {
187 Expression::ColumnRef { column, table: None } => column.clone(),
188 Expression::ColumnRef { column, table: Some(_) } => {
189 column.clone()
191 }
192 _ => {
193 return None;
195 }
196 };
197 columns.push(ColumnProjection {
198 column_name,
199 alias: alias.clone(),
200 });
201 }
202 }
203 }
204
205 Some(ProjectionPlan::Columns(columns))
206}
207
208fn extract_pk_param_mappings(expr: &Expression) -> Option<Vec<(usize, String)>> {
213 let mut mappings = Vec::new();
214 collect_pk_param_mappings(expr, &mut mappings)?;
215
216 mappings.sort_by_key(|(idx, _)| *idx);
218
219 Some(mappings)
220}
221
222fn collect_pk_param_mappings(
224 expr: &Expression,
225 mappings: &mut Vec<(usize, String)>,
226) -> Option<()> {
227 match expr {
228 Expression::BinaryOp { left, op, right } => {
229 match op {
230 vibesql_ast::BinaryOperator::And => {
231 collect_pk_param_mappings(left, mappings)?;
233 collect_pk_param_mappings(right, mappings)?;
234 }
235 vibesql_ast::BinaryOperator::Equal => {
236 if let Some(mapping) = extract_column_placeholder_pair(left, right) {
238 mappings.push(mapping);
239 } else {
240 return None;
242 }
243 }
244 _ => {
245 return None;
247 }
248 }
249 }
250 _ => {
251 return None;
253 }
254 }
255 Some(())
256}
257
258fn extract_column_placeholder_pair(left: &Expression, right: &Expression) -> Option<(usize, String)> {
260 if let Expression::ColumnRef { column, .. } = left {
262 if let Expression::Placeholder(idx) = right {
263 return Some((*idx, column.clone()));
264 }
265 }
266
267 if let Expression::Placeholder(idx) = left {
269 if let Expression::ColumnRef { column, .. } = right {
270 return Some((*idx, column.clone()));
271 }
272 }
273
274 None
275}
276
#[cfg(test)]
mod tests {
    use super::*;
    use vibesql_parser::Parser;

    /// Parses `sql` and runs the plan analysis on the resulting statement.
    fn parse_to_plan(sql: &str) -> CachedPlan {
        let stmt = Parser::parse_sql(sql).unwrap();
        analyze_for_plan(&stmt)
    }

    /// Analyzes `sql` and unwraps the point-lookup plan, failing the test if
    /// the statement was not recognised as a point lookup.
    fn expect_pk_lookup(sql: &str) -> PkPointLookupPlan {
        match parse_to_plan(sql) {
            CachedPlan::PkPointLookup(plan) => plan,
            other => panic!("Expected PkPointLookup, got {:?}", other),
        }
    }

    #[test]
    fn test_simple_pk_lookup() {
        let p = expect_pk_lookup("SELECT * FROM users WHERE id = ?");
        assert_eq!(p.table_name, "USERS");
        assert_eq!(p.pk_columns, vec!["ID"]);
        assert_eq!(p.param_to_pk_col, vec![(0, 0)]);
        assert!(matches!(p.projection, ProjectionPlan::Wildcard));
    }

    #[test]
    fn test_composite_pk_lookup() {
        let p = expect_pk_lookup("SELECT * FROM orders WHERE customer_id = ? AND order_id = ?");
        assert_eq!(p.table_name, "ORDERS");
        assert_eq!(p.pk_columns, vec!["CUSTOMER_ID", "ORDER_ID"]);
        // Pin the full mapping, not just its length: each placeholder must
        // point at the column in the same position.
        assert_eq!(p.param_to_pk_col, vec![(0, 0), (1, 1)]);
    }

    #[test]
    fn test_projected_columns() {
        let p = expect_pk_lookup("SELECT name, email FROM users WHERE id = ?");
        match p.projection {
            ProjectionPlan::Columns(cols) => {
                assert_eq!(cols.len(), 2);
                assert_eq!(cols[0].column_name, "NAME");
                assert_eq!(cols[1].column_name, "EMAIL");
            }
            _ => panic!("Expected Columns projection"),
        }
    }

    #[test]
    fn test_is_fast_path() {
        assert!(parse_to_plan("SELECT * FROM users WHERE id = ?").is_fast_path());
        assert!(!parse_to_plan("SELECT * FROM users WHERE id = 1").is_fast_path());
    }

    #[test]
    fn test_not_cacheable_join() {
        let plan = parse_to_plan("SELECT * FROM users u JOIN orders o ON u.id = o.user_id WHERE u.id = ?");
        assert!(matches!(plan, CachedPlan::Standard));
    }

    #[test]
    fn test_not_cacheable_table_alias() {
        let plan = parse_to_plan("SELECT * FROM users u WHERE id = ?");
        assert!(matches!(plan, CachedPlan::Standard));
    }

    #[test]
    fn test_not_cacheable_missing_where() {
        let plan = parse_to_plan("SELECT * FROM users");
        assert!(matches!(plan, CachedPlan::Standard));
    }

    #[test]
    fn test_not_cacheable_aggregate() {
        let plan = parse_to_plan("SELECT COUNT(*) FROM users WHERE id = ?");
        assert!(matches!(plan, CachedPlan::Standard));
    }

    #[test]
    fn test_not_cacheable_order_by() {
        let plan = parse_to_plan("SELECT * FROM users WHERE id = ? ORDER BY name");
        assert!(matches!(plan, CachedPlan::Standard));
    }

    #[test]
    fn test_not_cacheable_or() {
        let plan = parse_to_plan("SELECT * FROM users WHERE id = ? OR name = ?");
        assert!(matches!(plan, CachedPlan::Standard));
    }

    #[test]
    fn test_not_cacheable_literal() {
        let plan = parse_to_plan("SELECT * FROM users WHERE id = 1");
        assert!(matches!(plan, CachedPlan::Standard));
    }
}