vibesql_executor/cache/
parameterized.rs

1//! Parameterized plan support for literal extraction and binding
2//!
3//! Allows queries with different literal values to share the same execution plan
4//! by extracting literals into parameters and binding them at execution time.
5
6use vibesql_ast::{Expression, Statement};
7use vibesql_types::SqlValue;
8
9/// Type-safe literal value representation
10#[derive(Clone, Debug, PartialEq)]
11pub enum LiteralValue {
12    Integer(i64),
13    Smallint(i16),
14    Bigint(i64),
15    Unsigned(u64),
16    Numeric(f64),
17    Float(f32),
18    Real(f64),  // SQLite REAL is 8-byte IEEE float
19    Double(f64),
20    Character(String),
21    Varchar(String),
22    Boolean(bool),
23    Date(String),
24    Time(String),
25    Timestamp(String),
26    Blob(Vec<u8>),
27    Null,
28}
29
30impl LiteralValue {
31    /// Convert from SqlValue
32    pub fn from_sql_value(value: &SqlValue) -> Self {
33        match value {
34            SqlValue::Integer(n) => LiteralValue::Integer(*n),
35            SqlValue::Smallint(n) => LiteralValue::Smallint(*n),
36            SqlValue::Bigint(n) => LiteralValue::Bigint(*n),
37            SqlValue::Unsigned(n) => LiteralValue::Unsigned(*n),
38            SqlValue::Numeric(n) => LiteralValue::Numeric(*n),
39            SqlValue::Float(n) => LiteralValue::Float(*n),
40            SqlValue::Real(n) => LiteralValue::Real(*n),
41            SqlValue::Double(n) => LiteralValue::Double(*n),
42            SqlValue::Character(s) => LiteralValue::Character(s.to_string()),
43            SqlValue::Varchar(s) => LiteralValue::Varchar(s.to_string()),
44            SqlValue::Boolean(b) => LiteralValue::Boolean(*b),
45            SqlValue::Date(s) => LiteralValue::Date(s.to_string()),
46            SqlValue::Time(s) => LiteralValue::Time(s.to_string()),
47            SqlValue::Timestamp(s) => LiteralValue::Timestamp(s.to_string()),
48            SqlValue::Interval(s) => LiteralValue::Varchar(s.to_string()), /* Treat interval as */
49            // string for now
50            SqlValue::Vector(v) => {
51                // Treat vector as string representation for now
52                let formatted: Vec<String> = v.iter().map(|f| f.to_string()).collect();
53                LiteralValue::Varchar(format!("[{}]", formatted.join(", ")))
54            }
55            SqlValue::Blob(b) => LiteralValue::Blob(b.clone()),
56            SqlValue::Null => LiteralValue::Null,
57        }
58    }
59
60    /// Convert to SQL string representation
61    pub fn to_sql(&self) -> String {
62        match self {
63            LiteralValue::Integer(n) => n.to_string(),
64            LiteralValue::Smallint(n) => n.to_string(),
65            LiteralValue::Bigint(n) => n.to_string(),
66            LiteralValue::Unsigned(n) => n.to_string(),
67            LiteralValue::Numeric(n) => n.to_string(),
68            LiteralValue::Float(n) => n.to_string(),
69            LiteralValue::Real(n) => n.to_string(),
70            LiteralValue::Double(n) => n.to_string(),
71            LiteralValue::Character(s) | LiteralValue::Varchar(s) => {
72                format!("'{}'", s.replace("'", "''"))
73            }
74            LiteralValue::Boolean(b) => if *b { "true" } else { "false" }.to_string(),
75            LiteralValue::Date(s) => format!("DATE '{}'", s),
76            LiteralValue::Time(s) => format!("TIME '{}'", s),
77            LiteralValue::Timestamp(s) => format!("TIMESTAMP '{}'", s),
78            LiteralValue::Blob(b) => {
79                let hex: String = b.iter().map(|byte| format!("{:02X}", byte)).collect();
80                format!("x'{}'", hex)
81            }
82            LiteralValue::Null => "NULL".to_string(),
83        }
84    }
85}
86
87/// Position of a parameter in a query
88#[derive(Clone, Debug)]
89pub struct ParameterPosition {
90    pub position: usize,
91    pub context: String,
92}
93
94/// Query plan with parameter placeholders
95#[derive(Clone, Debug)]
96pub struct ParameterizedPlan {
97    pub normalized_query: String,
98    pub param_positions: Vec<ParameterPosition>,
99    pub literal_values: Vec<LiteralValue>,
100}
101
102impl ParameterizedPlan {
103    /// Create a new parameterized plan
104    pub fn new(
105        normalized_query: String,
106        param_positions: Vec<ParameterPosition>,
107        literal_values: Vec<LiteralValue>,
108    ) -> Self {
109        Self { normalized_query, param_positions, literal_values }
110    }
111
112    /// Bind literal values to create executable query
113    pub fn bind(&self, values: &[LiteralValue]) -> Result<String, String> {
114        if values.len() != self.param_positions.len() {
115            return Err(format!(
116                "Expected {} parameters, got {}",
117                self.param_positions.len(),
118                values.len()
119            ));
120        }
121
122        let mut result = self.normalized_query.clone();
123        let mut offset = 0;
124
125        for (i, value) in values.iter().enumerate() {
126            if let Some(_pos) = self.param_positions.get(i) {
127                let sql_value = value.to_sql();
128                let placeholder = "?";
129
130                if let Some(idx) = result[offset..].find(placeholder) {
131                    let insert_pos = offset + idx;
132                    result.replace_range(insert_pos..insert_pos + 1, &sql_value);
133                    offset = insert_pos + sql_value.len();
134                }
135            }
136        }
137
138        Ok(result)
139    }
140
141    /// Get a comparison key for the plan
142    pub fn comparison_key(&self) -> String {
143        self.normalized_query.clone()
144    }
145}
146
147/// Utility for extracting literals from AST
148pub struct LiteralExtractor;
149
150impl LiteralExtractor {
151    /// Extract literals from a statement AST in order of appearance
152    pub fn extract(stmt: &Statement) -> Vec<LiteralValue> {
153        let mut literals = Vec::new();
154        Self::extract_from_statement(stmt, &mut literals);
155        literals
156    }
157
158    fn extract_from_statement(stmt: &Statement, literals: &mut Vec<LiteralValue>) {
159        match stmt {
160            Statement::Select(select) => Self::extract_from_select(select, literals),
161            Statement::Insert(insert) => {
162                // Extract from VALUES or SELECT source
163                match &insert.source {
164                    vibesql_ast::InsertSource::Values(rows) => {
165                        for row in rows {
166                            for expr in row {
167                                Self::extract_from_expression(expr, literals);
168                            }
169                        }
170                    }
171                    vibesql_ast::InsertSource::Select(select) => {
172                        Self::extract_from_select(select, literals);
173                    }
174                    vibesql_ast::InsertSource::DefaultValues => {
175                        // No literals to extract for DEFAULT VALUES
176                    }
177                }
178            }
179            Statement::Update(update) => {
180                // Extract from assignments
181                for assignment in &update.assignments {
182                    Self::extract_from_expression(&assignment.value, literals);
183                }
184                // Extract from WHERE clause
185                if let Some(ref where_clause) = update.where_clause {
186                    match where_clause {
187                        vibesql_ast::WhereClause::Condition(expr) => {
188                            Self::extract_from_expression(expr, literals);
189                        }
190                        vibesql_ast::WhereClause::CurrentOf(_) => {
191                            // No literals in positioned update/delete
192                        }
193                    }
194                }
195            }
196            Statement::Delete(delete) => {
197                // Extract from WHERE clause
198                if let Some(ref where_clause) = delete.where_clause {
199                    match where_clause {
200                        vibesql_ast::WhereClause::Condition(expr) => {
201                            Self::extract_from_expression(expr, literals);
202                        }
203                        vibesql_ast::WhereClause::CurrentOf(_) => {
204                            // No literals in positioned update/delete
205                        }
206                    }
207                }
208            }
209            // Other statement types don't have literals we need to extract
210            _ => {}
211        }
212    }
213
214    fn extract_from_select(select: &vibesql_ast::SelectStmt, literals: &mut Vec<LiteralValue>) {
215        // Extract from SELECT items
216        for item in &select.select_list {
217            if let vibesql_ast::SelectItem::Expression { expr, .. } = item {
218                Self::extract_from_expression(expr, literals);
219            }
220        }
221
222        // Extract from FROM clause
223        if let Some(ref from) = select.from {
224            Self::extract_from_from_clause(from, literals);
225        }
226
227        // Extract from WHERE
228        if let Some(ref where_clause) = select.where_clause {
229            Self::extract_from_expression(where_clause, literals);
230        }
231
232        // Extract from GROUP BY
233        if let Some(ref group_by) = select.group_by {
234            Self::extract_from_group_by(group_by, literals);
235        }
236
237        // Extract from HAVING
238        if let Some(ref having) = select.having {
239            Self::extract_from_expression(having, literals);
240        }
241
242        // Extract from ORDER BY
243        if let Some(ref order_by) = select.order_by {
244            for item in order_by {
245                Self::extract_from_expression(&item.expr, literals);
246            }
247        }
248    }
249
250    fn extract_from_from_clause(from: &vibesql_ast::FromClause, literals: &mut Vec<LiteralValue>) {
251        match from {
252            vibesql_ast::FromClause::Join { left, right, condition, .. } => {
253                Self::extract_from_from_clause(left, literals);
254                Self::extract_from_from_clause(right, literals);
255                if let Some(expr) = condition {
256                    Self::extract_from_expression(expr, literals);
257                }
258            }
259            vibesql_ast::FromClause::Subquery { query, .. } => {
260                Self::extract_from_select(query, literals);
261            }
262            vibesql_ast::FromClause::Table { .. } => {
263                // No literals in table references
264            }
265            vibesql_ast::FromClause::Values { rows, .. } => {
266                // Extract literals from VALUES expressions
267                for row in rows {
268                    for expr in row {
269                        Self::extract_from_expression(expr, literals);
270                    }
271                }
272            }
273        }
274    }
275
276    fn extract_from_group_by(
277        group_by: &vibesql_ast::GroupByClause,
278        literals: &mut Vec<LiteralValue>,
279    ) {
280        match group_by {
281            vibesql_ast::GroupByClause::Simple(exprs) => {
282                for expr in exprs {
283                    Self::extract_from_expression(expr, literals);
284                }
285            }
286            vibesql_ast::GroupByClause::Rollup(elements)
287            | vibesql_ast::GroupByClause::Cube(elements) => {
288                Self::extract_from_grouping_elements(elements, literals);
289            }
290            vibesql_ast::GroupByClause::GroupingSets(sets) => {
291                Self::extract_from_grouping_sets(sets, literals);
292            }
293            vibesql_ast::GroupByClause::Mixed(items) => {
294                for item in items {
295                    match item {
296                        vibesql_ast::MixedGroupingItem::Simple(expr) => {
297                            Self::extract_from_expression(expr, literals);
298                        }
299                        vibesql_ast::MixedGroupingItem::Rollup(elements)
300                        | vibesql_ast::MixedGroupingItem::Cube(elements) => {
301                            Self::extract_from_grouping_elements(elements, literals);
302                        }
303                        vibesql_ast::MixedGroupingItem::GroupingSets(sets) => {
304                            Self::extract_from_grouping_sets(sets, literals);
305                        }
306                    }
307                }
308            }
309        }
310    }
311
312    fn extract_from_grouping_elements(
313        elements: &[vibesql_ast::GroupingElement],
314        literals: &mut Vec<LiteralValue>,
315    ) {
316        for element in elements {
317            match element {
318                vibesql_ast::GroupingElement::Single(expr) => {
319                    Self::extract_from_expression(expr, literals);
320                }
321                vibesql_ast::GroupingElement::Composite(exprs) => {
322                    for expr in exprs {
323                        Self::extract_from_expression(expr, literals);
324                    }
325                }
326            }
327        }
328    }
329
330    fn extract_from_grouping_sets(
331        sets: &[vibesql_ast::GroupingSet],
332        literals: &mut Vec<LiteralValue>,
333    ) {
334        for set in sets {
335            for expr in &set.columns {
336                Self::extract_from_expression(expr, literals);
337            }
338        }
339    }
340
341    fn extract_from_expression(expr: &Expression, literals: &mut Vec<LiteralValue>) {
342        match expr {
343            Expression::Literal(value) => {
344                literals.push(LiteralValue::from_sql_value(value));
345            }
346
347            Expression::BinaryOp { left, right, .. } => {
348                Self::extract_from_expression(left, literals);
349                Self::extract_from_expression(right, literals);
350            }
351
352            Expression::Conjunction(children)
353            | Expression::Disjunction(children)
354            | Expression::RowValueConstructor(children) => {
355                for child in children {
356                    Self::extract_from_expression(child, literals);
357                }
358            }
359
360            Expression::Collate { expr, .. } => {
361                Self::extract_from_expression(expr, literals);
362            }
363
364            Expression::UnaryOp { expr, .. } => {
365                Self::extract_from_expression(expr, literals);
366            }
367
368            Expression::Function { args, .. } => {
369                for arg in args {
370                    Self::extract_from_expression(arg, literals);
371                }
372            }
373
374            Expression::AggregateFunction { args, .. } => {
375                for arg in args {
376                    Self::extract_from_expression(arg, literals);
377                }
378            }
379
380            Expression::IsNull { expr, .. } => {
381                Self::extract_from_expression(expr, literals);
382            }
383
384            Expression::IsDistinctFrom { left, right, .. } => {
385                Self::extract_from_expression(left, literals);
386                Self::extract_from_expression(right, literals);
387            }
388
389            Expression::IsTruthValue { expr, .. } => {
390                Self::extract_from_expression(expr, literals);
391            }
392
393            Expression::Case { operand, when_clauses, else_result } => {
394                if let Some(ref op) = operand {
395                    Self::extract_from_expression(op, literals);
396                }
397                for when in when_clauses {
398                    for cond in &when.conditions {
399                        Self::extract_from_expression(cond, literals);
400                    }
401                    Self::extract_from_expression(&when.result, literals);
402                }
403                if let Some(ref else_expr) = else_result {
404                    Self::extract_from_expression(else_expr, literals);
405                }
406            }
407
408            Expression::ScalarSubquery(subquery) => {
409                Self::extract_from_select(subquery, literals);
410            }
411
412            Expression::In { expr, subquery, .. } => {
413                Self::extract_from_expression(expr, literals);
414                Self::extract_from_select(subquery, literals);
415            }
416
417            Expression::InList { expr, values, .. } => {
418                Self::extract_from_expression(expr, literals);
419                for val in values {
420                    Self::extract_from_expression(val, literals);
421                }
422            }
423
424            Expression::Between { expr, low, high, .. } => {
425                Self::extract_from_expression(expr, literals);
426                Self::extract_from_expression(low, literals);
427                Self::extract_from_expression(high, literals);
428            }
429
430            Expression::Cast { expr, .. } => {
431                Self::extract_from_expression(expr, literals);
432            }
433
434            Expression::Position { substring, string, .. } => {
435                Self::extract_from_expression(substring, literals);
436                Self::extract_from_expression(string, literals);
437            }
438
439            Expression::Trim { removal_char, string, .. } => {
440                if let Some(ref ch) = removal_char {
441                    Self::extract_from_expression(ch, literals);
442                }
443                Self::extract_from_expression(string, literals);
444            }
445
446            Expression::Extract { expr, .. } => {
447                Self::extract_from_expression(expr, literals);
448            }
449
450            Expression::Like { expr, pattern, .. } => {
451                Self::extract_from_expression(expr, literals);
452                Self::extract_from_expression(pattern, literals);
453            }
454
455            Expression::Glob { expr, pattern, .. } => {
456                Self::extract_from_expression(expr, literals);
457                Self::extract_from_expression(pattern, literals);
458            }
459
460            Expression::Exists { subquery, .. } => {
461                Self::extract_from_select(subquery, literals);
462            }
463
464            Expression::QuantifiedComparison { expr, subquery, .. } => {
465                Self::extract_from_expression(expr, literals);
466                Self::extract_from_select(subquery, literals);
467            }
468
469            Expression::WindowFunction { function, over } => {
470                // Extract from function arguments
471                match function {
472                    vibesql_ast::WindowFunctionSpec::Aggregate { args, .. }
473                    | vibesql_ast::WindowFunctionSpec::Ranking { args, .. }
474                    | vibesql_ast::WindowFunctionSpec::Value { args, .. } => {
475                        for arg in args {
476                            Self::extract_from_expression(arg, literals);
477                        }
478                    }
479                }
480
481                // Extract from PARTITION BY
482                if let Some(ref partition_by) = over.partition_by {
483                    for expr in partition_by {
484                        Self::extract_from_expression(expr, literals);
485                    }
486                }
487
488                // Extract from ORDER BY
489                if let Some(ref order_by) = over.order_by {
490                    for item in order_by {
491                        Self::extract_from_expression(&item.expr, literals);
492                    }
493                }
494
495                // Extract from frame bounds
496                if let Some(ref frame) = over.frame {
497                    if let vibesql_ast::FrameBound::Preceding(expr)
498                    | vibesql_ast::FrameBound::Following(expr) = &frame.start
499                    {
500                        Self::extract_from_expression(expr, literals);
501                    }
502                    if let Some(
503                        vibesql_ast::FrameBound::Preceding(expr)
504                        | vibesql_ast::FrameBound::Following(expr),
505                    ) = &frame.end
506                    {
507                        Self::extract_from_expression(expr, literals);
508                    }
509                }
510            }
511
512            // These expressions don't contain literals
513            Expression::ColumnRef(_)
514            | Expression::Placeholder(_)
515            | Expression::NumberedPlaceholder(_)
516            | Expression::NamedPlaceholder(_)
517            | Expression::PseudoVariable { .. }
518            | Expression::Wildcard
519            | Expression::CurrentDate
520            | Expression::CurrentTime { .. }
521            | Expression::CurrentTimestamp { .. }
522            | Expression::Interval { .. }
523            | Expression::Default
524            | Expression::DuplicateKeyValue { .. }
525            | Expression::NextValue { .. }
526            | Expression::MatchAgainst { .. }
527            | Expression::SessionVariable { .. } => {}
528        }
529    }
530}
531
#[cfg(test)]
mod tests {
    use super::*;
    use vibesql_ast::{BinaryOperator, FromClause, SelectItem, SelectStmt};

