1use std::{
25 num::NonZeroUsize,
26 sync::{
27 atomic::{AtomicUsize, Ordering},
28 Arc, Mutex,
29 },
30};
31
32use lru::LruCache;
33use vibesql_ast::Statement;
34use vibesql_types::SqlValue;
35
36use super::{extract_tables_from_statement, QuerySignature};
37
38pub mod arena_prepared;
39mod bind;
40pub mod plan;
41
42pub use plan::{
43 CachedPlan, ColumnProjection, PkDeletePlan, PkPointLookupPlan, ProjectionPlan,
44 ResolvedProjection, SimpleFastPathPlan,
45};
46
47#[derive(Debug, Clone)]
49pub struct PreparedStatement {
50 sql: String,
52 statement: Statement,
54 signature: QuerySignature,
56 param_count: usize,
58 tables: std::collections::HashSet<String>,
60 cached_plan: CachedPlan,
62}
63
64impl PreparedStatement {
65 pub fn new(sql: String, statement: Statement) -> Self {
67 let signature = QuerySignature::from_ast(&statement);
68 let param_count = bind::count_placeholders(&statement);
70 let tables = extract_tables_from_statement(&statement);
71 let cached_plan = plan::analyze_for_plan(&statement);
73
74 Self { sql, statement, signature, param_count, tables, cached_plan }
75 }
76
77 pub fn sql(&self) -> &str {
79 &self.sql
80 }
81
82 pub fn statement(&self) -> &Statement {
84 &self.statement
85 }
86
87 pub fn signature(&self) -> &QuerySignature {
89 &self.signature
90 }
91
92 pub fn param_count(&self) -> usize {
94 self.param_count
95 }
96
97 pub fn tables(&self) -> &std::collections::HashSet<String> {
99 &self.tables
100 }
101
102 pub fn cached_plan(&self) -> &CachedPlan {
104 &self.cached_plan
105 }
106
107 pub fn bind(&self, params: &[SqlValue]) -> Result<Statement, PreparedStatementError> {
116 if params.len() != self.param_count {
117 return Err(PreparedStatementError::ParameterCountMismatch {
118 expected: self.param_count,
119 actual: params.len(),
120 });
121 }
122
123 if self.param_count == 0 {
124 return Ok(self.statement.clone());
126 }
127
128 Ok(bind::bind_parameters(&self.statement, params))
130 }
131}
132
/// Errors raised while preparing or binding a prepared statement.
#[derive(Debug, Clone)]
pub enum PreparedStatementError {
    /// `bind` received a different number of values than the statement has placeholders.
    ParameterCountMismatch { expected: usize, actual: usize },
    /// The SQL text could not be parsed; holds the parser's error message.
    ParseError(String),
}

impl std::fmt::Display for PreparedStatementError {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        match self {
            Self::ParameterCountMismatch { expected, actual } => {
                write!(f, "Parameter count mismatch: expected {expected}, got {actual}")
            }
            Self::ParseError(msg) => write!(f, "Parse error: {msg}"),
        }
    }
}

