1use std::sync::{
25 atomic::{AtomicUsize, Ordering},
26 Arc, Mutex,
27};
28
29use lru::LruCache;
30use std::num::NonZeroUsize;
31use vibesql_ast::Statement;
32use vibesql_types::SqlValue;
33
34use super::{extract_tables_from_statement, QuerySignature};
35
36pub mod arena_prepared;
37mod bind;
38pub mod plan;
39
40pub use plan::{CachedPlan, ColumnProjection, PkPointLookupPlan, ProjectionPlan, ResolvedProjection};
41
/// A parsed SQL statement together with metadata derived from it once at construction,
/// so it can be bound to different parameter values repeatedly without re-parsing.
#[derive(Debug, Clone)]
pub struct PreparedStatement {
    /// The SQL text this statement was prepared from.
    sql: String,
    /// The parsed AST template. `bind` takes `&self` and returns a new `Statement`,
    /// so this template is never modified.
    statement: Statement,
    /// Signature computed from the AST via `QuerySignature::from_ast`.
    signature: QuerySignature,
    /// Number of parameter placeholders found in `statement`; `bind` requires exactly
    /// this many values.
    param_count: usize,
    /// Tables referenced by the statement; `PreparedStatementCache::invalidate_table`
    /// matches against these names ASCII-case-insensitively.
    tables: std::collections::HashSet<String>,
    /// Plan computed once at construction via `plan::analyze_for_plan`.
    cached_plan: CachedPlan,
}
58
59impl PreparedStatement {
60 pub fn new(sql: String, statement: Statement) -> Self {
62 let signature = QuerySignature::from_ast(&statement);
63 let param_count = bind::count_placeholders(&statement);
65 let tables = extract_tables_from_statement(&statement);
66 let cached_plan = plan::analyze_for_plan(&statement);
68
69 Self { sql, statement, signature, param_count, tables, cached_plan }
70 }
71
72 pub fn sql(&self) -> &str {
74 &self.sql
75 }
76
77 pub fn statement(&self) -> &Statement {
79 &self.statement
80 }
81
82 pub fn signature(&self) -> &QuerySignature {
84 &self.signature
85 }
86
87 pub fn param_count(&self) -> usize {
89 self.param_count
90 }
91
92 pub fn tables(&self) -> &std::collections::HashSet<String> {
94 &self.tables
95 }
96
97 pub fn cached_plan(&self) -> &CachedPlan {
99 &self.cached_plan
100 }
101
102 pub fn bind(&self, params: &[SqlValue]) -> Result<Statement, PreparedStatementError> {
111 if params.len() != self.param_count {
112 return Err(PreparedStatementError::ParameterCountMismatch {
113 expected: self.param_count,
114 actual: params.len(),
115 });
116 }
117
118 if self.param_count == 0 {
119 return Ok(self.statement.clone());
121 }
122
123 Ok(bind::bind_parameters(&self.statement, params))
125 }
126}
127
/// Errors produced while preparing SQL into a `PreparedStatement` or binding its
/// parameters.
#[derive(Debug, Clone)]
pub enum PreparedStatementError {
    /// `bind` received a different number of values than the statement has placeholders.
    ParameterCountMismatch { expected: usize, actual: usize },
    /// The SQL text could not be parsed; carries the parser's error message.
    ParseError(String),
}

impl std::fmt::Display for PreparedStatementError {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        match self {
            Self::ParameterCountMismatch { expected: want, actual: got } => {
                write!(f, "Parameter count mismatch: expected {}, got {}", want, got)
            }
            Self::ParseError(message) => write!(f, "Parse error: {}", message),
        }
    }
}