    /// Unquoted simple column reference.
    fn column(name: &str) -> Expression {
        Expression::ColumnRef(vibesql_ast::ColumnIdentifier::simple(name, false))
    }

    /// Integer literal expression.
    fn int_literal(n: i64) -> Expression {
        Expression::Literal(SqlValue::Integer(n))
    }

    /// Builds `SELECT <select_list> FROM tab WHERE <where_clause>` with every other clause unset.
    fn select_from_tab(select_list: Vec<SelectItem>, where_clause: Expression) -> Statement {
        Statement::Select(Box::new(SelectStmt {
            with_clause: None,
            distinct: false,
            select_list,
            into_table: None,
            into_variables: None,
            from: Some(FromClause::Table {
                name: "tab".to_string(),
                alias: None,
                column_aliases: None,
                quoted: false,
            }),
            where_clause: Some(where_clause),
            group_by: None,
            having: None,
            order_by: None,
            limit: None,
            offset: None,
            set_operation: None,
            values: None,
        }))
    }

    /// Plan with a single parameter, as used by the bind tests.
    fn single_param_plan(query: &str, context: &str, stored: LiteralValue) -> ParameterizedPlan {
        ParameterizedPlan::new(
            query.to_string(),
            vec![ParameterPosition { position: 40, context: context.to_string() }],
            vec![stored],
        )
    }

