// vibesql_executor/cache/parameterized.rs

//! Parameterized plan support for literal extraction and binding.
//!
//! Allows queries with different literal values to share the same execution plan
//! by extracting literals into parameters and binding them at execution time.
5
6use vibesql_ast::{Expression, Statement};
7use vibesql_types::SqlValue;
8
/// Type-safe literal value representation.
///
/// Each variant pairs a SQL literal kind with its payload. The derived
/// `PartialEq` compares both the variant and the payload, so `Integer(1)` and
/// `Bigint(1)` are not equal.
#[derive(Clone, Debug, PartialEq)]
pub enum LiteralValue {
    /// Integer value (`i64` payload).
    Integer(i64),
    /// SMALLINT value (`i16` payload).
    Smallint(i16),
    /// BIGINT value (`i64` payload).
    Bigint(i64),
    /// Unsigned integer value (`u64` payload).
    Unsigned(u64),
    /// NUMERIC value, carried as an `f64`.
    Numeric(f64),
    /// FLOAT value, carried as an `f32`.
    Float(f32),
    /// REAL value, carried as an `f32`.
    Real(f32),
    /// DOUBLE value (`f64` payload).
    Double(f64),
    /// CHARACTER value (text payload, without surrounding quotes).
    Character(String),
    /// VARCHAR value (text payload, without surrounding quotes).
    Varchar(String),
    /// Boolean value.
    Boolean(bool),
    /// DATE value, held as its textual form.
    Date(String),
    /// TIME value, held as its textual form.
    Time(String),
    /// TIMESTAMP value, held as its textual form.
    Timestamp(String),
    /// SQL NULL.
    Null,
}
28
29impl LiteralValue {
30    /// Convert from SqlValue
31    pub fn from_sql_value(value: &SqlValue) -> Self {
32        match value {
33            SqlValue::Integer(n) => LiteralValue::Integer(*n),
34            SqlValue::Smallint(n) => LiteralValue::Smallint(*n),
35            SqlValue::Bigint(n) => LiteralValue::Bigint(*n),
36            SqlValue::Unsigned(n) => LiteralValue::Unsigned(*n),
37            SqlValue::Numeric(n) => LiteralValue::Numeric(*n),
38            SqlValue::Float(n) => LiteralValue::Float(*n),
39            SqlValue::Real(n) => LiteralValue::Real(*n),
40            SqlValue::Double(n) => LiteralValue::Double(*n),
41            SqlValue::Character(s) => LiteralValue::Character(s.clone()),
42            SqlValue::Varchar(s) => LiteralValue::Varchar(s.clone()),
43            SqlValue::Boolean(b) => LiteralValue::Boolean(*b),
44            SqlValue::Date(s) => LiteralValue::Date(s.to_string()),
45            SqlValue::Time(s) => LiteralValue::Time(s.to_string()),
46            SqlValue::Timestamp(s) => LiteralValue::Timestamp(s.to_string()),
47            SqlValue::Interval(s) => LiteralValue::Varchar(s.to_string()), /* Treat interval as */
48            // string for now
49            SqlValue::Vector(v) => {
50                // Treat vector as string representation for now
51                let formatted: Vec<String> = v.iter().map(|f| f.to_string()).collect();
52                LiteralValue::Varchar(format!("[{}]", formatted.join(", ")))
53            }
54            SqlValue::Null => LiteralValue::Null,
55        }
56    }
57
58    /// Convert to SQL string representation
59    pub fn to_sql(&self) -> String {
60        match self {
61            LiteralValue::Integer(n) => n.to_string(),
62            LiteralValue::Smallint(n) => n.to_string(),
63            LiteralValue::Bigint(n) => n.to_string(),
64            LiteralValue::Unsigned(n) => n.to_string(),
65            LiteralValue::Numeric(n) => n.to_string(),
66            LiteralValue::Float(n) => n.to_string(),
67            LiteralValue::Real(n) => n.to_string(),
68            LiteralValue::Double(n) => n.to_string(),
69            LiteralValue::Character(s) | LiteralValue::Varchar(s) => {
70                format!("'{}'", s.replace("'", "''"))
71            }
72            LiteralValue::Boolean(b) => if *b { "true" } else { "false" }.to_string(),
73            LiteralValue::Date(s) => format!("DATE '{}'", s),
74            LiteralValue::Time(s) => format!("TIME '{}'", s),
75            LiteralValue::Timestamp(s) => format!("TIMESTAMP '{}'", s),
76            LiteralValue::Null => "NULL".to_string(),
77        }
78    }
79}
80
/// Position of a parameter in a query.
///
/// NOTE(review): `ParameterizedPlan::bind` only counts these entries and never
/// reads their fields, so the unit of `position` (byte offset? ordinal?) is not
/// established in this file. Confirm it with whatever produces the plan.
#[derive(Clone, Debug)]
pub struct ParameterPosition {
    /// Location of the parameter within the query (unit not established here).
    pub position: usize,
    /// Free-form description of where the parameter appears (tests use a column name).
    pub context: String,
}
87
88/// Query plan with parameter placeholders
89#[derive(Clone, Debug)]
90pub struct ParameterizedPlan {
91    pub normalized_query: String,
92    pub param_positions: Vec<ParameterPosition>,
93    pub literal_values: Vec<LiteralValue>,
94}
95
96impl ParameterizedPlan {
97    /// Create a new parameterized plan
98    pub fn new(
99        normalized_query: String,
100        param_positions: Vec<ParameterPosition>,
101        literal_values: Vec<LiteralValue>,
102    ) -> Self {
103        Self { normalized_query, param_positions, literal_values }
104    }
105
106    /// Bind literal values to create executable query
107    pub fn bind(&self, values: &[LiteralValue]) -> Result<String, String> {
108        if values.len() != self.param_positions.len() {
109            return Err(format!(
110                "Expected {} parameters, got {}",
111                self.param_positions.len(),
112                values.len()
113            ));
114        }
115
116        let mut result = self.normalized_query.clone();
117        let mut offset = 0;
118
119        for (i, value) in values.iter().enumerate() {
120            if let Some(_pos) = self.param_positions.get(i) {
121                let sql_value = value.to_sql();
122                let placeholder = "?";
123
124                if let Some(idx) = result[offset..].find(placeholder) {
125                    let insert_pos = offset + idx;
126                    result.replace_range(insert_pos..insert_pos + 1, &sql_value);
127                    offset = insert_pos + sql_value.len();
128                }
129            }
130        }
131
132        Ok(result)
133    }
134
135    /// Get a comparison key for the plan
136    pub fn comparison_key(&self) -> String {
137        self.normalized_query.clone()
138    }
139}
140
141/// Utility for extracting literals from AST
142pub struct LiteralExtractor;
143
144impl LiteralExtractor {
145    /// Extract literals from a statement AST in order of appearance
146    pub fn extract(stmt: &Statement) -> Vec<LiteralValue> {
147        let mut literals = Vec::new();
148        Self::extract_from_statement(stmt, &mut literals);
149        literals
150    }
151
152    fn extract_from_statement(stmt: &Statement, literals: &mut Vec<LiteralValue>) {
153        match stmt {
154            Statement::Select(select) => Self::extract_from_select(select, literals),
155            Statement::Insert(insert) => {
156                // Extract from VALUES or SELECT source
157                match &insert.source {
158                    vibesql_ast::InsertSource::Values(rows) => {
159                        for row in rows {
160                            for expr in row {
161                                Self::extract_from_expression(expr, literals);
162                            }
163                        }
164                    }
165                    vibesql_ast::InsertSource::Select(select) => {
166                        Self::extract_from_select(select, literals);
167                    }
168                }
169            }
170            Statement::Update(update) => {
171                // Extract from assignments
172                for assignment in &update.assignments {
173                    Self::extract_from_expression(&assignment.value, literals);
174                }
175                // Extract from WHERE clause
176                if let Some(ref where_clause) = update.where_clause {
177                    match where_clause {
178                        vibesql_ast::WhereClause::Condition(expr) => {
179                            Self::extract_from_expression(expr, literals);
180                        }
181                        vibesql_ast::WhereClause::CurrentOf(_) => {
182                            // No literals in positioned update/delete
183                        }
184                    }
185                }
186            }
187            Statement::Delete(delete) => {
188                // Extract from WHERE clause
189                if let Some(ref where_clause) = delete.where_clause {
190                    match where_clause {
191                        vibesql_ast::WhereClause::Condition(expr) => {
192                            Self::extract_from_expression(expr, literals);
193                        }
194                        vibesql_ast::WhereClause::CurrentOf(_) => {
195                            // No literals in positioned update/delete
196                        }
197                    }
198                }
199            }
200            // Other statement types don't have literals we need to extract
201            _ => {}
202        }
203    }
204
205    fn extract_from_select(select: &vibesql_ast::SelectStmt, literals: &mut Vec<LiteralValue>) {
206        // Extract from SELECT items
207        for item in &select.select_list {
208            if let vibesql_ast::SelectItem::Expression { expr, .. } = item {
209                Self::extract_from_expression(expr, literals);
210            }
211        }
212
213        // Extract from FROM clause
214        if let Some(ref from) = select.from {
215            Self::extract_from_from_clause(from, literals);
216        }
217
218        // Extract from WHERE
219        if let Some(ref where_clause) = select.where_clause {
220            Self::extract_from_expression(where_clause, literals);
221        }
222
223        // Extract from GROUP BY
224        if let Some(ref group_by) = select.group_by {
225            Self::extract_from_group_by(group_by, literals);
226        }
227
228        // Extract from HAVING
229        if let Some(ref having) = select.having {
230            Self::extract_from_expression(having, literals);
231        }
232
233        // Extract from ORDER BY
234        if let Some(ref order_by) = select.order_by {
235            for item in order_by {
236                Self::extract_from_expression(&item.expr, literals);
237            }
238        }
239    }
240
241    fn extract_from_from_clause(from: &vibesql_ast::FromClause, literals: &mut Vec<LiteralValue>) {
242        match from {
243            vibesql_ast::FromClause::Join { left, right, condition, .. } => {
244                Self::extract_from_from_clause(left, literals);
245                Self::extract_from_from_clause(right, literals);
246                if let Some(expr) = condition {
247                    Self::extract_from_expression(expr, literals);
248                }
249            }
250            vibesql_ast::FromClause::Subquery { query, .. } => {
251                Self::extract_from_select(query, literals);
252            }
253            vibesql_ast::FromClause::Table { .. } => {
254                // No literals in table references
255            }
256        }
257    }
258
259    fn extract_from_group_by(
260        group_by: &vibesql_ast::GroupByClause,
261        literals: &mut Vec<LiteralValue>,
262    ) {
263        match group_by {
264            vibesql_ast::GroupByClause::Simple(exprs) => {
265                for expr in exprs {
266                    Self::extract_from_expression(expr, literals);
267                }
268            }
269            vibesql_ast::GroupByClause::Rollup(elements)
270            | vibesql_ast::GroupByClause::Cube(elements) => {
271                Self::extract_from_grouping_elements(elements, literals);
272            }
273            vibesql_ast::GroupByClause::GroupingSets(sets) => {
274                Self::extract_from_grouping_sets(sets, literals);
275            }
276            vibesql_ast::GroupByClause::Mixed(items) => {
277                for item in items {
278                    match item {
279                        vibesql_ast::MixedGroupingItem::Simple(expr) => {
280                            Self::extract_from_expression(expr, literals);
281                        }
282                        vibesql_ast::MixedGroupingItem::Rollup(elements)
283                        | vibesql_ast::MixedGroupingItem::Cube(elements) => {
284                            Self::extract_from_grouping_elements(elements, literals);
285                        }
286                        vibesql_ast::MixedGroupingItem::GroupingSets(sets) => {
287                            Self::extract_from_grouping_sets(sets, literals);
288                        }
289                    }
290                }
291            }
292        }
293    }
294
295    fn extract_from_grouping_elements(
296        elements: &[vibesql_ast::GroupingElement],
297        literals: &mut Vec<LiteralValue>,
298    ) {
299        for element in elements {
300            match element {
301                vibesql_ast::GroupingElement::Single(expr) => {
302                    Self::extract_from_expression(expr, literals);
303                }
304                vibesql_ast::GroupingElement::Composite(exprs) => {
305                    for expr in exprs {
306                        Self::extract_from_expression(expr, literals);
307                    }
308                }
309            }
310        }
311    }
312
313    fn extract_from_grouping_sets(
314        sets: &[vibesql_ast::GroupingSet],
315        literals: &mut Vec<LiteralValue>,
316    ) {
317        for set in sets {
318            for expr in &set.columns {
319                Self::extract_from_expression(expr, literals);
320            }
321        }
322    }
323
324    fn extract_from_expression(expr: &Expression, literals: &mut Vec<LiteralValue>) {
325        match expr {
326            Expression::Literal(value) => {
327                literals.push(LiteralValue::from_sql_value(value));
328            }
329
330            Expression::BinaryOp { left, right, .. } => {
331                Self::extract_from_expression(left, literals);
332                Self::extract_from_expression(right, literals);
333            }
334
335            Expression::Conjunction(children) | Expression::Disjunction(children) => {
336                for child in children {
337                    Self::extract_from_expression(child, literals);
338                }
339            }
340
341            Expression::UnaryOp { expr, .. } => {
342                Self::extract_from_expression(expr, literals);
343            }
344
345            Expression::Function { args, .. } => {
346                for arg in args {
347                    Self::extract_from_expression(arg, literals);
348                }
349            }
350
351            Expression::AggregateFunction { args, .. } => {
352                for arg in args {
353                    Self::extract_from_expression(arg, literals);
354                }
355            }
356
357            Expression::IsNull { expr, .. } => {
358                Self::extract_from_expression(expr, literals);
359            }
360
361            Expression::Case { operand, when_clauses, else_result } => {
362                if let Some(ref op) = operand {
363                    Self::extract_from_expression(op, literals);
364                }
365                for when in when_clauses {
366                    for cond in &when.conditions {
367                        Self::extract_from_expression(cond, literals);
368                    }
369                    Self::extract_from_expression(&when.result, literals);
370                }
371                if let Some(ref else_expr) = else_result {
372                    Self::extract_from_expression(else_expr, literals);
373                }
374            }
375
376            Expression::ScalarSubquery(subquery) => {
377                Self::extract_from_select(subquery, literals);
378            }
379
380            Expression::In { expr, subquery, .. } => {
381                Self::extract_from_expression(expr, literals);
382                Self::extract_from_select(subquery, literals);
383            }
384
385            Expression::InList { expr, values, .. } => {
386                Self::extract_from_expression(expr, literals);
387                for val in values {
388                    Self::extract_from_expression(val, literals);
389                }
390            }
391
392            Expression::Between { expr, low, high, .. } => {
393                Self::extract_from_expression(expr, literals);
394                Self::extract_from_expression(low, literals);
395                Self::extract_from_expression(high, literals);
396            }
397
398            Expression::Cast { expr, .. } => {
399                Self::extract_from_expression(expr, literals);
400            }
401
402            Expression::Position { substring, string, .. } => {
403                Self::extract_from_expression(substring, literals);
404                Self::extract_from_expression(string, literals);
405            }
406
407            Expression::Trim { removal_char, string, .. } => {
408                if let Some(ref ch) = removal_char {
409                    Self::extract_from_expression(ch, literals);
410                }
411                Self::extract_from_expression(string, literals);
412            }
413
414            Expression::Extract { expr, .. } => {
415                Self::extract_from_expression(expr, literals);
416            }
417
418            Expression::Like { expr, pattern, .. } => {
419                Self::extract_from_expression(expr, literals);
420                Self::extract_from_expression(pattern, literals);
421            }
422
423            Expression::Exists { subquery, .. } => {
424                Self::extract_from_select(subquery, literals);
425            }
426
427            Expression::QuantifiedComparison { expr, subquery, .. } => {
428                Self::extract_from_expression(expr, literals);
429                Self::extract_from_select(subquery, literals);
430            }
431
432            Expression::WindowFunction { function, over } => {
433                // Extract from function arguments
434                match function {
435                    vibesql_ast::WindowFunctionSpec::Aggregate { args, .. }
436                    | vibesql_ast::WindowFunctionSpec::Ranking { args, .. }
437                    | vibesql_ast::WindowFunctionSpec::Value { args, .. } => {
438                        for arg in args {
439                            Self::extract_from_expression(arg, literals);
440                        }
441                    }
442                }
443
444                // Extract from PARTITION BY
445                if let Some(ref partition_by) = over.partition_by {
446                    for expr in partition_by {
447                        Self::extract_from_expression(expr, literals);
448                    }
449                }
450
451                // Extract from ORDER BY
452                if let Some(ref order_by) = over.order_by {
453                    for item in order_by {
454                        Self::extract_from_expression(&item.expr, literals);
455                    }
456                }
457
458                // Extract from frame bounds
459                if let Some(ref frame) = over.frame {
460                    if let vibesql_ast::FrameBound::Preceding(expr) | vibesql_ast::FrameBound::Following(expr) =
461                        &frame.start
462                    {
463                        Self::extract_from_expression(expr, literals);
464                    }
465                    if let Some(vibesql_ast::FrameBound::Preceding(expr) | vibesql_ast::FrameBound::Following(expr)) = &frame.end {
466                        Self::extract_from_expression(expr, literals);
467                    }
468                }
469            }
470
471            // These expressions don't contain literals
472            Expression::ColumnRef { .. }
473            | Expression::Placeholder(_)
474            | Expression::NumberedPlaceholder(_)
475            | Expression::NamedPlaceholder(_)
476            | Expression::PseudoVariable { .. }
477            | Expression::Wildcard
478            | Expression::CurrentDate
479            | Expression::CurrentTime { .. }
480            | Expression::CurrentTimestamp { .. }
481            | Expression::Interval { .. }
482            | Expression::Default
483            | Expression::DuplicateKeyValue { .. }
484            | Expression::NextValue { .. }
485            | Expression::MatchAgainst { .. }
486            | Expression::SessionVariable { .. } => {}
487        }
488    }
489}
490
#[cfg(test)]
mod tests {
    use super::*;
    use vibesql_ast::{BinaryOperator, Expression, FromClause, SelectItem, SelectStmt, Statement};

