1use std::{collections::HashSet, pin::Pin};
20
21use bumpalo::Bump;
22use vibesql_ast::arena::{ArenaInterner, Expression, ExtendedExpr, SelectStmt};
23use vibesql_parser::arena_parser::ArenaParser;
24use vibesql_types::SqlValue;
25
/// A parsed `SELECT` statement together with the bump arena that owns its AST.
///
/// The AST and the string interner borrow from `arena`. A plain struct field cannot express
/// that lifetime, so both are stored as raw pointers with the arena lifetime erased to
/// `'static`. The accessors (`statement()`, `interner()`) re-attach a lifetime bounded by
/// `&self`, so a returned reference can never outlive the arena.
pub struct ArenaPreparedStatement {
    /// The SQL text exactly as passed to `new`.
    sql: String,
    /// Arena holding the parsed AST and the interner. Boxed so its address stays put when the
    /// `ArenaPreparedStatement` value itself is moved.
    arena: Pin<Box<Bump>>,
    /// Root of the parsed statement returned by the parser (which was handed `&arena`, so it
    /// presumably lives there). The `'static` lifetime is erased; only dereference via
    /// `statement()`.
    statement_ptr: *const SelectStmt<'static>,
    /// The interner returned by the parser, moved into `arena` by `new`. The `'static`
    /// lifetime is erased; only dereference via `interner()`.
    interner_ptr: *const ArenaInterner<'static>,
    /// Number of `Expression::Placeholder` nodes, computed once by `count_arena_placeholders`.
    param_count: usize,
    /// Table names computed once by `extract_arena_tables`.
    tables: HashSet<String>,
}
57
// SAFETY: the raw pointers refer only to data owned by this value: the interner is moved
// into `arena` in `new`, and the AST is presumably allocated there by the parser. Sending the
// statement to another thread therefore moves everything they point at along with it.
// NOTE(review): this also assumes `SelectStmt` and `ArenaInterner` contain nothing
// thread-bound — confirm.
unsafe impl Send for ArenaPreparedStatement {}
// SAFETY: the AST and interner are only ever exposed as shared references borrowed from
// `&self`.
// NOTE(review): bumpalo documents `Bump` as `!Sync` (its allocation state lives in `Cell`s),
// yet `arena()` returns `&Bump` from a shared `&self`. Two threads allocating through
// `arena()` on the same statement would be a data race. Confirm no caller does this, or
// remove that accessor. This also assumes `SelectStmt`/`ArenaInterner` have no interior
// mutability.
unsafe impl Sync for ArenaPreparedStatement {}
62
63impl ArenaPreparedStatement {
64 pub fn new(sql: String) -> Result<Self, ArenaParseError> {
68 let arena = Pin::new(Box::new(Bump::new()));
70
71 let (stmt, interner): (&SelectStmt<'_>, ArenaInterner<'_>) =
73 ArenaParser::parse_select_with_interner(&sql, &arena)
74 .map_err(|e| ArenaParseError::ParseError(e.to_string()))?;
75
76 let param_count = count_arena_placeholders(stmt);
78 let tables = extract_arena_tables(stmt, &interner);
79
80 let statement_ptr = stmt as *const SelectStmt<'_>;
84 let statement_ptr = statement_ptr.cast::<SelectStmt<'static>>();
86
87 let interner_in_arena = arena.alloc(interner);
90 let interner_ptr = interner_in_arena as *const ArenaInterner<'_>;
91 let interner_ptr = interner_ptr.cast::<ArenaInterner<'static>>();
93
94 Ok(Self { sql, arena, statement_ptr, interner_ptr, param_count, tables })
95 }
96
97 pub fn sql(&self) -> &str {
99 &self.sql
100 }
101
102 pub fn statement(&self) -> &SelectStmt<'_> {
106 unsafe { &*self.statement_ptr }
109 }
110
111 pub fn interner(&self) -> &ArenaInterner<'_> {
115 unsafe { &*self.interner_ptr }
117 }
118
119 pub fn param_count(&self) -> usize {
121 self.param_count
122 }
123
124 pub fn tables(&self) -> &HashSet<String> {
126 &self.tables
127 }
128
129 pub fn arena(&self) -> &Bump {
133 &self.arena
134 }
135
136 pub fn validate_params(&self, params: &[SqlValue]) -> Result<(), ArenaBindError> {
138 if params.len() != self.param_count {
139 return Err(ArenaBindError::ParameterCountMismatch {
140 expected: self.param_count,
141 actual: params.len(),
142 });
143 }
144 Ok(())
145 }
146}
147
148impl std::fmt::Debug for ArenaPreparedStatement {
149 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
150 f.debug_struct("ArenaPreparedStatement")
151 .field("sql", &self.sql)
152 .field("param_count", &self.param_count)
153 .field("tables", &self.tables)
154 .finish_non_exhaustive()
155 }
156}
157
/// Error returned when SQL cannot be turned into an [`ArenaPreparedStatement`].
#[derive(Debug, Clone)]
pub enum ArenaParseError {
    /// The arena parser rejected the input; the payload is the parser's message.
    ParseError(String),
}

