vibesql_executor/cache/prepared_statement/
arena_prepared.rs

1//! Arena-based prepared statement for zero-allocation execution
2//!
3//! This module implements Option A from the arena parser optimization:
4//! storing the arena-allocated AST directly in the prepared statement,
5//! avoiding conversion overhead entirely for simple queries.
6//!
7//! # Key Benefits
8//!
9//! - **Zero conversion overhead**: Arena AST is used directly during execution
10//! - **Near-zero allocation**: Simple prepared queries avoid heap allocation
11//! - **Arena reuse**: Arena can be reset between executions for parameter binding
12//!
13//! # Safety
14//!
15//! This module uses careful unsafe code to manage self-referential data.
16//! The arena is pinned to prevent moves, and the statement pointer is valid
17//! as long as the arena exists.
18
19use std::collections::HashSet;
20use std::pin::Pin;
21
22use bumpalo::Bump;
23use vibesql_ast::arena::{ArenaInterner, Expression, ExtendedExpr, SelectStmt};
24use vibesql_parser::arena_parser::ArenaParser;
25use vibesql_types::SqlValue;
26
/// Arena-based prepared statement for zero-allocation execution.
///
/// This stores the parsed AST in an arena, keeping it alive for the lifetime
/// of the prepared statement. This avoids conversion overhead for simple queries.
///
/// The struct is self-referential: `statement_ptr` and `interner_ptr` point
/// into memory handed out by `arena`, which is why they are stored as raw
/// pointers with an erased (`'static`) lifetime instead of as references.
///
/// # Safety Invariants
///
/// 1. The arena is pinned and will not move after construction
/// 2. The statement pointer is valid as long as the arena exists
/// 3. The arena is only dropped when the ArenaPreparedStatement is dropped
/// 4. The interner pointer is valid as long as the arena exists
///
/// Note on invariant 1: the constructor builds the arena with `Pin::new`,
/// which is only available when the pointee is `Unpin`, so the `Pin` wrapper
/// documents intent rather than enforcing anything. What keeps the `Bump`
/// itself at a stable address is the `Box`.
pub struct ArenaPreparedStatement {
    /// Original SQL with placeholders
    sql: String,
    /// Arena containing the parsed statement (pinned to prevent moves)
    arena: Pin<Box<Bump>>,
    /// Pointer to the statement in the arena.
    ///
    /// The `'static` in the type is a placeholder for "as long as `arena`";
    /// `statement()` hands the data back out with a lifetime tied to `&self`.
    ///
    /// SAFETY: This pointer is valid as long as `arena` is not dropped or reset.
    /// Since ArenaPreparedStatement owns the arena and only drops it in Drop,
    /// this pointer is always valid during the lifetime of the struct.
    statement_ptr: *const SelectStmt<'static>,
    /// Pointer to the interner in the arena.
    ///
    /// The interner is moved into the arena by the constructor so that it lives
    /// exactly as long as the statement whose symbols it resolves.
    ///
    /// SAFETY: Same invariants as statement_ptr.
    interner_ptr: *const ArenaInterner<'static>,
    /// Number of parameters expected (? placeholders)
    param_count: usize,
    /// Tables referenced by this statement (for cache invalidation)
    tables: HashSet<String>,
}
58
// The raw pointer fields make the struct `!Send`/`!Sync` by default, so both
// auto traits are asserted manually here.
//
// SAFETY: The arena and statement are not accessed from multiple threads.
// The statement pointer is only dereferenced through &self methods.
//
// NOTE(review): the justification above is a usage assumption, not something
// the type enforces. `arena()` returns `&Bump` from `&self`; if `Bump` is not
// `Sync` (bumpalo documents it as a single-threaded allocator — confirm), the
// `Sync` impl allows two threads holding `&ArenaPreparedStatement` to allocate
// through the same arena concurrently. Confirm that no caller shares a
// statement across threads while also using `arena()`.
unsafe impl Send for ArenaPreparedStatement {}
unsafe impl Sync for ArenaPreparedStatement {}
63
64impl ArenaPreparedStatement {
65    /// Create a new arena-based prepared statement from SQL.
66    ///
67    /// Parses the SQL using the arena parser and stores the result in an owned arena.
68    pub fn new(sql: String) -> Result<Self, ArenaParseError> {
69        // Create pinned arena (won't move)
70        let arena = Pin::new(Box::new(Bump::new()));
71
72        // Parse into arena, getting both statement and interner
73        let (stmt, interner): (&SelectStmt<'_>, ArenaInterner<'_>) =
74            ArenaParser::parse_select_with_interner(&sql, &arena)
75                .map_err(|e| ArenaParseError::ParseError(e.to_string()))?;
76
77        // Count placeholders and extract tables (using interner for symbol resolution)
78        let param_count = count_arena_placeholders(stmt);
79        let tables = extract_arena_tables(stmt, &interner);
80
81        // Store as raw pointer, erasing the lifetime.
82        // SAFETY: The arena is owned by this struct and won't be dropped
83        // while the statement exists. The pointer remains valid.
84        let statement_ptr = stmt as *const SelectStmt<'_>;
85        // Cast to 'static lifetime - this is safe because the arena owns the data
86        let statement_ptr = statement_ptr.cast::<SelectStmt<'static>>();
87
88        // Allocate interner in arena and store pointer
89        // SAFETY: Same invariants as statement_ptr
90        let interner_in_arena = arena.alloc(interner);
91        let interner_ptr = interner_in_arena as *const ArenaInterner<'_>;
92        // Cast to 'static lifetime - this is safe because the arena owns the data
93        let interner_ptr = interner_ptr.cast::<ArenaInterner<'static>>();
94
95        Ok(Self {
96            sql,
97            arena,
98            statement_ptr,
99            interner_ptr,
100            param_count,
101            tables,
102        })
103    }
104
105    /// Get the original SQL.
106    pub fn sql(&self) -> &str {
107        &self.sql
108    }
109
110    /// Get a reference to the arena-allocated statement.
111    ///
112    /// The returned reference is valid for the lifetime of this struct.
113    pub fn statement(&self) -> &SelectStmt<'_> {
114        // SAFETY: The pointer is valid as documented in the struct invariants.
115        // We're returning a reference tied to self's lifetime, which is correct.
116        unsafe { &*self.statement_ptr }
117    }
118
119    /// Get a reference to the interner for symbol resolution.
120    ///
121    /// The returned reference is valid for the lifetime of this struct.
122    pub fn interner(&self) -> &ArenaInterner<'_> {
123        // SAFETY: The pointer is valid as documented in the struct invariants.
124        unsafe { &*self.interner_ptr }
125    }
126
127    /// Get the number of parameters expected.
128    pub fn param_count(&self) -> usize {
129        self.param_count
130    }
131
132    /// Get the tables referenced by this statement.
133    pub fn tables(&self) -> &HashSet<String> {
134        &self.tables
135    }
136
137    /// Get a reference to the arena for additional allocations.
138    ///
139    /// This can be used for allocating bound values during execution.
140    pub fn arena(&self) -> &Bump {
141        &self.arena
142    }
143
144    /// Validate that the correct number of parameters is provided.
145    pub fn validate_params(&self, params: &[SqlValue]) -> Result<(), ArenaBindError> {
146        if params.len() != self.param_count {
147            return Err(ArenaBindError::ParameterCountMismatch {
148                expected: self.param_count,
149                actual: params.len(),
150            });
151        }
152        Ok(())
153    }
154}
155
156impl std::fmt::Debug for ArenaPreparedStatement {
157    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
158        f.debug_struct("ArenaPreparedStatement")
159            .field("sql", &self.sql)
160            .field("param_count", &self.param_count)
161            .field("tables", &self.tables)
162            .finish_non_exhaustive()
163    }
164}
165
/// Error during arena-based parsing.
///
/// Implements `PartialEq`/`Eq` so callers and tests can compare a returned
/// error against an expected value instead of only checking `is_err()`.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum ArenaParseError {
    /// SQL parsing failed; the payload is the parser's error rendered as text
    ParseError(String),
}

