1use vibesql_ast::{Expression, Statement};
7use vibesql_types::SqlValue;
8
/// A SQL literal lifted out of a parsed statement so it can be substituted back later.
///
/// Numeric variants keep the width and signedness of the value they came from. Textual and
/// temporal variants hold their text form as an owned `String`.
#[derive(Debug, Clone, PartialEq)]
pub enum LiteralValue {
    /// Signed integer value (`i64`).
    Integer(i64),
    /// Small signed integer value (`i16`).
    Smallint(i16),
    /// Big signed integer value (`i64`).
    Bigint(i64),
    /// Unsigned integer value (`u64`).
    Unsigned(u64),
    /// Numeric value, held as an `f64` (so exact decimal precision is not preserved).
    Numeric(f64),
    /// Single-precision floating-point value.
    Float(f32),
    /// Single-precision floating-point value.
    Real(f32),
    /// Double-precision floating-point value.
    Double(f64),
    /// Fixed-length character string.
    Character(String),
    /// Variable-length character string.
    Varchar(String),
    /// Boolean value.
    Boolean(bool),
    /// Date, stored as text.
    Date(String),
    /// Time of day, stored as text.
    Time(String),
    /// Timestamp, stored as text.
    Timestamp(String),
    /// SQL `NULL`.
    Null,
}
28
29impl LiteralValue {
30 pub fn from_sql_value(value: &SqlValue) -> Self {
32 match value {
33 SqlValue::Integer(n) => LiteralValue::Integer(*n),
34 SqlValue::Smallint(n) => LiteralValue::Smallint(*n),
35 SqlValue::Bigint(n) => LiteralValue::Bigint(*n),
36 SqlValue::Unsigned(n) => LiteralValue::Unsigned(*n),
37 SqlValue::Numeric(n) => LiteralValue::Numeric(*n),
38 SqlValue::Float(n) => LiteralValue::Float(*n),
39 SqlValue::Real(n) => LiteralValue::Real(*n),
40 SqlValue::Double(n) => LiteralValue::Double(*n),
41 SqlValue::Character(s) => LiteralValue::Character(s.clone()),
42 SqlValue::Varchar(s) => LiteralValue::Varchar(s.clone()),
43 SqlValue::Boolean(b) => LiteralValue::Boolean(*b),
44 SqlValue::Date(s) => LiteralValue::Date(s.to_string()),
45 SqlValue::Time(s) => LiteralValue::Time(s.to_string()),
46 SqlValue::Timestamp(s) => LiteralValue::Timestamp(s.to_string()),
47 SqlValue::Interval(s) => LiteralValue::Varchar(s.to_string()), SqlValue::Vector(v) => {
50 let formatted: Vec<String> = v.iter().map(|f| f.to_string()).collect();
52 LiteralValue::Varchar(format!("[{}]", formatted.join(", ")))
53 }
54 SqlValue::Null => LiteralValue::Null,
55 }
56 }
57
58 pub fn to_sql(&self) -> String {
60 match self {
61 LiteralValue::Integer(n) => n.to_string(),
62 LiteralValue::Smallint(n) => n.to_string(),
63 LiteralValue::Bigint(n) => n.to_string(),
64 LiteralValue::Unsigned(n) => n.to_string(),
65 LiteralValue::Numeric(n) => n.to_string(),
66 LiteralValue::Float(n) => n.to_string(),
67 LiteralValue::Real(n) => n.to_string(),
68 LiteralValue::Double(n) => n.to_string(),
69 LiteralValue::Character(s) | LiteralValue::Varchar(s) => {
70 format!("'{}'", s.replace("'", "''"))
71 }
72 LiteralValue::Boolean(b) => if *b { "true" } else { "false" }.to_string(),
73 LiteralValue::Date(s) => format!("DATE '{}'", s),
74 LiteralValue::Time(s) => format!("TIME '{}'", s),
75 LiteralValue::Timestamp(s) => format!("TIMESTAMP '{}'", s),
76 LiteralValue::Null => "NULL".to_string(),
77 }
78 }
79}
80
/// Metadata about one parameter slot in a [`ParameterizedPlan`].
#[derive(Clone, Debug)]
pub struct ParameterPosition {
    /// Numeric location recorded for the parameter.
    ///
    /// NOTE(review): `ParameterizedPlan::bind` does not read this field, and its unit
    /// (byte offset, token index, ...) is not established here; confirm with the producer.
    pub position: usize,
    /// Free-form description of where the parameter appears (e.g. the column it relates to).
    pub context: String,
}
87
88#[derive(Clone, Debug)]
90pub struct ParameterizedPlan {
91 pub normalized_query: String,
92 pub param_positions: Vec<ParameterPosition>,
93 pub literal_values: Vec<LiteralValue>,
94}
95
96impl ParameterizedPlan {
97 pub fn new(
99 normalized_query: String,
100 param_positions: Vec<ParameterPosition>,
101 literal_values: Vec<LiteralValue>,
102 ) -> Self {
103 Self { normalized_query, param_positions, literal_values }
104 }
105
106 pub fn bind(&self, values: &[LiteralValue]) -> Result<String, String> {
108 if values.len() != self.param_positions.len() {
109 return Err(format!(
110 "Expected {} parameters, got {}",
111 self.param_positions.len(),
112 values.len()
113 ));
114 }
115
116 let mut result = self.normalized_query.clone();
117 let mut offset = 0;
118
119 for (i, value) in values.iter().enumerate() {
120 if let Some(_pos) = self.param_positions.get(i) {
121 let sql_value = value.to_sql();
122 let placeholder = "?";
123
124 if let Some(idx) = result[offset..].find(placeholder) {
125 let insert_pos = offset + idx;
126 result.replace_range(insert_pos..insert_pos + 1, &sql_value);
127 offset = insert_pos + sql_value.len();
128 }
129 }
130 }
131
132 Ok(result)
133 }
134
135 pub fn comparison_key(&self) -> String {
137 self.normalized_query.clone()
138 }
139}
140
/// Stateless walker that collects the literal values appearing in a parsed statement.
///
/// All functionality lives in associated functions (entry point: `LiteralExtractor::extract`).
/// The type itself carries no data, so it cheaply derives the common traits.
#[derive(Debug, Clone, Copy, Default, PartialEq, Eq)]
pub struct LiteralExtractor;
143
/// Recursive traversal that appends every `Expression::Literal` it visits to a buffer.
///
/// Literals are pushed in the order the walk reaches them (children left to right, clauses
/// in the order listed in each function). Changing the visiting order changes the order of
/// the returned literals.
/// NOTE(review): presumably callers rely on this order to line up with `?` placeholders;
/// confirm before reordering anything.
impl LiteralExtractor {
    /// Returns every literal found in `stmt`, in traversal order.
    ///
    /// Only `SELECT`, `INSERT`, `UPDATE` and `DELETE` statements are walked. Any other
    /// statement kind yields an empty vector.
    pub fn extract(stmt: &Statement) -> Vec<LiteralValue> {
        let mut literals = Vec::new();
        Self::extract_from_statement(stmt, &mut literals);
        literals
    }

