1use std::collections::HashSet;
20use std::pin::Pin;
21
22use bumpalo::Bump;
23use vibesql_ast::arena::{ArenaInterner, Expression, ExtendedExpr, SelectStmt};
24use vibesql_parser::arena_parser::ArenaParser;
25use vibesql_types::SqlValue;
26
27pub struct ArenaPreparedStatement {
39 sql: String,
41 arena: Pin<Box<Bump>>,
43 statement_ptr: *const SelectStmt<'static>,
49 interner_ptr: *const ArenaInterner<'static>,
53 param_count: usize,
55 tables: HashSet<String>,
57}
58
59unsafe impl Send for ArenaPreparedStatement {}
62unsafe impl Sync for ArenaPreparedStatement {}
63
64impl ArenaPreparedStatement {
65 pub fn new(sql: String) -> Result<Self, ArenaParseError> {
69 let arena = Pin::new(Box::new(Bump::new()));
71
72 let (stmt, interner): (&SelectStmt<'_>, ArenaInterner<'_>) =
74 ArenaParser::parse_select_with_interner(&sql, &arena)
75 .map_err(|e| ArenaParseError::ParseError(e.to_string()))?;
76
77 let param_count = count_arena_placeholders(stmt);
79 let tables = extract_arena_tables(stmt, &interner);
80
81 let statement_ptr = stmt as *const SelectStmt<'_>;
85 let statement_ptr = statement_ptr.cast::<SelectStmt<'static>>();
87
88 let interner_in_arena = arena.alloc(interner);
91 let interner_ptr = interner_in_arena as *const ArenaInterner<'_>;
92 let interner_ptr = interner_ptr.cast::<ArenaInterner<'static>>();
94
95 Ok(Self { sql, arena, statement_ptr, interner_ptr, param_count, tables })
96 }
97
98 pub fn sql(&self) -> &str {
100 &self.sql
101 }
102
103 pub fn statement(&self) -> &SelectStmt<'_> {
107 unsafe { &*self.statement_ptr }
110 }
111
112 pub fn interner(&self) -> &ArenaInterner<'_> {
116 unsafe { &*self.interner_ptr }
118 }
119
120 pub fn param_count(&self) -> usize {
122 self.param_count
123 }
124
125 pub fn tables(&self) -> &HashSet<String> {
127 &self.tables
128 }
129
130 pub fn arena(&self) -> &Bump {
134 &self.arena
135 }
136
137 pub fn validate_params(&self, params: &[SqlValue]) -> Result<(), ArenaBindError> {
139 if params.len() != self.param_count {
140 return Err(ArenaBindError::ParameterCountMismatch {
141 expected: self.param_count,
142 actual: params.len(),
143 });
144 }
145 Ok(())
146 }
147}
148
149impl std::fmt::Debug for ArenaPreparedStatement {
150 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
151 f.debug_struct("ArenaPreparedStatement")
152 .field("sql", &self.sql)
153 .field("param_count", &self.param_count)
154 .field("tables", &self.tables)
155 .finish_non_exhaustive()
156 }
157}
158
/// Error produced when SQL text cannot be parsed into an arena-backed statement.
#[derive(Debug, Clone)]
pub enum ArenaParseError {
    /// The parser rejected the input; holds the parser's error message.
    ParseError(String),
}

impl std::fmt::Display for ArenaParseError {
    /// Renders as `Parse error: <parser message>`.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let Self::ParseError(message) = self;
        write!(f, "Parse error: {message}")
    }
}

impl std::error::Error for ArenaParseError {}
175
/// Error produced when the values supplied for a prepared statement do not fit it.
#[derive(Debug, Clone)]
pub enum ArenaBindError {
    /// The number of supplied values differs from the number of placeholders.
    ParameterCountMismatch { expected: usize, actual: usize },
}

impl std::fmt::Display for ArenaBindError {
    /// Renders as `Parameter count mismatch: expected <n>, got <m>`.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let Self::ParameterCountMismatch { expected, actual } = self;
        write!(f, "Parameter count mismatch: expected {expected}, got {actual}")
    }
}