impl std::fmt::Display for ArenaParseError {
    /// Render as `Parse error: <message>`.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        match self {
            Self::ParseError(msg) => write!(f, "Parse error: {msg}"),
        }
    }
}

impl std::error::Error for ArenaParseError {}
182
/// Error during arena-based parameter binding.
///
/// Implements `PartialEq`/`Eq` so callers and tests can assert on the exact
/// mismatch that was reported instead of only checking `is_err()`.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum ArenaBindError {
    /// Wrong number of parameters provided
    ParameterCountMismatch {
        /// Number of placeholders in the prepared statement
        expected: usize,
        /// Number of values the caller supplied
        actual: usize,
    },
}

impl std::fmt::Display for ArenaBindError {
    /// Render as `Parameter count mismatch: expected <n>, got <m>`.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        match self {
            Self::ParameterCountMismatch { expected, actual } => {
                write!(
                    f,
                    "Parameter count mismatch: expected {expected}, got {actual}"
                )
            }
        }
    }
}

impl std::error::Error for ArenaBindError {}
205
206/// Count the number of placeholder parameters in an arena statement.
207fn count_arena_placeholders(stmt: &SelectStmt<'_>) -> usize {
208    let mut count = 0;
209    visit_arena_statement(stmt, &mut |expr| {
210        if matches!(expr, Expression::Placeholder(_)) {
211            count += 1;
212        }
213    });
214    count
215}
216
217/// Extract table names from an arena statement.
218fn extract_arena_tables(stmt: &SelectStmt<'_>, interner: &ArenaInterner<'_>) -> HashSet<String> {
219    let mut tables = HashSet::new();
220    visit_arena_from_clause(stmt.from.as_ref(), &mut tables, interner);
221    tables
222}
223
224/// Visit all expressions in an arena statement.
225fn visit_arena_statement<F>(stmt: &SelectStmt<'_>, visitor: &mut F)
226where
227    F: FnMut(&Expression<'_>),
228{
229    // Visit CTEs
230    if let Some(ctes) = &stmt.with_clause {
231        for cte in ctes.iter() {
232            visit_arena_statement(cte.query, visitor);
233        }
234    }
235
236    // Visit select items
237    for item in stmt.select_list.iter() {
238        if let vibesql_ast::arena::SelectItem::Expression { expr, .. } = item {
239            visit_arena_expression(expr, visitor);
240        }
241    }
242
243    // Visit FROM clause
244    if let Some(from) = &stmt.from {
245        visit_arena_from_clause_exprs(from, visitor);
246    }
247
248    // Visit WHERE
249    if let Some(where_clause) = &stmt.where_clause {
250        visit_arena_expression(where_clause, visitor);
251    }
252
253    // Visit GROUP BY
254    if let Some(group_by) = &stmt.group_by {
255        visit_arena_group_by(group_by, visitor);
256    }
257
258    // Visit HAVING
259    if let Some(having) = &stmt.having {
260        visit_arena_expression(having, visitor);
261    }
262
263    // Visit ORDER BY
264    if let Some(order_by) = &stmt.order_by {
265        for item in order_by.iter() {
266            visit_arena_expression(&item.expr, visitor);
267        }
268    }
269
270    // Visit set operation
271    if let Some(set_op) = &stmt.set_operation {
272        visit_arena_statement(set_op.right, visitor);
273    }
274}
275
276/// Visit expressions in a FROM clause.
277fn visit_arena_from_clause_exprs<F>(from: &vibesql_ast::arena::FromClause<'_>, visitor: &mut F)
278where
279    F: FnMut(&Expression<'_>),
280{
281    match from {
282        vibesql_ast::arena::FromClause::Table { .. } => {}
283        vibesql_ast::arena::FromClause::Subquery { query, .. } => {
284            visit_arena_statement(query, visitor);
285        }
286        vibesql_ast::arena::FromClause::Join {
287            left,
288            right,
289            condition,
290            ..
291        } => {
292            visit_arena_from_clause_exprs(left, visitor);
293            visit_arena_from_clause_exprs(right, visitor);
294            if let Some(cond) = condition {
295                visit_arena_expression(cond, visitor);
296            }
297        }
298    }
299}
300
301/// Extract table names from a FROM clause.
302fn visit_arena_from_clause(
303    from: Option<&vibesql_ast::arena::FromClause<'_>>,
304    tables: &mut HashSet<String>,
305    interner: &ArenaInterner<'_>,
306) {
307    let Some(from) = from else { return };
308
309    match from {
310        vibesql_ast::arena::FromClause::Table { name, .. } => {
311            tables.insert(interner.resolve(*name).to_string());
312        }
313        vibesql_ast::arena::FromClause::Subquery { query, .. } => {
314            visit_arena_from_clause(query.from.as_ref(), tables, interner);
315        }
316        vibesql_ast::arena::FromClause::Join { left, right, .. } => {
317            visit_arena_from_clause(Some(left), tables, interner);
318            visit_arena_from_clause(Some(right), tables, interner);
319        }
320    }
321}
322
323/// Visit GROUP BY clause expressions.
324fn visit_arena_group_by<F>(group_by: &vibesql_ast::arena::GroupByClause<'_>, visitor: &mut F)
325where
326    F: FnMut(&Expression<'_>),
327{
328    use vibesql_ast::arena::GroupByClause;
329    match group_by {
330        GroupByClause::Simple(exprs) => {
331            for expr in exprs.iter() {
332                visit_arena_expression(expr, visitor);
333            }
334        }
335        GroupByClause::Rollup(elements) | GroupByClause::Cube(elements) => {
336            for element in elements.iter() {
337                match element {
338                    vibesql_ast::arena::GroupingElement::Single(expr) => {
339                        visit_arena_expression(expr, visitor);
340                    }
341                    vibesql_ast::arena::GroupingElement::Composite(exprs) => {
342                        for expr in exprs.iter() {
343                            visit_arena_expression(expr, visitor);
344                        }
345                    }
346                }
347            }
348        }
349        GroupByClause::GroupingSets(sets) => {
350            for set in sets.iter() {
351                for expr in set.columns.iter() {
352                    visit_arena_expression(expr, visitor);
353                }
354            }
355        }
356        GroupByClause::Mixed(items) => {
357            for item in items.iter() {
358                match item {
359                    vibesql_ast::arena::MixedGroupingItem::Simple(expr) => {
360                        visit_arena_expression(expr, visitor);
361                    }
362                    vibesql_ast::arena::MixedGroupingItem::Rollup(elements)
363                    | vibesql_ast::arena::MixedGroupingItem::Cube(elements) => {
364                        for element in elements.iter() {
365                            match element {
366                                vibesql_ast::arena::GroupingElement::Single(expr) => {
367                                    visit_arena_expression(expr, visitor);
368                                }
369                                vibesql_ast::arena::GroupingElement::Composite(exprs) => {
370                                    for expr in exprs.iter() {
371                                        visit_arena_expression(expr, visitor);
372                                    }
373                                }
374                            }
375                        }
376                    }
377                    vibesql_ast::arena::MixedGroupingItem::GroupingSets(sets) => {
378                        for set in sets.iter() {
379                            for expr in set.columns.iter() {
380                                visit_arena_expression(expr, visitor);
381                            }
382                        }
383                    }
384                }
385            }
386        }
387    }
388}
389
390/// Visit all expressions in an arena expression tree.
391fn visit_arena_expression<F>(expr: &Expression<'_>, visitor: &mut F)
392where
393    F: FnMut(&Expression<'_>),
394{
395    visitor(expr);
396
397    match expr {
398        // Hot-path inline variants
399        Expression::BinaryOp { left, right, .. } => {
400            visit_arena_expression(left, visitor);
401            visit_arena_expression(right, visitor);
402        }
403        Expression::Conjunction(children) | Expression::Disjunction(children) => {
404            for child in children.iter() {
405                visit_arena_expression(child, visitor);
406            }
407        }
408        Expression::UnaryOp { expr: inner, .. } => {
409            visit_arena_expression(inner, visitor);
410        }
411        Expression::IsNull { expr: inner, .. } => {
412            visit_arena_expression(inner, visitor);
413        }
414        // Leaf nodes - no recursion needed
415        Expression::Literal(_)
416        | Expression::Placeholder(_)
417        | Expression::NumberedPlaceholder(_)
418        | Expression::NamedPlaceholder(_)
419        | Expression::ColumnRef { .. }
420        | Expression::Wildcard
421        | Expression::CurrentDate
422        | Expression::CurrentTime { .. }
423        | Expression::CurrentTimestamp { .. }
424        | Expression::Default => {}
425        // Cold-path extended variants
426        Expression::Extended(ext) => match ext {
427            ExtendedExpr::Function { args, .. } | ExtendedExpr::AggregateFunction { args, .. } => {
428                for arg in args.iter() {
429                    visit_arena_expression(arg, visitor);
430                }
431            }
432            ExtendedExpr::Case {
433                operand,
434                when_clauses,
435                else_result,
436            } => {
437                if let Some(op) = operand {
438                    visit_arena_expression(op, visitor);
439                }
440                for w in when_clauses.iter() {
441                    for c in w.conditions.iter() {
442                        visit_arena_expression(c, visitor);
443                    }
444                    visit_arena_expression(&w.result, visitor);
445                }
446                if let Some(e) = else_result {
447                    visit_arena_expression(e, visitor);
448                }
449            }
450            ExtendedExpr::ScalarSubquery(select) => visit_arena_statement(select, visitor),
451            ExtendedExpr::In {
452                expr: inner,
453                subquery,
454                ..
455            } => {
456                visit_arena_expression(inner, visitor);
457                visit_arena_statement(subquery, visitor);
458            }
459            ExtendedExpr::InList {
460                expr: inner,
461                values,
462                ..
463            } => {
464                visit_arena_expression(inner, visitor);
465                for v in values.iter() {
466                    visit_arena_expression(v, visitor);
467                }
468            }
469            ExtendedExpr::Between {
470                expr: inner,
471                low,
472                high,
473                ..
474            } => {
475                visit_arena_expression(inner, visitor);
476                visit_arena_expression(low, visitor);
477                visit_arena_expression(high, visitor);
478            }
479            ExtendedExpr::Cast { expr: inner, .. } => {
480                visit_arena_expression(inner, visitor);
481            }
482            ExtendedExpr::Position {
483                substring, string, ..
484            } => {
485                visit_arena_expression(substring, visitor);
486                visit_arena_expression(string, visitor);
487            }
488            ExtendedExpr::Trim {
489                removal_char,
490                string,
491                ..
492            } => {
493                if let Some(c) = removal_char {
494                    visit_arena_expression(c, visitor);
495                }
496                visit_arena_expression(string, visitor);
497            }
498            ExtendedExpr::Extract { expr: inner, .. } => {
499                visit_arena_expression(inner, visitor);
500            }
501            ExtendedExpr::Like {
502                expr: inner,
503                pattern,
504                ..
505            } => {
506                visit_arena_expression(inner, visitor);
507                visit_arena_expression(pattern, visitor);
508            }
509            ExtendedExpr::Exists { subquery, .. } => {
510                visit_arena_statement(subquery, visitor);
511            }
512            ExtendedExpr::QuantifiedComparison {
513                expr: inner,
514                subquery,
515                ..
516            } => {
517                visit_arena_expression(inner, visitor);
518                visit_arena_statement(subquery, visitor);
519            }
520            ExtendedExpr::Interval { value, .. } => {
521                visit_arena_expression(value, visitor);
522            }
523            ExtendedExpr::WindowFunction { function, over } => {
524                match function {
525                    vibesql_ast::arena::WindowFunctionSpec::Aggregate { args, .. }
526                    | vibesql_ast::arena::WindowFunctionSpec::Ranking { args, .. }
527                    | vibesql_ast::arena::WindowFunctionSpec::Value { args, .. } => {
528                        for arg in args.iter() {
529                            visit_arena_expression(arg, visitor);
530                        }
531                    }
532                }
533                if let Some(partition_by) = &over.partition_by {
534                    for expr in partition_by.iter() {
535                        visit_arena_expression(expr, visitor);
536                    }
537                }
538                if let Some(order_by) = &over.order_by {
539                    for item in order_by.iter() {
540                        visit_arena_expression(&item.expr, visitor);
541                    }
542                }
543            }
544            ExtendedExpr::MatchAgainst { search_modifier, .. } => {
545                visit_arena_expression(search_modifier, visitor);
546            }
547            // Extended leaf nodes - no recursion needed
548            ExtendedExpr::DuplicateKeyValue { .. }
549            | ExtendedExpr::NextValue { .. }
550            | ExtendedExpr::PseudoVariable { .. }
551            | ExtendedExpr::SessionVariable { .. } => {}
552        },
553    }
554}
555
#[cfg(test)]
mod tests {
    use super::*;

    /// Prepare `sql`, panicking if the arena parser rejects it.
    fn prepare(sql: &str) -> ArenaPreparedStatement {
        ArenaPreparedStatement::new(sql.to_string()).unwrap()
    }

    #[test]
    fn test_arena_prepared_basic() {
        let sql = "SELECT id, name FROM users WHERE id = 1";
        let prepared = prepare(sql);

        assert_eq!(prepared.sql(), sql);
        assert_eq!(prepared.param_count(), 0);

        // Arena parser stores table names in uppercase
        let tables = prepared.tables();
        assert!(
            tables.contains("USERS"),
            "Expected 'USERS' in tables {:?}",
            tables
        );
    }

    #[test]
    fn test_arena_prepared_with_placeholder() {
        let prepared = prepare("SELECT * FROM users WHERE id = ?");

        assert_eq!(prepared.param_count(), 1);
        // Arena parser stores table names in uppercase
        assert!(prepared.tables().contains("USERS"));
    }

    #[test]
    fn test_arena_prepared_multiple_placeholders() {
        let prepared = prepare("SELECT * FROM users WHERE id = ? AND name = ? AND age > ?");

        assert_eq!(prepared.param_count(), 3);
    }

    #[test]
    fn test_arena_prepared_param_validation() {
        let prepared = prepare("SELECT * FROM users WHERE id = ?");

        // Exactly one value for the single placeholder is accepted
        let exact = [SqlValue::Integer(1)];
        assert!(prepared.validate_params(&exact).is_ok());

        // Too few and too many values are both rejected
        assert!(prepared.validate_params(&[]).is_err());
        let too_many = [SqlValue::Integer(1), SqlValue::Integer(2)];
        assert!(prepared.validate_params(&too_many).is_err());
    }

    #[test]
    fn test_arena_prepared_join_tables() {
        let prepared = prepare("SELECT u.id FROM users u JOIN orders o ON u.id = o.user_id");

        // Arena parser stores table names in uppercase
        let tables = prepared.tables();
        assert!(tables.contains("USERS"), "Expected 'USERS' in {:?}", tables);
        assert!(tables.contains("ORDERS"), "Expected 'ORDERS' in {:?}", tables);
    }
}