1use vibesql_ast::{Expression, Statement};
7use vibesql_types::SqlValue;
8
/// A SQL literal captured from a parsed statement.
///
/// Most variants correspond one-to-one to the `SqlValue` variant of the same
/// name; date/time values are held as text.
#[derive(Clone, Debug, PartialEq)]
pub enum LiteralValue {
    // ----- integer kinds -----
    /// 64-bit signed integer.
    Integer(i64),
    /// 16-bit signed integer.
    Smallint(i16),
    /// 64-bit signed integer (BIGINT).
    Bigint(i64),
    /// 64-bit unsigned integer.
    Unsigned(u64),

    // ----- floating-point / numeric kinds -----
    /// NUMERIC value, held as an `f64`.
    Numeric(f64),
    /// Single-precision float.
    Float(f32),
    /// Single-precision REAL.
    Real(f32),
    /// Double-precision float.
    Double(f64),

    // ----- character kinds -----
    /// Fixed-length character data.
    Character(String),
    /// Variable-length character data.
    Varchar(String),

    // ----- other scalars -----
    /// Boolean value.
    Boolean(bool),
    /// Date, as text.
    Date(String),
    /// Time of day, as text.
    Time(String),
    /// Timestamp, as text.
    Timestamp(String),
    /// SQL NULL.
    Null,
}
28
29impl LiteralValue {
30 pub fn from_sql_value(value: &SqlValue) -> Self {
32 match value {
33 SqlValue::Integer(n) => LiteralValue::Integer(*n),
34 SqlValue::Smallint(n) => LiteralValue::Smallint(*n),
35 SqlValue::Bigint(n) => LiteralValue::Bigint(*n),
36 SqlValue::Unsigned(n) => LiteralValue::Unsigned(*n),
37 SqlValue::Numeric(n) => LiteralValue::Numeric(*n),
38 SqlValue::Float(n) => LiteralValue::Float(*n),
39 SqlValue::Real(n) => LiteralValue::Real(*n),
40 SqlValue::Double(n) => LiteralValue::Double(*n),
41 SqlValue::Character(s) => LiteralValue::Character(s.clone()),
42 SqlValue::Varchar(s) => LiteralValue::Varchar(s.clone()),
43 SqlValue::Boolean(b) => LiteralValue::Boolean(*b),
44 SqlValue::Date(s) => LiteralValue::Date(s.to_string()),
45 SqlValue::Time(s) => LiteralValue::Time(s.to_string()),
46 SqlValue::Timestamp(s) => LiteralValue::Timestamp(s.to_string()),
47 SqlValue::Interval(s) => LiteralValue::Varchar(s.to_string()), SqlValue::Vector(v) => {
50 let formatted: Vec<String> = v.iter().map(|f| f.to_string()).collect();
52 LiteralValue::Varchar(format!("[{}]", formatted.join(", ")))
53 }
54 SqlValue::Null => LiteralValue::Null,
55 }
56 }
57
58 pub fn to_sql(&self) -> String {
60 match self {
61 LiteralValue::Integer(n) => n.to_string(),
62 LiteralValue::Smallint(n) => n.to_string(),
63 LiteralValue::Bigint(n) => n.to_string(),
64 LiteralValue::Unsigned(n) => n.to_string(),
65 LiteralValue::Numeric(n) => n.to_string(),
66 LiteralValue::Float(n) => n.to_string(),
67 LiteralValue::Real(n) => n.to_string(),
68 LiteralValue::Double(n) => n.to_string(),
69 LiteralValue::Character(s) | LiteralValue::Varchar(s) => {
70 format!("'{}'", s.replace("'", "''"))
71 }
72 LiteralValue::Boolean(b) => if *b { "true" } else { "false" }.to_string(),
73 LiteralValue::Date(s) => format!("DATE '{}'", s),
74 LiteralValue::Time(s) => format!("TIME '{}'", s),
75 LiteralValue::Timestamp(s) => format!("TIMESTAMP '{}'", s),
76 LiteralValue::Null => "NULL".to_string(),
77 }
78 }
79}
80
/// Describes one parameter slot of a normalized query.
#[derive(Clone, Debug)]
pub struct ParameterPosition {
    /// Location of the parameter.
    ///
    /// NOTE(review): presumably a byte offset into the normalized query.
    /// `ParameterizedPlan::bind` only uses the number of positions, not this
    /// value — confirm its meaning with the producer.
    pub position: usize,
    /// Free-form description of where the parameter appears (tests use the
    /// compared column's name, e.g. `"age"`).
    pub context: String,
}
87
88#[derive(Clone, Debug)]
90pub struct ParameterizedPlan {
91 pub normalized_query: String,
92 pub param_positions: Vec<ParameterPosition>,
93 pub literal_values: Vec<LiteralValue>,
94}
95
96impl ParameterizedPlan {
97 pub fn new(
99 normalized_query: String,
100 param_positions: Vec<ParameterPosition>,
101 literal_values: Vec<LiteralValue>,
102 ) -> Self {
103 Self { normalized_query, param_positions, literal_values }
104 }
105
106 pub fn bind(&self, values: &[LiteralValue]) -> Result<String, String> {
108 if values.len() != self.param_positions.len() {
109 return Err(format!(
110 "Expected {} parameters, got {}",
111 self.param_positions.len(),
112 values.len()
113 ));
114 }
115
116 let mut result = self.normalized_query.clone();
117 let mut offset = 0;
118
119 for (i, value) in values.iter().enumerate() {
120 if let Some(_pos) = self.param_positions.get(i) {
121 let sql_value = value.to_sql();
122 let placeholder = "?";
123
124 if let Some(idx) = result[offset..].find(placeholder) {
125 let insert_pos = offset + idx;
126 result.replace_range(insert_pos..insert_pos + 1, &sql_value);
127 offset = insert_pos + sql_value.len();
128 }
129 }
130 }
131
132 Ok(result)
133 }
134
135 pub fn comparison_key(&self) -> String {
137 self.normalized_query.clone()
138 }
139}
140
/// Stateless visitor that collects the literal values of a parsed statement.
///
/// All functionality lives in associated functions (see `LiteralExtractor::extract`).
/// The standard derives let the type be debug-printed, defaulted and copied
/// like any other public value.
#[derive(Debug, Clone, Copy, Default)]
pub struct LiteralExtractor;
143
impl LiteralExtractor {
    /// Collects the literal values that appear in `stmt`, in visit order.
    ///
    /// Only `Select`, `Insert`, `Update` and `Delete` statements are traversed;
    /// every other statement kind yields an empty vector. Placeholders
    /// (`?`, numbered, named) are not literals and are not collected.
    ///
    /// NOTE(review): the returned order is fixed by the traversal order of the
    /// helpers below. If callers bind these values positionally against a
    /// normalized query, that order presumably must match the textual order of
    /// the placeholders — confirm before reordering any visit.
    pub fn extract(stmt: &Statement) -> Vec<LiteralValue> {
        let mut literals = Vec::new();
        Self::extract_from_statement(stmt, &mut literals);
        literals
    }