    /// Dispatches on the statement kind and appends its literals to `literals`.
    fn extract_from_statement(stmt: &Statement, literals: &mut Vec<LiteralValue>) {
        match stmt {
            Statement::Select(select) => Self::extract_from_select(select, literals),
            Statement::Insert(insert) => {
                // Literals come either from the VALUES rows (row by row, left to right)
                // or from the source SELECT.
                match &insert.source {
                    vibesql_ast::InsertSource::Values(rows) => {
                        for row in rows {
                            for expr in row {
                                Self::extract_from_expression(expr, literals);
                            }
                        }
                    }
                    vibesql_ast::InsertSource::Select(select) => {
                        Self::extract_from_select(select, literals);
                    }
                }
            }
            Statement::Update(update) => {
                // SET values first, then the WHERE predicate.
                for assignment in &update.assignments {
                    Self::extract_from_expression(&assignment.value, literals);
                }
                if let Some(ref where_clause) = update.where_clause {
                    match where_clause {
                        vibesql_ast::WhereClause::Condition(expr) => {
                            Self::extract_from_expression(expr, literals);
                        }
                        vibesql_ast::WhereClause::CurrentOf(_) => {
                            // Cursor-positioned form: no predicate expression to walk.
                        }
                    }
                }
            }
            Statement::Delete(delete) => {
                if let Some(ref where_clause) = delete.where_clause {
                    match where_clause {
                        vibesql_ast::WhereClause::Condition(expr) => {
                            Self::extract_from_expression(expr, literals);
                        }
                        vibesql_ast::WhereClause::CurrentOf(_) => {
                            // Cursor-positioned form: no predicate expression to walk.
                        }
                    }
                }
            }
            // Any other statement kind contributes no literals.
            _ => {}
        }
    }

