1use std::sync::{
25 atomic::{AtomicUsize, Ordering},
26 Arc, Mutex,
27};
28
29use lru::LruCache;
30use std::num::NonZeroUsize;
31use vibesql_ast::Statement;
32use vibesql_types::SqlValue;
33
34use super::{extract_tables_from_statement, QuerySignature};
35
36pub mod arena_prepared;
37mod bind;
38pub mod plan;
39
40pub use plan::{
41 CachedPlan, ColumnProjection, PkDeletePlan, PkPointLookupPlan, ProjectionPlan, ResolvedProjection,
42 SimpleFastPathPlan,
43};
44
45#[derive(Debug, Clone)]
47pub struct PreparedStatement {
48 sql: String,
50 statement: Statement,
52 signature: QuerySignature,
54 param_count: usize,
56 tables: std::collections::HashSet<String>,
58 cached_plan: CachedPlan,
60}
61
62impl PreparedStatement {
63 pub fn new(sql: String, statement: Statement) -> Self {
65 let signature = QuerySignature::from_ast(&statement);
66 let param_count = bind::count_placeholders(&statement);
68 let tables = extract_tables_from_statement(&statement);
69 let cached_plan = plan::analyze_for_plan(&statement);
71
72 Self { sql, statement, signature, param_count, tables, cached_plan }
73 }
74
75 pub fn sql(&self) -> &str {
77 &self.sql
78 }
79
80 pub fn statement(&self) -> &Statement {
82 &self.statement
83 }
84
85 pub fn signature(&self) -> &QuerySignature {
87 &self.signature
88 }
89
90 pub fn param_count(&self) -> usize {
92 self.param_count
93 }
94
95 pub fn tables(&self) -> &std::collections::HashSet<String> {
97 &self.tables
98 }
99
100 pub fn cached_plan(&self) -> &CachedPlan {
102 &self.cached_plan
103 }
104
105 pub fn bind(&self, params: &[SqlValue]) -> Result<Statement, PreparedStatementError> {
114 if params.len() != self.param_count {
115 return Err(PreparedStatementError::ParameterCountMismatch {
116 expected: self.param_count,
117 actual: params.len(),
118 });
119 }
120
121 if self.param_count == 0 {
122 return Ok(self.statement.clone());
124 }
125
126 Ok(bind::bind_parameters(&self.statement, params))
128 }
129}
130
/// Errors raised while preparing SQL text or binding parameters to a prepared statement.
#[derive(Debug, Clone)]
pub enum PreparedStatementError {
    /// `bind` received a different number of values than the statement has placeholders.
    ParameterCountMismatch { expected: usize, actual: usize },
    /// The SQL text could not be parsed; holds the parser's error message.
    ParseError(String),
}

impl std::fmt::Display for PreparedStatementError {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        match self {
            Self::ParseError(msg) => write!(f, "Parse error: {}", msg),
            Self::ParameterCountMismatch { expected, actual } => {
                write!(f, "Parameter count mismatch: expected {}, got {}", expected, actual)
            }
        }
    }
}