    #[test]
    fn test_literal_value_to_string() {
        let cases = [
            (LiteralValue::Integer(42), "42"),
            (LiteralValue::Varchar("hello".to_string()), "'hello'"),
            (LiteralValue::Boolean(true), "true"),
            (LiteralValue::Null, "NULL"),
        ];
        for (value, expected) in cases.iter() {
            assert_eq!(value.to_sql(), *expected);
        }
    }

    #[test]
    fn test_literal_value_string_escape() {
        let escaped = LiteralValue::Varchar("it's".to_string()).to_sql();
        assert_eq!(escaped, "'it''s'");
    }

    #[test]
    fn test_literal_extraction_simple() {
        // SELECT col0 FROM tab WHERE col1 > 25 AND col2 = 'John'
        let age_filter = Expression::BinaryOp {
            op: BinaryOperator::GreaterThan,
            left: Box::new(column("col1")),
            right: Box::new(int_literal(25)),
        };
        let name_filter = Expression::BinaryOp {
            op: BinaryOperator::Equal,
            left: Box::new(column("col2")),
            right: Box::new(Expression::Literal(SqlValue::Varchar(arcstr::ArcStr::from("John")))),
        };
        let stmt = select_from_tab(
            vec![SelectItem::Expression { expr: column("col0"), alias: None, source_text: None }],
            Expression::BinaryOp {
                op: BinaryOperator::And,
                left: Box::new(age_filter),
                right: Box::new(name_filter),
            },
        );

        let literals = LiteralExtractor::extract(&stmt);

        assert_eq!(
            literals,
            vec![LiteralValue::Integer(25), LiteralValue::Varchar("John".to_string())]
        );
    }