    /// Walks a SELECT in clause order: select list, FROM, WHERE, GROUP BY, HAVING, ORDER BY.
    ///
    /// NOTE(review): `SelectStmt` also has `with_clause`, `set_operation`, `limit` and
    /// `offset` fields, which are not visited here. Literals inside CTEs, set-operation
    /// branches, or LIMIT/OFFSET (if those hold expressions) are therefore not collected.
    /// Confirm this is intended.
    fn extract_from_select(select: &vibesql_ast::SelectStmt, literals: &mut Vec<LiteralValue>) {
        for item in &select.select_list {
            // Only expression items can hold literals; other select items are skipped.
            if let vibesql_ast::SelectItem::Expression { expr, .. } = item {
                Self::extract_from_expression(expr, literals);
            }
        }

        if let Some(ref from) = select.from {
            Self::extract_from_from_clause(from, literals);
        }

        if let Some(ref where_clause) = select.where_clause {
            Self::extract_from_expression(where_clause, literals);
        }

        if let Some(ref group_by) = select.group_by {
            Self::extract_from_group_by(group_by, literals);
        }

        if let Some(ref having) = select.having {
            Self::extract_from_expression(having, literals);
        }

        if let Some(ref order_by) = select.order_by {
            for item in order_by {
                Self::extract_from_expression(&item.expr, literals);
            }
        }
    }

    /// Walks a FROM clause.
    ///
    /// For joins, the walk covers the left side, then the right side, then the ON condition.
    /// Derived tables (subqueries) are walked, and plain table references contribute nothing.
    fn extract_from_from_clause(from: &vibesql_ast::FromClause, literals: &mut Vec<LiteralValue>) {
        match from {
            vibesql_ast::FromClause::Join { left, right, condition, .. } => {
                Self::extract_from_from_clause(left, literals);
                Self::extract_from_from_clause(right, literals);
                if let Some(expr) = condition {
                    Self::extract_from_expression(expr, literals);
                }
            }
            vibesql_ast::FromClause::Subquery { query, .. } => {
                Self::extract_from_select(query, literals);
            }
            vibesql_ast::FromClause::Table { .. } => {
                // Named table: no expressions to walk.
            }
        }
    }

    /// Walks every form of GROUP BY clause (simple, ROLLUP, CUBE, GROUPING SETS, mixed).
    fn extract_from_group_by(
        group_by: &vibesql_ast::GroupByClause,
        literals: &mut Vec<LiteralValue>,
    ) {
        match group_by {
            vibesql_ast::GroupByClause::Simple(exprs) => {
                for expr in exprs {
                    Self::extract_from_expression(expr, literals);
                }
            }
            vibesql_ast::GroupByClause::Rollup(elements)
            | vibesql_ast::GroupByClause::Cube(elements) => {
                Self::extract_from_grouping_elements(elements, literals);
            }
            vibesql_ast::GroupByClause::GroupingSets(sets) => {
                Self::extract_from_grouping_sets(sets, literals);
            }
            vibesql_ast::GroupByClause::Mixed(items) => {
                // Mixed lists interleave plain expressions with ROLLUP/CUBE/GROUPING SETS;
                // each item is walked in list order.
                for item in items {
                    match item {
                        vibesql_ast::MixedGroupingItem::Simple(expr) => {
                            Self::extract_from_expression(expr, literals);
                        }
                        vibesql_ast::MixedGroupingItem::Rollup(elements)
                        | vibesql_ast::MixedGroupingItem::Cube(elements) => {
                            Self::extract_from_grouping_elements(elements, literals);
                        }
                        vibesql_ast::MixedGroupingItem::GroupingSets(sets) => {
                            Self::extract_from_grouping_sets(sets, literals);
                        }
                    }
                }
            }
        }
    }

