1use std::sync::{
25 atomic::{AtomicUsize, Ordering},
26 Arc, Mutex,
27};
28
29use lru::LruCache;
30use std::num::NonZeroUsize;
31use vibesql_ast::Statement;
32use vibesql_types::SqlValue;
33
34use super::{extract_tables_from_statement, QuerySignature};
35
36pub mod arena_prepared;
37mod bind;
38pub mod plan;
39
40pub use plan::{CachedPlan, PkPointLookupPlan, ProjectionPlan, ColumnProjection};
41
/// A parsed SQL statement bundled with the analysis computed once at prepare time.
///
/// Instances are built by `PreparedStatement::new` and handed out (wrapped in `Arc`) by
/// `PreparedStatementCache`; `bind` turns one into an executable `Statement`.
#[derive(Debug, Clone)]
pub struct PreparedStatement {
    /// The SQL text the statement was prepared from (also the cache key).
    sql: String,
    /// The parsed AST; may still contain `?` placeholders.
    statement: Statement,
    /// Signature computed from the AST by `QuerySignature::from_ast`.
    signature: QuerySignature,
    /// Number of placeholders counted in `statement` by `bind::count_placeholders`.
    param_count: usize,
    /// Table names returned by `extract_tables_from_statement`; matched by
    /// `PreparedStatementCache::invalidate_table`.
    tables: std::collections::HashSet<String>,
    /// Plan information produced by `plan::analyze_for_plan`.
    cached_plan: CachedPlan,
}
58
59impl PreparedStatement {
60 pub fn new(sql: String, statement: Statement) -> Self {
62 let signature = QuerySignature::from_ast(&statement);
63 let param_count = bind::count_placeholders(&statement);
65 let tables = extract_tables_from_statement(&statement);
66 let cached_plan = plan::analyze_for_plan(&statement);
68
69 Self {
70 sql,
71 statement,
72 signature,
73 param_count,
74 tables,
75 cached_plan,
76 }
77 }
78
79 pub fn sql(&self) -> &str {
81 &self.sql
82 }
83
84 pub fn statement(&self) -> &Statement {
86 &self.statement
87 }
88
89 pub fn signature(&self) -> &QuerySignature {
91 &self.signature
92 }
93
94 pub fn param_count(&self) -> usize {
96 self.param_count
97 }
98
99 pub fn tables(&self) -> &std::collections::HashSet<String> {
101 &self.tables
102 }
103
104 pub fn cached_plan(&self) -> &CachedPlan {
106 &self.cached_plan
107 }
108
109 pub fn bind(&self, params: &[SqlValue]) -> Result<Statement, PreparedStatementError> {
118 if params.len() != self.param_count {
119 return Err(PreparedStatementError::ParameterCountMismatch {
120 expected: self.param_count,
121 actual: params.len(),
122 });
123 }
124
125 if self.param_count == 0 {
126 return Ok(self.statement.clone());
128 }
129
130 Ok(bind::bind_parameters(&self.statement, params))
132 }
133}
134
/// Errors produced while preparing or binding a prepared statement.
///
/// Derives `PartialEq`/`Eq` so callers (and tests) can compare errors directly with
/// `assert_eq!` instead of pattern-matching.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum PreparedStatementError {
    /// `bind` received a different number of values than the statement has placeholders.
    ParameterCountMismatch {
        /// Number of placeholders in the prepared statement.
        expected: usize,
        /// Number of values the caller supplied.
        actual: usize,
    },
    /// The SQL text could not be parsed; holds the parser's error message.
    ParseError(String),
}
143
144impl std::fmt::Display for PreparedStatementError {
145 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
146 match self {
147 PreparedStatementError::ParameterCountMismatch { expected, actual } => {
148 write!(f, "Parameter count mismatch: expected {}, got {}", expected, actual)
149 }
150 PreparedStatementError::ParseError(msg) => write!(f, "Parse error: {}", msg),
151 }
152 }
153}
154
155impl std::error::Error for PreparedStatementError {}
156
/// Point-in-time snapshot of the counters returned by `PreparedStatementCache::stats`.
#[derive(Debug, Clone)]
pub struct PreparedStatementCacheStats {
    /// Lookups on the main (non-arena) statement cache that found an entry.
    pub hits: usize,
    /// Lookups on the main (non-arena) statement cache that found nothing.
    pub misses: usize,
    /// Evictions recorded by the cache's insert paths; both the main cache and the arena
    /// cache contribute to this single counter.
    pub evictions: usize,
    /// Number of entries currently held in the main (non-arena) statement cache.
    pub size: usize,
    /// `hits / (hits + misses)`, or `0.0` when no lookups have been recorded.
    pub hit_rate: f64,
}
166
/// LRU cache of prepared statements keyed by their exact SQL text.
///
/// Keeps two independent LRU caches: one of owned `PreparedStatement`s and one of
/// arena-backed statements. Each is behind its own `Mutex` and the statistics are atomics, so
/// every operation works through `&self`.
pub struct PreparedStatementCache {
    /// Owned prepared statements, keyed by SQL text.
    cache: Mutex<LruCache<String, Arc<PreparedStatement>>>,
    /// Arena-backed prepared statements, keyed by SQL text.
    arena_cache: Mutex<LruCache<String, Arc<arena_prepared::ArenaPreparedStatement>>>,
    /// Capacity passed to `new`; a value of 0 is clamped to 1 when sizing the LRU caches.
    max_size: usize,
    /// Hits recorded against `cache`.
    hits: AtomicUsize,
    /// Misses recorded against `cache`.
    misses: AtomicUsize,
    /// Evictions recorded for either cache (shared counter).
    evictions: AtomicUsize,
    /// Hits recorded against `arena_cache`.
    arena_hits: AtomicUsize,
    /// Misses recorded against `arena_cache`.
    arena_misses: AtomicUsize,
}
186
187impl PreparedStatementCache {
188 pub fn new(max_size: usize) -> Self {
190 let cap = NonZeroUsize::new(max_size).unwrap_or(NonZeroUsize::new(1).unwrap());
191 Self {
192 cache: Mutex::new(LruCache::new(cap)),
193 arena_cache: Mutex::new(LruCache::new(cap)),
194 max_size,
195 hits: AtomicUsize::new(0),
196 misses: AtomicUsize::new(0),
197 evictions: AtomicUsize::new(0),
198 arena_hits: AtomicUsize::new(0),
199 arena_misses: AtomicUsize::new(0),
200 }
201 }
202
203 pub fn default_cache() -> Self {
205 Self::new(1000)
206 }
207
208 pub fn get(&self, sql: &str) -> Option<Arc<PreparedStatement>> {
210 let mut cache = self.cache.lock().unwrap();
211 if let Some(stmt) = cache.get(sql) {
212 self.hits.fetch_add(1, Ordering::Relaxed);
213 Some(Arc::clone(stmt))
214 } else {
215 self.misses.fetch_add(1, Ordering::Relaxed);
216 None
217 }
218 }
219
220 pub fn get_or_prepare(
226 &self,
227 sql: &str,
228 ) -> Result<Arc<PreparedStatement>, PreparedStatementError> {
229 let mut cache = self.cache.lock().unwrap();
231
232 if let Some(stmt) = cache.get(sql) {
234 self.hits.fetch_add(1, Ordering::Relaxed);
235 return Ok(Arc::clone(stmt));
236 }
237
238 self.misses.fetch_add(1, Ordering::Relaxed);
240 let statement = parse_with_arena_fallback(sql)
241 .map_err(|e| PreparedStatementError::ParseError(e.to_string()))?;
242
243 let prepared = Arc::new(PreparedStatement::new(sql.to_string(), statement));
244
245 if cache.len() >= self.max_size {
247 self.evictions.fetch_add(1, Ordering::Relaxed);
248 }
249
250 cache.put(sql.to_string(), Arc::clone(&prepared));
252
253 Ok(prepared)
254 }
255
256 pub fn get_or_prepare_arena(
269 &self,
270 sql: &str,
271 ) -> Result<Arc<arena_prepared::ArenaPreparedStatement>, arena_prepared::ArenaParseError> {
272 let mut cache = self.arena_cache.lock().unwrap();
274
275 if let Some(stmt) = cache.get(sql) {
277 self.arena_hits.fetch_add(1, Ordering::Relaxed);
278 return Ok(Arc::clone(stmt));
279 }
280
281 self.arena_misses.fetch_add(1, Ordering::Relaxed);
283 let prepared = Arc::new(arena_prepared::ArenaPreparedStatement::new(sql.to_string())?);
284
285 if cache.len() >= self.max_size {
287 self.evictions.fetch_add(1, Ordering::Relaxed);
288 }
289
290 cache.put(sql.to_string(), Arc::clone(&prepared));
292
293 Ok(prepared)
294 }
295
296 pub fn get_arena(&self, sql: &str) -> Option<Arc<arena_prepared::ArenaPreparedStatement>> {
298 let mut cache = self.arena_cache.lock().unwrap();
299 if let Some(stmt) = cache.get(sql) {
300 self.arena_hits.fetch_add(1, Ordering::Relaxed);
301 Some(Arc::clone(stmt))
302 } else {
303 None
304 }
305 }
306
307 pub fn clear(&self) {
309 self.cache.lock().unwrap().clear();
310 self.arena_cache.lock().unwrap().clear();
311 }
312
313 pub fn invalidate_table(&self, table: &str) {
315 {
317 let mut cache = self.cache.lock().unwrap();
318 let keys_to_remove: Vec<String> = cache
319 .iter()
320 .filter(|(_, stmt)| stmt.tables.iter().any(|t| t.eq_ignore_ascii_case(table)))
321 .map(|(k, _)| k.clone())
322 .collect();
323
324 for key in keys_to_remove {
325 cache.pop(&key);
326 }
327 }
328
329 {
331 let mut arena_cache = self.arena_cache.lock().unwrap();
332 let keys_to_remove: Vec<String> = arena_cache
333 .iter()
334 .filter(|(_, stmt)| stmt.tables().iter().any(|t| t.eq_ignore_ascii_case(table)))
335 .map(|(k, _)| k.clone())
336 .collect();
337
338 for key in keys_to_remove {
339 arena_cache.pop(&key);
340 }
341 }
342 }
343
344 pub fn stats(&self) -> PreparedStatementCacheStats {
346 let cache = self.cache.lock().unwrap();
347 let hits = self.hits.load(Ordering::Relaxed);
348 let misses = self.misses.load(Ordering::Relaxed);
349 let total = hits + misses;
350 let hit_rate = if total > 0 { hits as f64 / total as f64 } else { 0.0 };
351
352 PreparedStatementCacheStats {
353 hits,
354 misses,
355 evictions: self.evictions.load(Ordering::Relaxed),
356 size: cache.len(),
357 hit_rate,
358 }
359 }
360
361 pub fn max_size(&self) -> usize {
363 self.max_size
364 }
365}
366
367fn parse_with_arena_fallback(sql: &str) -> Result<Statement, vibesql_parser::ParseError> {
379 let trimmed = sql.trim_start();
381 let first_word = trimmed.split_whitespace().next().unwrap_or("");
382
383 if first_word.eq_ignore_ascii_case("SELECT") || first_word.eq_ignore_ascii_case("WITH") {
385 if let Ok(select_stmt) = vibesql_parser::arena_parser::parse_select_to_owned(sql) {
386 return Ok(Statement::Select(Box::new(select_stmt)));
387 }
388 }
390
391 if first_word.eq_ignore_ascii_case("INSERT") {
393 if let Ok(insert_stmt) = vibesql_parser::arena_parser::parse_insert_to_owned(sql) {
394 return Ok(Statement::Insert(insert_stmt));
395 }
396 }
398
399 if first_word.eq_ignore_ascii_case("REPLACE") {
401 if let Ok(insert_stmt) = vibesql_parser::arena_parser::parse_insert_to_owned(sql) {
402 return Ok(Statement::Insert(insert_stmt));
403 }
404 }
406
407 if first_word.eq_ignore_ascii_case("UPDATE") {
409 if let Ok(update_stmt) = vibesql_parser::arena_parser::parse_update_to_owned(sql) {
410 return Ok(Statement::Update(update_stmt));
411 }
412 }
414
415 if first_word.eq_ignore_ascii_case("DELETE") {
417 if let Ok(delete_stmt) = vibesql_parser::arena_parser::parse_delete_to_owned(sql) {
418 return Ok(Statement::Delete(delete_stmt));
419 }
420 }
422
423 vibesql_parser::Parser::parse_sql(sql)
425}
426
#[cfg(test)]
mod tests {
    use super::*;
    use vibesql_ast::Expression;

