1use std::collections::HashSet;
20use std::pin::Pin;
21
22use bumpalo::Bump;
23use vibesql_ast::arena::{ArenaInterner, Expression, ExtendedExpr, SelectStmt};
24use vibesql_parser::arena_parser::ArenaParser;
25use vibesql_types::SqlValue;
26
/// A `SELECT` statement parsed once into its own bump arena, bundled with the
/// metadata computed at preparation time (placeholder count, referenced tables).
///
/// The AST and the interner are stored as lifetime-erased raw pointers so the
/// struct is self-contained. They are only handed out through `&self` accessors,
/// whose returned lifetimes are bounded by the borrow of `self`.
pub struct ArenaPreparedStatement {
    /// The SQL text the statement was parsed from.
    sql: String,
    /// Arena passed to the parser. It is boxed so its heap address does not
    /// change when `Self` is moved; the AST references it.
    arena: Pin<Box<Bump>>,
    /// Root of the parsed AST (produced by the parser from `arena`). The
    /// `'static` lifetime is fake: the pointee is only valid while `self` is alive.
    statement_ptr: *const SelectStmt<'static>,
    /// Interner used to resolve identifiers in the AST. The `'static` lifetime
    /// is fake: it is only valid while `self` is alive.
    interner_ptr: *const ArenaInterner<'static>,
    /// Number of positional placeholders found by `count_arena_placeholders`.
    param_count: usize,
    /// Table names referenced by the statement, resolved through the interner.
    tables: HashSet<String>,
}
58
// SAFETY (Send): moving the statement to another thread moves the boxed arena
// together with everything the raw pointers refer to.
// NOTE(review): this is only sound if `SelectStmt` and `ArenaInterner` hold no
// thread-affine data (e.g. `Rc`); confirm against their definitions.
unsafe impl Send for ArenaPreparedStatement {}
// SAFETY (Sync): shared access only yields `&SelectStmt`, `&ArenaInterner`,
// `&str`, `&HashSet` and `&Bump`.
// NOTE(review): bumpalo documents `Bump` as `!Sync`, yet `arena()` exposes a
// shared `&Bump`. Allocating through it from two threads at once would be a data
// race. This impl is only sound if callers never allocate through `arena()`
// while the statement is shared, and if `SelectStmt`/`ArenaInterner` have no
// interior mutability. Confirm.
unsafe impl Sync for ArenaPreparedStatement {}
63
64impl ArenaPreparedStatement {
65 pub fn new(sql: String) -> Result<Self, ArenaParseError> {
69 let arena = Pin::new(Box::new(Bump::new()));
71
72 let (stmt, interner): (&SelectStmt<'_>, ArenaInterner<'_>) =
74 ArenaParser::parse_select_with_interner(&sql, &arena)
75 .map_err(|e| ArenaParseError::ParseError(e.to_string()))?;
76
77 let param_count = count_arena_placeholders(stmt);
79 let tables = extract_arena_tables(stmt, &interner);
80
81 let statement_ptr = stmt as *const SelectStmt<'_>;
85 let statement_ptr = statement_ptr.cast::<SelectStmt<'static>>();
87
88 let interner_in_arena = arena.alloc(interner);
91 let interner_ptr = interner_in_arena as *const ArenaInterner<'_>;
92 let interner_ptr = interner_ptr.cast::<ArenaInterner<'static>>();
94
95 Ok(Self {
96 sql,
97 arena,
98 statement_ptr,
99 interner_ptr,
100 param_count,
101 tables,
102 })
103 }
104
105 pub fn sql(&self) -> &str {
107 &self.sql
108 }
109
110 pub fn statement(&self) -> &SelectStmt<'_> {
114 unsafe { &*self.statement_ptr }
117 }
118
119 pub fn interner(&self) -> &ArenaInterner<'_> {
123 unsafe { &*self.interner_ptr }
125 }
126
127 pub fn param_count(&self) -> usize {
129 self.param_count
130 }
131
132 pub fn tables(&self) -> &HashSet<String> {
134 &self.tables
135 }
136
137 pub fn arena(&self) -> &Bump {
141 &self.arena
142 }
143
144 pub fn validate_params(&self, params: &[SqlValue]) -> Result<(), ArenaBindError> {
146 if params.len() != self.param_count {
147 return Err(ArenaBindError::ParameterCountMismatch {
148 expected: self.param_count,
149 actual: params.len(),
150 });
151 }
152 Ok(())
153 }
154}
155
156impl std::fmt::Debug for ArenaPreparedStatement {
157 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
158 f.debug_struct("ArenaPreparedStatement")
159 .field("sql", &self.sql)
160 .field("param_count", &self.param_count)
161 .field("tables", &self.tables)
162 .finish_non_exhaustive()
163 }
164}
165
/// Error returned by `ArenaPreparedStatement::new` when the SQL cannot be parsed.
#[derive(Debug, Clone)]
pub enum ArenaParseError {
    /// The arena parser rejected the SQL; carries the parser's error message.
    ParseError(String),
}