// No source error to chain; the impl lets callers use `?` into `Box<dyn Error>`.
impl std::error::Error for PreparedStatementError {}
149
/// Snapshot of counters returned by `PreparedStatementCache::stats`.
#[derive(Debug, Clone)]
pub struct PreparedStatementCacheStats {
    /// Lookups in the owned-statement cache that found an entry.
    pub hits: usize,
    /// Lookups in the owned-statement cache that found no entry.
    pub misses: usize,
    /// Evictions recorded by the cache. This counter is shared by the owned-statement
    /// and arena caches.
    pub evictions: usize,
    /// Current number of entries in the owned-statement cache (arena entries excluded).
    pub size: usize,
    /// `hits / (hits + misses)`, or `0.0` when no lookups have been recorded.
    pub hit_rate: f64,
}
159
/// LRU cache of prepared statements keyed by exact SQL text.
///
/// Two independent LRU maps are kept: one for owned-AST `PreparedStatement`s and one
/// for arena-backed statements. Each map is guarded by its own `Mutex` and sized from
/// `max_size`. The counters are atomics updated with `Relaxed` ordering and serve only
/// as statistics.
pub struct PreparedStatementCache {
    /// Owned-AST statements keyed by SQL text.
    cache: Mutex<LruCache<String, Arc<PreparedStatement>>>,
    /// Arena-backed statements keyed by SQL text.
    arena_cache: Mutex<LruCache<String, Arc<arena_prepared::ArenaPreparedStatement>>>,
    /// Capacity requested at construction, stored as given (may be 0 even though each
    /// map is created with a capacity of at least 1).
    max_size: usize,
    /// Hits in `cache`.
    hits: AtomicUsize,
    /// Misses in `cache`.
    misses: AtomicUsize,
    /// Evictions from either map (one shared counter).
    evictions: AtomicUsize,
    /// Hits in `arena_cache`.
    arena_hits: AtomicUsize,
    /// Misses in `arena_cache`.
    arena_misses: AtomicUsize,
}
179
180impl PreparedStatementCache {
181 pub fn new(max_size: usize) -> Self {
183 let cap = NonZeroUsize::new(max_size).unwrap_or(NonZeroUsize::new(1).unwrap());
184 Self {
185 cache: Mutex::new(LruCache::new(cap)),
186 arena_cache: Mutex::new(LruCache::new(cap)),
187 max_size,
188 hits: AtomicUsize::new(0),
189 misses: AtomicUsize::new(0),
190 evictions: AtomicUsize::new(0),
191 arena_hits: AtomicUsize::new(0),
192 arena_misses: AtomicUsize::new(0),
193 }
194 }
195
196 pub fn default_cache() -> Self {
198 Self::new(1000)
199 }
200
201 pub fn get(&self, sql: &str) -> Option<Arc<PreparedStatement>> {
203 let mut cache = self.cache.lock().unwrap();
204 if let Some(stmt) = cache.get(sql) {
205 self.hits.fetch_add(1, Ordering::Relaxed);
206 Some(Arc::clone(stmt))
207 } else {
208 self.misses.fetch_add(1, Ordering::Relaxed);
209 None
210 }
211 }
212
213 pub fn get_or_prepare(
219 &self,
220 sql: &str,
221 ) -> Result<Arc<PreparedStatement>, PreparedStatementError> {
222 let mut cache = self.cache.lock().unwrap();
224
225 if let Some(stmt) = cache.get(sql) {
227 self.hits.fetch_add(1, Ordering::Relaxed);
228 return Ok(Arc::clone(stmt));
229 }
230
231 self.misses.fetch_add(1, Ordering::Relaxed);
233 let statement = parse_with_arena_fallback(sql)
234 .map_err(|e| PreparedStatementError::ParseError(e.to_string()))?;
235
236 let prepared = Arc::new(PreparedStatement::new(sql.to_string(), statement));
237
238 if cache.len() >= self.max_size {
240 self.evictions.fetch_add(1, Ordering::Relaxed);
241 }
242
243 cache.put(sql.to_string(), Arc::clone(&prepared));
245
246 Ok(prepared)
247 }
248
249 pub fn get_or_prepare_arena(
262 &self,
263 sql: &str,
264 ) -> Result<Arc<arena_prepared::ArenaPreparedStatement>, arena_prepared::ArenaParseError> {
265 let mut cache = self.arena_cache.lock().unwrap();
267
268 if let Some(stmt) = cache.get(sql) {
270 self.arena_hits.fetch_add(1, Ordering::Relaxed);
271 return Ok(Arc::clone(stmt));
272 }
273
274 self.arena_misses.fetch_add(1, Ordering::Relaxed);
276 let prepared = Arc::new(arena_prepared::ArenaPreparedStatement::new(sql.to_string())?);
277
278 if cache.len() >= self.max_size {
280 self.evictions.fetch_add(1, Ordering::Relaxed);
281 }
282
283 cache.put(sql.to_string(), Arc::clone(&prepared));
285
286 Ok(prepared)
287 }
288
289 pub fn get_arena(&self, sql: &str) -> Option<Arc<arena_prepared::ArenaPreparedStatement>> {
291 let mut cache = self.arena_cache.lock().unwrap();
292 if let Some(stmt) = cache.get(sql) {
293 self.arena_hits.fetch_add(1, Ordering::Relaxed);
294 Some(Arc::clone(stmt))
295 } else {
296 None
297 }
298 }
299
300 pub fn clear(&self) {
302 self.cache.lock().unwrap().clear();
303 self.arena_cache.lock().unwrap().clear();
304 }
305
306 pub fn invalidate_table(&self, table: &str) {
308 {
310 let mut cache = self.cache.lock().unwrap();
311 let keys_to_remove: Vec<String> = cache
312 .iter()
313 .filter(|(_, stmt)| stmt.tables.iter().any(|t| t.eq_ignore_ascii_case(table)))
314 .map(|(k, _)| k.clone())
315 .collect();
316
317 for key in keys_to_remove {
318 cache.pop(&key);
319 }
320 }
321
322 {
324 let mut arena_cache = self.arena_cache.lock().unwrap();
325 let keys_to_remove: Vec<String> = arena_cache
326 .iter()
327 .filter(|(_, stmt)| stmt.tables().iter().any(|t| t.eq_ignore_ascii_case(table)))
328 .map(|(k, _)| k.clone())
329 .collect();
330
331 for key in keys_to_remove {
332 arena_cache.pop(&key);
333 }
334 }
335 }
336
337 pub fn stats(&self) -> PreparedStatementCacheStats {
339 let cache = self.cache.lock().unwrap();
340 let hits = self.hits.load(Ordering::Relaxed);
341 let misses = self.misses.load(Ordering::Relaxed);
342 let total = hits + misses;
343 let hit_rate = if total > 0 { hits as f64 / total as f64 } else { 0.0 };
344
345 PreparedStatementCacheStats {
346 hits,
347 misses,
348 evictions: self.evictions.load(Ordering::Relaxed),
349 size: cache.len(),
350 hit_rate,
351 }
352 }
353
354 pub fn max_size(&self) -> usize {
356 self.max_size
357 }
358}
359
360fn parse_with_arena_fallback(sql: &str) -> Result<Statement, vibesql_parser::ParseError> {
372 let trimmed = sql.trim_start();
374 let first_word = trimmed.split_whitespace().next().unwrap_or("");
375
376 if first_word.eq_ignore_ascii_case("SELECT") || first_word.eq_ignore_ascii_case("WITH") {
378 if let Ok(select_stmt) = vibesql_parser::arena_parser::parse_select_to_owned(sql) {
379 return Ok(Statement::Select(Box::new(select_stmt)));
380 }
381 }
383
384 if first_word.eq_ignore_ascii_case("INSERT") {
386 if let Ok(insert_stmt) = vibesql_parser::arena_parser::parse_insert_to_owned(sql) {
387 return Ok(Statement::Insert(insert_stmt));
388 }
389 }
391
392 if first_word.eq_ignore_ascii_case("REPLACE") {
394 if let Ok(insert_stmt) = vibesql_parser::arena_parser::parse_insert_to_owned(sql) {
395 return Ok(Statement::Insert(insert_stmt));
396 }
397 }
399
400 if first_word.eq_ignore_ascii_case("UPDATE") {
402 if let Ok(update_stmt) = vibesql_parser::arena_parser::parse_update_to_owned(sql) {
403 return Ok(Statement::Update(update_stmt));
404 }
405 }
407
408 if first_word.eq_ignore_ascii_case("DELETE") {
410 if let Ok(delete_stmt) = vibesql_parser::arena_parser::parse_delete_to_owned(sql) {
411 return Ok(Statement::Delete(delete_stmt));
412 }
413 }
415
416 vibesql_parser::Parser::parse_sql(sql)
418}
419
#[cfg(test)]
mod tests {
    use super::*;
    use vibesql_ast::Expression;