    /// Returns the right-hand operand of the `WHERE` clause of a bound `SELECT`.
    ///
    /// Panics, failing the test, if the statement is not a `SELECT` whose `WHERE` clause is a
    /// binary operation. This way a shape regression cannot make an assertion silently vanish.
    fn where_rhs(bound: &Statement) -> &Expression {
        match bound {
            Statement::Select(select) => match &select.where_clause {
                Some(Expression::BinaryOp { right, .. }) => &**right,
                other => panic!("Expected BinaryOp in WHERE clause, got {:?}", other),
            },
            other => panic!("Expected SELECT statement, got {:?}", other),
        }
    }

    #[test]
    fn test_prepared_statement_no_params() {
        let sql = "SELECT * FROM users";
        let statement = vibesql_parser::Parser::parse_sql(sql).unwrap();
        let prepared = PreparedStatement::new(sql.to_string(), statement);

        assert_eq!(prepared.param_count(), 0);
        assert!(prepared.bind(&[]).is_ok());
    }

    #[test]
    fn test_prepared_statement_with_placeholder() {
        let sql = "SELECT * FROM users WHERE id = ?";
        let statement = vibesql_parser::Parser::parse_sql(sql).unwrap();
        let prepared = PreparedStatement::new(sql.to_string(), statement);

        assert_eq!(prepared.param_count(), 1);

        let bound = prepared.bind(&[SqlValue::Integer(42)]).unwrap();
        assert!(matches!(bound, Statement::Select(_)));
        assert_eq!(*where_rhs(&bound), Expression::Literal(SqlValue::Integer(42)));
    }