impl std::fmt::Display for ArenaParseError {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        // Single-variant enum, so the destructuring is irrefutable.
        let ArenaParseError::ParseError(message) = self;
        write!(f, "Parse error: {}", message)
    }
}

impl std::error::Error for ArenaParseError {}
182
/// Error returned when the parameters supplied for a prepared statement do not fit it.
#[derive(Debug, Clone)]
pub enum ArenaBindError {
    /// The number of supplied parameters differs from the placeholder count.
    ParameterCountMismatch { expected: usize, actual: usize },
}

impl std::fmt::Display for ArenaBindError {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        // Single-variant enum, so the destructuring is irrefutable.
        let ArenaBindError::ParameterCountMismatch { expected, actual } = self;
        write!(
            f,
            "Parameter count mismatch: expected {}, got {}",
            expected, actual
        )
    }
}

impl std::error::Error for ArenaBindError {}
205
206fn count_arena_placeholders(stmt: &SelectStmt<'_>) -> usize {
208 let mut count = 0;
209 visit_arena_statement(stmt, &mut |expr| {
210 if matches!(expr, Expression::Placeholder(_)) {
211 count += 1;
212 }
213 });
214 count
215}
216
217fn extract_arena_tables(stmt: &SelectStmt<'_>, interner: &ArenaInterner<'_>) -> HashSet<String> {
219 let mut tables = HashSet::new();
220 visit_arena_from_clause(stmt.from.as_ref(), &mut tables, interner);
221 tables
222}
223
224fn visit_arena_statement<F>(stmt: &SelectStmt<'_>, visitor: &mut F)
226where
227 F: FnMut(&Expression<'_>),
228{
229 if let Some(ctes) = &stmt.with_clause {
231 for cte in ctes.iter() {
232 visit_arena_statement(cte.query, visitor);
233 }
234 }
235
236 for item in stmt.select_list.iter() {
238 if let vibesql_ast::arena::SelectItem::Expression { expr, .. } = item {
239 visit_arena_expression(expr, visitor);
240 }
241 }
242
243 if let Some(from) = &stmt.from {
245 visit_arena_from_clause_exprs(from, visitor);
246 }
247
248 if let Some(where_clause) = &stmt.where_clause {
250 visit_arena_expression(where_clause, visitor);
251 }
252
253 if let Some(group_by) = &stmt.group_by {
255 visit_arena_group_by(group_by, visitor);
256 }
257
258 if let Some(having) = &stmt.having {
260 visit_arena_expression(having, visitor);
261 }
262
263 if let Some(order_by) = &stmt.order_by {
265 for item in order_by.iter() {
266 visit_arena_expression(&item.expr, visitor);
267 }
268 }
269
270 if let Some(set_op) = &stmt.set_operation {
272 visit_arena_statement(set_op.right, visitor);
273 }
274}
275
276fn visit_arena_from_clause_exprs<F>(from: &vibesql_ast::arena::FromClause<'_>, visitor: &mut F)
278where
279 F: FnMut(&Expression<'_>),
280{
281 match from {
282 vibesql_ast::arena::FromClause::Table { .. } => {}
283 vibesql_ast::arena::FromClause::Subquery { query, .. } => {
284 visit_arena_statement(query, visitor);
285 }
286 vibesql_ast::arena::FromClause::Join {
287 left,
288 right,
289 condition,
290 ..
291 } => {
292 visit_arena_from_clause_exprs(left, visitor);
293 visit_arena_from_clause_exprs(right, visitor);
294 if let Some(cond) = condition {
295 visit_arena_expression(cond, visitor);
296 }
297 }
298 }
299}
300
301fn visit_arena_from_clause(
303 from: Option<&vibesql_ast::arena::FromClause<'_>>,
304 tables: &mut HashSet<String>,
305 interner: &ArenaInterner<'_>,
306) {
307 let Some(from) = from else { return };
308
309 match from {
310 vibesql_ast::arena::FromClause::Table { name, .. } => {
311 tables.insert(interner.resolve(*name).to_string());
312 }
313 vibesql_ast::arena::FromClause::Subquery { query, .. } => {
314 visit_arena_from_clause(query.from.as_ref(), tables, interner);
315 }
316 vibesql_ast::arena::FromClause::Join { left, right, .. } => {
317 visit_arena_from_clause(Some(left), tables, interner);
318 visit_arena_from_clause(Some(right), tables, interner);
319 }
320 }
321}
322
323fn visit_arena_group_by<F>(group_by: &vibesql_ast::arena::GroupByClause<'_>, visitor: &mut F)
325where
326 F: FnMut(&Expression<'_>),
327{
328 use vibesql_ast::arena::GroupByClause;
329 match group_by {
330 GroupByClause::Simple(exprs) => {
331 for expr in exprs.iter() {
332 visit_arena_expression(expr, visitor);
333 }
334 }
335 GroupByClause::Rollup(elements) | GroupByClause::Cube(elements) => {
336 for element in elements.iter() {
337 match element {
338 vibesql_ast::arena::GroupingElement::Single(expr) => {
339 visit_arena_expression(expr, visitor);
340 }
341 vibesql_ast::arena::GroupingElement::Composite(exprs) => {
342 for expr in exprs.iter() {
343 visit_arena_expression(expr, visitor);
344 }
345 }
346 }
347 }
348 }
349 GroupByClause::GroupingSets(sets) => {
350 for set in sets.iter() {
351 for expr in set.columns.iter() {
352 visit_arena_expression(expr, visitor);
353 }
354 }
355 }
356 GroupByClause::Mixed(items) => {
357 for item in items.iter() {
358 match item {
359 vibesql_ast::arena::MixedGroupingItem::Simple(expr) => {
360 visit_arena_expression(expr, visitor);
361 }
362 vibesql_ast::arena::MixedGroupingItem::Rollup(elements)
363 | vibesql_ast::arena::MixedGroupingItem::Cube(elements) => {
364 for element in elements.iter() {
365 match element {
366 vibesql_ast::arena::GroupingElement::Single(expr) => {
367 visit_arena_expression(expr, visitor);
368 }
369 vibesql_ast::arena::GroupingElement::Composite(exprs) => {
370 for expr in exprs.iter() {
371 visit_arena_expression(expr, visitor);
372 }
373 }
374 }
375 }
376 }
377 vibesql_ast::arena::MixedGroupingItem::GroupingSets(sets) => {
378 for set in sets.iter() {
379 for expr in set.columns.iter() {
380 visit_arena_expression(expr, visitor);
381 }
382 }
383 }
384 }
385 }
386 }
387 }
388}
389
/// Calls `visitor` on `expr` itself and then recursively on every child
/// expression it contains (pre-order: a node is reported before its children).
///
/// Children are visited in the order the arms below list them. The variants
/// that carry a subquery hand it to `visit_arena_statement`, so the walk
/// continues into nested statements. Those variants are `ScalarSubquery`,
/// `In`, `Exists` and `QuantifiedComparison`.
///
/// NOTE(review): arms that end in `..` (e.g. `Function`, `AggregateFunction`,
/// `Like`, `MatchAgainst`) and the `over` clause of `WindowFunction` (only
/// `partition_by`/`order_by` are read) may skip fields that hold further
/// expressions. Confirm against the `vibesql_ast::arena` definitions wherever
/// this walk must be exhaustive, e.g. placeholder counting.
fn visit_arena_expression<F>(expr: &Expression<'_>, visitor: &mut F)
where
    F: FnMut(&Expression<'_>),
{
    // Pre-order: report this node before descending into it.
    visitor(expr);

    match expr {
        Expression::BinaryOp { left, right, .. } => {
            visit_arena_expression(left, visitor);
            visit_arena_expression(right, visitor);
        }
        Expression::Conjunction(children) | Expression::Disjunction(children) => {
            for child in children.iter() {
                visit_arena_expression(child, visitor);
            }
        }
        Expression::UnaryOp { expr: inner, .. } => {
            visit_arena_expression(inner, visitor);
        }
        Expression::IsNull { expr: inner, .. } => {
            visit_arena_expression(inner, visitor);
        }
        // Leaf variants: nothing further to descend into.
        Expression::Literal(_)
        | Expression::Placeholder(_)
        | Expression::NumberedPlaceholder(_)
        | Expression::NamedPlaceholder(_)
        | Expression::ColumnRef { .. }
        | Expression::Wildcard
        | Expression::CurrentDate
        | Expression::CurrentTime { .. }
        | Expression::CurrentTimestamp { .. }
        | Expression::Default => {}
        // Remaining expression forms are wrapped in `Expression::Extended`.
        Expression::Extended(ext) => match ext {
            ExtendedExpr::Function { args, .. } | ExtendedExpr::AggregateFunction { args, .. } => {
                for arg in args.iter() {
                    visit_arena_expression(arg, visitor);
                }
            }
            ExtendedExpr::Case {
                operand,
                when_clauses,
                else_result,
            } => {
                // Operand first, then each WHEN's conditions followed by its
                // result, then the ELSE branch.
                if let Some(op) = operand {
                    visit_arena_expression(op, visitor);
                }
                for w in when_clauses.iter() {
                    for c in w.conditions.iter() {
                        visit_arena_expression(c, visitor);
                    }
                    visit_arena_expression(&w.result, visitor);
                }
                if let Some(e) = else_result {
                    visit_arena_expression(e, visitor);
                }
            }
            ExtendedExpr::ScalarSubquery(select) => visit_arena_statement(select, visitor),
            ExtendedExpr::In {
                expr: inner,
                subquery,
                ..
            } => {
                visit_arena_expression(inner, visitor);
                visit_arena_statement(subquery, visitor);
            }
            ExtendedExpr::InList {
                expr: inner,
                values,
                ..
            } => {
                visit_arena_expression(inner, visitor);
                for v in values.iter() {
                    visit_arena_expression(v, visitor);
                }
            }
            ExtendedExpr::Between {
                expr: inner,
                low,
                high,
                ..
            } => {
                visit_arena_expression(inner, visitor);
                visit_arena_expression(low, visitor);
                visit_arena_expression(high, visitor);
            }
            ExtendedExpr::Cast { expr: inner, .. } => {
                visit_arena_expression(inner, visitor);
            }
            ExtendedExpr::Position {
                substring, string, ..
            } => {
                visit_arena_expression(substring, visitor);
                visit_arena_expression(string, visitor);
            }
            ExtendedExpr::Trim {
                removal_char,
                string,
                ..
            } => {
                if let Some(c) = removal_char {
                    visit_arena_expression(c, visitor);
                }
                visit_arena_expression(string, visitor);
            }
            ExtendedExpr::Extract { expr: inner, .. } => {
                visit_arena_expression(inner, visitor);
            }
            ExtendedExpr::Like {
                expr: inner,
                pattern,
                ..
            } => {
                visit_arena_expression(inner, visitor);
                visit_arena_expression(pattern, visitor);
            }
            ExtendedExpr::Exists { subquery, .. } => {
                visit_arena_statement(subquery, visitor);
            }
            ExtendedExpr::QuantifiedComparison {
                expr: inner,
                subquery,
                ..
            } => {
                visit_arena_expression(inner, visitor);
                visit_arena_statement(subquery, visitor);
            }
            ExtendedExpr::Interval { value, .. } => {
                visit_arena_expression(value, visitor);
            }
            ExtendedExpr::WindowFunction { function, over } => {
                // Function arguments, then PARTITION BY, then ORDER BY.
                match function {
                    vibesql_ast::arena::WindowFunctionSpec::Aggregate { args, .. }
                    | vibesql_ast::arena::WindowFunctionSpec::Ranking { args, .. }
                    | vibesql_ast::arena::WindowFunctionSpec::Value { args, .. } => {
                        for arg in args.iter() {
                            visit_arena_expression(arg, visitor);
                        }
                    }
                }
                if let Some(partition_by) = &over.partition_by {
                    for expr in partition_by.iter() {
                        visit_arena_expression(expr, visitor);
                    }
                }
                if let Some(order_by) = &over.order_by {
                    for item in order_by.iter() {
                        visit_arena_expression(&item.expr, visitor);
                    }
                }
            }
            ExtendedExpr::MatchAgainst { search_modifier, .. } => {
                visit_arena_expression(search_modifier, visitor);
            }
            // Variants with no child expressions to visit.
            ExtendedExpr::DuplicateKeyValue { .. }
            | ExtendedExpr::NextValue { .. }
            | ExtendedExpr::PseudoVariable { .. }
            | ExtendedExpr::SessionVariable { .. } => {}
        },
    }
}
555
#[cfg(test)]
mod tests {
    use super::*;

    /// Prepares `sql`, panicking if the arena parser rejects it.
    fn prepare(sql: &str) -> ArenaPreparedStatement {
        ArenaPreparedStatement::new(sql.to_string()).unwrap()
    }

    #[test]
    fn test_arena_prepared_basic() {
        let sql = "SELECT id, name FROM users WHERE id = 1";
        let stmt = prepare(sql);

        assert_eq!(stmt.sql(), sql);
        assert_eq!(stmt.param_count(), 0);
        assert!(
            stmt.tables().contains("USERS"),
            "Expected 'USERS' in tables {:?}",
            stmt.tables()
        );
    }

    #[test]
    fn test_arena_prepared_with_placeholder() {
        let stmt = prepare("SELECT * FROM users WHERE id = ?");

        assert_eq!(stmt.param_count(), 1);
        assert!(stmt.tables().contains("USERS"));
    }

    #[test]
    fn test_arena_prepared_multiple_placeholders() {
        let stmt = prepare("SELECT * FROM users WHERE id = ? AND name = ? AND age > ?");

        assert_eq!(stmt.param_count(), 3);
    }

    #[test]
    fn test_arena_prepared_param_validation() {
        let stmt = prepare("SELECT * FROM users WHERE id = ?");

        // Exactly one parameter is accepted...
        assert!(stmt.validate_params(&[SqlValue::Integer(1)]).is_ok());

        // ...while too few or too many are rejected.
        assert!(stmt.validate_params(&[]).is_err());
        assert!(stmt
            .validate_params(&[SqlValue::Integer(1), SqlValue::Integer(2)])
            .is_err());
    }

    #[test]
    fn test_arena_prepared_join_tables() {
        let stmt = prepare("SELECT u.id FROM users u JOIN orders o ON u.id = o.user_id");
        let tables = stmt.tables();

        assert!(tables.contains("USERS"), "Expected 'USERS' in {:?}", tables);
        assert!(tables.contains("ORDERS"), "Expected 'ORDERS' in {:?}", tables);
    }
}