1use vibesql_ast::{Expression, Statement};
7use vibesql_types::SqlValue;
8
/// A literal constant taken from a SQL statement.
///
/// Values are produced from `SqlValue` by [`LiteralValue::from_sql_value`] and
/// rendered back to SQL text by [`LiteralValue::to_sql`]. The per-variant notes
/// below describe that rendering.
#[derive(Clone, Debug, PartialEq)]
pub enum LiteralValue {
    /// Rendered as a bare decimal number.
    Integer(i64),
    /// Rendered as a bare decimal number.
    Smallint(i16),
    /// Rendered as a bare decimal number.
    Bigint(i64),
    /// Rendered as a bare decimal number.
    Unsigned(u64),
    /// Rendered with the float's `to_string()` text.
    Numeric(f64),
    /// Rendered with the float's `to_string()` text.
    Float(f32),
    /// Rendered with the float's `to_string()` text.
    Real(f64),
    /// Rendered with the float's `to_string()` text.
    Double(f64),
    /// Rendered as a single-quoted string with embedded `'` doubled.
    Character(String),
    /// Rendered as a single-quoted string with embedded `'` doubled.
    ///
    /// `from_sql_value` also uses this variant for `SqlValue::Interval` and
    /// `SqlValue::Vector`, which have no variant of their own here.
    Varchar(String),
    /// Rendered as `true` or `false`.
    Boolean(bool),
    /// Rendered as `DATE '<text>'`.
    Date(String),
    /// Rendered as `TIME '<text>'`.
    Time(String),
    /// Rendered as `TIMESTAMP '<text>'`.
    Timestamp(String),
    /// Rendered as `x'<upper-case hex>'`.
    Blob(Vec<u8>),
    /// Rendered as `NULL`.
    Null,
}
29
30impl LiteralValue {
31 pub fn from_sql_value(value: &SqlValue) -> Self {
33 match value {
34 SqlValue::Integer(n) => LiteralValue::Integer(*n),
35 SqlValue::Smallint(n) => LiteralValue::Smallint(*n),
36 SqlValue::Bigint(n) => LiteralValue::Bigint(*n),
37 SqlValue::Unsigned(n) => LiteralValue::Unsigned(*n),
38 SqlValue::Numeric(n) => LiteralValue::Numeric(*n),
39 SqlValue::Float(n) => LiteralValue::Float(*n),
40 SqlValue::Real(n) => LiteralValue::Real(*n),
41 SqlValue::Double(n) => LiteralValue::Double(*n),
42 SqlValue::Character(s) => LiteralValue::Character(s.to_string()),
43 SqlValue::Varchar(s) => LiteralValue::Varchar(s.to_string()),
44 SqlValue::Boolean(b) => LiteralValue::Boolean(*b),
45 SqlValue::Date(s) => LiteralValue::Date(s.to_string()),
46 SqlValue::Time(s) => LiteralValue::Time(s.to_string()),
47 SqlValue::Timestamp(s) => LiteralValue::Timestamp(s.to_string()),
48 SqlValue::Interval(s) => LiteralValue::Varchar(s.to_string()), SqlValue::Vector(v) => {
51 let formatted: Vec<String> = v.iter().map(|f| f.to_string()).collect();
53 LiteralValue::Varchar(format!("[{}]", formatted.join(", ")))
54 }
55 SqlValue::Blob(b) => LiteralValue::Blob(b.clone()),
56 SqlValue::Null => LiteralValue::Null,
57 }
58 }
59
60 pub fn to_sql(&self) -> String {
62 match self {
63 LiteralValue::Integer(n) => n.to_string(),
64 LiteralValue::Smallint(n) => n.to_string(),
65 LiteralValue::Bigint(n) => n.to_string(),
66 LiteralValue::Unsigned(n) => n.to_string(),
67 LiteralValue::Numeric(n) => n.to_string(),
68 LiteralValue::Float(n) => n.to_string(),
69 LiteralValue::Real(n) => n.to_string(),
70 LiteralValue::Double(n) => n.to_string(),
71 LiteralValue::Character(s) | LiteralValue::Varchar(s) => {
72 format!("'{}'", s.replace("'", "''"))
73 }
74 LiteralValue::Boolean(b) => if *b { "true" } else { "false" }.to_string(),
75 LiteralValue::Date(s) => format!("DATE '{}'", s),
76 LiteralValue::Time(s) => format!("TIME '{}'", s),
77 LiteralValue::Timestamp(s) => format!("TIMESTAMP '{}'", s),
78 LiteralValue::Blob(b) => {
79 let hex: String = b.iter().map(|byte| format!("{:02X}", byte)).collect();
80 format!("x'{}'", hex)
81 }
82 LiteralValue::Null => "NULL".to_string(),
83 }
84 }
85}
86
/// Describes one parameter slot of a [`ParameterizedPlan`].
///
/// `ParameterizedPlan::bind` only counts these entries; it does not read
/// either field.
#[derive(Clone, Debug)]
pub struct ParameterPosition {
    /// Position of the parameter.
    ///
    /// NOTE(review): nothing in this file establishes the unit (byte offset in
    /// the query text vs. ordinal index) — confirm against the producer.
    pub position: usize,
    /// Free-form label for the parameter; the tests in this file pass a column
    /// name such as `"age"`.
    pub context: String,
}
93
/// A query whose literals have been replaced by `?` placeholders, together
/// with bookkeeping about those placeholders.
#[derive(Clone, Debug)]
pub struct ParameterizedPlan {
    /// Query text in which `bind` searches for `?` placeholders; also returned
    /// verbatim by `comparison_key`.
    pub normalized_query: String,
    /// One entry per parameter; `bind` requires exactly this many values.
    pub param_positions: Vec<ParameterPosition>,
    /// Literal values stored alongside the plan. No method on this type reads
    /// them — presumably the literals the query was normalized from; confirm
    /// against the caller.
    pub literal_values: Vec<LiteralValue>,
}
101
102impl ParameterizedPlan {
103 pub fn new(
105 normalized_query: String,
106 param_positions: Vec<ParameterPosition>,
107 literal_values: Vec<LiteralValue>,
108 ) -> Self {
109 Self { normalized_query, param_positions, literal_values }
110 }
111
112 pub fn bind(&self, values: &[LiteralValue]) -> Result<String, String> {
114 if values.len() != self.param_positions.len() {
115 return Err(format!(
116 "Expected {} parameters, got {}",
117 self.param_positions.len(),
118 values.len()
119 ));
120 }
121
122 let mut result = self.normalized_query.clone();
123 let mut offset = 0;
124
125 for (i, value) in values.iter().enumerate() {
126 if let Some(_pos) = self.param_positions.get(i) {
127 let sql_value = value.to_sql();
128 let placeholder = "?";
129
130 if let Some(idx) = result[offset..].find(placeholder) {
131 let insert_pos = offset + idx;
132 result.replace_range(insert_pos..insert_pos + 1, &sql_value);
133 offset = insert_pos + sql_value.len();
134 }
135 }
136 }
137
138 Ok(result)
139 }
140
141 pub fn comparison_key(&self) -> String {
143 self.normalized_query.clone()
144 }
145}
146
/// Stateless entry point for collecting the literals of a statement.
///
/// All functionality lives in associated functions; start with
/// `LiteralExtractor::extract`.
pub struct LiteralExtractor;
149
150impl LiteralExtractor {
151 pub fn extract(stmt: &Statement) -> Vec<LiteralValue> {
153 let mut literals = Vec::new();
154 Self::extract_from_statement(stmt, &mut literals);
155 literals
156 }
157
158 fn extract_from_statement(stmt: &Statement, literals: &mut Vec<LiteralValue>) {
159 match stmt {
160 Statement::Select(select) => Self::extract_from_select(select, literals),
161 Statement::Insert(insert) => {
162 match &insert.source {
164 vibesql_ast::InsertSource::Values(rows) => {
165 for row in rows {
166 for expr in row {
167 Self::extract_from_expression(expr, literals);
168 }
169 }
170 }
171 vibesql_ast::InsertSource::Select(select) => {
172 Self::extract_from_select(select, literals);
173 }
174 vibesql_ast::InsertSource::DefaultValues => {
175 }
177 }
178 }
179 Statement::Update(update) => {
180 for assignment in &update.assignments {
182 Self::extract_from_expression(&assignment.value, literals);
183 }
184 if let Some(ref where_clause) = update.where_clause {
186 match where_clause {
187 vibesql_ast::WhereClause::Condition(expr) => {
188 Self::extract_from_expression(expr, literals);
189 }
190 vibesql_ast::WhereClause::CurrentOf(_) => {
191 }
193 }
194 }
195 }
196 Statement::Delete(delete) => {
197 if let Some(ref where_clause) = delete.where_clause {
199 match where_clause {
200 vibesql_ast::WhereClause::Condition(expr) => {
201 Self::extract_from_expression(expr, literals);
202 }
203 vibesql_ast::WhereClause::CurrentOf(_) => {
204 }
206 }
207 }
208 }
209 _ => {}
211 }
212 }
213
214 fn extract_from_select(select: &vibesql_ast::SelectStmt, literals: &mut Vec<LiteralValue>) {
215 for item in &select.select_list {
217 if let vibesql_ast::SelectItem::Expression { expr, .. } = item {
218 Self::extract_from_expression(expr, literals);
219 }
220 }
221
222 if let Some(ref from) = select.from {
224 Self::extract_from_from_clause(from, literals);
225 }
226
227 if let Some(ref where_clause) = select.where_clause {
229 Self::extract_from_expression(where_clause, literals);
230 }
231
232 if let Some(ref group_by) = select.group_by {
234 Self::extract_from_group_by(group_by, literals);
235 }
236
237 if let Some(ref having) = select.having {
239 Self::extract_from_expression(having, literals);
240 }
241
242 if let Some(ref order_by) = select.order_by {
244 for item in order_by {
245 Self::extract_from_expression(&item.expr, literals);
246 }
247 }
248 }
249
250 fn extract_from_from_clause(from: &vibesql_ast::FromClause, literals: &mut Vec<LiteralValue>) {
251 match from {
252 vibesql_ast::FromClause::Join { left, right, condition, .. } => {
253 Self::extract_from_from_clause(left, literals);
254 Self::extract_from_from_clause(right, literals);
255 if let Some(expr) = condition {
256 Self::extract_from_expression(expr, literals);
257 }
258 }
259 vibesql_ast::FromClause::Subquery { query, .. } => {
260 Self::extract_from_select(query, literals);
261 }
262 vibesql_ast::FromClause::Table { .. } => {
263 }
265 vibesql_ast::FromClause::Values { rows, .. } => {
266 for row in rows {
268 for expr in row {
269 Self::extract_from_expression(expr, literals);
270 }
271 }
272 }
273 }
274 }
275
276 fn extract_from_group_by(
277 group_by: &vibesql_ast::GroupByClause,
278 literals: &mut Vec<LiteralValue>,
279 ) {
280 match group_by {
281 vibesql_ast::GroupByClause::Simple(exprs) => {
282 for expr in exprs {
283 Self::extract_from_expression(expr, literals);
284 }
285 }
286 vibesql_ast::GroupByClause::Rollup(elements)
287 | vibesql_ast::GroupByClause::Cube(elements) => {
288 Self::extract_from_grouping_elements(elements, literals);
289 }
290 vibesql_ast::GroupByClause::GroupingSets(sets) => {
291 Self::extract_from_grouping_sets(sets, literals);
292 }
293 vibesql_ast::GroupByClause::Mixed(items) => {
294 for item in items {
295 match item {
296 vibesql_ast::MixedGroupingItem::Simple(expr) => {
297 Self::extract_from_expression(expr, literals);
298 }
299 vibesql_ast::MixedGroupingItem::Rollup(elements)
300 | vibesql_ast::MixedGroupingItem::Cube(elements) => {
301 Self::extract_from_grouping_elements(elements, literals);
302 }
303 vibesql_ast::MixedGroupingItem::GroupingSets(sets) => {
304 Self::extract_from_grouping_sets(sets, literals);
305 }
306 }
307 }
308 }
309 }
310 }
311
312 fn extract_from_grouping_elements(
313 elements: &[vibesql_ast::GroupingElement],
314 literals: &mut Vec<LiteralValue>,
315 ) {
316 for element in elements {
317 match element {
318 vibesql_ast::GroupingElement::Single(expr) => {
319 Self::extract_from_expression(expr, literals);
320 }
321 vibesql_ast::GroupingElement::Composite(exprs) => {
322 for expr in exprs {
323 Self::extract_from_expression(expr, literals);
324 }
325 }
326 }
327 }
328 }
329
330 fn extract_from_grouping_sets(
331 sets: &[vibesql_ast::GroupingSet],
332 literals: &mut Vec<LiteralValue>,
333 ) {
334 for set in sets {
335 for expr in &set.columns {
336 Self::extract_from_expression(expr, literals);
337 }
338 }
339 }
340
341 fn extract_from_expression(expr: &Expression, literals: &mut Vec<LiteralValue>) {
342 match expr {
343 Expression::Literal(value) => {
344 literals.push(LiteralValue::from_sql_value(value));
345 }
346
347 Expression::BinaryOp { left, right, .. } => {
348 Self::extract_from_expression(left, literals);
349 Self::extract_from_expression(right, literals);
350 }
351
352 Expression::Conjunction(children)
353 | Expression::Disjunction(children)
354 | Expression::RowValueConstructor(children) => {
355 for child in children {
356 Self::extract_from_expression(child, literals);
357 }
358 }
359
360 Expression::Collate { expr, .. } => {
361 Self::extract_from_expression(expr, literals);
362 }
363
364 Expression::UnaryOp { expr, .. } => {
365 Self::extract_from_expression(expr, literals);
366 }
367
368 Expression::Function { args, .. } => {
369 for arg in args {
370 Self::extract_from_expression(arg, literals);
371 }
372 }
373
374 Expression::AggregateFunction { args, .. } => {
375 for arg in args {
376 Self::extract_from_expression(arg, literals);
377 }
378 }
379
380 Expression::IsNull { expr, .. } => {
381 Self::extract_from_expression(expr, literals);
382 }
383
384 Expression::IsDistinctFrom { left, right, .. } => {
385 Self::extract_from_expression(left, literals);
386 Self::extract_from_expression(right, literals);
387 }
388
389 Expression::IsTruthValue { expr, .. } => {
390 Self::extract_from_expression(expr, literals);
391 }
392
393 Expression::Case { operand, when_clauses, else_result } => {
394 if let Some(ref op) = operand {
395 Self::extract_from_expression(op, literals);
396 }
397 for when in when_clauses {
398 for cond in &when.conditions {
399 Self::extract_from_expression(cond, literals);
400 }
401 Self::extract_from_expression(&when.result, literals);
402 }
403 if let Some(ref else_expr) = else_result {
404 Self::extract_from_expression(else_expr, literals);
405 }
406 }
407
408 Expression::ScalarSubquery(subquery) => {
409 Self::extract_from_select(subquery, literals);
410 }
411
412 Expression::In { expr, subquery, .. } => {
413 Self::extract_from_expression(expr, literals);
414 Self::extract_from_select(subquery, literals);
415 }
416
417 Expression::InList { expr, values, .. } => {
418 Self::extract_from_expression(expr, literals);
419 for val in values {
420 Self::extract_from_expression(val, literals);
421 }
422 }
423
424 Expression::Between { expr, low, high, .. } => {
425 Self::extract_from_expression(expr, literals);
426 Self::extract_from_expression(low, literals);
427 Self::extract_from_expression(high, literals);
428 }
429
430 Expression::Cast { expr, .. } => {
431 Self::extract_from_expression(expr, literals);
432 }
433
434 Expression::Position { substring, string, .. } => {
435 Self::extract_from_expression(substring, literals);
436 Self::extract_from_expression(string, literals);
437 }
438
439 Expression::Trim { removal_char, string, .. } => {
440 if let Some(ref ch) = removal_char {
441 Self::extract_from_expression(ch, literals);
442 }
443 Self::extract_from_expression(string, literals);
444 }
445
446 Expression::Extract { expr, .. } => {
447 Self::extract_from_expression(expr, literals);
448 }
449
450 Expression::Like { expr, pattern, .. } => {
451 Self::extract_from_expression(expr, literals);
452 Self::extract_from_expression(pattern, literals);
453 }
454
455 Expression::Glob { expr, pattern, .. } => {
456 Self::extract_from_expression(expr, literals);
457 Self::extract_from_expression(pattern, literals);
458 }
459
460 Expression::Exists { subquery, .. } => {
461 Self::extract_from_select(subquery, literals);
462 }
463
464 Expression::QuantifiedComparison { expr, subquery, .. } => {
465 Self::extract_from_expression(expr, literals);
466 Self::extract_from_select(subquery, literals);
467 }
468
469 Expression::WindowFunction { function, over } => {
470 match function {
472 vibesql_ast::WindowFunctionSpec::Aggregate { args, .. }
473 | vibesql_ast::WindowFunctionSpec::Ranking { args, .. }
474 | vibesql_ast::WindowFunctionSpec::Value { args, .. } => {
475 for arg in args {
476 Self::extract_from_expression(arg, literals);
477 }
478 }
479 }
480
481 if let Some(ref partition_by) = over.partition_by {
483 for expr in partition_by {
484 Self::extract_from_expression(expr, literals);
485 }
486 }
487
488 if let Some(ref order_by) = over.order_by {
490 for item in order_by {
491 Self::extract_from_expression(&item.expr, literals);
492 }
493 }
494
495 if let Some(ref frame) = over.frame {
497 if let vibesql_ast::FrameBound::Preceding(expr)
498 | vibesql_ast::FrameBound::Following(expr) = &frame.start
499 {
500 Self::extract_from_expression(expr, literals);
501 }
502 if let Some(
503 vibesql_ast::FrameBound::Preceding(expr)
504 | vibesql_ast::FrameBound::Following(expr),
505 ) = &frame.end
506 {
507 Self::extract_from_expression(expr, literals);
508 }
509 }
510 }
511
512 Expression::ColumnRef(_)
514 | Expression::Placeholder(_)
515 | Expression::NumberedPlaceholder(_)
516 | Expression::NamedPlaceholder(_)
517 | Expression::PseudoVariable { .. }
518 | Expression::Wildcard
519 | Expression::CurrentDate
520 | Expression::CurrentTime { .. }
521 | Expression::CurrentTimestamp { .. }
522 | Expression::Interval { .. }
523 | Expression::Default
524 | Expression::DuplicateKeyValue { .. }
525 | Expression::NextValue { .. }
526 | Expression::MatchAgainst { .. }
527 | Expression::SessionVariable { .. } => {}
528 }
529 }
530}
531
#[cfg(test)]
mod tests {
    use super::*;
    use vibesql_ast::{BinaryOperator, FromClause, SelectItem, SelectStmt};