    #[test]
    fn test_prepared_statement_multiple_placeholders() {
        let sql = "SELECT * FROM users WHERE id = ? AND name = ?";
        let statement = vibesql_parser::Parser::parse_sql(sql).unwrap();
        let prepared = PreparedStatement::new(sql.to_string(), statement);

        assert_eq!(prepared.param_count(), 2);

        let params = [SqlValue::Integer(42), SqlValue::Varchar("John".to_string())];
        let bound = prepared.bind(&params).unwrap();
        assert!(matches!(bound, Statement::Select(_)));
    }

    #[test]
    fn test_prepared_statement_bind_param_mismatch() {
        let sql = "SELECT * FROM users WHERE id = ?";
        let statement = vibesql_parser::Parser::parse_sql(sql).unwrap();
        let prepared = PreparedStatement::new(sql.to_string(), statement);

        let result = prepared.bind(&[]);
        assert!(matches!(
            result,
            Err(PreparedStatementError::ParameterCountMismatch { expected: 1, actual: 0 })
        ));

        let result = prepared.bind(&[SqlValue::Integer(1), SqlValue::Integer(2)]);
        assert!(matches!(
            result,
            Err(PreparedStatementError::ParameterCountMismatch { expected: 1, actual: 2 })
        ));
    }

    #[test]
    fn test_prepared_statement_reuse() {
        let sql = "SELECT * FROM users WHERE id = ?";
        let statement = vibesql_parser::Parser::parse_sql(sql).unwrap();
        let prepared = PreparedStatement::new(sql.to_string(), statement);

        let bound1 = prepared.bind(&[SqlValue::Integer(1)]).unwrap();
        let bound2 = prepared.bind(&[SqlValue::Integer(2)]).unwrap();
        let bound3 = prepared.bind(&[SqlValue::Integer(3)]).unwrap();

        // Each binding must carry its own value; earlier binds must not leak into later ones.
        for (bound, expected_id) in [(&bound1, 1), (&bound2, 2), (&bound3, 3)] {
            assert_eq!(*where_rhs(bound), Expression::Literal(SqlValue::Integer(expected_id)));
        }
    }