    #[test]
    fn test_literal_extraction_in_list() {
        // SELECT * FROM tab WHERE id IN (1, 2, 3)
        let stmt = select_from_tab(
            vec![SelectItem::Wildcard { alias: None }],
            Expression::InList {
                expr: Box::new(column("id")),
                values: vec![int_literal(1), int_literal(2), int_literal(3)],
                negated: false,
            },
        );

        let literals = LiteralExtractor::extract(&stmt);

        assert_eq!(
            literals,
            vec![LiteralValue::Integer(1), LiteralValue::Integer(2), LiteralValue::Integer(3)]
        );
    }

    #[test]
    fn test_parameterized_plan_bind() {
        let plan = single_param_plan(
            "SELECT * FROM users WHERE age > ?",
            "age",
            LiteralValue::Integer(25),
        );

        let bound = plan.bind(&[LiteralValue::Integer(30)]).unwrap();
        assert_eq!(bound, "SELECT * FROM users WHERE age > 30");
    }

    #[test]
    fn test_parameterized_plan_bind_string() {
        let plan = single_param_plan(
            "SELECT * FROM users WHERE name = ?",
            "name",
            LiteralValue::Varchar("John".to_string()),
        );

        let bound = plan.bind(&[LiteralValue::Varchar("Jane".to_string())]).unwrap();
        assert_eq!(bound, "SELECT * FROM users WHERE name = 'Jane'");
    }

    #[test]
    fn test_parameterized_plan_bind_error() {
        let plan = single_param_plan(
            "SELECT * FROM users WHERE age > ?",
            "age",
            LiteralValue::Integer(25),
        );

        // Two values for a one-parameter plan must be rejected.
        let too_many = [LiteralValue::Integer(30), LiteralValue::Integer(40)];
        assert!(plan.bind(&too_many).is_err());
    }

    #[test]
    fn test_comparison_key() {
        let query = "SELECT * FROM users";
        let plan = ParameterizedPlan::new(query.to_string(), Vec::new(), Vec::new());

        assert_eq!(plan.comparison_key(), "SELECT * FROM users");
    }
}