impl std::error::Error for PreparedStatementError {}
154
/// Point-in-time snapshot of prepared-statement cache counters.
#[derive(Debug, Clone)]
pub struct PreparedStatementCacheStats {
    /// Lookups that found a cached statement.
    pub hits: usize,
    /// Lookups that found nothing.
    pub misses: usize,
    /// Entries pushed out to make room for new ones.
    pub evictions: usize,
    /// Number of statements currently cached.
    pub size: usize,
    /// `hits / (hits + misses)`, or `0.0` when no lookups have happened yet.
    pub hit_rate: f64,
}
164
165pub struct PreparedStatementCache {
167 cache: Mutex<LruCache<String, Arc<PreparedStatement>>>,
169 arena_cache: Mutex<LruCache<String, Arc<arena_prepared::ArenaPreparedStatement>>>,
171 max_size: usize,
173 hits: AtomicUsize,
175 misses: AtomicUsize,
177 evictions: AtomicUsize,
179 arena_hits: AtomicUsize,
181 arena_misses: AtomicUsize,
183}
184
185impl PreparedStatementCache {
186 pub fn new(max_size: usize) -> Self {
188 let cap = NonZeroUsize::new(max_size).unwrap_or(NonZeroUsize::new(1).unwrap());
189 Self {
190 cache: Mutex::new(LruCache::new(cap)),
191 arena_cache: Mutex::new(LruCache::new(cap)),
192 max_size,
193 hits: AtomicUsize::new(0),
194 misses: AtomicUsize::new(0),
195 evictions: AtomicUsize::new(0),
196 arena_hits: AtomicUsize::new(0),
197 arena_misses: AtomicUsize::new(0),
198 }
199 }
200
201 pub fn default_cache() -> Self {
203 Self::new(1000)
204 }
205
206 pub fn get(&self, sql: &str) -> Option<Arc<PreparedStatement>> {
208 let mut cache = self.cache.lock().unwrap();
209 if let Some(stmt) = cache.get(sql) {
210 self.hits.fetch_add(1, Ordering::Relaxed);
211 Some(Arc::clone(stmt))
212 } else {
213 self.misses.fetch_add(1, Ordering::Relaxed);
214 None
215 }
216 }
217
218 pub fn get_or_prepare(
224 &self,
225 sql: &str,
226 ) -> Result<Arc<PreparedStatement>, PreparedStatementError> {
227 let mut cache = self.cache.lock().unwrap();
229
230 if let Some(stmt) = cache.get(sql) {
232 self.hits.fetch_add(1, Ordering::Relaxed);
233 return Ok(Arc::clone(stmt));
234 }
235
236 self.misses.fetch_add(1, Ordering::Relaxed);
238 let statement = parse_with_arena_fallback(sql)
239 .map_err(|e| PreparedStatementError::ParseError(e.to_string()))?;
240
241 let prepared = Arc::new(PreparedStatement::new(sql.to_string(), statement));
242
243 if cache.len() >= self.max_size {
245 self.evictions.fetch_add(1, Ordering::Relaxed);
246 }
247
248 cache.put(sql.to_string(), Arc::clone(&prepared));
250
251 Ok(prepared)
252 }
253
254 pub fn get_or_prepare_arena(
267 &self,
268 sql: &str,
269 ) -> Result<Arc<arena_prepared::ArenaPreparedStatement>, arena_prepared::ArenaParseError> {
270 let mut cache = self.arena_cache.lock().unwrap();
272
273 if let Some(stmt) = cache.get(sql) {
275 self.arena_hits.fetch_add(1, Ordering::Relaxed);
276 return Ok(Arc::clone(stmt));
277 }
278
279 self.arena_misses.fetch_add(1, Ordering::Relaxed);
281 let prepared = Arc::new(arena_prepared::ArenaPreparedStatement::new(sql.to_string())?);
282
283 if cache.len() >= self.max_size {
285 self.evictions.fetch_add(1, Ordering::Relaxed);
286 }
287
288 cache.put(sql.to_string(), Arc::clone(&prepared));
290
291 Ok(prepared)
292 }
293
294 pub fn get_arena(&self, sql: &str) -> Option<Arc<arena_prepared::ArenaPreparedStatement>> {
296 let mut cache = self.arena_cache.lock().unwrap();
297 if let Some(stmt) = cache.get(sql) {
298 self.arena_hits.fetch_add(1, Ordering::Relaxed);
299 Some(Arc::clone(stmt))
300 } else {
301 None
302 }
303 }
304
305 pub fn clear(&self) {
307 self.cache.lock().unwrap().clear();
308 self.arena_cache.lock().unwrap().clear();
309 }
310
311 pub fn invalidate_table(&self, table: &str) {
313 {
315 let mut cache = self.cache.lock().unwrap();
316 let keys_to_remove: Vec<String> = cache
317 .iter()
318 .filter(|(_, stmt)| stmt.tables.iter().any(|t| t.eq_ignore_ascii_case(table)))
319 .map(|(k, _)| k.clone())
320 .collect();
321
322 for key in keys_to_remove {
323 cache.pop(&key);
324 }
325 }
326
327 {
329 let mut arena_cache = self.arena_cache.lock().unwrap();
330 let keys_to_remove: Vec<String> = arena_cache
331 .iter()
332 .filter(|(_, stmt)| stmt.tables().iter().any(|t| t.eq_ignore_ascii_case(table)))
333 .map(|(k, _)| k.clone())
334 .collect();
335
336 for key in keys_to_remove {
337 arena_cache.pop(&key);
338 }
339 }
340 }
341
342 pub fn stats(&self) -> PreparedStatementCacheStats {
344 let cache = self.cache.lock().unwrap();
345 let hits = self.hits.load(Ordering::Relaxed);
346 let misses = self.misses.load(Ordering::Relaxed);
347 let total = hits + misses;
348 let hit_rate = if total > 0 { hits as f64 / total as f64 } else { 0.0 };
349
350 PreparedStatementCacheStats {
351 hits,
352 misses,
353 evictions: self.evictions.load(Ordering::Relaxed),
354 size: cache.len(),
355 hit_rate,
356 }
357 }
358
359 pub fn max_size(&self) -> usize {
361 self.max_size
362 }
363}
364
365fn parse_with_arena_fallback(sql: &str) -> Result<Statement, vibesql_parser::ParseError> {
377 let trimmed = sql.trim_start();
379 let first_word = trimmed.split_whitespace().next().unwrap_or("");
380
381 if first_word.eq_ignore_ascii_case("SELECT") || first_word.eq_ignore_ascii_case("WITH") {
383 if let Ok(select_stmt) = vibesql_parser::arena_parser::parse_select_to_owned(sql) {
384 return Ok(Statement::Select(Box::new(select_stmt)));
385 }
386 }
388
389 if first_word.eq_ignore_ascii_case("INSERT") {
391 if let Ok(insert_stmt) = vibesql_parser::arena_parser::parse_insert_to_owned(sql) {
392 return Ok(Statement::Insert(insert_stmt));
393 }
394 }
396
397 if first_word.eq_ignore_ascii_case("REPLACE") {
399 if let Ok(insert_stmt) = vibesql_parser::arena_parser::parse_insert_to_owned(sql) {
400 return Ok(Statement::Insert(insert_stmt));
401 }
402 }
404
405 if first_word.eq_ignore_ascii_case("UPDATE") {
407 if let Ok(update_stmt) = vibesql_parser::arena_parser::parse_update_to_owned(sql) {
408 return Ok(Statement::Update(update_stmt));
409 }
410 }
412
413 if first_word.eq_ignore_ascii_case("DELETE") {
415 if let Ok(delete_stmt) = vibesql_parser::arena_parser::parse_delete_to_owned(sql) {
416 return Ok(Statement::Delete(delete_stmt));
417 }
418 }
420
421 vibesql_parser::Parser::parse_sql(sql)
423}
424
#[cfg(test)]
mod tests {
    use vibesql_ast::Expression;