    /// Appends the literals of one statement to `literals`.
    ///
    /// Visit order per statement kind:
    /// - INSERT: the `VALUES` rows, row by row and left to right within a row,
    ///   or its source SELECT.
    /// - UPDATE: all assignment values, then the WHERE condition.
    /// - DELETE: only its WHERE condition.
    fn extract_from_statement(stmt: &Statement, literals: &mut Vec<LiteralValue>) {
        match stmt {
            Statement::Select(select) => Self::extract_from_select(select, literals),
            Statement::Insert(insert) => {
                // Only the row source of the INSERT is inspected here.
                match &insert.source {
                    vibesql_ast::InsertSource::Values(rows) => {
                        for row in rows {
                            for expr in row {
                                Self::extract_from_expression(expr, literals);
                            }
                        }
                    }
                    vibesql_ast::InsertSource::Select(select) => {
                        Self::extract_from_select(select, literals);
                    }
                }
            }
            Statement::Update(update) => {
                for assignment in &update.assignments {
                    // Only the assigned value expression is visited.
                    Self::extract_from_expression(&assignment.value, literals);
                }
                if let Some(ref where_clause) = update.where_clause {
                    // Only a plain condition has expressions to visit.
                    match where_clause {
                        vibesql_ast::WhereClause::Condition(expr) => {
                            Self::extract_from_expression(expr, literals);
                        }
                        vibesql_ast::WhereClause::CurrentOf(_) => {
                            // WHERE CURRENT OF: no expression is traversed.
                        }
                    }
                }
            }
            Statement::Delete(delete) => {
                if let Some(ref where_clause) = delete.where_clause {
                    // Same handling as the UPDATE arm above.
                    match where_clause {
                        vibesql_ast::WhereClause::Condition(expr) => {
                            Self::extract_from_expression(expr, literals);
                        }
                        vibesql_ast::WhereClause::CurrentOf(_) => {
                            // WHERE CURRENT OF: no expression is traversed.
                        }
                    }
                }
            }
            // Any other statement kind contributes no literals.
            _ => {}
        }
    }