    #[test]
    fn test_cache_get_or_prepare() {
        let cache = PreparedStatementCache::new(10);
        let sql = "SELECT * FROM users WHERE id = 1";

        let stmt1 = cache.get_or_prepare(sql).unwrap();
        let stats = cache.stats();
        assert_eq!(stats.misses, 1);
        assert_eq!(stats.hits, 0);

        let stmt2 = cache.get_or_prepare(sql).unwrap();
        let stats = cache.stats();
        assert_eq!(stats.misses, 1);
        assert_eq!(stats.hits, 1);

        assert!(Arc::ptr_eq(&stmt1, &stmt2));
    }

    #[test]
    fn test_cache_placeholder_reuse() {
        let cache = PreparedStatementCache::new(10);
        let sql = "SELECT * FROM users WHERE id = ?";

        let stmt1 = cache.get_or_prepare(sql).unwrap();
        assert_eq!(cache.stats().misses, 1);
        assert_eq!(cache.stats().hits, 0);

        let stmt2 = cache.get_or_prepare(sql).unwrap();
        assert_eq!(cache.stats().misses, 1);
        assert_eq!(cache.stats().hits, 1);

        assert!(Arc::ptr_eq(&stmt1, &stmt2));

        let bound1 = stmt1.bind(&[SqlValue::Integer(1)]).unwrap();
        let bound2 = stmt2.bind(&[SqlValue::Integer(999)]).unwrap();

        assert_eq!(*where_rhs(&bound1), Expression::Literal(SqlValue::Integer(1)));
        assert_eq!(*where_rhs(&bound2), Expression::Literal(SqlValue::Integer(999)));
    }