    use super::*;

    #[test]
    fn test_prepared_statement_no_params() {
        let sql = "SELECT * FROM users";
        let statement = vibesql_parser::Parser::parse_sql(sql).unwrap();
        let prepared = PreparedStatement::new(sql.to_string(), statement);

        assert_eq!(prepared.param_count(), 0);
        assert!(prepared.bind(&[]).is_ok());
    }

    #[test]
    fn test_prepared_statement_with_placeholder() {
        let sql = "SELECT * FROM users WHERE id = ?";
        let statement = vibesql_parser::Parser::parse_sql(sql).unwrap();
        let prepared = PreparedStatement::new(sql.to_string(), statement);

        assert_eq!(prepared.param_count(), 1);

        let bound = prepared.bind(&[SqlValue::Integer(42)]).unwrap();
        assert!(matches!(bound, Statement::Select(_)));

        // The placeholder on the right of `id = ?` must now be the bound literal.
        if let Statement::Select(select) = bound {
            if let Some(Expression::BinaryOp { right, .. }) = &select.where_clause {
                assert_eq!(**right, Expression::Literal(SqlValue::Integer(42)));
            } else {
                panic!("Expected BinaryOp in WHERE clause");
            }
        }
    }

    #[test]
    fn test_prepared_statement_multiple_placeholders() {
        let sql = "SELECT * FROM users WHERE id = ? AND name = ?";
        let statement = vibesql_parser::Parser::parse_sql(sql).unwrap();
        let prepared = PreparedStatement::new(sql.to_string(), statement);

        assert_eq!(prepared.param_count(), 2);

        let params = [SqlValue::Integer(42), SqlValue::Varchar(arcstr::ArcStr::from("John"))];
        let bound = prepared.bind(&params).unwrap();
        assert!(matches!(bound, Statement::Select(_)));
    }