    /// Builds `SELECT <select_list> FROM tab WHERE <predicate>` with every other clause empty.
    fn select_from_tab(select_list: Vec<SelectItem>, predicate: Expression) -> Statement {
        Statement::Select(Box::new(SelectStmt {
            with_clause: None,
            distinct: false,
            select_list,
            into_table: None,
            into_variables: None,
            from: Some(FromClause::Table {
                name: "tab".to_string(),
                alias: None,
                column_aliases: None,
            }),
            where_clause: Some(predicate),
            group_by: None,
            having: None,
            order_by: None,
            limit: None,
            offset: None,
            set_operation: None,
        }))
    }

    /// Unqualified column reference.
    fn col(name: &str) -> Expression {
        Expression::ColumnRef { table: None, column: name.to_string() }
    }

    /// `left <op> right`.
    fn binary(op: BinaryOperator, left: Expression, right: Expression) -> Expression {
        Expression::BinaryOp { op, left: Box::new(left), right: Box::new(right) }
    }

    /// Literal expression.
    fn lit(value: SqlValue) -> Expression {
        Expression::Literal(value)
    }

    /// One-placeholder plan whose `ParameterPosition` is the `?`'s byte offset.
    fn single_param_plan(sql: &str, context: &str, original: LiteralValue) -> ParameterizedPlan {
        let position = sql.find('?').expect("test query must contain a placeholder");
        ParameterizedPlan::new(
            sql.to_string(),
            vec![ParameterPosition { position, context: context.to_string() }],
            vec![original],
        )
    }