    /// Appends the literals of a SELECT, visiting clauses in this order:
    /// select list, FROM, WHERE, GROUP BY, HAVING, ORDER BY.
    ///
    /// NOTE(review): `with_clause`, `limit`, `offset` and `set_operation` are not
    /// visited, so any literals they contain are not collected — confirm this
    /// is intended.
    fn extract_from_select(select: &vibesql_ast::SelectStmt, literals: &mut Vec<LiteralValue>) {
        for item in &select.select_list {
            // Only expression items are visited; other select-item kinds are skipped.
            if let vibesql_ast::SelectItem::Expression { expr, .. } = item {
                Self::extract_from_expression(expr, literals);
            }
        }

        if let Some(ref from) = select.from {
            // Joins and derived tables may hold literals (see extract_from_from_clause).
            Self::extract_from_from_clause(from, literals);
        }

        if let Some(ref where_clause) = select.where_clause {
            // Here WHERE is a plain expression, unlike UPDATE/DELETE's WhereClause.
            Self::extract_from_expression(where_clause, literals);
        }

        if let Some(ref group_by) = select.group_by {
            // Handles simple, ROLLUP, CUBE, GROUPING SETS and mixed forms.
            Self::extract_from_group_by(group_by, literals);
        }

        if let Some(ref having) = select.having {
            // HAVING predicate.
            Self::extract_from_expression(having, literals);
        }

        if let Some(ref order_by) = select.order_by {
            // Only each item's expression is visited.
            for item in order_by {
                Self::extract_from_expression(&item.expr, literals);
            }
        }
    }

    /// Appends the literals of a FROM clause.
    ///
    /// For a join: left side, then right side, then the join condition (if
    /// any). A derived table recurses into its SELECT.
    fn extract_from_from_clause(from: &vibesql_ast::FromClause, literals: &mut Vec<LiteralValue>) {
        match from {
            vibesql_ast::FromClause::Join { left, right, condition, .. } => {
                Self::extract_from_from_clause(left, literals);
                Self::extract_from_from_clause(right, literals);
                if let Some(expr) = condition {
                    Self::extract_from_expression(expr, literals);
                }
            }
            vibesql_ast::FromClause::Subquery { query, .. } => {
                Self::extract_from_select(query, literals);
            }
            vibesql_ast::FromClause::Table { .. } => {
                // A table reference carries no expressions to visit.
            }
        }
    }

    /// Appends the literals of a GROUP BY clause, dispatching on its form.
    ///
    /// For `Mixed`, items are visited in sequence using the same helpers as
    /// the corresponding standalone forms.
    fn extract_from_group_by(
        group_by: &vibesql_ast::GroupByClause,
        literals: &mut Vec<LiteralValue>,
    ) {
        match group_by {
            vibesql_ast::GroupByClause::Simple(exprs) => {
                for expr in exprs {
                    Self::extract_from_expression(expr, literals);
                }
            }
            vibesql_ast::GroupByClause::Rollup(elements)
            | vibesql_ast::GroupByClause::Cube(elements) => {
                Self::extract_from_grouping_elements(elements, literals);
            }
            vibesql_ast::GroupByClause::GroupingSets(sets) => {
                Self::extract_from_grouping_sets(sets, literals);
            }
            vibesql_ast::GroupByClause::Mixed(items) => {
                for item in items {
                    match item {
                        vibesql_ast::MixedGroupingItem::Simple(expr) => {
                            Self::extract_from_expression(expr, literals);
                        }
                        vibesql_ast::MixedGroupingItem::Rollup(elements)
                        | vibesql_ast::MixedGroupingItem::Cube(elements) => {
                            Self::extract_from_grouping_elements(elements, literals);
                        }
                        vibesql_ast::MixedGroupingItem::GroupingSets(sets) => {
                            Self::extract_from_grouping_sets(sets, literals);
                        }
                    }
                }
            }
        }
    }