impl std::error::Error for PreparedStatementError {}
152
/// Snapshot of a [`PreparedStatementCache`]'s counters, as returned by its `stats` method.
///
/// The hit/miss figures and `size` describe the standard (non-arena) statement map only.
#[derive(Debug, Clone)]
pub struct PreparedStatementCacheStats {
    /// Lookups that found a cached statement.
    pub hits: usize,
    /// Lookups that found nothing.
    pub misses: usize,
    /// Entries displaced to make room for new ones.
    pub evictions: usize,
    /// Number of statements currently cached.
    pub size: usize,
    /// `hits / (hits + misses)`, or `0.0` when no lookups have happened yet.
    pub hit_rate: f64,
}
162
163pub struct PreparedStatementCache {
165 cache: Mutex<LruCache<String, Arc<PreparedStatement>>>,
167 arena_cache: Mutex<LruCache<String, Arc<arena_prepared::ArenaPreparedStatement>>>,
169 max_size: usize,
171 hits: AtomicUsize,
173 misses: AtomicUsize,
175 evictions: AtomicUsize,
177 arena_hits: AtomicUsize,
179 arena_misses: AtomicUsize,
181}
182
183impl PreparedStatementCache {
184 pub fn new(max_size: usize) -> Self {
186 let cap = NonZeroUsize::new(max_size).unwrap_or(NonZeroUsize::new(1).unwrap());
187 Self {
188 cache: Mutex::new(LruCache::new(cap)),
189 arena_cache: Mutex::new(LruCache::new(cap)),
190 max_size,
191 hits: AtomicUsize::new(0),
192 misses: AtomicUsize::new(0),
193 evictions: AtomicUsize::new(0),
194 arena_hits: AtomicUsize::new(0),
195 arena_misses: AtomicUsize::new(0),
196 }
197 }
198
199 pub fn default_cache() -> Self {
201 Self::new(1000)
202 }
203
204 pub fn get(&self, sql: &str) -> Option<Arc<PreparedStatement>> {
206 let mut cache = self.cache.lock().unwrap();
207 if let Some(stmt) = cache.get(sql) {
208 self.hits.fetch_add(1, Ordering::Relaxed);
209 Some(Arc::clone(stmt))
210 } else {
211 self.misses.fetch_add(1, Ordering::Relaxed);
212 None
213 }
214 }
215
216 pub fn get_or_prepare(
222 &self,
223 sql: &str,
224 ) -> Result<Arc<PreparedStatement>, PreparedStatementError> {
225 let mut cache = self.cache.lock().unwrap();
227
228 if let Some(stmt) = cache.get(sql) {
230 self.hits.fetch_add(1, Ordering::Relaxed);
231 return Ok(Arc::clone(stmt));
232 }
233
234 self.misses.fetch_add(1, Ordering::Relaxed);
236 let statement = parse_with_arena_fallback(sql)
237 .map_err(|e| PreparedStatementError::ParseError(e.to_string()))?;
238
239 let prepared = Arc::new(PreparedStatement::new(sql.to_string(), statement));
240
241 if cache.len() >= self.max_size {
243 self.evictions.fetch_add(1, Ordering::Relaxed);
244 }
245
246 cache.put(sql.to_string(), Arc::clone(&prepared));
248
249 Ok(prepared)
250 }
251
252 pub fn get_or_prepare_arena(
265 &self,
266 sql: &str,
267 ) -> Result<Arc<arena_prepared::ArenaPreparedStatement>, arena_prepared::ArenaParseError> {
268 let mut cache = self.arena_cache.lock().unwrap();
270
271 if let Some(stmt) = cache.get(sql) {
273 self.arena_hits.fetch_add(1, Ordering::Relaxed);
274 return Ok(Arc::clone(stmt));
275 }
276
277 self.arena_misses.fetch_add(1, Ordering::Relaxed);
279 let prepared = Arc::new(arena_prepared::ArenaPreparedStatement::new(sql.to_string())?);
280
281 if cache.len() >= self.max_size {
283 self.evictions.fetch_add(1, Ordering::Relaxed);
284 }
285
286 cache.put(sql.to_string(), Arc::clone(&prepared));
288
289 Ok(prepared)
290 }
291
292 pub fn get_arena(&self, sql: &str) -> Option<Arc<arena_prepared::ArenaPreparedStatement>> {
294 let mut cache = self.arena_cache.lock().unwrap();
295 if let Some(stmt) = cache.get(sql) {
296 self.arena_hits.fetch_add(1, Ordering::Relaxed);
297 Some(Arc::clone(stmt))
298 } else {
299 None
300 }
301 }
302
303 pub fn clear(&self) {
305 self.cache.lock().unwrap().clear();
306 self.arena_cache.lock().unwrap().clear();
307 }
308
309 pub fn invalidate_table(&self, table: &str) {
311 {
313 let mut cache = self.cache.lock().unwrap();
314 let keys_to_remove: Vec<String> = cache
315 .iter()
316 .filter(|(_, stmt)| stmt.tables.iter().any(|t| t.eq_ignore_ascii_case(table)))
317 .map(|(k, _)| k.clone())
318 .collect();
319
320 for key in keys_to_remove {
321 cache.pop(&key);
322 }
323 }
324
325 {
327 let mut arena_cache = self.arena_cache.lock().unwrap();
328 let keys_to_remove: Vec<String> = arena_cache
329 .iter()
330 .filter(|(_, stmt)| stmt.tables().iter().any(|t| t.eq_ignore_ascii_case(table)))
331 .map(|(k, _)| k.clone())
332 .collect();
333
334 for key in keys_to_remove {
335 arena_cache.pop(&key);
336 }
337 }
338 }
339
340 pub fn stats(&self) -> PreparedStatementCacheStats {
342 let cache = self.cache.lock().unwrap();
343 let hits = self.hits.load(Ordering::Relaxed);
344 let misses = self.misses.load(Ordering::Relaxed);
345 let total = hits + misses;
346 let hit_rate = if total > 0 { hits as f64 / total as f64 } else { 0.0 };
347
348 PreparedStatementCacheStats {
349 hits,
350 misses,
351 evictions: self.evictions.load(Ordering::Relaxed),
352 size: cache.len(),
353 hit_rate,
354 }
355 }
356
357 pub fn max_size(&self) -> usize {
359 self.max_size
360 }
361}
362
363fn parse_with_arena_fallback(sql: &str) -> Result<Statement, vibesql_parser::ParseError> {
375 let trimmed = sql.trim_start();
377 let first_word = trimmed.split_whitespace().next().unwrap_or("");
378
379 if first_word.eq_ignore_ascii_case("SELECT") || first_word.eq_ignore_ascii_case("WITH") {
381 if let Ok(select_stmt) = vibesql_parser::arena_parser::parse_select_to_owned(sql) {
382 return Ok(Statement::Select(Box::new(select_stmt)));
383 }
384 }
386
387 if first_word.eq_ignore_ascii_case("INSERT") {
389 if let Ok(insert_stmt) = vibesql_parser::arena_parser::parse_insert_to_owned(sql) {
390 return Ok(Statement::Insert(insert_stmt));
391 }
392 }
394
395 if first_word.eq_ignore_ascii_case("REPLACE") {
397 if let Ok(insert_stmt) = vibesql_parser::arena_parser::parse_insert_to_owned(sql) {
398 return Ok(Statement::Insert(insert_stmt));
399 }
400 }
402
403 if first_word.eq_ignore_ascii_case("UPDATE") {
405 if let Ok(update_stmt) = vibesql_parser::arena_parser::parse_update_to_owned(sql) {
406 return Ok(Statement::Update(update_stmt));
407 }
408 }
410
411 if first_word.eq_ignore_ascii_case("DELETE") {
413 if let Ok(delete_stmt) = vibesql_parser::arena_parser::parse_delete_to_owned(sql) {
414 return Ok(Statement::Delete(delete_stmt));
415 }
416 }
418
419 vibesql_parser::Parser::parse_sql(sql)
421}
422
#[cfg(test)]
mod tests {
    use super::*;
    use vibesql_ast::Expression;