    #[test]
    fn test_literal_value_to_string() {
        assert_eq!(LiteralValue::Integer(42).to_sql(), "42");
        assert_eq!(LiteralValue::Varchar("hello".to_string()).to_sql(), "'hello'");
        assert_eq!(LiteralValue::Boolean(true).to_sql(), "true");
        assert_eq!(LiteralValue::Null.to_sql(), "NULL");
    }

    #[test]
    fn test_literal_value_string_escape() {
        assert_eq!(LiteralValue::Varchar("it's".to_string()).to_sql(), "'it''s'");
    }

    #[test]
    fn test_literal_value_temporal_to_sql() {
        assert_eq!(LiteralValue::Date("2024-01-02".to_string()).to_sql(), "DATE '2024-01-02'");
        assert_eq!(LiteralValue::Time("10:30:00".to_string()).to_sql(), "TIME '10:30:00'");
        assert_eq!(
            LiteralValue::Timestamp("2024-01-02 10:30:00".to_string()).to_sql(),
            "TIMESTAMP '2024-01-02 10:30:00'"
        );
    }

    #[test]
    fn test_literal_extraction_simple() {
        // SELECT col0 FROM tab WHERE col1 > 25 AND col2 = 'John'
        let stmt = select_from_tab(
            vec![SelectItem::Expression { expr: col("col0"), alias: None }],
            binary(
                BinaryOperator::And,
                binary(BinaryOperator::GreaterThan, col("col1"), lit(SqlValue::Integer(25))),
                binary(
                    BinaryOperator::Equal,
                    col("col2"),
                    lit(SqlValue::Varchar("John".to_string())),
                ),
            ),
        );

        let literals = LiteralExtractor::extract(&stmt);

        assert_eq!(
            literals,
            vec![LiteralValue::Integer(25), LiteralValue::Varchar("John".to_string())]
        );
    }