    #[test]
    fn test_prepared_statement_bind_param_mismatch() {
        let sql = "SELECT * FROM users WHERE id = ?";
        let statement = vibesql_parser::Parser::parse_sql(sql).unwrap();
        let prepared = PreparedStatement::new(sql.to_string(), statement);

        // Too few parameters.
        let result = prepared.bind(&[]);
        assert!(matches!(
            result,
            Err(PreparedStatementError::ParameterCountMismatch { expected: 1, actual: 0 })
        ));

        // Too many parameters.
        let result = prepared.bind(&[SqlValue::Integer(1), SqlValue::Integer(2)]);
        assert!(matches!(
            result,
            Err(PreparedStatementError::ParameterCountMismatch { expected: 1, actual: 2 })
        ));
    }

    #[test]
    fn test_prepared_statement_reuse() {
        let sql = "SELECT * FROM users WHERE id = ?";
        let statement = vibesql_parser::Parser::parse_sql(sql).unwrap();
        let prepared = PreparedStatement::new(sql.to_string(), statement);

        // One prepared statement, bound repeatedly with different values.
        let bound1 = prepared.bind(&[SqlValue::Integer(1)]).unwrap();
        let bound2 = prepared.bind(&[SqlValue::Integer(2)]).unwrap();
        let bound3 = prepared.bind(&[SqlValue::Integer(3)]).unwrap();

        for (bound, expected_id) in [(bound1, 1), (bound2, 2), (bound3, 3)] {
            // Fail loudly on an unexpected shape instead of silently skipping the assertion.
            if let Statement::Select(select) = bound {
                if let Some(Expression::BinaryOp { right, .. }) = &select.where_clause {
                    assert_eq!(**right, Expression::Literal(SqlValue::Integer(expected_id)));
                } else {
                    panic!("Expected BinaryOp in WHERE clause");
                }
            } else {
                panic!("Expected SELECT statement");
            }
        }
    }

    #[test]
    fn test_cache_get_or_prepare() {
        let cache = PreparedStatementCache::new(10);
        let sql = "SELECT * FROM users WHERE id = 1";

        let stmt1 = cache.get_or_prepare(sql).unwrap();
        let stats = cache.stats();
        assert_eq!(stats.misses, 1);
        assert_eq!(stats.hits, 0);

        let stmt2 = cache.get_or_prepare(sql).unwrap();
        let stats = cache.stats();
        assert_eq!(stats.misses, 1);
        assert_eq!(stats.hits, 1);

        // A hit returns the very same shared statement.
        assert!(Arc::ptr_eq(&stmt1, &stmt2));
    }

    #[test]
    fn test_cache_placeholder_reuse() {
        let cache = PreparedStatementCache::new(10);
        let sql = "SELECT * FROM users WHERE id = ?";

        let stmt1 = cache.get_or_prepare(sql).unwrap();
        assert_eq!(cache.stats().misses, 1);
        assert_eq!(cache.stats().hits, 0);

        let stmt2 = cache.get_or_prepare(sql).unwrap();
        assert_eq!(cache.stats().misses, 1);
        assert_eq!(cache.stats().hits, 1);

        assert!(Arc::ptr_eq(&stmt1, &stmt2));

        let bound1 = stmt1.bind(&[SqlValue::Integer(1)]).unwrap();
        let bound2 = stmt2.bind(&[SqlValue::Integer(999)]).unwrap();

        if let (Statement::Select(s1), Statement::Select(s2)) = (&bound1, &bound2) {
            if let (
                Some(Expression::BinaryOp { right: r1, .. }),
                Some(Expression::BinaryOp { right: r2, .. }),
            ) = (&s1.where_clause, &s2.where_clause)
            {
                assert_eq!(**r1, Expression::Literal(SqlValue::Integer(1)));
                assert_eq!(**r2, Expression::Literal(SqlValue::Integer(999)));
            }
        }
    }

    #[test]
    fn test_cache_lru_eviction() {
        let cache = PreparedStatementCache::new(2);

        cache.get_or_prepare("SELECT * FROM users").unwrap();
        cache.get_or_prepare("SELECT * FROM orders").unwrap();
        assert_eq!(cache.stats().size, 2);
        assert_eq!(cache.stats().evictions, 0);

        // A third distinct statement pushes out the least recently used one.
        cache.get_or_prepare("SELECT * FROM products").unwrap();
        assert_eq!(cache.stats().size, 2);
        assert_eq!(cache.stats().evictions, 1);

        assert!(cache.get("SELECT * FROM users").is_none());
        assert!(cache.get("SELECT * FROM orders").is_some());
        assert!(cache.get("SELECT * FROM products").is_some());
    }