    /// Appends the literals of ROLLUP/CUBE elements.
    ///
    /// Each element is either a single expression or a composite list of
    /// expressions, visited left to right.
    fn extract_from_grouping_elements(
        elements: &[vibesql_ast::GroupingElement],
        literals: &mut Vec<LiteralValue>,
    ) {
        for element in elements {
            match element {
                vibesql_ast::GroupingElement::Single(expr) => {
                    Self::extract_from_expression(expr, literals);
                }
                vibesql_ast::GroupingElement::Composite(exprs) => {
                    for expr in exprs {
                        Self::extract_from_expression(expr, literals);
                    }
                }
            }
        }
    }

    /// Appends the literals of GROUPING SETS: every column expression of every
    /// set, in order.
    fn extract_from_grouping_sets(
        sets: &[vibesql_ast::GroupingSet],
        literals: &mut Vec<LiteralValue>,
    ) {
        for set in sets {
            for expr in &set.columns {
                Self::extract_from_expression(expr, literals);
            }
        }
    }

    /// Recursively appends the literals of `expr`.
    ///
    /// Each `Expression::Literal` is converted with `LiteralValue::from_sql_value`
    /// and pushed. Compound expressions visit their operands in the order the
    /// arms list them. Subqueries recurse through `extract_from_select`.
    ///
    /// The final arm lists variants whose contents are not visited at all.
    /// NOTE(review): `Interval` and `MatchAgainst` are among them, so any literal
    /// operands they hold are not collected — confirm this is intended.
    fn extract_from_expression(expr: &Expression, literals: &mut Vec<LiteralValue>) {
        match expr {
            Expression::Literal(value) => {
                literals.push(LiteralValue::from_sql_value(value));
            }

            Expression::BinaryOp { left, right, .. } => {
                Self::extract_from_expression(left, literals);
                Self::extract_from_expression(right, literals);
            }

            // Flattened AND / OR chains: children in order.
            Expression::Conjunction(children) | Expression::Disjunction(children) => {
                for child in children {
                    Self::extract_from_expression(child, literals);
                }
            }

            Expression::UnaryOp { expr, .. } => {
                Self::extract_from_expression(expr, literals);
            }

            Expression::Function { args, .. } => {
                for arg in args {
                    Self::extract_from_expression(arg, literals);
                }
            }

            Expression::AggregateFunction { args, .. } => {
                for arg in args {
                    Self::extract_from_expression(arg, literals);
                }
            }

            Expression::IsNull { expr, .. } => {
                Self::extract_from_expression(expr, literals);
            }

            // CASE [operand] WHEN conditions THEN result ... [ELSE else_result]
            Expression::Case { operand, when_clauses, else_result } => {
                if let Some(ref op) = operand {
                    Self::extract_from_expression(op, literals);
                }
                for when in when_clauses {
                    for cond in &when.conditions {
                        Self::extract_from_expression(cond, literals);
                    }
                    Self::extract_from_expression(&when.result, literals);
                }
                if let Some(ref else_expr) = else_result {
                    Self::extract_from_expression(else_expr, literals);
                }
            }

            Expression::ScalarSubquery(subquery) => {
                Self::extract_from_select(subquery, literals);
            }

            // expr IN (subquery)
            Expression::In { expr, subquery, .. } => {
                Self::extract_from_expression(expr, literals);
                Self::extract_from_select(subquery, literals);
            }

            // expr IN (v1, v2, ...)
            Expression::InList { expr, values, .. } => {
                Self::extract_from_expression(expr, literals);
                for val in values {
                    Self::extract_from_expression(val, literals);
                }
            }

            Expression::Between { expr, low, high, .. } => {
                Self::extract_from_expression(expr, literals);
                Self::extract_from_expression(low, literals);
                Self::extract_from_expression(high, literals);
            }

            // Only the operand is visited.
            Expression::Cast { expr, .. } => {
                Self::extract_from_expression(expr, literals);
            }

            // POSITION(substring IN string)
            Expression::Position { substring, string, .. } => {
                Self::extract_from_expression(substring, literals);
                Self::extract_from_expression(string, literals);
            }

            // TRIM([removal_char FROM] string)
            Expression::Trim { removal_char, string, .. } => {
                if let Some(ref ch) = removal_char {
                    Self::extract_from_expression(ch, literals);
                }
                Self::extract_from_expression(string, literals);
            }

            Expression::Extract { expr, .. } => {
                Self::extract_from_expression(expr, literals);
            }

            // expr LIKE pattern; other fields of this variant are not visited.
            Expression::Like { expr, pattern, .. } => {
                Self::extract_from_expression(expr, literals);
                Self::extract_from_expression(pattern, literals);
            }

            Expression::Exists { subquery, .. } => {
                Self::extract_from_select(subquery, literals);
            }

            // expr <op> ANY/ALL (subquery)
            Expression::QuantifiedComparison { expr, subquery, .. } => {
                Self::extract_from_expression(expr, literals);
                Self::extract_from_select(subquery, literals);
            }

            // Function arguments, then OVER(PARTITION BY ..., ORDER BY ..., frame).
            Expression::WindowFunction { function, over } => {
                match function {
                    // All three function kinds carry an `args` list.
                    vibesql_ast::WindowFunctionSpec::Aggregate { args, .. }
                    | vibesql_ast::WindowFunctionSpec::Ranking { args, .. }
                    | vibesql_ast::WindowFunctionSpec::Value { args, .. } => {
                        for arg in args {
                            Self::extract_from_expression(arg, literals);
                        }
                    }
                }

                if let Some(ref partition_by) = over.partition_by {
                    // PARTITION BY expressions in order.
                    for expr in partition_by {
                        Self::extract_from_expression(expr, literals);
                    }
                }

                if let Some(ref order_by) = over.order_by {
                    // ORDER BY item expressions in order.
                    for item in order_by {
                        Self::extract_from_expression(&item.expr, literals);
                    }
                }

                if let Some(ref frame) = over.frame {
                    // Only PRECEDING/FOLLOWING bounds carry an expression; the
                    // start bound is visited before the optional end bound.
                    if let vibesql_ast::FrameBound::Preceding(expr) | vibesql_ast::FrameBound::Following(expr) =
                        &frame.start
                    {
                        Self::extract_from_expression(expr, literals);
                    }
                    if let Some(vibesql_ast::FrameBound::Preceding(expr) | vibesql_ast::FrameBound::Following(expr)) = &frame.end {
                        Self::extract_from_expression(expr, literals);
                    }
                }
            }

            // Variants whose contents are not traversed for literals.
            Expression::ColumnRef { .. }
            | Expression::Placeholder(_)
            | Expression::NumberedPlaceholder(_)
            | Expression::NamedPlaceholder(_)
            | Expression::PseudoVariable { .. }
            | Expression::Wildcard
            | Expression::CurrentDate
            | Expression::CurrentTime { .. }
            | Expression::CurrentTimestamp { .. }
            | Expression::Interval { .. }
            | Expression::Default
            | Expression::DuplicateKeyValue { .. }
            | Expression::NextValue { .. }
            | Expression::MatchAgainst { .. }
            | Expression::SessionVariable { .. } => {}
        }
    }
}
490
#[cfg(test)]
mod tests {
    use super::*;
    use vibesql_ast::{BinaryOperator, Expression, FromClause, SelectItem, SelectStmt, Statement};