impl std::error::Error for ArenaBindError {}
194
195fn count_arena_placeholders(stmt: &SelectStmt<'_>) -> usize {
197 let mut count = 0;
198 visit_arena_statement(stmt, &mut |expr| {
199 if matches!(expr, Expression::Placeholder(_)) {
200 count += 1;
201 }
202 });
203 count
204}
205
206fn extract_arena_tables(stmt: &SelectStmt<'_>, interner: &ArenaInterner<'_>) -> HashSet<String> {
208 let mut tables = HashSet::new();
209 visit_arena_from_clause(stmt.from.as_ref(), &mut tables, interner);
210 tables
211}
212
213fn visit_arena_statement<F>(stmt: &SelectStmt<'_>, visitor: &mut F)
215where
216 F: FnMut(&Expression<'_>),
217{
218 if let Some(ctes) = &stmt.with_clause {
220 for cte in ctes.iter() {
221 visit_arena_statement(cte.query, visitor);
222 }
223 }
224
225 for item in stmt.select_list.iter() {
227 if let vibesql_ast::arena::SelectItem::Expression { expr, .. } = item {
228 visit_arena_expression(expr, visitor);
229 }
230 }
231
232 if let Some(from) = &stmt.from {
234 visit_arena_from_clause_exprs(from, visitor);
235 }
236
237 if let Some(where_clause) = &stmt.where_clause {
239 visit_arena_expression(where_clause, visitor);
240 }
241
242 if let Some(group_by) = &stmt.group_by {
244 visit_arena_group_by(group_by, visitor);
245 }
246
247 if let Some(having) = &stmt.having {
249 visit_arena_expression(having, visitor);
250 }
251
252 if let Some(order_by) = &stmt.order_by {
254 for item in order_by.iter() {
255 visit_arena_expression(&item.expr, visitor);
256 }
257 }
258
259 if let Some(set_op) = &stmt.set_operation {
261 visit_arena_statement(set_op.right, visitor);
262 }
263}
264
265fn visit_arena_from_clause_exprs<F>(from: &vibesql_ast::arena::FromClause<'_>, visitor: &mut F)
267where
268 F: FnMut(&Expression<'_>),
269{
270 match from {
271 vibesql_ast::arena::FromClause::Table { .. } => {}
272 vibesql_ast::arena::FromClause::Subquery { query, .. } => {
273 visit_arena_statement(query, visitor);
274 }
275 vibesql_ast::arena::FromClause::Join { left, right, condition, .. } => {
276 visit_arena_from_clause_exprs(left, visitor);
277 visit_arena_from_clause_exprs(right, visitor);
278 if let Some(cond) = condition {
279 visit_arena_expression(cond, visitor);
280 }
281 }
282 }
283}
284
285fn visit_arena_from_clause(
287 from: Option<&vibesql_ast::arena::FromClause<'_>>,
288 tables: &mut HashSet<String>,
289 interner: &ArenaInterner<'_>,
290) {
291 let Some(from) = from else { return };
292
293 match from {
294 vibesql_ast::arena::FromClause::Table { name, .. } => {
295 tables.insert(interner.resolve(*name).to_string());
296 }
297 vibesql_ast::arena::FromClause::Subquery { query, .. } => {
298 visit_arena_from_clause(query.from.as_ref(), tables, interner);
299 }
300 vibesql_ast::arena::FromClause::Join { left, right, .. } => {
301 visit_arena_from_clause(Some(left), tables, interner);
302 visit_arena_from_clause(Some(right), tables, interner);
303 }
304 }
305}
306
307fn visit_arena_group_by<F>(group_by: &vibesql_ast::arena::GroupByClause<'_>, visitor: &mut F)
309where
310 F: FnMut(&Expression<'_>),
311{
312 use vibesql_ast::arena::GroupByClause;
313 match group_by {
314 GroupByClause::Simple(exprs) => {
315 for expr in exprs.iter() {
316 visit_arena_expression(expr, visitor);
317 }
318 }
319 GroupByClause::Rollup(elements) | GroupByClause::Cube(elements) => {
320 for element in elements.iter() {
321 match element {
322 vibesql_ast::arena::GroupingElement::Single(expr) => {
323 visit_arena_expression(expr, visitor);
324 }
325 vibesql_ast::arena::GroupingElement::Composite(exprs) => {
326 for expr in exprs.iter() {
327 visit_arena_expression(expr, visitor);
328 }
329 }
330 }
331 }
332 }
333 GroupByClause::GroupingSets(sets) => {
334 for set in sets.iter() {
335 for expr in set.columns.iter() {
336 visit_arena_expression(expr, visitor);
337 }
338 }
339 }
340 GroupByClause::Mixed(items) => {
341 for item in items.iter() {
342 match item {
343 vibesql_ast::arena::MixedGroupingItem::Simple(expr) => {
344 visit_arena_expression(expr, visitor);
345 }
346 vibesql_ast::arena::MixedGroupingItem::Rollup(elements)
347 | vibesql_ast::arena::MixedGroupingItem::Cube(elements) => {
348 for element in elements.iter() {
349 match element {
350 vibesql_ast::arena::GroupingElement::Single(expr) => {
351 visit_arena_expression(expr, visitor);
352 }
353 vibesql_ast::arena::GroupingElement::Composite(exprs) => {
354 for expr in exprs.iter() {
355 visit_arena_expression(expr, visitor);
356 }
357 }
358 }
359 }
360 }
361 vibesql_ast::arena::MixedGroupingItem::GroupingSets(sets) => {
362 for set in sets.iter() {
363 for expr in set.columns.iter() {
364 visit_arena_expression(expr, visitor);
365 }
366 }
367 }
368 }
369 }
370 }
371 }
372}
373
374fn visit_arena_expression<F>(expr: &Expression<'_>, visitor: &mut F)
376where
377 F: FnMut(&Expression<'_>),
378{
379 visitor(expr);
380
381 match expr {
382 Expression::BinaryOp { left, right, .. } => {
384 visit_arena_expression(left, visitor);
385 visit_arena_expression(right, visitor);
386 }
387 Expression::Conjunction(children) | Expression::Disjunction(children) => {
388 for child in children.iter() {
389 visit_arena_expression(child, visitor);
390 }
391 }
392 Expression::UnaryOp { expr: inner, .. } => {
393 visit_arena_expression(inner, visitor);
394 }
395 Expression::IsNull { expr: inner, .. } => {
396 visit_arena_expression(inner, visitor);
397 }
398 Expression::Literal(_)
400 | Expression::Placeholder(_)
401 | Expression::NumberedPlaceholder(_)
402 | Expression::NamedPlaceholder(_)
403 | Expression::ColumnRef { .. }
404 | Expression::Wildcard
405 | Expression::CurrentDate
406 | Expression::CurrentTime { .. }
407 | Expression::CurrentTimestamp { .. }
408 | Expression::Default => {}
409 Expression::Extended(ext) => match ext {
411 ExtendedExpr::Function { args, .. } | ExtendedExpr::AggregateFunction { args, .. } => {
412 for arg in args.iter() {
413 visit_arena_expression(arg, visitor);
414 }
415 }
416 ExtendedExpr::Case { operand, when_clauses, else_result } => {
417 if let Some(op) = operand {
418 visit_arena_expression(op, visitor);
419 }
420 for w in when_clauses.iter() {
421 for c in w.conditions.iter() {
422 visit_arena_expression(c, visitor);
423 }
424 visit_arena_expression(&w.result, visitor);
425 }
426 if let Some(e) = else_result {
427 visit_arena_expression(e, visitor);
428 }
429 }
430 ExtendedExpr::ScalarSubquery(select) => visit_arena_statement(select, visitor),
431 ExtendedExpr::In { expr: inner, subquery, .. } => {
432 visit_arena_expression(inner, visitor);
433 visit_arena_statement(subquery, visitor);
434 }
435 ExtendedExpr::InList { expr: inner, values, .. } => {
436 visit_arena_expression(inner, visitor);
437 for v in values.iter() {
438 visit_arena_expression(v, visitor);
439 }
440 }
441 ExtendedExpr::Between { expr: inner, low, high, .. } => {
442 visit_arena_expression(inner, visitor);
443 visit_arena_expression(low, visitor);
444 visit_arena_expression(high, visitor);
445 }
446 ExtendedExpr::Cast { expr: inner, .. } => {
447 visit_arena_expression(inner, visitor);
448 }
449 ExtendedExpr::Position { substring, string, .. } => {
450 visit_arena_expression(substring, visitor);
451 visit_arena_expression(string, visitor);
452 }
453 ExtendedExpr::Trim { removal_char, string, .. } => {
454 if let Some(c) = removal_char {
455 visit_arena_expression(c, visitor);
456 }
457 visit_arena_expression(string, visitor);
458 }
459 ExtendedExpr::Extract { expr: inner, .. } => {
460 visit_arena_expression(inner, visitor);
461 }
462 ExtendedExpr::Like { expr: inner, pattern, .. } => {
463 visit_arena_expression(inner, visitor);
464 visit_arena_expression(pattern, visitor);
465 }
466 ExtendedExpr::Exists { subquery, .. } => {
467 visit_arena_statement(subquery, visitor);
468 }
469 ExtendedExpr::QuantifiedComparison { expr: inner, subquery, .. } => {
470 visit_arena_expression(inner, visitor);
471 visit_arena_statement(subquery, visitor);
472 }
473 ExtendedExpr::Interval { value, .. } => {
474 visit_arena_expression(value, visitor);
475 }
476 ExtendedExpr::WindowFunction { function, over } => {
477 match function {
478 vibesql_ast::arena::WindowFunctionSpec::Aggregate { args, .. }
479 | vibesql_ast::arena::WindowFunctionSpec::Ranking { args, .. }
480 | vibesql_ast::arena::WindowFunctionSpec::Value { args, .. } => {
481 for arg in args.iter() {
482 visit_arena_expression(arg, visitor);
483 }
484 }
485 }
486 if let Some(partition_by) = &over.partition_by {
487 for expr in partition_by.iter() {
488 visit_arena_expression(expr, visitor);
489 }
490 }
491 if let Some(order_by) = &over.order_by {
492 for item in order_by.iter() {
493 visit_arena_expression(&item.expr, visitor);
494 }
495 }
496 }
497 ExtendedExpr::MatchAgainst { search_modifier, .. } => {
498 visit_arena_expression(search_modifier, visitor);
499 }
500 ExtendedExpr::DuplicateKeyValue { .. }
502 | ExtendedExpr::NextValue { .. }
503 | ExtendedExpr::PseudoVariable { .. }
504 | ExtendedExpr::SessionVariable { .. } => {}
505 },
506 }
507}
508
#[cfg(test)]
mod tests {
    use super::*;

    /// Prepares `sql`, panicking if it does not parse.
    fn prepare(sql: &str) -> ArenaPreparedStatement {
        ArenaPreparedStatement::new(sql.to_owned()).unwrap()
    }

    #[test]
    fn test_arena_prepared_basic() {
        let sql = "SELECT id, name FROM users WHERE id = 1";
        let stmt = prepare(sql);

        assert_eq!(stmt.sql(), sql);
        assert_eq!(stmt.param_count(), 0);

        let tables = stmt.tables();
        assert!(tables.contains("USERS"), "Expected 'USERS' in tables {:?}", tables);
    }

    #[test]
    fn test_arena_prepared_with_placeholder() {
        let stmt = prepare("SELECT * FROM users WHERE id = ?");

        assert_eq!(stmt.param_count(), 1);
        assert!(stmt.tables().contains("USERS"));
    }

    #[test]
    fn test_arena_prepared_multiple_placeholders() {
        let stmt = prepare("SELECT * FROM users WHERE id = ? AND name = ? AND age > ?");

        assert_eq!(stmt.param_count(), 3);
    }

    #[test]
    fn test_arena_prepared_param_validation() {
        let stmt = prepare("SELECT * FROM users WHERE id = ?");

        // Exactly one value is accepted; too few or too many are rejected.
        assert!(stmt.validate_params(&[SqlValue::Integer(1)]).is_ok());
        assert!(stmt.validate_params(&[]).is_err());
        assert!(stmt.validate_params(&[SqlValue::Integer(1), SqlValue::Integer(2)]).is_err());
    }

    #[test]
    fn test_arena_prepared_join_tables() {
        let stmt = prepare("SELECT u.id FROM users u JOIN orders o ON u.id = o.user_id");

        let tables = stmt.tables();
        for expected in ["USERS", "ORDERS"] {
            assert!(tables.contains(expected), "Expected '{}' in {:?}", expected, tables);
        }
    }
}