    #[test]
    fn test_cache_lru_eviction() {
        let cache = PreparedStatementCache::new(2);

        cache.get_or_prepare("SELECT * FROM users").unwrap();
        cache.get_or_prepare("SELECT * FROM orders").unwrap();
        assert_eq!(cache.stats().size, 2);
        assert_eq!(cache.stats().evictions, 0);

        cache.get_or_prepare("SELECT * FROM products").unwrap();
        assert_eq!(cache.stats().size, 2);
        assert_eq!(cache.stats().evictions, 1);

        // The least recently used entry is the one that was dropped.
        assert!(cache.get("SELECT * FROM users").is_none());
        assert!(cache.get("SELECT * FROM orders").is_some());
        assert!(cache.get("SELECT * FROM products").is_some());
    }

    #[test]
    fn test_cache_table_invalidation() {
        let cache = PreparedStatementCache::new(10);

        cache.get_or_prepare("SELECT * FROM users WHERE id = ?").unwrap();
        cache.get_or_prepare("SELECT * FROM orders WHERE id = ?").unwrap();
        assert_eq!(cache.stats().size, 2);

        cache.invalidate_table("users");
        assert_eq!(cache.stats().size, 1);

        assert!(cache.get("SELECT * FROM orders WHERE id = ?").is_some());
    }

    #[test]
    fn test_cache_table_invalidation_is_case_insensitive() {
        let cache = PreparedStatementCache::new(10);

        cache.get_or_prepare("SELECT * FROM users WHERE id = ?").unwrap();
        assert_eq!(cache.stats().size, 1);

        cache.invalidate_table("USERS");
        assert_eq!(cache.stats().size, 0);
    }

    #[test]
    fn test_cache_parse_error_is_reported_and_not_cached() {
        let cache = PreparedStatementCache::new(10);

        let result = cache.get_or_prepare("THIS IS NOT SQL");
        assert!(matches!(result, Err(PreparedStatementError::ParseError(_))));
        assert_eq!(cache.stats().size, 0);
        assert_eq!(cache.stats().misses, 1);
    }

    #[test]
    fn test_arena_parse_insert() {
        let sql = "INSERT INTO users (name, email) VALUES ('Alice', 'alice@example.com')";
        let result = parse_with_arena_fallback(sql);
        assert!(result.is_ok());
        assert!(matches!(result.unwrap(), Statement::Insert(_)));
    }

    #[test]
    fn test_arena_parse_insert_with_placeholders() {
        let cache = PreparedStatementCache::new(10);
        let sql = "INSERT INTO users (name, email) VALUES (?, ?)";

        let stmt = cache.get_or_prepare(sql).unwrap();
        assert_eq!(stmt.param_count(), 2);

        let bound = stmt
            .bind(&[
                SqlValue::Varchar("Bob".to_string()),
                SqlValue::Varchar("bob@example.com".to_string()),
            ])
            .unwrap();
        assert!(matches!(bound, Statement::Insert(_)));
    }

    #[test]
    fn test_arena_parse_update() {
        let sql = "UPDATE users SET name = 'Bob' WHERE id = 1";
        let result = parse_with_arena_fallback(sql);
        assert!(result.is_ok());
        assert!(matches!(result.unwrap(), Statement::Update(_)));
    }

    #[test]
    fn test_arena_parse_update_with_placeholders() {
        let cache = PreparedStatementCache::new(10);
        let sql = "UPDATE users SET name = ? WHERE id = ?";

        let stmt = cache.get_or_prepare(sql).unwrap();
        assert_eq!(stmt.param_count(), 2);

        let bound = stmt
            .bind(&[SqlValue::Varchar("Charlie".to_string()), SqlValue::Integer(42)])
            .unwrap();
        assert!(matches!(bound, Statement::Update(_)));
    }

    #[test]
    fn test_arena_parse_delete() {
        let sql = "DELETE FROM users WHERE id = 1";
        let result = parse_with_arena_fallback(sql);
        assert!(result.is_ok());
        assert!(matches!(result.unwrap(), Statement::Delete(_)));
    }

    #[test]
    fn test_arena_parse_delete_with_placeholders() {
        let cache = PreparedStatementCache::new(10);
        let sql = "DELETE FROM users WHERE id = ?";

        let stmt = cache.get_or_prepare(sql).unwrap();
        assert_eq!(stmt.param_count(), 1);

        let bound = stmt.bind(&[SqlValue::Integer(99)]).unwrap();
        assert!(matches!(bound, Statement::Delete(_)));
    }
}