    /// Builds `SELECT <select_list> FROM tab WHERE <where_clause>` with every
    /// other clause left empty.
    fn select_from_tab(select_list: Vec<SelectItem>, where_clause: Expression) -> Statement {
        Statement::Select(Box::new(SelectStmt {
            with_clause: None,
            distinct: false,
            select_list,
            into_table: None,
            into_variables: None,
            from: Some(FromClause::Table {
                name: "tab".to_string(),
                alias: None,
                column_aliases: None,
            }),
            where_clause: Some(where_clause),
            group_by: None,
            having: None,
            order_by: None,
            limit: None,
            offset: None,
            set_operation: None,
        }))
    }

    /// Unqualified column reference.
    fn column(name: &str) -> Expression {
        Expression::ColumnRef { table: None, column: name.to_string() }
    }

    /// `left <op> right`.
    fn binary(op: BinaryOperator, left: Expression, right: Expression) -> Expression {
        Expression::BinaryOp { op, left: Box::new(left), right: Box::new(right) }
    }

    /// Plan over `query` with a single parameter described by `context`.
    fn single_param_plan(query: &str, context: &str, original: LiteralValue) -> ParameterizedPlan {
        ParameterizedPlan::new(
            query.to_string(),
            vec![ParameterPosition { position: 40, context: context.to_string() }],
            vec![original],
        )
    }