    /// Unquoted reference to the column `name`.
    fn column(name: &'static str) -> Expression {
        Expression::ColumnRef(vibesql_ast::ColumnIdentifier::simple(name, false))
    }

    /// Integer literal expression.
    fn integer(n: i64) -> Expression {
        Expression::Literal(SqlValue::Integer(n))
    }

    /// Builds `SELECT <select_list> FROM tab WHERE <predicate>` with every
    /// other clause left empty.
    fn select_from_tab(select_list: Vec<SelectItem>, predicate: Expression) -> Statement {
        Statement::Select(Box::new(SelectStmt {
            with_clause: None,
            distinct: false,
            select_list,
            into_table: None,
            into_variables: None,
            from: Some(FromClause::Table {
                name: "tab".to_string(),
                alias: None,
                column_aliases: None,
                quoted: false,
            }),
            where_clause: Some(predicate),
            group_by: None,
            having: None,
            order_by: None,
            limit: None,
            offset: None,
            set_operation: None,
            values: None,
        }))
    }

    /// Plan with a single parameter slot, as used by the `bind` tests.
    fn single_param_plan(query: &str, context: &str, original: LiteralValue) -> ParameterizedPlan {
        ParameterizedPlan::new(
            query.to_string(),
            vec![ParameterPosition { position: 40, context: context.to_string() }],
            vec![original],
        )
    }