    #[test]
    fn test_prepared_statement_no_params() {
        let sql = "SELECT * FROM users";
        let statement = vibesql_parser::Parser::parse_sql(sql).unwrap();
        let prepared = PreparedStatement::new(sql.to_string(), statement);

        assert_eq!(prepared.param_count(), 0);
        assert!(prepared.bind(&[]).is_ok());
    }

    #[test]
    fn test_prepared_statement_with_placeholder() {
        let sql = "SELECT * FROM users WHERE id = ?";
        let statement = vibesql_parser::Parser::parse_sql(sql).unwrap();
        let prepared = PreparedStatement::new(sql.to_string(), statement);

        assert_eq!(prepared.param_count(), 1);

        let bound = prepared.bind(&[SqlValue::Integer(42)]).unwrap();
        assert!(matches!(bound, Statement::Select(_)));

        if let Statement::Select(select) = bound {
            if let Some(Expression::BinaryOp { right, .. }) = &select.where_clause {
                assert_eq!(**right, Expression::Literal(SqlValue::Integer(42)));
            } else {
                panic!("Expected BinaryOp in WHERE clause");
            }
        }
    }

    #[test]
    fn test_prepared_statement_multiple_placeholders() {
        let sql = "SELECT * FROM users WHERE id = ? AND name = ?";
        let statement = vibesql_parser::Parser::parse_sql(sql).unwrap();
        let prepared = PreparedStatement::new(sql.to_string(), statement);

        assert_eq!(prepared.param_count(), 2);

        // Was `¶ms` (a mis-encoded `&params`), which does not compile.
        let params = [SqlValue::Integer(42), SqlValue::Varchar(arcstr::ArcStr::from("John"))];
        let bound = prepared.bind(&params).unwrap();
        assert!(matches!(bound, Statement::Select(_)));
    }