    /// Walks ROLLUP/CUBE elements, each being a single expression or a composite list.
    fn extract_from_grouping_elements(
        elements: &[vibesql_ast::GroupingElement],
        literals: &mut Vec<LiteralValue>,
    ) {
        for element in elements {
            match element {
                vibesql_ast::GroupingElement::Single(expr) => {
                    Self::extract_from_expression(expr, literals);
                }
                vibesql_ast::GroupingElement::Composite(exprs) => {
                    for expr in exprs {
                        Self::extract_from_expression(expr, literals);
                    }
                }
            }
        }
    }

    /// Walks the column expressions of each GROUPING SETS entry, in order.
    fn extract_from_grouping_sets(
        sets: &[vibesql_ast::GroupingSet],
        literals: &mut Vec<LiteralValue>,
    ) {
        for set in sets {
            for expr in &set.columns {
                Self::extract_from_expression(expr, literals);
            }
        }
    }

    /// Recursively walks an expression, pushing each literal it contains.
    ///
    /// The match lists every `Expression` variant explicitly (no wildcard arm). Adding a
    /// variant to the AST therefore fails to compile here until it is handled.
    fn extract_from_expression(expr: &Expression, literals: &mut Vec<LiteralValue>) {
        match expr {
            Expression::Literal(value) => {
                literals.push(LiteralValue::from_sql_value(value));
            }

            Expression::BinaryOp { left, right, .. } => {
                Self::extract_from_expression(left, literals);
                Self::extract_from_expression(right, literals);
            }

            Expression::Conjunction(children) | Expression::Disjunction(children) => {
                for child in children {
                    Self::extract_from_expression(child, literals);
                }
            }

            Expression::UnaryOp { expr, .. } => {
                Self::extract_from_expression(expr, literals);
            }

            Expression::Function { args, .. } => {
                for arg in args {
                    Self::extract_from_expression(arg, literals);
                }
            }

            Expression::AggregateFunction { args, .. } => {
                for arg in args {
                    Self::extract_from_expression(arg, literals);
                }
            }

            Expression::IsNull { expr, .. } => {
                Self::extract_from_expression(expr, literals);
            }

            // Operand first, then each WHEN's conditions followed by its result, then ELSE.
            Expression::Case { operand, when_clauses, else_result } => {
                if let Some(ref op) = operand {
                    Self::extract_from_expression(op, literals);
                }
                for when in when_clauses {
                    for cond in &when.conditions {
                        Self::extract_from_expression(cond, literals);
                    }
                    Self::extract_from_expression(&when.result, literals);
                }
                if let Some(ref else_expr) = else_result {
                    Self::extract_from_expression(else_expr, literals);
                }
            }

            Expression::ScalarSubquery(subquery) => {
                Self::extract_from_select(subquery, literals);
            }

            Expression::In { expr, subquery, .. } => {
                Self::extract_from_expression(expr, literals);
                Self::extract_from_select(subquery, literals);
            }

            Expression::InList { expr, values, .. } => {
                Self::extract_from_expression(expr, literals);
                for val in values {
                    Self::extract_from_expression(val, literals);
                }
            }

            Expression::Between { expr, low, high, .. } => {
                Self::extract_from_expression(expr, literals);
                Self::extract_from_expression(low, literals);
                Self::extract_from_expression(high, literals);
            }

            Expression::Cast { expr, .. } => {
                Self::extract_from_expression(expr, literals);
            }

            Expression::Position { substring, string, .. } => {
                Self::extract_from_expression(substring, literals);
                Self::extract_from_expression(string, literals);
            }

            Expression::Trim { removal_char, string, .. } => {
                if let Some(ref ch) = removal_char {
                    Self::extract_from_expression(ch, literals);
                }
                Self::extract_from_expression(string, literals);
            }

            Expression::Extract { expr, .. } => {
                Self::extract_from_expression(expr, literals);
            }

            Expression::Like { expr, pattern, .. } => {
                Self::extract_from_expression(expr, literals);
                Self::extract_from_expression(pattern, literals);
            }

            Expression::Exists { subquery, .. } => {
                Self::extract_from_select(subquery, literals);
            }

            Expression::QuantifiedComparison { expr, subquery, .. } => {
                Self::extract_from_expression(expr, literals);
                Self::extract_from_select(subquery, literals);
            }

            // Function arguments, then PARTITION BY, then ORDER BY, then frame bounds.
            Expression::WindowFunction { function, over } => {
                match function {
                    vibesql_ast::WindowFunctionSpec::Aggregate { args, .. }
                    | vibesql_ast::WindowFunctionSpec::Ranking { args, .. }
                    | vibesql_ast::WindowFunctionSpec::Value { args, .. } => {
                        for arg in args {
                            Self::extract_from_expression(arg, literals);
                        }
                    }
                }

                if let Some(ref partition_by) = over.partition_by {
                    for expr in partition_by {
                        Self::extract_from_expression(expr, literals);
                    }
                }

                if let Some(ref order_by) = over.order_by {
                    for item in order_by {
                        Self::extract_from_expression(&item.expr, literals);
                    }
                }

                // Only `Preceding(expr)` / `Following(expr)` bounds carry an expression;
                // other bound kinds are ignored here.
                if let Some(ref frame) = over.frame {
                    if let vibesql_ast::FrameBound::Preceding(expr)
                    | vibesql_ast::FrameBound::Following(expr) = &frame.start
                    {
                        Self::extract_from_expression(expr, literals);
                    }
                    if let Some(
                        vibesql_ast::FrameBound::Preceding(expr)
                        | vibesql_ast::FrameBound::Following(expr),
                    ) = &frame.end
                    {
                        Self::extract_from_expression(expr, literals);
                    }
                }
            }

            // Treated as leaves: nothing inside these variants is walked.
            // NOTE(review): if `Interval` or `MatchAgainst` embed literal values, they are not
            // collected here; confirm this is intended.
            Expression::ColumnRef { .. }
            | Expression::Placeholder(_)
            | Expression::NumberedPlaceholder(_)
            | Expression::NamedPlaceholder(_)
            | Expression::PseudoVariable { .. }
            | Expression::Wildcard
            | Expression::CurrentDate
            | Expression::CurrentTime { .. }
            | Expression::CurrentTimestamp { .. }
            | Expression::Interval { .. }
            | Expression::Default
            | Expression::DuplicateKeyValue { .. }
            | Expression::NextValue { .. }
            | Expression::MatchAgainst { .. }
            | Expression::SessionVariable { .. } => {}
        }
    }
}
494
#[cfg(test)]
mod tests {
    use super::*;
    use vibesql_ast::{BinaryOperator, Expression, FromClause, SelectItem, SelectStmt, Statement};