    #[test]
    fn test_literal_value_to_string() {
        let cases = [
            (LiteralValue::Integer(42), "42"),
            (LiteralValue::Varchar("hello".to_string()), "'hello'"),
            (LiteralValue::Boolean(true), "true"),
            (LiteralValue::Null, "NULL"),
        ];

        for (literal, expected) in cases.iter() {
            assert_eq!(literal.to_sql(), *expected);
        }
    }

    #[test]
    fn test_literal_value_string_escape() {
        let quoted = LiteralValue::Varchar("it's".to_string()).to_sql();
        assert_eq!(quoted, "'it''s'");
    }

    #[test]
    fn test_literal_extraction_simple() {
        // SELECT col0 FROM tab WHERE col1 > 25 AND col2 = 'John'
        let age_filter = Expression::BinaryOp {
            op: BinaryOperator::GreaterThan,
            left: Box::new(column("col1")),
            right: Box::new(integer(25)),
        };
        let name_filter = Expression::BinaryOp {
            op: BinaryOperator::Equal,
            left: Box::new(column("col2")),
            right: Box::new(Expression::Literal(SqlValue::Varchar(arcstr::ArcStr::from(
                "John",
            )))),
        };
        let predicate = Expression::BinaryOp {
            op: BinaryOperator::And,
            left: Box::new(age_filter),
            right: Box::new(name_filter),
        };
        let projection =
            vec![SelectItem::Expression { expr: column("col0"), alias: None, source_text: None }];

        let literals = LiteralExtractor::extract(&select_from_tab(projection, predicate));

        assert_eq!(
            literals,
            vec![LiteralValue::Integer(25), LiteralValue::Varchar("John".to_string())]
        );
    }