    #[test]
    fn test_cache_table_invalidation() {
        let cache = PreparedStatementCache::new(10);

        cache.get_or_prepare("SELECT * FROM users WHERE id = ?").unwrap();
        cache.get_or_prepare("SELECT * FROM orders WHERE id = ?").unwrap();
        assert_eq!(cache.stats().size, 2);

        cache.invalidate_table("users");
        assert_eq!(cache.stats().size, 1);

        assert!(cache.get("SELECT * FROM orders WHERE id = ?").is_some());
    }

    #[test]
    fn test_arena_parse_insert() {
        let sql = "INSERT INTO users (name, email) VALUES ('Alice', 'alice@example.com')";
        let result = parse_with_arena_fallback(sql);
        assert!(result.is_ok());
        assert!(matches!(result.unwrap(), Statement::Insert(_)));
    }

    #[test]
    fn test_arena_parse_insert_with_placeholders() {
        let cache = PreparedStatementCache::new(10);
        let sql = "INSERT INTO users (name, email) VALUES (?, ?)";

        let stmt = cache.get_or_prepare(sql).unwrap();
        assert_eq!(stmt.param_count(), 2);

        let bound = stmt
            .bind(&[
                SqlValue::Varchar(arcstr::ArcStr::from("Bob")),
                SqlValue::Varchar(arcstr::ArcStr::from("bob@example.com")),
            ])
            .unwrap();
        assert!(matches!(bound, Statement::Insert(_)));
    }

    #[test]
    fn test_arena_parse_update() {
        let sql = "UPDATE users SET name = 'Bob' WHERE id = 1";
        let result = parse_with_arena_fallback(sql);
        assert!(result.is_ok());
        assert!(matches!(result.unwrap(), Statement::Update(_)));
    }

    #[test]
    fn test_arena_parse_update_with_placeholders() {
        let cache = PreparedStatementCache::new(10);
        let sql = "UPDATE users SET name = ? WHERE id = ?";

        let stmt = cache.get_or_prepare(sql).unwrap();
        assert_eq!(stmt.param_count(), 2);

        let bound = stmt
            .bind(&[SqlValue::Varchar(arcstr::ArcStr::from("Charlie")), SqlValue::Integer(42)])
            .unwrap();
        assert!(matches!(bound, Statement::Update(_)));
    }

    #[test]
    fn test_arena_parse_delete() {
        let sql = "DELETE FROM users WHERE id = 1";
        let result = parse_with_arena_fallback(sql);
        assert!(result.is_ok());
        assert!(matches!(result.unwrap(), Statement::Delete(_)));
    }

    #[test]
    fn test_arena_parse_delete_with_placeholders() {
        let cache = PreparedStatementCache::new(10);
        let sql = "DELETE FROM users WHERE id = ?";

        let stmt = cache.get_or_prepare(sql).unwrap();
        assert_eq!(stmt.param_count(), 1);

        let bound = stmt.bind(&[SqlValue::Integer(99)]).unwrap();
        assert!(matches!(bound, Statement::Delete(_)));
    }

    #[test]
    fn test_arena_parse_sysbench_insert() {
        let cache = PreparedStatementCache::new(10);

        let sql3 = "INSERT INTO test (a, b, c) VALUES (?, ?, ?)";
        let stmt3 = cache.get_or_prepare(sql3);
        assert!(stmt3.is_ok(), "3-column INSERT failed: {:?}", stmt3.err());

        let sql4_gen = "INSERT INTO test (a, b, c, d) VALUES (?, ?, ?, ?)";
        let stmt4_gen = cache.get_or_prepare(sql4_gen);
        assert!(stmt4_gen.is_ok(), "4-column INSERT (generic) failed: {:?}", stmt4_gen.err());

        let sql4 = "INSERT INTO sbtest1 (id, k, c, padding) VALUES (?, ?, ?, ?)";
        let stmt4 = cache.get_or_prepare(sql4);
        assert!(stmt4.is_ok(), "4-column INSERT (sysbench) failed: {:?}", stmt4.err());
        assert_eq!(stmt4.unwrap().param_count(), 4);
    }
}