    /// Builds `SELECT <select_list> FROM tab WHERE <predicate>` with every other clause empty.
    fn select_from_tab(select_list: Vec<SelectItem>, predicate: Expression) -> Statement {
        Statement::Select(Box::new(SelectStmt {
            with_clause: None,
            distinct: false,
            select_list,
            into_table: None,
            into_variables: None,
            from: Some(FromClause::Table {
                name: "tab".to_string(),
                alias: None,
                column_aliases: None,
            }),
            where_clause: Some(predicate),
            group_by: None,
            having: None,
            order_by: None,
            limit: None,
            offset: None,
            set_operation: None,
        }))
    }

    /// Unqualified column reference.
    fn column(name: &str) -> Expression {
        Expression::ColumnRef { table: None, column: name.to_string() }
    }

    /// `<col> <op> <literal>`.
    fn comparison(op: BinaryOperator, col: &str, literal: SqlValue) -> Expression {
        Expression::BinaryOp {
            op,
            left: Box::new(column(col)),
            right: Box::new(Expression::Literal(literal)),
        }
    }

    /// Plan with exactly one parameter slot, seeded with `seed`.
    fn one_param_plan(query: &str, context: &str, seed: LiteralValue) -> ParameterizedPlan {
        ParameterizedPlan::new(
            query.to_string(),
            vec![ParameterPosition { position: 40, context: context.to_string() }],
            vec![seed],
        )
    }