impl std::fmt::Display for ArenaParseError {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        match self {
            Self::ParseError(message) => write!(f, "Parse error: {message}"),
        }
    }
}

impl std::error::Error for ArenaParseError {}
174
/// Error returned when the values supplied for a prepared statement do not fit it.
#[derive(Debug, Clone)]
pub enum ArenaBindError {
    /// The number of supplied values differs from the number of placeholders.
    ParameterCountMismatch {
        /// Placeholders found in the statement.
        expected: usize,
        /// Values supplied by the caller.
        actual: usize,
    },
}

impl std::fmt::Display for ArenaBindError {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        match self {
            Self::ParameterCountMismatch { expected, actual } => {
                write!(f, "Parameter count mismatch: expected {expected}, got {actual}")
            }
        }
    }
}

impl std::error::Error for ArenaBindError {}
193
194fn count_arena_placeholders(stmt: &SelectStmt<'_>) -> usize {
196 let mut count = 0;
197 visit_arena_statement(stmt, &mut |expr| {
198 if matches!(expr, Expression::Placeholder(_)) {
199 count += 1;
200 }
201 });
202 count
203}
204
205fn extract_arena_tables(stmt: &SelectStmt<'_>, interner: &ArenaInterner<'_>) -> HashSet<String> {
207 let mut tables = HashSet::new();
208 visit_arena_from_clause(stmt.from.as_ref(), &mut tables, interner);
209 tables
210}
211
212fn visit_arena_statement<F>(stmt: &SelectStmt<'_>, visitor: &mut F)
214where
215 F: FnMut(&Expression<'_>),
216{
217 if let Some(ctes) = &stmt.with_clause {
219 for cte in ctes.iter() {
220 visit_arena_statement(cte.query, visitor);
221 }
222 }
223
224 for item in stmt.select_list.iter() {
226 if let vibesql_ast::arena::SelectItem::Expression { expr, .. } = item {
227 visit_arena_expression(expr, visitor);
228 }
229 }
230
231 if let Some(from) = &stmt.from {
233 visit_arena_from_clause_exprs(from, visitor);
234 }
235
236 if let Some(where_clause) = &stmt.where_clause {
238 visit_arena_expression(where_clause, visitor);
239 }
240
241 if let Some(group_by) = &stmt.group_by {
243 visit_arena_group_by(group_by, visitor);
244 }
245
246 if let Some(having) = &stmt.having {
248 visit_arena_expression(having, visitor);
249 }
250
251 if let Some(order_by) = &stmt.order_by {
253 for item in order_by.iter() {
254 visit_arena_expression(&item.expr, visitor);
255 }
256 }
257
258 if let Some(set_op) = &stmt.set_operation {
260 visit_arena_statement(set_op.right, visitor);
261 }
262}
263
264fn visit_arena_from_clause_exprs<F>(from: &vibesql_ast::arena::FromClause<'_>, visitor: &mut F)
266where
267 F: FnMut(&Expression<'_>),
268{
269 match from {
270 vibesql_ast::arena::FromClause::Table { .. } => {}
271 vibesql_ast::arena::FromClause::Subquery { query, .. } => {
272 visit_arena_statement(query, visitor);
273 }
274 vibesql_ast::arena::FromClause::Join { left, right, condition, .. } => {
275 visit_arena_from_clause_exprs(left, visitor);
276 visit_arena_from_clause_exprs(right, visitor);
277 if let Some(cond) = condition {
278 visit_arena_expression(cond, visitor);
279 }
280 }
281 }
282}
283
284fn visit_arena_from_clause(
286 from: Option<&vibesql_ast::arena::FromClause<'_>>,
287 tables: &mut HashSet<String>,
288 interner: &ArenaInterner<'_>,
289) {
290 let Some(from) = from else { return };
291
292 match from {
293 vibesql_ast::arena::FromClause::Table { name, .. } => {
294 tables.insert(interner.resolve(*name).to_string());
295 }
296 vibesql_ast::arena::FromClause::Subquery { query, .. } => {
297 visit_arena_from_clause(query.from.as_ref(), tables, interner);
298 }
299 vibesql_ast::arena::FromClause::Join { left, right, .. } => {
300 visit_arena_from_clause(Some(left), tables, interner);
301 visit_arena_from_clause(Some(right), tables, interner);
302 }
303 }
304}
305
306fn visit_arena_group_by<F>(group_by: &vibesql_ast::arena::GroupByClause<'_>, visitor: &mut F)
308where
309 F: FnMut(&Expression<'_>),
310{
311 use vibesql_ast::arena::GroupByClause;
312 match group_by {
313 GroupByClause::Simple(exprs) => {
314 for expr in exprs.iter() {
315 visit_arena_expression(expr, visitor);
316 }
317 }
318 GroupByClause::Rollup(elements) | GroupByClause::Cube(elements) => {
319 for element in elements.iter() {
320 match element {
321 vibesql_ast::arena::GroupingElement::Single(expr) => {
322 visit_arena_expression(expr, visitor);
323 }
324 vibesql_ast::arena::GroupingElement::Composite(exprs) => {
325 for expr in exprs.iter() {
326 visit_arena_expression(expr, visitor);
327 }
328 }
329 }
330 }
331 }
332 GroupByClause::GroupingSets(sets) => {
333 for set in sets.iter() {
334 for expr in set.columns.iter() {
335 visit_arena_expression(expr, visitor);
336 }
337 }
338 }
339 GroupByClause::Mixed(items) => {
340 for item in items.iter() {
341 match item {
342 vibesql_ast::arena::MixedGroupingItem::Simple(expr) => {
343 visit_arena_expression(expr, visitor);
344 }
345 vibesql_ast::arena::MixedGroupingItem::Rollup(elements)
346 | vibesql_ast::arena::MixedGroupingItem::Cube(elements) => {
347 for element in elements.iter() {
348 match element {
349 vibesql_ast::arena::GroupingElement::Single(expr) => {
350 visit_arena_expression(expr, visitor);
351 }
352 vibesql_ast::arena::GroupingElement::Composite(exprs) => {
353 for expr in exprs.iter() {
354 visit_arena_expression(expr, visitor);
355 }
356 }
357 }
358 }
359 }
360 vibesql_ast::arena::MixedGroupingItem::GroupingSets(sets) => {
361 for set in sets.iter() {
362 for expr in set.columns.iter() {
363 visit_arena_expression(expr, visitor);
364 }
365 }
366 }
367 }
368 }
369 }
370 }
371}
372
/// Calls `visitor` on `expr` and then, depth first, on every expression nested inside it.
///
/// Traversal is pre-order: a node is reported before its children. Subqueries
/// (`ScalarSubquery`, `IN (SELECT ...)`, `EXISTS`, quantified comparisons) are walked in full
/// through `visit_arena_statement`, so expressions inside them are visited too. Variants
/// matched with an empty arm have no children visited.
fn visit_arena_expression<F>(expr: &Expression<'_>, visitor: &mut F)
where
    F: FnMut(&Expression<'_>),
{
    // Pre-order: report this node before descending into it.
    visitor(expr);

    match expr {
        Expression::BinaryOp { left, right, .. } => {
            visit_arena_expression(left, visitor);
            visit_arena_expression(right, visitor);
        }
        // Flattened AND / OR chains.
        Expression::Conjunction(children) | Expression::Disjunction(children) => {
            for child in children.iter() {
                visit_arena_expression(child, visitor);
            }
        }
        Expression::UnaryOp { expr: inner, .. } => {
            visit_arena_expression(inner, visitor);
        }
        Expression::IsNull { expr: inner, .. } => {
            visit_arena_expression(inner, visitor);
        }
        Expression::IsDistinctFrom { left, right, .. } => {
            visit_arena_expression(left, visitor);
            visit_arena_expression(right, visitor);
        }
        Expression::IsTruthValue { expr: inner, .. } => {
            visit_arena_expression(inner, visitor);
        }
        // Leaves: nothing nested to visit.
        Expression::Literal(_)
        | Expression::Placeholder(_)
        | Expression::NumberedPlaceholder(_)
        | Expression::NamedPlaceholder(_)
        | Expression::ColumnRef { .. }
        | Expression::Wildcard
        | Expression::CurrentDate
        | Expression::CurrentTime { .. }
        | Expression::CurrentTimestamp { .. }
        | Expression::Default => {}
        Expression::Extended(ext) => match ext {
            ExtendedExpr::Function { args, .. } | ExtendedExpr::AggregateFunction { args, .. } => {
                for arg in args.iter() {
                    visit_arena_expression(arg, visitor);
                }
            }
            // Operand (if any), then every WHEN condition and its result, then ELSE.
            ExtendedExpr::Case { operand, when_clauses, else_result } => {
                if let Some(op) = operand {
                    visit_arena_expression(op, visitor);
                }
                for w in when_clauses.iter() {
                    for c in w.conditions.iter() {
                        visit_arena_expression(c, visitor);
                    }
                    visit_arena_expression(&w.result, visitor);
                }
                if let Some(e) = else_result {
                    visit_arena_expression(e, visitor);
                }
            }
            ExtendedExpr::ScalarSubquery(select) => visit_arena_statement(select, visitor),
            ExtendedExpr::In { expr: inner, subquery, .. } => {
                visit_arena_expression(inner, visitor);
                visit_arena_statement(subquery, visitor);
            }
            ExtendedExpr::InList { expr: inner, values, .. } => {
                visit_arena_expression(inner, visitor);
                for v in values.iter() {
                    visit_arena_expression(v, visitor);
                }
            }
            ExtendedExpr::Between { expr: inner, low, high, .. } => {
                visit_arena_expression(inner, visitor);
                visit_arena_expression(low, visitor);
                visit_arena_expression(high, visitor);
            }
            ExtendedExpr::Cast { expr: inner, .. } => {
                visit_arena_expression(inner, visitor);
            }
            ExtendedExpr::Position { substring, string, .. } => {
                visit_arena_expression(substring, visitor);
                visit_arena_expression(string, visitor);
            }
            ExtendedExpr::Trim { removal_char, string, .. } => {
                if let Some(c) = removal_char {
                    visit_arena_expression(c, visitor);
                }
                visit_arena_expression(string, visitor);
            }
            ExtendedExpr::Extract { expr: inner, .. } => {
                visit_arena_expression(inner, visitor);
            }
            ExtendedExpr::Like { expr: inner, pattern, .. }
            | ExtendedExpr::Glob { expr: inner, pattern, .. } => {
                visit_arena_expression(inner, visitor);
                visit_arena_expression(pattern, visitor);
            }
            ExtendedExpr::Exists { subquery, .. } => {
                visit_arena_statement(subquery, visitor);
            }
            ExtendedExpr::QuantifiedComparison { expr: inner, subquery, .. } => {
                visit_arena_expression(inner, visitor);
                visit_arena_statement(subquery, visitor);
            }
            ExtendedExpr::Interval { value, .. } => {
                visit_arena_expression(value, visitor);
            }
            // Function arguments, then PARTITION BY, then ORDER BY of the OVER clause.
            // NOTE(review): any other parts of `over` (e.g. frame bounds, if the AST has them)
            // are not visited — confirm that is intended.
            ExtendedExpr::WindowFunction { function, over } => {
                match function {
                    vibesql_ast::arena::WindowFunctionSpec::Aggregate { args, .. }
                    | vibesql_ast::arena::WindowFunctionSpec::Ranking { args, .. }
                    | vibesql_ast::arena::WindowFunctionSpec::Value { args, .. } => {
                        for arg in args.iter() {
                            visit_arena_expression(arg, visitor);
                        }
                    }
                }
                if let Some(partition_by) = &over.partition_by {
                    for expr in partition_by.iter() {
                        visit_arena_expression(expr, visitor);
                    }
                }
                if let Some(order_by) = &over.order_by {
                    for item in order_by.iter() {
                        visit_arena_expression(&item.expr, visitor);
                    }
                }
            }
            // Only the search expression is visited.
            // NOTE(review): the matched columns are skipped, presumably because they are plain
            // names rather than expressions — confirm.
            ExtendedExpr::MatchAgainst { search_modifier, .. } => {
                visit_arena_expression(search_modifier, visitor);
            }
            ExtendedExpr::RowValueConstructor(children) => {
                for child in children.iter() {
                    visit_arena_expression(child, visitor);
                }
            }
            // Treated as leaves: no nested expressions are visited.
            ExtendedExpr::DuplicateKeyValue { .. }
            | ExtendedExpr::NextValue { .. }
            | ExtendedExpr::PseudoVariable { .. }
            | ExtendedExpr::SessionVariable { .. } => {}
        },
    }
}
520
#[cfg(test)]
mod tests {
    use super::*;

    /// A literal-only query has no parameters and records its single table.
    #[test]
    fn test_arena_prepared_basic() {
        let sql = "SELECT id, name FROM users WHERE id = 1";
        let prepared = ArenaPreparedStatement::new(sql.to_string()).unwrap();

        assert_eq!(prepared.sql(), sql);
        assert_eq!(prepared.param_count(), 0);
        // The message names exactly what is checked, so a failure is not misleading.
        assert!(
            prepared.tables().contains("users"),
            "Expected 'users' in tables {:?}",
            prepared.tables()
        );
    }

    /// A single `?` placeholder is counted once.
    #[test]
    fn test_arena_prepared_with_placeholder() {
        let sql = "SELECT * FROM users WHERE id = ?";
        let prepared = ArenaPreparedStatement::new(sql.to_string()).unwrap();

        assert_eq!(prepared.param_count(), 1);
        assert!(prepared.tables().contains("users"));
    }

    /// Placeholders spread across an AND chain are all counted.
    #[test]
    fn test_arena_prepared_multiple_placeholders() {
        let sql = "SELECT * FROM users WHERE id = ? AND name = ? AND age > ?";
        let prepared = ArenaPreparedStatement::new(sql.to_string()).unwrap();

        assert_eq!(prepared.param_count(), 3);
    }

    /// `validate_params` accepts exactly `param_count` values and rejects too few or too many.
    #[test]
    fn test_arena_prepared_param_validation() {
        let sql = "SELECT * FROM users WHERE id = ?";
        let prepared = ArenaPreparedStatement::new(sql.to_string()).unwrap();

        assert!(prepared.validate_params(&[SqlValue::Integer(1)]).is_ok());

        assert!(prepared.validate_params(&[]).is_err());

        assert!(prepared.validate_params(&[SqlValue::Integer(1), SqlValue::Integer(2)]).is_err());
    }

    /// Both sides of a join are recorded.
    #[test]
    fn test_arena_prepared_join_tables() {
        let sql = "SELECT u.id FROM users u JOIN orders o ON u.id = o.user_id";
        let prepared = ArenaPreparedStatement::new(sql.to_string()).unwrap();

        let tables = prepared.tables();
        assert!(tables.contains("users"), "Expected 'users' in {:?}", tables);
        assert!(tables.contains("orders"), "Expected 'orders' in {:?}", tables);
    }
}