vibesql_executor/cache/prepared_statement/
arena_prepared.rs

1//! Arena-based prepared statement for zero-allocation execution
2//!
3//! This module implements Option A from the arena parser optimization:
4//! storing the arena-allocated AST directly in the prepared statement,
5//! avoiding conversion overhead entirely for simple queries.
6//!
7//! # Key Benefits
8//!
9//! - **Zero conversion overhead**: Arena AST is used directly during execution
10//! - **Near-zero allocation**: Simple prepared queries avoid heap allocation
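//! - **Arena reuse**: The owning arena stays available (via `arena()`) for
//!   per-execution allocations such as bound parameter values. It must not be
//!   reset while the prepared statement is alive, because the stored AST and
//!   interner pointers point into it.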
12//!
13//! # Safety
14//!
15//! This module uses careful unsafe code to manage self-referential data.
16//! The arena is pinned to prevent moves, and the statement pointer is valid
17//! as long as the arena exists.
18
19use std::{collections::HashSet, pin::Pin};
20
21use bumpalo::Bump;
22use vibesql_ast::arena::{ArenaInterner, Expression, ExtendedExpr, SelectStmt};
23use vibesql_parser::arena_parser::ArenaParser;
24use vibesql_types::SqlValue;
25
/// Arena-based prepared statement for zero-allocation execution.
///
/// This stores the parsed AST in an arena, keeping it alive for the lifetime
/// of the prepared statement. This avoids conversion overhead for simple queries.
///
/// The struct is self-referential: `statement_ptr` and `interner_ptr` point
/// into memory owned by `arena`. Their lifetimes are erased to `'static` for
/// storage and re-attached to `&self` by the accessor methods.
///
/// # Safety Invariants
///
/// 1. The arena is boxed and pinned, so moving this struct moves only the box
///    pointer and never the `Bump` itself
/// 2. The statement pointer is valid as long as the arena exists
/// 3. The arena is only dropped when the ArenaPreparedStatement is dropped
///    (as an ordinary field; no explicit `Drop` impl is visible in this file)
/// 4. The interner pointer is valid as long as the arena exists
/// 5. The arena must never be reset while this struct is alive, since a reset
///    would invalidate both pointers
pub struct ArenaPreparedStatement {
    /// Original SQL with placeholders
    sql: String,
    /// Arena containing the parsed statement (pinned to prevent moves)
    arena: Pin<Box<Bump>>,
    /// Pointer to the statement in the arena.
    ///
    /// SAFETY: This pointer is valid as long as `arena` is not dropped or reset.
    /// Since ArenaPreparedStatement owns the arena and it is only dropped
    /// together with the struct, this pointer is always valid during the
    /// lifetime of the struct. The `'static` lifetime is a placeholder only;
    /// it must be shortened to a borrow of `self` before being handed out.
    statement_ptr: *const SelectStmt<'static>,
    /// Pointer to the interner in the arena.
    ///
    /// SAFETY: Same invariants as statement_ptr.
    ///
    /// NOTE(review): the interner value is moved into the bump arena by `new`.
    /// bumpalo does not run destructors for arena-allocated values, so if
    /// `ArenaInterner` owns heap memory it would leak - TODO confirm.
    interner_ptr: *const ArenaInterner<'static>,
    /// Number of parameters expected (? placeholders)
    param_count: usize,
    /// Tables referenced by this statement (for cache invalidation)
    tables: HashSet<String>,
}
57
// The raw-pointer fields make the struct `!Send`/`!Sync` by default, hence
// the manual impls below.
//
// SAFETY: The arena and statement are not accessed from multiple threads.
// The statement pointer is only dereferenced through &self methods.
//
// NOTE(review): `arena()` hands out `&Bump` through `&self`, and bumpalo
// documents `Bump` as not `Sync` (allocation mutates interior state through a
// shared reference). With `Sync` implemented here, nothing in this file stops
// two threads from allocating concurrently via `arena()`; the claim above
// therefore relies on caller discipline - TODO confirm or enforce.
unsafe impl Send for ArenaPreparedStatement {}
unsafe impl Sync for ArenaPreparedStatement {}
62
63impl ArenaPreparedStatement {
64    /// Create a new arena-based prepared statement from SQL.
65    ///
66    /// Parses the SQL using the arena parser and stores the result in an owned arena.
67    pub fn new(sql: String) -> Result<Self, ArenaParseError> {
68        // Create pinned arena (won't move)
69        let arena = Pin::new(Box::new(Bump::new()));
70
71        // Parse into arena, getting both statement and interner
72        let (stmt, interner): (&SelectStmt<'_>, ArenaInterner<'_>) =
73            ArenaParser::parse_select_with_interner(&sql, &arena)
74                .map_err(|e| ArenaParseError::ParseError(e.to_string()))?;
75
76        // Count placeholders and extract tables (using interner for symbol resolution)
77        let param_count = count_arena_placeholders(stmt);
78        let tables = extract_arena_tables(stmt, &interner);
79
80        // Store as raw pointer, erasing the lifetime.
81        // SAFETY: The arena is owned by this struct and won't be dropped
82        // while the statement exists. The pointer remains valid.
83        let statement_ptr = stmt as *const SelectStmt<'_>;
84        // Cast to 'static lifetime - this is safe because the arena owns the data
85        let statement_ptr = statement_ptr.cast::<SelectStmt<'static>>();
86
87        // Allocate interner in arena and store pointer
88        // SAFETY: Same invariants as statement_ptr
89        let interner_in_arena = arena.alloc(interner);
90        let interner_ptr = interner_in_arena as *const ArenaInterner<'_>;
91        // Cast to 'static lifetime - this is safe because the arena owns the data
92        let interner_ptr = interner_ptr.cast::<ArenaInterner<'static>>();
93
94        Ok(Self { sql, arena, statement_ptr, interner_ptr, param_count, tables })
95    }
96
97    /// Get the original SQL.
98    pub fn sql(&self) -> &str {
99        &self.sql
100    }
101
102    /// Get a reference to the arena-allocated statement.
103    ///
104    /// The returned reference is valid for the lifetime of this struct.
105    pub fn statement(&self) -> &SelectStmt<'_> {
106        // SAFETY: The pointer is valid as documented in the struct invariants.
107        // We're returning a reference tied to self's lifetime, which is correct.
108        unsafe { &*self.statement_ptr }
109    }
110
111    /// Get a reference to the interner for symbol resolution.
112    ///
113    /// The returned reference is valid for the lifetime of this struct.
114    pub fn interner(&self) -> &ArenaInterner<'_> {
115        // SAFETY: The pointer is valid as documented in the struct invariants.
116        unsafe { &*self.interner_ptr }
117    }
118
119    /// Get the number of parameters expected.
120    pub fn param_count(&self) -> usize {
121        self.param_count
122    }
123
124    /// Get the tables referenced by this statement.
125    pub fn tables(&self) -> &HashSet<String> {
126        &self.tables
127    }
128
129    /// Get a reference to the arena for additional allocations.
130    ///
131    /// This can be used for allocating bound values during execution.
132    pub fn arena(&self) -> &Bump {
133        &self.arena
134    }
135
136    /// Validate that the correct number of parameters is provided.
137    pub fn validate_params(&self, params: &[SqlValue]) -> Result<(), ArenaBindError> {
138        if params.len() != self.param_count {
139            return Err(ArenaBindError::ParameterCountMismatch {
140                expected: self.param_count,
141                actual: params.len(),
142            });
143        }
144        Ok(())
145    }
146}
147
148impl std::fmt::Debug for ArenaPreparedStatement {
149    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
150        f.debug_struct("ArenaPreparedStatement")
151            .field("sql", &self.sql)
152            .field("param_count", &self.param_count)
153            .field("tables", &self.tables)
154            .finish_non_exhaustive()
155    }
156}
157
/// Failure produced while building an arena-based prepared statement.
#[derive(Debug, Clone)]
pub enum ArenaParseError {
    /// The SQL text could not be parsed; carries the parser's message.
    ParseError(String),
}

impl std::fmt::Display for ArenaParseError {
    /// Render as `Parse error: <message>`.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        // Single-variant enum: an irrefutable destructure stops compiling if
        // a variant is ever added, just like an exhaustive `match` would.
        let Self::ParseError(reason) = self;
        write!(f, "Parse error: {}", reason)
    }
}

impl std::error::Error for ArenaParseError {}
174
/// Failure produced while checking parameters against a prepared statement.
#[derive(Debug, Clone)]
pub enum ArenaBindError {
    /// The caller supplied `actual` values but the statement expects `expected`.
    ParameterCountMismatch { expected: usize, actual: usize },
}

impl std::fmt::Display for ArenaBindError {
    /// Render as `Parameter count mismatch: expected <n>, got <m>`.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        // Single-variant enum: an irrefutable destructure stops compiling if
        // a variant is ever added, just like an exhaustive `match` would.
        let Self::ParameterCountMismatch { expected, actual } = self;
        write!(f, "Parameter count mismatch: expected {}, got {}", expected, actual)
    }
}

impl std::error::Error for ArenaBindError {}
193
194/// Count the number of placeholder parameters in an arena statement.
195fn count_arena_placeholders(stmt: &SelectStmt<'_>) -> usize {
196    let mut count = 0;
197    visit_arena_statement(stmt, &mut |expr| {
198        if matches!(expr, Expression::Placeholder(_)) {
199            count += 1;
200        }
201    });
202    count
203}
204
205/// Extract table names from an arena statement.
206fn extract_arena_tables(stmt: &SelectStmt<'_>, interner: &ArenaInterner<'_>) -> HashSet<String> {
207    let mut tables = HashSet::new();
208    visit_arena_from_clause(stmt.from.as_ref(), &mut tables, interner);
209    tables
210}
211
212/// Visit all expressions in an arena statement.
213fn visit_arena_statement<F>(stmt: &SelectStmt<'_>, visitor: &mut F)
214where
215    F: FnMut(&Expression<'_>),
216{
217    // Visit CTEs
218    if let Some(ctes) = &stmt.with_clause {
219        for cte in ctes.iter() {
220            visit_arena_statement(cte.query, visitor);
221        }
222    }
223
224    // Visit select items
225    for item in stmt.select_list.iter() {
226        if let vibesql_ast::arena::SelectItem::Expression { expr, .. } = item {
227            visit_arena_expression(expr, visitor);
228        }
229    }
230
231    // Visit FROM clause
232    if let Some(from) = &stmt.from {
233        visit_arena_from_clause_exprs(from, visitor);
234    }
235
236    // Visit WHERE
237    if let Some(where_clause) = &stmt.where_clause {
238        visit_arena_expression(where_clause, visitor);
239    }
240
241    // Visit GROUP BY
242    if let Some(group_by) = &stmt.group_by {
243        visit_arena_group_by(group_by, visitor);
244    }
245
246    // Visit HAVING
247    if let Some(having) = &stmt.having {
248        visit_arena_expression(having, visitor);
249    }
250
251    // Visit ORDER BY
252    if let Some(order_by) = &stmt.order_by {
253        for item in order_by.iter() {
254            visit_arena_expression(&item.expr, visitor);
255        }
256    }
257
258    // Visit set operation
259    if let Some(set_op) = &stmt.set_operation {
260        visit_arena_statement(set_op.right, visitor);
261    }
262}
263
264/// Visit expressions in a FROM clause.
265fn visit_arena_from_clause_exprs<F>(from: &vibesql_ast::arena::FromClause<'_>, visitor: &mut F)
266where
267    F: FnMut(&Expression<'_>),
268{
269    match from {
270        vibesql_ast::arena::FromClause::Table { .. } => {}
271        vibesql_ast::arena::FromClause::Subquery { query, .. } => {
272            visit_arena_statement(query, visitor);
273        }
274        vibesql_ast::arena::FromClause::Join { left, right, condition, .. } => {
275            visit_arena_from_clause_exprs(left, visitor);
276            visit_arena_from_clause_exprs(right, visitor);
277            if let Some(cond) = condition {
278                visit_arena_expression(cond, visitor);
279            }
280        }
281    }
282}
283
284/// Extract table names from a FROM clause.
285fn visit_arena_from_clause(
286    from: Option<&vibesql_ast::arena::FromClause<'_>>,
287    tables: &mut HashSet<String>,
288    interner: &ArenaInterner<'_>,
289) {
290    let Some(from) = from else { return };
291
292    match from {
293        vibesql_ast::arena::FromClause::Table { name, .. } => {
294            tables.insert(interner.resolve(*name).to_string());
295        }
296        vibesql_ast::arena::FromClause::Subquery { query, .. } => {
297            visit_arena_from_clause(query.from.as_ref(), tables, interner);
298        }
299        vibesql_ast::arena::FromClause::Join { left, right, .. } => {
300            visit_arena_from_clause(Some(left), tables, interner);
301            visit_arena_from_clause(Some(right), tables, interner);
302        }
303    }
304}
305
306/// Visit GROUP BY clause expressions.
307fn visit_arena_group_by<F>(group_by: &vibesql_ast::arena::GroupByClause<'_>, visitor: &mut F)
308where
309    F: FnMut(&Expression<'_>),
310{
311    use vibesql_ast::arena::GroupByClause;
312    match group_by {
313        GroupByClause::Simple(exprs) => {
314            for expr in exprs.iter() {
315                visit_arena_expression(expr, visitor);
316            }
317        }
318        GroupByClause::Rollup(elements) | GroupByClause::Cube(elements) => {
319            for element in elements.iter() {
320                match element {
321                    vibesql_ast::arena::GroupingElement::Single(expr) => {
322                        visit_arena_expression(expr, visitor);
323                    }
324                    vibesql_ast::arena::GroupingElement::Composite(exprs) => {
325                        for expr in exprs.iter() {
326                            visit_arena_expression(expr, visitor);
327                        }
328                    }
329                }
330            }
331        }
332        GroupByClause::GroupingSets(sets) => {
333            for set in sets.iter() {
334                for expr in set.columns.iter() {
335                    visit_arena_expression(expr, visitor);
336                }
337            }
338        }
339        GroupByClause::Mixed(items) => {
340            for item in items.iter() {
341                match item {
342                    vibesql_ast::arena::MixedGroupingItem::Simple(expr) => {
343                        visit_arena_expression(expr, visitor);
344                    }
345                    vibesql_ast::arena::MixedGroupingItem::Rollup(elements)
346                    | vibesql_ast::arena::MixedGroupingItem::Cube(elements) => {
347                        for element in elements.iter() {
348                            match element {
349                                vibesql_ast::arena::GroupingElement::Single(expr) => {
350                                    visit_arena_expression(expr, visitor);
351                                }
352                                vibesql_ast::arena::GroupingElement::Composite(exprs) => {
353                                    for expr in exprs.iter() {
354                                        visit_arena_expression(expr, visitor);
355                                    }
356                                }
357                            }
358                        }
359                    }
360                    vibesql_ast::arena::MixedGroupingItem::GroupingSets(sets) => {
361                        for set in sets.iter() {
362                            for expr in set.columns.iter() {
363                                visit_arena_expression(expr, visitor);
364                            }
365                        }
366                    }
367                }
368            }
369        }
370    }
371}
372
373/// Visit all expressions in an arena expression tree.
374fn visit_arena_expression<F>(expr: &Expression<'_>, visitor: &mut F)
375where
376    F: FnMut(&Expression<'_>),
377{
378    visitor(expr);
379
380    match expr {
381        // Hot-path inline variants
382        Expression::BinaryOp { left, right, .. } => {
383            visit_arena_expression(left, visitor);
384            visit_arena_expression(right, visitor);
385        }
386        Expression::Conjunction(children) | Expression::Disjunction(children) => {
387            for child in children.iter() {
388                visit_arena_expression(child, visitor);
389            }
390        }
391        Expression::UnaryOp { expr: inner, .. } => {
392            visit_arena_expression(inner, visitor);
393        }
394        Expression::IsNull { expr: inner, .. } => {
395            visit_arena_expression(inner, visitor);
396        }
397        Expression::IsDistinctFrom { left, right, .. } => {
398            visit_arena_expression(left, visitor);
399            visit_arena_expression(right, visitor);
400        }
401        Expression::IsTruthValue { expr: inner, .. } => {
402            visit_arena_expression(inner, visitor);
403        }
404        // Leaf nodes - no recursion needed
405        Expression::Literal(_)
406        | Expression::Placeholder(_)
407        | Expression::NumberedPlaceholder(_)
408        | Expression::NamedPlaceholder(_)
409        | Expression::ColumnRef { .. }
410        | Expression::Wildcard
411        | Expression::CurrentDate
412        | Expression::CurrentTime { .. }
413        | Expression::CurrentTimestamp { .. }
414        | Expression::Default => {}
415        // Cold-path extended variants
416        Expression::Extended(ext) => match ext {
417            ExtendedExpr::Function { args, .. } | ExtendedExpr::AggregateFunction { args, .. } => {
418                for arg in args.iter() {
419                    visit_arena_expression(arg, visitor);
420                }
421            }
422            ExtendedExpr::Case { operand, when_clauses, else_result } => {
423                if let Some(op) = operand {
424                    visit_arena_expression(op, visitor);
425                }
426                for w in when_clauses.iter() {
427                    for c in w.conditions.iter() {
428                        visit_arena_expression(c, visitor);
429                    }
430                    visit_arena_expression(&w.result, visitor);
431                }
432                if let Some(e) = else_result {
433                    visit_arena_expression(e, visitor);
434                }
435            }
436            ExtendedExpr::ScalarSubquery(select) => visit_arena_statement(select, visitor),
437            ExtendedExpr::In { expr: inner, subquery, .. } => {
438                visit_arena_expression(inner, visitor);
439                visit_arena_statement(subquery, visitor);
440            }
441            ExtendedExpr::InList { expr: inner, values, .. } => {
442                visit_arena_expression(inner, visitor);
443                for v in values.iter() {
444                    visit_arena_expression(v, visitor);
445                }
446            }
447            ExtendedExpr::Between { expr: inner, low, high, .. } => {
448                visit_arena_expression(inner, visitor);
449                visit_arena_expression(low, visitor);
450                visit_arena_expression(high, visitor);
451            }
452            ExtendedExpr::Cast { expr: inner, .. } => {
453                visit_arena_expression(inner, visitor);
454            }
455            ExtendedExpr::Position { substring, string, .. } => {
456                visit_arena_expression(substring, visitor);
457                visit_arena_expression(string, visitor);
458            }
459            ExtendedExpr::Trim { removal_char, string, .. } => {
460                if let Some(c) = removal_char {
461                    visit_arena_expression(c, visitor);
462                }
463                visit_arena_expression(string, visitor);
464            }
465            ExtendedExpr::Extract { expr: inner, .. } => {
466                visit_arena_expression(inner, visitor);
467            }
468            ExtendedExpr::Like { expr: inner, pattern, .. }
469            | ExtendedExpr::Glob { expr: inner, pattern, .. } => {
470                visit_arena_expression(inner, visitor);
471                visit_arena_expression(pattern, visitor);
472            }
473            ExtendedExpr::Exists { subquery, .. } => {
474                visit_arena_statement(subquery, visitor);
475            }
476            ExtendedExpr::QuantifiedComparison { expr: inner, subquery, .. } => {
477                visit_arena_expression(inner, visitor);
478                visit_arena_statement(subquery, visitor);
479            }
480            ExtendedExpr::Interval { value, .. } => {
481                visit_arena_expression(value, visitor);
482            }
483            ExtendedExpr::WindowFunction { function, over } => {
484                match function {
485                    vibesql_ast::arena::WindowFunctionSpec::Aggregate { args, .. }
486                    | vibesql_ast::arena::WindowFunctionSpec::Ranking { args, .. }
487                    | vibesql_ast::arena::WindowFunctionSpec::Value { args, .. } => {
488                        for arg in args.iter() {
489                            visit_arena_expression(arg, visitor);
490                        }
491                    }
492                }
493                if let Some(partition_by) = &over.partition_by {
494                    for expr in partition_by.iter() {
495                        visit_arena_expression(expr, visitor);
496                    }
497                }
498                if let Some(order_by) = &over.order_by {
499                    for item in order_by.iter() {
500                        visit_arena_expression(&item.expr, visitor);
501                    }
502                }
503            }
504            ExtendedExpr::MatchAgainst { search_modifier, .. } => {
505                visit_arena_expression(search_modifier, visitor);
506            }
507            ExtendedExpr::RowValueConstructor(children) => {
508                for child in children.iter() {
509                    visit_arena_expression(child, visitor);
510                }
511            }
512            // Extended leaf nodes - no recursion needed
513            ExtendedExpr::DuplicateKeyValue { .. }
514            | ExtendedExpr::NextValue { .. }
515            | ExtendedExpr::PseudoVariable { .. }
516            | ExtendedExpr::SessionVariable { .. } => {}
517        },
518    }
519}
520
#[cfg(test)]
mod tests {
    use super::*;

    /// A statement without placeholders reports zero parameters and its table.
    #[test]
    fn test_arena_prepared_basic() {
        let sql = "SELECT id, name FROM users WHERE id = 1";
        let prepared = ArenaPreparedStatement::new(sql.to_string()).unwrap();

        assert_eq!(prepared.sql(), sql);
        assert_eq!(prepared.param_count(), 0);
        // The table set is keyed by the name as asserted here (lowercase).
        assert!(
            prepared.tables().contains("users"),
            "Expected 'users' in tables {:?}",
            prepared.tables()
        );
    }

    /// A single `?` is counted as one parameter.
    #[test]
    fn test_arena_prepared_with_placeholder() {
        let sql = "SELECT * FROM users WHERE id = ?";
        let prepared = ArenaPreparedStatement::new(sql.to_string()).unwrap();

        assert_eq!(prepared.param_count(), 1);
        assert!(
            prepared.tables().contains("users"),
            "Expected 'users' in tables {:?}",
            prepared.tables()
        );
    }

    /// Every `?` in the WHERE clause is counted.
    #[test]
    fn test_arena_prepared_multiple_placeholders() {
        let sql = "SELECT * FROM users WHERE id = ? AND name = ? AND age > ?";
        let prepared = ArenaPreparedStatement::new(sql.to_string()).unwrap();

        assert_eq!(prepared.param_count(), 3);
    }

    /// `validate_params` accepts exactly `param_count` values and nothing else.
    #[test]
    fn test_arena_prepared_param_validation() {
        let sql = "SELECT * FROM users WHERE id = ?";
        let prepared = ArenaPreparedStatement::new(sql.to_string()).unwrap();

        // Correct param count should pass
        assert!(prepared.validate_params(&[SqlValue::Integer(1)]).is_ok());

        // Wrong param count should fail
        assert!(prepared.validate_params(&[]).is_err());
        assert!(prepared.validate_params(&[SqlValue::Integer(1), SqlValue::Integer(2)]).is_err());
    }

    /// Both sides of a join contribute to the table set.
    #[test]
    fn test_arena_prepared_join_tables() {
        let sql = "SELECT u.id FROM users u JOIN orders o ON u.id = o.user_id";
        let prepared = ArenaPreparedStatement::new(sql.to_string()).unwrap();

        let tables = prepared.tables();
        assert!(tables.contains("users"), "Expected 'users' in {:?}", tables);
        assert!(tables.contains("orders"), "Expected 'orders' in {:?}", tables);
    }
}
581}