    #[test]
    fn test_prepared_statement_bind_param_mismatch() {
        let sql = "SELECT * FROM users WHERE id = ?";
        let statement = vibesql_parser::Parser::parse_sql(sql).unwrap();
        let prepared = PreparedStatement::new(sql.to_string(), statement);

        let result = prepared.bind(&[]);
        assert!(matches!(
            result,
            Err(PreparedStatementError::ParameterCountMismatch { expected: 1, actual: 0 })
        ));

        let result = prepared.bind(&[SqlValue::Integer(1), SqlValue::Integer(2)]);
        assert!(matches!(
            result,
            Err(PreparedStatementError::ParameterCountMismatch { expected: 1, actual: 2 })
        ));
    }

    #[test]
    fn test_prepared_statement_reuse() {
        let sql = "SELECT * FROM users WHERE id = ?";
        let statement = vibesql_parser::Parser::parse_sql(sql).unwrap();
        let prepared = PreparedStatement::new(sql.to_string(), statement);

        let bound1 = prepared.bind(&[SqlValue::Integer(1)]).unwrap();
        let bound2 = prepared.bind(&[SqlValue::Integer(2)]).unwrap();
        let bound3 = prepared.bind(&[SqlValue::Integer(3)]).unwrap();

        // Fail loudly on an unexpected shape instead of silently skipping the assertion.
        for (bound, expected_id) in [(bound1, 1), (bound2, 2), (bound3, 3)] {
            match bound {
                Statement::Select(select) => match &select.where_clause {
                    Some(Expression::BinaryOp { right, .. }) => {
                        assert_eq!(**right, Expression::Literal(SqlValue::Integer(expected_id)));
                    }
                    _ => panic!("Expected BinaryOp in WHERE clause"),
                },
                _ => panic!("Expected SELECT statement"),
            }
        }
    }

    #[test]
    fn test_cache_get_or_prepare() {
        let cache = PreparedStatementCache::new(10);
        let sql = "SELECT * FROM users WHERE id = 1";

        let stmt1 = cache.get_or_prepare(sql).unwrap();
        let stats = cache.stats();
        assert_eq!(stats.misses, 1);
        assert_eq!(stats.hits, 0);

        let stmt2 = cache.get_or_prepare(sql).unwrap();
        let stats = cache.stats();
        assert_eq!(stats.misses, 1);
        assert_eq!(stats.hits, 1);

        assert!(Arc::ptr_eq(&stmt1, &stmt2));
    }

    #[test]
    fn test_cache_placeholder_reuse() {
        let cache = PreparedStatementCache::new(10);
        let sql = "SELECT * FROM users WHERE id = ?";

        let stmt1 = cache.get_or_prepare(sql).unwrap();
        assert_eq!(cache.stats().misses, 1);
        assert_eq!(cache.stats().hits, 0);

        let stmt2 = cache.get_or_prepare(sql).unwrap();
        assert_eq!(cache.stats().misses, 1);
        assert_eq!(cache.stats().hits, 1);

        assert!(Arc::ptr_eq(&stmt1, &stmt2));

        let bound1 = stmt1.bind(&[SqlValue::Integer(1)]).unwrap();
        let bound2 = stmt2.bind(&[SqlValue::Integer(999)]).unwrap();

        if let (Statement::Select(s1), Statement::Select(s2)) = (&bound1, &bound2) {
            if let (
                Some(Expression::BinaryOp { right: r1, .. }),
                Some(Expression::BinaryOp { right: r2, .. }),
            ) = (&s1.where_clause, &s2.where_clause)
            {
                assert_eq!(**r1, Expression::Literal(SqlValue::Integer(1)));
                assert_eq!(**r2, Expression::Literal(SqlValue::Integer(999)));
            }
        }
    }

    #[test]
    fn test_cache_lru_eviction() {
        let cache = PreparedStatementCache::new(2);

        cache.get_or_prepare("SELECT * FROM users").unwrap();
        cache.get_or_prepare("SELECT * FROM orders").unwrap();
        assert_eq!(cache.stats().size, 2);
        assert_eq!(cache.stats().evictions, 0);

        cache.get_or_prepare("SELECT * FROM products").unwrap();
        assert_eq!(cache.stats().size, 2);
        assert_eq!(cache.stats().evictions, 1);

        assert!(cache.get("SELECT * FROM users").is_none());
        assert!(cache.get("SELECT * FROM orders").is_some());
        assert!(cache.get("SELECT * FROM products").is_some());
    }