    #[test]
    fn test_literal_extraction_select_list_before_where() {
        // SELECT 1 FROM tab WHERE col1 = 2
        let stmt = select_from_tab(
            vec![SelectItem::Expression { expr: lit(SqlValue::Integer(1)), alias: None }],
            binary(BinaryOperator::Equal, col("col1"), lit(SqlValue::Integer(2))),
        );

        assert_eq!(
            LiteralExtractor::extract(&stmt),
            vec![LiteralValue::Integer(1), LiteralValue::Integer(2)]
        );
    }

    #[test]
    fn test_literal_extraction_in_list() {
        // SELECT * FROM tab WHERE id IN (1, 2, 3)
        let stmt = select_from_tab(
            vec![SelectItem::Wildcard { alias: None }],
            Expression::InList {
                expr: Box::new(col("id")),
                values: (1..=3).map(|n| lit(SqlValue::Integer(n))).collect(),
                negated: false,
            },
        );

        let literals = LiteralExtractor::extract(&stmt);

        assert_eq!(
            literals,
            vec![LiteralValue::Integer(1), LiteralValue::Integer(2), LiteralValue::Integer(3)]
        );
    }

    #[test]
    fn test_parameterized_plan_bind() {
        let plan = single_param_plan(
            "SELECT * FROM users WHERE age > ?",
            "age",
            LiteralValue::Integer(25),
        );

        let result = plan.bind(&[LiteralValue::Integer(30)]).unwrap();
        assert_eq!(result, "SELECT * FROM users WHERE age > 30");
    }