    #[test]
    fn test_literal_value_to_string() {
        let cases = [
            (LiteralValue::Integer(42), "42"),
            (LiteralValue::Varchar("hello".to_string()), "'hello'"),
            (LiteralValue::Boolean(true), "true"),
            (LiteralValue::Null, "NULL"),
        ];
        for (value, expected) in cases.iter() {
            assert_eq!(value.to_sql(), *expected);
        }
    }

    #[test]
    fn test_literal_value_string_escape() {
        let quoted = LiteralValue::Varchar("it's".to_string()).to_sql();
        assert_eq!(quoted, "'it''s'");
    }

    #[test]
    fn test_literal_extraction_simple() {
        let age_check = binary(
            BinaryOperator::GreaterThan,
            column("col1"),
            Expression::Literal(SqlValue::Integer(25)),
        );
        let name_check = binary(
            BinaryOperator::Equal,
            column("col2"),
            Expression::Literal(SqlValue::Varchar("John".to_string())),
        );
        let stmt = select_from_tab(
            vec![SelectItem::Expression { expr: column("col0"), alias: None }],
            binary(BinaryOperator::And, age_check, name_check),
        );

        assert_eq!(
            LiteralExtractor::extract(&stmt),
            vec![LiteralValue::Integer(25), LiteralValue::Varchar("John".to_string())]
        );
    }

    #[test]
    fn test_literal_extraction_in_list() {
        let ids: Vec<Expression> =
            (1..=3).map(|n| Expression::Literal(SqlValue::Integer(n))).collect();
        let stmt = select_from_tab(
            vec![SelectItem::Wildcard { alias: None }],
            Expression::InList { expr: Box::new(column("id")), values: ids, negated: false },
        );

        let expected: Vec<LiteralValue> = (1..=3).map(LiteralValue::Integer).collect();
        assert_eq!(LiteralExtractor::extract(&stmt), expected);
    }

    #[test]
    fn test_parameterized_plan_bind() {
        let plan =
            single_param_plan("SELECT * FROM users WHERE age > ?", "age", LiteralValue::Integer(25));
        let bound = plan.bind(&[LiteralValue::Integer(30)]).unwrap();
        assert_eq!(bound, "SELECT * FROM users WHERE age > 30");
    }

    #[test]
    fn test_parameterized_plan_bind_string() {
        let plan = single_param_plan(
            "SELECT * FROM users WHERE name = ?",
            "name",
            LiteralValue::Varchar("John".to_string()),
        );
        let bound = plan.bind(&[LiteralValue::Varchar("Jane".to_string())]).unwrap();
        assert_eq!(bound, "SELECT * FROM users WHERE name = 'Jane'");
    }

    #[test]
    fn test_parameterized_plan_bind_error() {
        let plan =
            single_param_plan("SELECT * FROM users WHERE age > ?", "age", LiteralValue::Integer(25));
        let outcome = plan.bind(&[LiteralValue::Integer(30), LiteralValue::Integer(40)]);
        assert!(outcome.is_err());
    }

    #[test]
    fn test_comparison_key() {
        let plan = ParameterizedPlan::new("SELECT * FROM users".to_string(), Vec::new(), Vec::new());
        assert_eq!(plan.comparison_key(), "SELECT * FROM users");
    }
}