    /// Returns the right-hand operand of the top-level `WHERE` comparison of a `SELECT`.
    /// Panics with a descriptive message if `stmt` has any other shape, so tests fail
    /// loudly instead of passing silently.
    fn where_rhs(stmt: &Statement) -> &Expression {
        match stmt {
            Statement::Select(select) => match &select.where_clause {
                Some(Expression::BinaryOp { right, .. }) => &**right,
                other => panic!("Expected BinaryOp in WHERE clause, got {:?}", other),
            },
            other => panic!("Expected SELECT statement, got {:?}", other),
        }
    }

    #[test]
    fn test_prepared_statement_no_params() {
        let sql = "SELECT * FROM users";
        let statement = vibesql_parser::Parser::parse_sql(sql).unwrap();
        let prepared = PreparedStatement::new(sql.to_string(), statement);

        assert_eq!(prepared.param_count(), 0);
        assert!(prepared.bind(&[]).is_ok());
    }

    #[test]
    fn test_prepared_statement_with_placeholder() {
        let sql = "SELECT * FROM users WHERE id = ?";
        let statement = vibesql_parser::Parser::parse_sql(sql).unwrap();
        let prepared = PreparedStatement::new(sql.to_string(), statement);

        assert_eq!(prepared.param_count(), 1);

        let bound = prepared.bind(&[SqlValue::Integer(42)]).unwrap();
        assert!(matches!(bound, Statement::Select(_)));
        assert_eq!(*where_rhs(&bound), Expression::Literal(SqlValue::Integer(42)));
    }

