vibesql_executor/cache/prepared_statement/
arena_prepared.rs

1//! Arena-based prepared statement for zero-allocation execution
2//!
3//! This module implements Option A from the arena parser optimization:
4//! storing the arena-allocated AST directly in the prepared statement,
5//! avoiding conversion overhead entirely for simple queries.
6//!
7//! # Key Benefits
8//!
9//! - **Zero conversion overhead**: Arena AST is used directly during execution
10//! - **Near-zero allocation**: Simple prepared queries avoid heap allocation
//! - **Arena reuse**: The arena lives as long as the statement and can also hold allocations made during execution; it is never reset, because that would invalidate the stored AST
12//!
13//! # Safety
14//!
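//! This module uses unsafe code to manage self-referential data: the parsed
//! statement and its interner live in an arena owned by the same struct. The
//! stored pointers stay valid because the arena is never reset and is freed only
//! when the struct is dropped; moving the struct is harmless because bumpalo
//! allocates from heap chunks that do not move with the `Bump` handle.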
18
19use std::collections::HashSet;
20use std::pin::Pin;
21
22use bumpalo::Bump;
23use vibesql_ast::arena::{ArenaInterner, Expression, ExtendedExpr, SelectStmt};
24use vibesql_parser::arena_parser::ArenaParser;
25use vibesql_types::SqlValue;
26
27/// Arena-based prepared statement for zero-allocation execution.
28///
29/// This stores the parsed AST in an arena, keeping it alive for the lifetime
30/// of the prepared statement. This avoids conversion overhead for simple queries.
31///
32/// # Safety Invariants
33///
34/// 1. The arena is pinned and will not move after construction
35/// 2. The statement pointer is valid as long as the arena exists
36/// 3. The arena is only dropped when the ArenaPreparedStatement is dropped
37/// 4. The interner pointer is valid as long as the arena exists
38pub struct ArenaPreparedStatement {
39    /// Original SQL with placeholders
40    sql: String,
41    /// Arena containing the parsed statement (pinned to prevent moves)
42    arena: Pin<Box<Bump>>,
43    /// Pointer to the statement in the arena.
44    ///
45    /// SAFETY: This pointer is valid as long as `arena` is not dropped or reset.
46    /// Since ArenaPreparedStatement owns the arena and only drops it in Drop,
47    /// this pointer is always valid during the lifetime of the struct.
48    statement_ptr: *const SelectStmt<'static>,
49    /// Pointer to the interner in the arena.
50    ///
51    /// SAFETY: Same invariants as statement_ptr.
52    interner_ptr: *const ArenaInterner<'static>,
53    /// Number of parameters expected (? placeholders)
54    param_count: usize,
55    /// Tables referenced by this statement (for cache invalidation)
56    tables: HashSet<String>,
57}
58
59// SAFETY: The arena and statement are not accessed from multiple threads.
60// The statement pointer is only dereferenced through &self methods.
61unsafe impl Send for ArenaPreparedStatement {}
62unsafe impl Sync for ArenaPreparedStatement {}
63
64impl ArenaPreparedStatement {
65    /// Create a new arena-based prepared statement from SQL.
66    ///
67    /// Parses the SQL using the arena parser and stores the result in an owned arena.
68    pub fn new(sql: String) -> Result<Self, ArenaParseError> {
69        // Create pinned arena (won't move)
70        let arena = Pin::new(Box::new(Bump::new()));
71
72        // Parse into arena, getting both statement and interner
73        let (stmt, interner): (&SelectStmt<'_>, ArenaInterner<'_>) =
74            ArenaParser::parse_select_with_interner(&sql, &arena)
75                .map_err(|e| ArenaParseError::ParseError(e.to_string()))?;
76
77        // Count placeholders and extract tables (using interner for symbol resolution)
78        let param_count = count_arena_placeholders(stmt);
79        let tables = extract_arena_tables(stmt, &interner);
80
81        // Store as raw pointer, erasing the lifetime.
82        // SAFETY: The arena is owned by this struct and won't be dropped
83        // while the statement exists. The pointer remains valid.
84        let statement_ptr = stmt as *const SelectStmt<'_>;
85        // Cast to 'static lifetime - this is safe because the arena owns the data
86        let statement_ptr = statement_ptr.cast::<SelectStmt<'static>>();
87
88        // Allocate interner in arena and store pointer
89        // SAFETY: Same invariants as statement_ptr
90        let interner_in_arena = arena.alloc(interner);
91        let interner_ptr = interner_in_arena as *const ArenaInterner<'_>;
92        // Cast to 'static lifetime - this is safe because the arena owns the data
93        let interner_ptr = interner_ptr.cast::<ArenaInterner<'static>>();
94
95        Ok(Self { sql, arena, statement_ptr, interner_ptr, param_count, tables })
96    }
97
98    /// Get the original SQL.
99    pub fn sql(&self) -> &str {
100        &self.sql
101    }
102
103    /// Get a reference to the arena-allocated statement.
104    ///
105    /// The returned reference is valid for the lifetime of this struct.
106    pub fn statement(&self) -> &SelectStmt<'_> {
107        // SAFETY: The pointer is valid as documented in the struct invariants.
108        // We're returning a reference tied to self's lifetime, which is correct.
109        unsafe { &*self.statement_ptr }
110    }
111
112    /// Get a reference to the interner for symbol resolution.
113    ///
114    /// The returned reference is valid for the lifetime of this struct.
115    pub fn interner(&self) -> &ArenaInterner<'_> {
116        // SAFETY: The pointer is valid as documented in the struct invariants.
117        unsafe { &*self.interner_ptr }
118    }
119
120    /// Get the number of parameters expected.
121    pub fn param_count(&self) -> usize {
122        self.param_count
123    }
124
125    /// Get the tables referenced by this statement.
126    pub fn tables(&self) -> &HashSet<String> {
127        &self.tables
128    }
129
130    /// Get a reference to the arena for additional allocations.
131    ///
132    /// This can be used for allocating bound values during execution.
133    pub fn arena(&self) -> &Bump {
134        &self.arena
135    }
136
137    /// Validate that the correct number of parameters is provided.
138    pub fn validate_params(&self, params: &[SqlValue]) -> Result<(), ArenaBindError> {
139        if params.len() != self.param_count {
140            return Err(ArenaBindError::ParameterCountMismatch {
141                expected: self.param_count,
142                actual: params.len(),
143            });
144        }
145        Ok(())
146    }
147}
148
149impl std::fmt::Debug for ArenaPreparedStatement {
150    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
151        f.debug_struct("ArenaPreparedStatement")
152            .field("sql", &self.sql)
153            .field("param_count", &self.param_count)
154            .field("tables", &self.tables)
155            .finish_non_exhaustive()
156    }
157}
158
/// Error returned when SQL cannot be turned into an arena statement.
#[derive(Debug, Clone)]
pub enum ArenaParseError {
    /// SQL parsing failed; holds the parser's rendered error message.
    ParseError(String),
}

impl std::fmt::Display for ArenaParseError {
    /// Renders as `Parse error: <message>`.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        // Single-variant enum, so this pattern is irrefutable.
        let Self::ParseError(message) = self;
        write!(f, "Parse error: {message}")
    }
}

impl std::error::Error for ArenaParseError {}
175
/// Error returned when the supplied parameters do not fit an arena prepared statement.
#[derive(Debug, Clone)]
pub enum ArenaBindError {
    /// Wrong number of parameters provided
    ParameterCountMismatch {
        /// Number of `?` placeholders in the statement.
        expected: usize,
        /// Number of values actually supplied.
        actual: usize,
    },
}

impl std::fmt::Display for ArenaBindError {
    /// Renders as `Parameter count mismatch: expected <n>, got <m>`.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        // Single-variant enum, so this pattern is irrefutable.
        let Self::ParameterCountMismatch { expected, actual } = self;
        write!(f, "Parameter count mismatch: expected {expected}, got {actual}")
    }
}

impl std::error::Error for ArenaBindError {}
194
195/// Count the number of placeholder parameters in an arena statement.
196fn count_arena_placeholders(stmt: &SelectStmt<'_>) -> usize {
197    let mut count = 0;
198    visit_arena_statement(stmt, &mut |expr| {
199        if matches!(expr, Expression::Placeholder(_)) {
200            count += 1;
201        }
202    });
203    count
204}
205
206/// Extract table names from an arena statement.
207fn extract_arena_tables(stmt: &SelectStmt<'_>, interner: &ArenaInterner<'_>) -> HashSet<String> {
208    let mut tables = HashSet::new();
209    visit_arena_from_clause(stmt.from.as_ref(), &mut tables, interner);
210    tables
211}
212
213/// Visit all expressions in an arena statement.
214fn visit_arena_statement<F>(stmt: &SelectStmt<'_>, visitor: &mut F)
215where
216    F: FnMut(&Expression<'_>),
217{
218    // Visit CTEs
219    if let Some(ctes) = &stmt.with_clause {
220        for cte in ctes.iter() {
221            visit_arena_statement(cte.query, visitor);
222        }
223    }
224
225    // Visit select items
226    for item in stmt.select_list.iter() {
227        if let vibesql_ast::arena::SelectItem::Expression { expr, .. } = item {
228            visit_arena_expression(expr, visitor);
229        }
230    }
231
232    // Visit FROM clause
233    if let Some(from) = &stmt.from {
234        visit_arena_from_clause_exprs(from, visitor);
235    }
236
237    // Visit WHERE
238    if let Some(where_clause) = &stmt.where_clause {
239        visit_arena_expression(where_clause, visitor);
240    }
241
242    // Visit GROUP BY
243    if let Some(group_by) = &stmt.group_by {
244        visit_arena_group_by(group_by, visitor);
245    }
246
247    // Visit HAVING
248    if let Some(having) = &stmt.having {
249        visit_arena_expression(having, visitor);
250    }
251
252    // Visit ORDER BY
253    if let Some(order_by) = &stmt.order_by {
254        for item in order_by.iter() {
255            visit_arena_expression(&item.expr, visitor);
256        }
257    }
258
259    // Visit set operation
260    if let Some(set_op) = &stmt.set_operation {
261        visit_arena_statement(set_op.right, visitor);
262    }
263}
264
265/// Visit expressions in a FROM clause.
266fn visit_arena_from_clause_exprs<F>(from: &vibesql_ast::arena::FromClause<'_>, visitor: &mut F)
267where
268    F: FnMut(&Expression<'_>),
269{
270    match from {
271        vibesql_ast::arena::FromClause::Table { .. } => {}
272        vibesql_ast::arena::FromClause::Subquery { query, .. } => {
273            visit_arena_statement(query, visitor);
274        }
275        vibesql_ast::arena::FromClause::Join { left, right, condition, .. } => {
276            visit_arena_from_clause_exprs(left, visitor);
277            visit_arena_from_clause_exprs(right, visitor);
278            if let Some(cond) = condition {
279                visit_arena_expression(cond, visitor);
280            }
281        }
282    }
283}
284
285/// Extract table names from a FROM clause.
286fn visit_arena_from_clause(
287    from: Option<&vibesql_ast::arena::FromClause<'_>>,
288    tables: &mut HashSet<String>,
289    interner: &ArenaInterner<'_>,
290) {
291    let Some(from) = from else { return };
292
293    match from {
294        vibesql_ast::arena::FromClause::Table { name, .. } => {
295            tables.insert(interner.resolve(*name).to_string());
296        }
297        vibesql_ast::arena::FromClause::Subquery { query, .. } => {
298            visit_arena_from_clause(query.from.as_ref(), tables, interner);
299        }
300        vibesql_ast::arena::FromClause::Join { left, right, .. } => {
301            visit_arena_from_clause(Some(left), tables, interner);
302            visit_arena_from_clause(Some(right), tables, interner);
303        }
304    }
305}
306
307/// Visit GROUP BY clause expressions.
308fn visit_arena_group_by<F>(group_by: &vibesql_ast::arena::GroupByClause<'_>, visitor: &mut F)
309where
310    F: FnMut(&Expression<'_>),
311{
312    use vibesql_ast::arena::GroupByClause;
313    match group_by {
314        GroupByClause::Simple(exprs) => {
315            for expr in exprs.iter() {
316                visit_arena_expression(expr, visitor);
317            }
318        }
319        GroupByClause::Rollup(elements) | GroupByClause::Cube(elements) => {
320            for element in elements.iter() {
321                match element {
322                    vibesql_ast::arena::GroupingElement::Single(expr) => {
323                        visit_arena_expression(expr, visitor);
324                    }
325                    vibesql_ast::arena::GroupingElement::Composite(exprs) => {
326                        for expr in exprs.iter() {
327                            visit_arena_expression(expr, visitor);
328                        }
329                    }
330                }
331            }
332        }
333        GroupByClause::GroupingSets(sets) => {
334            for set in sets.iter() {
335                for expr in set.columns.iter() {
336                    visit_arena_expression(expr, visitor);
337                }
338            }
339        }
340        GroupByClause::Mixed(items) => {
341            for item in items.iter() {
342                match item {
343                    vibesql_ast::arena::MixedGroupingItem::Simple(expr) => {
344                        visit_arena_expression(expr, visitor);
345                    }
346                    vibesql_ast::arena::MixedGroupingItem::Rollup(elements)
347                    | vibesql_ast::arena::MixedGroupingItem::Cube(elements) => {
348                        for element in elements.iter() {
349                            match element {
350                                vibesql_ast::arena::GroupingElement::Single(expr) => {
351                                    visit_arena_expression(expr, visitor);
352                                }
353                                vibesql_ast::arena::GroupingElement::Composite(exprs) => {
354                                    for expr in exprs.iter() {
355                                        visit_arena_expression(expr, visitor);
356                                    }
357                                }
358                            }
359                        }
360                    }
361                    vibesql_ast::arena::MixedGroupingItem::GroupingSets(sets) => {
362                        for set in sets.iter() {
363                            for expr in set.columns.iter() {
364                                visit_arena_expression(expr, visitor);
365                            }
366                        }
367                    }
368                }
369            }
370        }
371    }
372}
373
374/// Visit all expressions in an arena expression tree.
375fn visit_arena_expression<F>(expr: &Expression<'_>, visitor: &mut F)
376where
377    F: FnMut(&Expression<'_>),
378{
379    visitor(expr);
380
381    match expr {
382        // Hot-path inline variants
383        Expression::BinaryOp { left, right, .. } => {
384            visit_arena_expression(left, visitor);
385            visit_arena_expression(right, visitor);
386        }
387        Expression::Conjunction(children) | Expression::Disjunction(children) => {
388            for child in children.iter() {
389                visit_arena_expression(child, visitor);
390            }
391        }
392        Expression::UnaryOp { expr: inner, .. } => {
393            visit_arena_expression(inner, visitor);
394        }
395        Expression::IsNull { expr: inner, .. } => {
396            visit_arena_expression(inner, visitor);
397        }
398        // Leaf nodes - no recursion needed
399        Expression::Literal(_)
400        | Expression::Placeholder(_)
401        | Expression::NumberedPlaceholder(_)
402        | Expression::NamedPlaceholder(_)
403        | Expression::ColumnRef { .. }
404        | Expression::Wildcard
405        | Expression::CurrentDate
406        | Expression::CurrentTime { .. }
407        | Expression::CurrentTimestamp { .. }
408        | Expression::Default => {}
409        // Cold-path extended variants
410        Expression::Extended(ext) => match ext {
411            ExtendedExpr::Function { args, .. } | ExtendedExpr::AggregateFunction { args, .. } => {
412                for arg in args.iter() {
413                    visit_arena_expression(arg, visitor);
414                }
415            }
416            ExtendedExpr::Case { operand, when_clauses, else_result } => {
417                if let Some(op) = operand {
418                    visit_arena_expression(op, visitor);
419                }
420                for w in when_clauses.iter() {
421                    for c in w.conditions.iter() {
422                        visit_arena_expression(c, visitor);
423                    }
424                    visit_arena_expression(&w.result, visitor);
425                }
426                if let Some(e) = else_result {
427                    visit_arena_expression(e, visitor);
428                }
429            }
430            ExtendedExpr::ScalarSubquery(select) => visit_arena_statement(select, visitor),
431            ExtendedExpr::In { expr: inner, subquery, .. } => {
432                visit_arena_expression(inner, visitor);
433                visit_arena_statement(subquery, visitor);
434            }
435            ExtendedExpr::InList { expr: inner, values, .. } => {
436                visit_arena_expression(inner, visitor);
437                for v in values.iter() {
438                    visit_arena_expression(v, visitor);
439                }
440            }
441            ExtendedExpr::Between { expr: inner, low, high, .. } => {
442                visit_arena_expression(inner, visitor);
443                visit_arena_expression(low, visitor);
444                visit_arena_expression(high, visitor);
445            }
446            ExtendedExpr::Cast { expr: inner, .. } => {
447                visit_arena_expression(inner, visitor);
448            }
449            ExtendedExpr::Position { substring, string, .. } => {
450                visit_arena_expression(substring, visitor);
451                visit_arena_expression(string, visitor);
452            }
453            ExtendedExpr::Trim { removal_char, string, .. } => {
454                if let Some(c) = removal_char {
455                    visit_arena_expression(c, visitor);
456                }
457                visit_arena_expression(string, visitor);
458            }
459            ExtendedExpr::Extract { expr: inner, .. } => {
460                visit_arena_expression(inner, visitor);
461            }
462            ExtendedExpr::Like { expr: inner, pattern, .. } => {
463                visit_arena_expression(inner, visitor);
464                visit_arena_expression(pattern, visitor);
465            }
466            ExtendedExpr::Exists { subquery, .. } => {
467                visit_arena_statement(subquery, visitor);
468            }
469            ExtendedExpr::QuantifiedComparison { expr: inner, subquery, .. } => {
470                visit_arena_expression(inner, visitor);
471                visit_arena_statement(subquery, visitor);
472            }
473            ExtendedExpr::Interval { value, .. } => {
474                visit_arena_expression(value, visitor);
475            }
476            ExtendedExpr::WindowFunction { function, over } => {
477                match function {
478                    vibesql_ast::arena::WindowFunctionSpec::Aggregate { args, .. }
479                    | vibesql_ast::arena::WindowFunctionSpec::Ranking { args, .. }
480                    | vibesql_ast::arena::WindowFunctionSpec::Value { args, .. } => {
481                        for arg in args.iter() {
482                            visit_arena_expression(arg, visitor);
483                        }
484                    }
485                }
486                if let Some(partition_by) = &over.partition_by {
487                    for expr in partition_by.iter() {
488                        visit_arena_expression(expr, visitor);
489                    }
490                }
491                if let Some(order_by) = &over.order_by {
492                    for item in order_by.iter() {
493                        visit_arena_expression(&item.expr, visitor);
494                    }
495                }
496            }
497            ExtendedExpr::MatchAgainst { search_modifier, .. } => {
498                visit_arena_expression(search_modifier, visitor);
499            }
500            // Extended leaf nodes - no recursion needed
501            ExtendedExpr::DuplicateKeyValue { .. }
502            | ExtendedExpr::NextValue { .. }
503            | ExtendedExpr::PseudoVariable { .. }
504            | ExtendedExpr::SessionVariable { .. } => {}
505        },
506    }
507}
508
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_arena_prepared_basic() {
        let sql = "SELECT id, name FROM users WHERE id = 1";
        let prepared = ArenaPreparedStatement::new(sql.to_string()).unwrap();

        assert_eq!(prepared.sql(), sql);
        assert_eq!(prepared.param_count(), 0);
        // Arena parser stores table names in uppercase
        assert!(
            prepared.tables().contains("USERS"),
            "Expected 'USERS' in tables {:?}",
            prepared.tables()
        );
    }

    #[test]
    fn test_arena_prepared_with_placeholder() {
        let sql = "SELECT * FROM users WHERE id = ?";
        let prepared = ArenaPreparedStatement::new(sql.to_string()).unwrap();

        assert_eq!(prepared.param_count(), 1);
        // Arena parser stores table names in uppercase
        assert!(prepared.tables().contains("USERS"));
    }

    #[test]
    fn test_arena_prepared_multiple_placeholders() {
        let sql = "SELECT * FROM users WHERE id = ? AND name = ? AND age > ?";
        let prepared = ArenaPreparedStatement::new(sql.to_string()).unwrap();

        assert_eq!(prepared.param_count(), 3);
    }

    #[test]
    fn test_arena_prepared_param_validation() {
        let sql = "SELECT * FROM users WHERE id = ?";
        let prepared = ArenaPreparedStatement::new(sql.to_string()).unwrap();

        // Correct param count should pass
        assert!(prepared.validate_params(&[SqlValue::Integer(1)]).is_ok());

        // Wrong param count should fail
        assert!(prepared.validate_params(&[]).is_err());
        assert!(prepared.validate_params(&[SqlValue::Integer(1), SqlValue::Integer(2)]).is_err());
    }

    #[test]
    fn test_arena_prepared_param_mismatch_details() {
        let sql = "SELECT * FROM users WHERE id = ?";
        let prepared = ArenaPreparedStatement::new(sql.to_string()).unwrap();

        let err = prepared.validate_params(&[]).unwrap_err();
        assert!(
            matches!(err, ArenaBindError::ParameterCountMismatch { expected: 1, actual: 0 }),
            "unexpected error {:?}",
            err
        );
        assert_eq!(err.to_string(), "Parameter count mismatch: expected 1, got 0");
    }

    #[test]
    fn test_arena_prepared_no_params_accepts_empty() {
        let prepared = ArenaPreparedStatement::new("SELECT id FROM users".to_string()).unwrap();

        assert!(prepared.validate_params(&[]).is_ok());
        assert!(prepared.validate_params(&[SqlValue::Integer(1)]).is_err());
    }

    #[test]
    fn test_arena_prepared_parse_error() {
        let err = ArenaPreparedStatement::new("this is not sql".to_string()).unwrap_err();

        assert!(err.to_string().starts_with("Parse error: "), "unexpected message: {err}");
    }

    #[test]
    fn test_arena_prepared_accessors_and_drop() {
        let prepared = ArenaPreparedStatement::new("SELECT id FROM users".to_string()).unwrap();

        assert!(prepared.statement().from.is_some());

        // Extra allocations share the statement's arena and live as long as it does.
        let extra = prepared.arena().alloc(42_i32);
        assert_eq!(*extra, 42);

        let debug = format!("{prepared:?}");
        assert!(debug.contains("ArenaPreparedStatement"), "{debug}");

        // Exercises `Drop` (interner destructor, then arena release).
        drop(prepared);
    }

    #[test]
    fn test_arena_prepared_join_tables() {
        let sql = "SELECT u.id FROM users u JOIN orders o ON u.id = o.user_id";
        let prepared = ArenaPreparedStatement::new(sql.to_string()).unwrap();

        let tables = prepared.tables();
        // Arena parser stores table names in uppercase
        assert!(tables.contains("USERS"), "Expected 'USERS' in {:?}", tables);
        assert!(tables.contains("ORDERS"), "Expected 'ORDERS' in {:?}", tables);
    }
}