    #[test]
    fn test_literal_extraction_in_list() {
        // SELECT * FROM tab WHERE id IN (1, 2, 3)
        let predicate = Expression::InList {
            expr: Box::new(column("id")),
            values: vec![integer(1), integer(2), integer(3)],
            negated: false,
        };
        let projection = vec![SelectItem::Wildcard { alias: None }];

        let literals = LiteralExtractor::extract(&select_from_tab(projection, predicate));

        assert_eq!(
            literals,
            vec![LiteralValue::Integer(1), LiteralValue::Integer(2), LiteralValue::Integer(3)]
        );
    }

    #[test]
    fn test_parameterized_plan_bind() {
        let plan = single_param_plan(
            "SELECT * FROM users WHERE age > ?",
            "age",
            LiteralValue::Integer(25),
        );

        let bound = plan.bind(&[LiteralValue::Integer(30)]).unwrap();

        assert_eq!(bound, "SELECT * FROM users WHERE age > 30");
    }

    #[test]
    fn test_parameterized_plan_bind_string() {
        let plan = single_param_plan(
            "SELECT * FROM users WHERE name = ?",
            "name",
            LiteralValue::Varchar("John".to_string()),
        );

        let bound = plan.bind(&[LiteralValue::Varchar("Jane".to_string())]).unwrap();

        assert_eq!(bound, "SELECT * FROM users WHERE name = 'Jane'");
    }

    #[test]
    fn test_parameterized_plan_bind_error() {
        let plan = single_param_plan(
            "SELECT * FROM users WHERE age > ?",
            "age",
            LiteralValue::Integer(25),
        );

        // Two values for a single parameter slot must be rejected.
        let outcome = plan.bind(&[LiteralValue::Integer(30), LiteralValue::Integer(40)]);

        assert!(outcome.is_err());
    }

    #[test]
    fn test_comparison_key() {
        let plan = ParameterizedPlan::new("SELECT * FROM users".to_string(), vec![], vec![]);

        assert_eq!(plan.comparison_key(), "SELECT * FROM users");
    }
}