    #[test]
    fn test_prepared_statement_multiple_placeholders() {
        let sql = "SELECT * FROM users WHERE id = ? AND name = ?";
        let statement = vibesql_parser::Parser::parse_sql(sql).unwrap();
        let prepared = PreparedStatement::new(sql.to_string(), statement);

        assert_eq!(prepared.param_count(), 2);

        let params = [SqlValue::Integer(42), SqlValue::Varchar("John".to_string())];
        let bound = prepared.bind(&params).unwrap();
        assert!(matches!(bound, Statement::Select(_)));
    }

    #[test]
    fn test_prepared_statement_bind_param_mismatch() {
        let sql = "SELECT * FROM users WHERE id = ?";
        let statement = vibesql_parser::Parser::parse_sql(sql).unwrap();
        let prepared = PreparedStatement::new(sql.to_string(), statement);

        // Too few parameters.
        let result = prepared.bind(&[]);
        assert!(matches!(
            result,
            Err(PreparedStatementError::ParameterCountMismatch { expected: 1, actual: 0 })
        ));

        // Too many parameters.
        let result = prepared.bind(&[SqlValue::Integer(1), SqlValue::Integer(2)]);
        assert!(matches!(
            result,
            Err(PreparedStatementError::ParameterCountMismatch { expected: 1, actual: 2 })
        ));
    }

    #[test]
    fn test_prepared_statement_reuse() {
        let sql = "SELECT * FROM users WHERE id = ?";
        let statement = vibesql_parser::Parser::parse_sql(sql).unwrap();
        let prepared = PreparedStatement::new(sql.to_string(), statement);

        // Each bind must produce an independent statement carrying its own value.
        let bound1 = prepared.bind(&[SqlValue::Integer(1)]).unwrap();
        let bound2 = prepared.bind(&[SqlValue::Integer(2)]).unwrap();
        let bound3 = prepared.bind(&[SqlValue::Integer(3)]).unwrap();

        for (bound, expected_id) in [(bound1, 1), (bound2, 2), (bound3, 3)] {
            assert_eq!(*where_rhs(&bound), Expression::Literal(SqlValue::Integer(expected_id)));
        }
    }

    #[test]
    fn test_cache_get_or_prepare() {
        let cache = PreparedStatementCache::new(10);
        let sql = "SELECT * FROM users WHERE id = 1";

        let stmt1 = cache.get_or_prepare(sql).unwrap();
        let stats = cache.stats();
        assert_eq!(stats.misses, 1);
        assert_eq!(stats.hits, 0);

        let stmt2 = cache.get_or_prepare(sql).unwrap();
        let stats = cache.stats();
        assert_eq!(stats.misses, 1);
        assert_eq!(stats.hits, 1);

        // The second lookup must return the very same cached allocation.
        assert!(Arc::ptr_eq(&stmt1, &stmt2));
    }

    #[test]
    fn test_cache_placeholder_reuse() {
        let cache = PreparedStatementCache::new(10);
        let sql = "SELECT * FROM users WHERE id = ?";

        let stmt1 = cache.get_or_prepare(sql).unwrap();
        assert_eq!(cache.stats().misses, 1);
        assert_eq!(cache.stats().hits, 0);

        let stmt2 = cache.get_or_prepare(sql).unwrap();
        assert_eq!(cache.stats().misses, 1);
        assert_eq!(cache.stats().hits, 1);

        assert!(Arc::ptr_eq(&stmt1, &stmt2));

        // One cached statement, bound twice with different values.
        let bound1 = stmt1.bind(&[SqlValue::Integer(1)]).unwrap();
        let bound2 = stmt2.bind(&[SqlValue::Integer(999)]).unwrap();

        assert_eq!(*where_rhs(&bound1), Expression::Literal(SqlValue::Integer(1)));
        assert_eq!(*where_rhs(&bound2), Expression::Literal(SqlValue::Integer(999)));
    }