    #[test]
    fn test_literal_value_to_string() {
        let cases = [
            (LiteralValue::Integer(42), "42"),
            (LiteralValue::Varchar("hello".to_string()), "'hello'"),
            (LiteralValue::Boolean(true), "true"),
            (LiteralValue::Null, "NULL"),
        ];
        for (value, expected) in cases.iter() {
            assert_eq!(value.to_sql(), *expected);
        }
    }

    #[test]
    fn test_literal_value_string_escape() {
        let escaped = LiteralValue::Varchar("it's".to_string()).to_sql();
        assert_eq!(escaped, "'it''s'");
    }

    #[test]
    fn test_literal_extraction_simple() {
        let predicate = Expression::BinaryOp {
            op: BinaryOperator::And,
            left: Box::new(comparison(BinaryOperator::GreaterThan, "col1", SqlValue::Integer(25))),
            right: Box::new(comparison(
                BinaryOperator::Equal,
                "col2",
                SqlValue::Varchar("John".to_string()),
            )),
        };
        let projection = vec![SelectItem::Expression { expr: column("col0"), alias: None }];
        let stmt = select_from_tab(projection, predicate);

        assert_eq!(
            LiteralExtractor::extract(&stmt),
            vec![LiteralValue::Integer(25), LiteralValue::Varchar("John".to_string())]
        );
    }

    #[test]
    fn test_literal_extraction_in_list() {
        let predicate = Expression::InList {
            expr: Box::new(column("id")),
            values: (1..=3).map(|n| Expression::Literal(SqlValue::Integer(n))).collect(),
            negated: false,
        };
        let stmt = select_from_tab(vec![SelectItem::Wildcard { alias: None }], predicate);

        assert_eq!(
            LiteralExtractor::extract(&stmt),
            vec![LiteralValue::Integer(1), LiteralValue::Integer(2), LiteralValue::Integer(3)]
        );
    }

    #[test]
    fn test_parameterized_plan_bind() {
        let plan = one_param_plan(
            "SELECT * FROM users WHERE age > ?",
            "age",
            LiteralValue::Integer(25),
        );
        let bound = plan.bind(&[LiteralValue::Integer(30)]).unwrap();
        assert_eq!(bound, "SELECT * FROM users WHERE age > 30");
    }

    #[test]
    fn test_parameterized_plan_bind_string() {
        let plan = one_param_plan(
            "SELECT * FROM users WHERE name = ?",
            "name",
            LiteralValue::Varchar("John".to_string()),
        );
        let bound = plan.bind(&[LiteralValue::Varchar("Jane".to_string())]).unwrap();
        assert_eq!(bound, "SELECT * FROM users WHERE name = 'Jane'");
    }

    #[test]
    fn test_parameterized_plan_bind_error() {
        let plan = one_param_plan(
            "SELECT * FROM users WHERE age > ?",
            "age",
            LiteralValue::Integer(25),
        );
        let too_many = [LiteralValue::Integer(30), LiteralValue::Integer(40)];
        assert!(plan.bind(&too_many).is_err());
    }

    #[test]
    fn test_comparison_key() {
        let plan = ParameterizedPlan::new("SELECT * FROM users".to_string(), Vec::new(), Vec::new());
        assert_eq!(plan.comparison_key(), "SELECT * FROM users");
    }
}