    #[test]
    fn test_cache_table_invalidation() {
        let cache = PreparedStatementCache::new(10);

        cache.get_or_prepare("SELECT * FROM users WHERE id = ?").unwrap();
        cache.get_or_prepare("SELECT * FROM orders WHERE id = ?").unwrap();
        assert_eq!(cache.stats().size, 2);

        cache.invalidate_table("users");
        assert_eq!(cache.stats().size, 1);

        assert!(cache.get("SELECT * FROM orders WHERE id = ?").is_some());
    }

    #[test]
    fn test_arena_parse_insert() {
        let sql = "INSERT INTO users (name, email) VALUES ('Alice', 'alice@example.com')";
        let result = parse_with_arena_fallback(sql);
        assert!(result.is_ok());
        assert!(matches!(result.unwrap(), Statement::Insert(_)));
    }

    #[test]
    fn test_arena_parse_insert_with_placeholders() {
        let cache = PreparedStatementCache::new(10);
        let sql = "INSERT INTO users (name, email) VALUES (?, ?)";

        let stmt = cache.get_or_prepare(sql).unwrap();
        assert_eq!(stmt.param_count(), 2);

        let bound = stmt
            .bind(&[
                SqlValue::Varchar(arcstr::ArcStr::from("Bob")),
                SqlValue::Varchar(arcstr::ArcStr::from("bob@example.com")),
            ])
            .unwrap();
        assert!(matches!(bound, Statement::Insert(_)));
    }

    #[test]
    fn test_arena_parse_update() {
        let sql = "UPDATE users SET name = 'Bob' WHERE id = 1";
        let result = parse_with_arena_fallback(sql);
        assert!(result.is_ok());
        assert!(matches!(result.unwrap(), Statement::Update(_)));
    }

    #[test]
    fn test_arena_parse_update_with_placeholders() {
        let cache = PreparedStatementCache::new(10);
        let sql = "UPDATE users SET name = ? WHERE id = ?";

        let stmt = cache.get_or_prepare(sql).unwrap();
        assert_eq!(stmt.param_count(), 2);

        let bound = stmt
            .bind(&[SqlValue::Varchar(arcstr::ArcStr::from("Charlie")), SqlValue::Integer(42)])
            .unwrap();
        assert!(matches!(bound, Statement::Update(_)));
    }

    #[test]
    fn test_arena_parse_delete() {
        let sql = "DELETE FROM users WHERE id = 1";
        let result = parse_with_arena_fallback(sql);
        assert!(result.is_ok());
        assert!(matches!(result.unwrap(), Statement::Delete(_)));
    }

    #[test]
    fn test_arena_parse_delete_with_placeholders() {
        let cache = PreparedStatementCache::new(10);
        let sql = "DELETE FROM users WHERE id = ?";

        let stmt = cache.get_or_prepare(sql).unwrap();
        assert_eq!(stmt.param_count(), 1);

        let bound = stmt.bind(&[SqlValue::Integer(99)]).unwrap();
        assert!(matches!(bound, Statement::Delete(_)));
    }

    #[test]
    fn test_arena_parse_sysbench_insert() {
        let cache = PreparedStatementCache::new(10);

        let sql3 = "INSERT INTO test (a, b, c) VALUES (?, ?, ?)";
        let stmt3 = cache.get_or_prepare(sql3);
        assert!(stmt3.is_ok(), "3-column INSERT failed: {:?}", stmt3.err());

        let sql4_gen = "INSERT INTO test (a, b, c, d) VALUES (?, ?, ?, ?)";
        let stmt4_gen = cache.get_or_prepare(sql4_gen);
        assert!(stmt4_gen.is_ok(), "4-column INSERT (generic) failed: {:?}", stmt4_gen.err());

        let sql4 = "INSERT INTO sbtest1 (id, k, c, padding) VALUES (?, ?, ?, ?)";
        let stmt4 = cache.get_or_prepare(sql4);
        assert!(stmt4.is_ok(), "4-column INSERT (sysbench) failed: {:?}", stmt4.err());
        assert_eq!(stmt4.unwrap().param_count(), 4);
    }
}