    #[test]
    fn test_cache_lru_eviction() {
        let cache = PreparedStatementCache::new(2);

        cache.get_or_prepare("SELECT * FROM users").unwrap();
        cache.get_or_prepare("SELECT * FROM orders").unwrap();
        assert_eq!(cache.stats().size, 2);
        assert_eq!(cache.stats().evictions, 0);

        // A third distinct statement evicts the least recently used one ("users").
        cache.get_or_prepare("SELECT * FROM products").unwrap();
        assert_eq!(cache.stats().size, 2);
        assert_eq!(cache.stats().evictions, 1);

        assert!(cache.get("SELECT * FROM users").is_none());
        assert!(cache.get("SELECT * FROM orders").is_some());
        assert!(cache.get("SELECT * FROM products").is_some());
    }

    #[test]
    fn test_cache_table_invalidation() {
        let cache = PreparedStatementCache::new(10);

        cache.get_or_prepare("SELECT * FROM users WHERE id = ?").unwrap();
        cache.get_or_prepare("SELECT * FROM orders WHERE id = ?").unwrap();
        assert_eq!(cache.stats().size, 2);

        cache.invalidate_table("users");
        assert_eq!(cache.stats().size, 1);

        assert!(cache.get("SELECT * FROM orders WHERE id = ?").is_some());
    }

    #[test]
    fn test_arena_parse_insert() {
        let sql = "INSERT INTO users (name, email) VALUES ('Alice', 'alice@example.com')";
        let result = parse_with_arena_fallback(sql);
        assert!(result.is_ok());
        assert!(matches!(result.unwrap(), Statement::Insert(_)));
    }

    #[test]
    fn test_arena_parse_insert_with_placeholders() {
        let cache = PreparedStatementCache::new(10);
        let sql = "INSERT INTO users (name, email) VALUES (?, ?)";

        let stmt = cache.get_or_prepare(sql).unwrap();
        assert_eq!(stmt.param_count(), 2);

        let bound = stmt
            .bind(&[
                SqlValue::Varchar("Bob".to_string()),
                SqlValue::Varchar("bob@example.com".to_string()),
            ])
            .unwrap();
        assert!(matches!(bound, Statement::Insert(_)));
    }

    #[test]
    fn test_arena_parse_update() {
        let sql = "UPDATE users SET name = 'Bob' WHERE id = 1";
        let result = parse_with_arena_fallback(sql);
        assert!(result.is_ok());
        assert!(matches!(result.unwrap(), Statement::Update(_)));
    }

    #[test]
    fn test_arena_parse_update_with_placeholders() {
        let cache = PreparedStatementCache::new(10);
        let sql = "UPDATE users SET name = ? WHERE id = ?";

        let stmt = cache.get_or_prepare(sql).unwrap();
        assert_eq!(stmt.param_count(), 2);

        let bound =
            stmt.bind(&[SqlValue::Varchar("Charlie".to_string()), SqlValue::Integer(42)]).unwrap();
        assert!(matches!(bound, Statement::Update(_)));
    }

    #[test]
    fn test_arena_parse_delete() {
        let sql = "DELETE FROM users WHERE id = 1";
        let result = parse_with_arena_fallback(sql);
        assert!(result.is_ok());
        assert!(matches!(result.unwrap(), Statement::Delete(_)));
    }

    #[test]
    fn test_arena_parse_delete_with_placeholders() {
        let cache = PreparedStatementCache::new(10);
        let sql = "DELETE FROM users WHERE id = ?";

        let stmt = cache.get_or_prepare(sql).unwrap();
        assert_eq!(stmt.param_count(), 1);

        let bound = stmt.bind(&[SqlValue::Integer(99)]).unwrap();
        assert!(matches!(bound, Statement::Delete(_)));
    }

    #[test]
    fn test_arena_parse_sysbench_insert() {
        let cache = PreparedStatementCache::new(10);

        let sql3 = "INSERT INTO test (a, b, c) VALUES (?, ?, ?)";
        let stmt3 = cache.get_or_prepare(sql3);
        assert!(stmt3.is_ok(), "3-column INSERT failed: {:?}", stmt3.err());

        let sql4_gen = "INSERT INTO test (a, b, c, d) VALUES (?, ?, ?, ?)";
        let stmt4_gen = cache.get_or_prepare(sql4_gen);
        assert!(stmt4_gen.is_ok(), "4-column INSERT (generic) failed: {:?}", stmt4_gen.err());

        let sql4 = "INSERT INTO sbtest1 (id, k, c, padding) VALUES (?, ?, ?, ?)";
        let stmt4 = cache.get_or_prepare(sql4);
        assert!(stmt4.is_ok(), "4-column INSERT (sysbench) failed: {:?}", stmt4.err());
        assert_eq!(stmt4.unwrap().param_count(), 4);
    }
}