    #[test]
    fn test_parameterized_plan_bind_string() {
        let plan = single_param_plan(
            "SELECT * FROM users WHERE name = ?",
            "name",
            LiteralValue::Varchar("John".to_string()),
        );

        let result = plan.bind(&[LiteralValue::Varchar("Jane".to_string())]).unwrap();
        assert_eq!(result, "SELECT * FROM users WHERE name = 'Jane'");
    }

    #[test]
    fn test_parameterized_plan_bind_multiple_in_order() {
        let sql = "SELECT * FROM t WHERE a = ? AND b = ?";
        let positions = sql
            .match_indices('?')
            .map(|(position, _)| ParameterPosition { position, context: String::new() })
            .collect();
        let plan = ParameterizedPlan::new(sql.to_string(), positions, vec![]);

        // A bound value that itself contains `?` must not be treated as a placeholder.
        let result = plan
            .bind(&[LiteralValue::Varchar("x?y".to_string()), LiteralValue::Integer(7)])
            .unwrap();
        assert_eq!(result, "SELECT * FROM t WHERE a = 'x?y' AND b = 7");
    }

    #[test]
    fn test_parameterized_plan_bind_without_parameters() {
        let plan = ParameterizedPlan::new("SELECT 1".to_string(), vec![], vec![]);
        assert_eq!(plan.bind(&[]).unwrap(), "SELECT 1");
    }

    #[test]
    fn test_parameterized_plan_bind_error() {
        let plan = single_param_plan(
            "SELECT * FROM users WHERE age > ?",
            "age",
            LiteralValue::Integer(25),
        );

        let result = plan.bind(&[LiteralValue::Integer(30), LiteralValue::Integer(40)]);
        assert!(result.is_err());
    }

    #[test]
    fn test_comparison_key() {
        let plan = ParameterizedPlan::new("SELECT * FROM users".to_string(), vec![], vec![]);

        assert_eq!(plan.comparison_key(), "SELECT * FROM users");
    }
}