1use std::sync::Arc;
46
47use vibesql_ast::Statement;
48use vibesql_storage::{Database, Row};
49use vibesql_types::SqlValue;
50
51use crate::cache::{
52 ArenaParseError, ArenaPreparedStatement, CachedPlan, PkPointLookupPlan, PreparedStatement,
53 PreparedStatementCache, PreparedStatementError, ProjectionPlan,
54};
55use crate::errors::ExecutorError;
56use crate::{DeleteExecutor, InsertExecutor, SelectExecutor, SelectResult, UpdateExecutor};
57
58#[derive(Debug)]
60pub enum PreparedExecutionResult {
61 Select(SelectResult),
63 RowsAffected(usize),
65 Ok,
67}
68
69impl PreparedExecutionResult {
70 pub fn rows(&self) -> Option<&[Row]> {
72 match self {
73 PreparedExecutionResult::Select(result) => Some(&result.rows),
74 _ => None,
75 }
76 }
77
78 pub fn rows_affected(&self) -> Option<usize> {
80 match self {
81 PreparedExecutionResult::RowsAffected(n) => Some(*n),
82 _ => None,
83 }
84 }
85
86 pub fn into_select_result(self) -> Option<SelectResult> {
88 match self {
89 PreparedExecutionResult::Select(result) => Some(result),
90 _ => None,
91 }
92 }
93}
94
95#[derive(Debug)]
97pub enum SessionError {
98 PreparedStatement(PreparedStatementError),
100 Execution(ExecutorError),
102 UnsupportedStatement(String),
104}
105
106impl std::fmt::Display for SessionError {
107 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
108 match self {
109 SessionError::PreparedStatement(e) => write!(f, "Prepared statement error: {}", e),
110 SessionError::Execution(e) => write!(f, "Execution error: {:?}", e),
111 SessionError::UnsupportedStatement(msg) => write!(f, "Unsupported statement: {}", msg),
112 }
113 }
114}
115
116impl std::error::Error for SessionError {}
117
118impl From<PreparedStatementError> for SessionError {
119 fn from(e: PreparedStatementError) -> Self {
120 SessionError::PreparedStatement(e)
121 }
122}
123
124impl From<ExecutorError> for SessionError {
125 fn from(e: ExecutorError) -> Self {
126 SessionError::Execution(e)
127 }
128}
129
130pub struct Session<'a> {
135 db: &'a Database,
136 cache: Arc<PreparedStatementCache>,
137}
138
139impl<'a> Session<'a> {
140 pub fn new(db: &'a Database) -> Self {
144 Self {
145 db,
146 cache: Arc::new(PreparedStatementCache::default_cache()),
147 }
148 }
149
150 pub fn with_cache_size(db: &'a Database, cache_size: usize) -> Self {
152 Self {
153 db,
154 cache: Arc::new(PreparedStatementCache::new(cache_size)),
155 }
156 }
157
158 pub fn with_shared_cache(db: &'a Database, cache: Arc<PreparedStatementCache>) -> Self {
163 Self { db, cache }
164 }
165
166 pub fn database(&self) -> &Database {
168 self.db
169 }
170
171 pub fn cache(&self) -> &PreparedStatementCache {
173 &self.cache
174 }
175
176 pub fn shared_cache(&self) -> Arc<PreparedStatementCache> {
178 Arc::clone(&self.cache)
179 }
180
181 pub fn prepare(&self, sql: &str) -> Result<Arc<PreparedStatement>, SessionError> {
195 self.cache.get_or_prepare(sql).map_err(SessionError::from)
196 }
197
198 pub fn prepare_arena(&self, sql: &str) -> Result<Arc<ArenaPreparedStatement>, ArenaParseError> {
220 self.cache.get_or_prepare_arena(sql)
221 }
222
223 pub fn execute_prepared(
238 &self,
239 stmt: &PreparedStatement,
240 params: &[SqlValue],
241 ) -> Result<PreparedExecutionResult, SessionError> {
242 if params.len() != stmt.param_count() {
244 return Err(SessionError::PreparedStatement(
245 PreparedStatementError::ParameterCountMismatch {
246 expected: stmt.param_count(),
247 actual: params.len(),
248 },
249 ));
250 }
251
252 if let CachedPlan::PkPointLookup(plan) = stmt.cached_plan() {
254 if let Some(result) = self.try_execute_pk_lookup(plan, params)? {
255 return Ok(result);
256 }
257 }
259
260 let bound_stmt = stmt.bind(params)?;
262
263 self.execute_statement(&bound_stmt)
265 }
266
267 fn try_execute_pk_lookup(
273 &self,
274 plan: &PkPointLookupPlan,
275 params: &[SqlValue],
276 ) -> Result<Option<PreparedExecutionResult>, SessionError> {
277 let table = match self.db.get_table(&plan.table_name) {
279 Some(t) => t,
280 None => return Ok(None), };
282
283 let actual_pk_columns = match &table.schema.primary_key {
285 Some(cols) if cols.len() == plan.pk_columns.len() => cols,
286 _ => return Ok(None), };
288
289 let mut pk_values = Vec::with_capacity(plan.pk_columns.len());
291 for (param_idx, pk_col_idx) in &plan.param_to_pk_col {
292 if *param_idx >= params.len() || *pk_col_idx >= plan.pk_columns.len() {
293 return Ok(None); }
295
296 let expected_col = &plan.pk_columns[*pk_col_idx];
298 let actual_col = &actual_pk_columns[*pk_col_idx];
299 if !expected_col.eq_ignore_ascii_case(actual_col) {
300 return Ok(None); }
302
303 pk_values.push(params[*param_idx].clone());
304 }
305
306 let row = if pk_values.len() == 1 {
308 self.db
309 .get_row_by_pk(&plan.table_name, &pk_values[0])
310 .map_err(|e| SessionError::Execution(ExecutorError::StorageError(e.to_string())))?
311 } else {
312 self.db
313 .get_row_by_composite_pk(&plan.table_name, &pk_values)
314 .map_err(|e| SessionError::Execution(ExecutorError::StorageError(e.to_string())))?
315 };
316
317 let rows = match row {
318 Some(r) => vec![r.clone()],
319 None => vec![],
320 };
321
322 let (columns, result_rows) = match &plan.projection {
324 ProjectionPlan::Wildcard => {
325 let columns: Vec<String> = table.schema.columns.iter().map(|c| c.name.clone()).collect();
327 (columns, rows)
328 }
329 ProjectionPlan::Columns(projections) => {
330 let mut col_indices = Vec::with_capacity(projections.len());
332 let mut column_names = Vec::with_capacity(projections.len());
333
334 for proj in projections {
335 let idx = table
336 .schema
337 .columns
338 .iter()
339 .position(|c| c.name.eq_ignore_ascii_case(&proj.column_name));
340
341 match idx {
342 Some(i) => {
343 col_indices.push(i);
344 column_names.push(proj.alias.clone().unwrap_or_else(|| proj.column_name.clone()));
345 }
346 None => return Ok(None), }
348 }
349
350 let projected_rows: Vec<Row> = rows
352 .into_iter()
353 .map(|row| {
354 let projected_values: Vec<SqlValue> =
355 col_indices.iter().map(|&i| row.values[i].clone()).collect();
356 Row::new(projected_values)
357 })
358 .collect();
359
360 (column_names, projected_rows)
361 }
362 };
363
364 Ok(Some(PreparedExecutionResult::Select(SelectResult {
365 columns,
366 rows: result_rows,
367 })))
368 }
369
370 fn execute_statement(&self, stmt: &Statement) -> Result<PreparedExecutionResult, SessionError> {
372 match stmt {
373 Statement::Select(select_stmt) => {
374 let executor = SelectExecutor::new(self.db);
375 let result = executor.execute_with_columns(select_stmt)?;
376 Ok(PreparedExecutionResult::Select(result))
377 }
378 _ => Err(SessionError::UnsupportedStatement(
379 "Only SELECT is supported for read-only sessions. Use SessionMut for DML.".into(),
380 )),
381 }
382 }
383}
384
385pub struct SessionMut<'a> {
390 db: &'a mut Database,
391 cache: Arc<PreparedStatementCache>,
392}
393
394impl<'a> SessionMut<'a> {
395 pub fn new(db: &'a mut Database) -> Self {
397 Self {
398 db,
399 cache: Arc::new(PreparedStatementCache::default_cache()),
400 }
401 }
402
403 pub fn with_cache_size(db: &'a mut Database, cache_size: usize) -> Self {
405 Self {
406 db,
407 cache: Arc::new(PreparedStatementCache::new(cache_size)),
408 }
409 }
410
411 pub fn with_shared_cache(db: &'a mut Database, cache: Arc<PreparedStatementCache>) -> Self {
413 Self { db, cache }
414 }
415
416 pub fn database(&self) -> &Database {
418 self.db
419 }
420
421 pub fn database_mut(&mut self) -> &mut Database {
423 self.db
424 }
425
426 pub fn cache(&self) -> &PreparedStatementCache {
428 &self.cache
429 }
430
431 pub fn shared_cache(&self) -> Arc<PreparedStatementCache> {
433 Arc::clone(&self.cache)
434 }
435
436 pub fn prepare(&self, sql: &str) -> Result<Arc<PreparedStatement>, SessionError> {
438 self.cache.get_or_prepare(sql).map_err(SessionError::from)
439 }
440
441 pub fn prepare_arena(&self, sql: &str) -> Result<Arc<ArenaPreparedStatement>, ArenaParseError> {
445 self.cache.get_or_prepare_arena(sql)
446 }
447
448 pub fn execute_prepared(
452 &self,
453 stmt: &PreparedStatement,
454 params: &[SqlValue],
455 ) -> Result<PreparedExecutionResult, SessionError> {
456 let bound_stmt = stmt.bind(params)?;
457 self.execute_statement_readonly(&bound_stmt)
458 }
459
460 pub fn execute_prepared_mut(
464 &mut self,
465 stmt: &PreparedStatement,
466 params: &[SqlValue],
467 ) -> Result<PreparedExecutionResult, SessionError> {
468 let bound_stmt = stmt.bind(params)?;
469 self.execute_statement_mut(&bound_stmt)
470 }
471
472 fn execute_statement_readonly(
474 &self,
475 stmt: &Statement,
476 ) -> Result<PreparedExecutionResult, SessionError> {
477 match stmt {
478 Statement::Select(select_stmt) => {
479 let executor = SelectExecutor::new(self.db);
480 let result = executor.execute_with_columns(select_stmt)?;
481 Ok(PreparedExecutionResult::Select(result))
482 }
483 _ => Err(SessionError::UnsupportedStatement(
484 "Use execute_prepared_mut for DML statements".into(),
485 )),
486 }
487 }
488
489 fn execute_statement_mut(
491 &mut self,
492 stmt: &Statement,
493 ) -> Result<PreparedExecutionResult, SessionError> {
494 match stmt {
495 Statement::Select(select_stmt) => {
496 let executor = SelectExecutor::new(self.db);
497 let result = executor.execute_with_columns(select_stmt)?;
498 Ok(PreparedExecutionResult::Select(result))
499 }
500 Statement::Insert(insert_stmt) => {
501 let rows_affected = InsertExecutor::execute(self.db, insert_stmt)?;
502 self.cache.invalidate_table(&insert_stmt.table_name);
504 Ok(PreparedExecutionResult::RowsAffected(rows_affected))
505 }
506 Statement::Update(update_stmt) => {
507 let rows_affected = UpdateExecutor::execute(update_stmt, self.db)?;
508 self.cache.invalidate_table(&update_stmt.table_name);
510 Ok(PreparedExecutionResult::RowsAffected(rows_affected))
511 }
512 Statement::Delete(delete_stmt) => {
513 let rows_affected = DeleteExecutor::execute(delete_stmt, self.db)?;
514 self.cache.invalidate_table(&delete_stmt.table_name);
516 Ok(PreparedExecutionResult::RowsAffected(rows_affected))
517 }
518 _ => Err(SessionError::UnsupportedStatement(format!(
519 "Statement type {:?} not supported for prepared execution",
520 std::mem::discriminant(stmt)
521 ))),
522 }
523 }
524}
525
#[cfg(test)]
mod tests {
    use super::*;
    use vibesql_catalog::{ColumnSchema, TableSchema};
    use vibesql_types::DataType;

    /// Builds a case-insensitive database with a `users(id PRIMARY KEY, name)`
    /// table holding (1, Alice), (2, Bob), and (3, Charlie).
    fn create_test_db() -> Database {
        let mut db = Database::new();
        db.catalog.set_case_sensitive_identifiers(false);

        let columns = vec![
            ColumnSchema::new("id".to_string(), DataType::Integer, false),
            ColumnSchema::new("name".to_string(), DataType::Varchar { max_length: Some(100) }, true),
        ];
        let schema = TableSchema::with_primary_key(
            "users".to_string(),
            columns,
            vec!["id".to_string()],
        );
        db.create_table(schema).unwrap();

        let row1 = Row::new(vec![SqlValue::Integer(1), SqlValue::Varchar("Alice".into())]);
        let row2 = Row::new(vec![SqlValue::Integer(2), SqlValue::Varchar("Bob".into())]);
        let row3 = Row::new(vec![SqlValue::Integer(3), SqlValue::Varchar("Charlie".into())]);

        db.insert_row("users", row1).unwrap();
        db.insert_row("users", row2).unwrap();
        db.insert_row("users", row3).unwrap();

        db
    }

    #[test]
    fn test_session_prepare() {
        let db = create_test_db();
        let session = Session::new(&db);

        let stmt = session.prepare("SELECT * FROM users WHERE id = ?").unwrap();
        assert_eq!(stmt.param_count(), 1);
    }

    #[test]
    fn test_session_execute_prepared() {
        let db = create_test_db();
        let session = Session::new(&db);

        let stmt = session.prepare("SELECT * FROM users WHERE id = ?").unwrap();

        let result = session
            .execute_prepared(&stmt, &[SqlValue::Integer(1)])
            .unwrap();

        if let PreparedExecutionResult::Select(select_result) = result {
            assert_eq!(select_result.rows.len(), 1);
            assert_eq!(select_result.rows[0].values[0], SqlValue::Integer(1));
            assert_eq!(
                select_result.rows[0].values[1],
                SqlValue::Varchar("Alice".into())
            );
        } else {
            panic!("Expected Select result");
        }
    }

    #[test]
    fn test_session_reuse_prepared() {
        let db = create_test_db();
        let session = Session::new(&db);

        let stmt = session.prepare("SELECT * FROM users WHERE id = ?").unwrap();

        let result1 = session
            .execute_prepared(&stmt, &[SqlValue::Integer(1)])
            .unwrap();
        let result2 = session
            .execute_prepared(&stmt, &[SqlValue::Integer(2)])
            .unwrap();
        let result3 = session
            .execute_prepared(&stmt, &[SqlValue::Integer(3)])
            .unwrap();

        assert_eq!(
            result1.rows().unwrap()[0].values[1],
            SqlValue::Varchar("Alice".into())
        );
        assert_eq!(
            result2.rows().unwrap()[0].values[1],
            SqlValue::Varchar("Bob".into())
        );
        assert_eq!(
            result3.rows().unwrap()[0].values[1],
            SqlValue::Varchar("Charlie".into())
        );

        // Only the single `prepare` call should have missed the cache.
        let stats = session.cache().stats();
        assert_eq!(stats.misses, 1);
    }

    #[test]
    fn test_session_param_count_mismatch() {
        let db = create_test_db();
        let session = Session::new(&db);

        let stmt = session.prepare("SELECT * FROM users WHERE id = ?").unwrap();

        let result = session.execute_prepared(&stmt, &[]);
        assert!(result.is_err());

        let result = session.execute_prepared(&stmt, &[SqlValue::Integer(1), SqlValue::Integer(2)]);
        assert!(result.is_err());
    }

    #[test]
    fn test_session_pk_lookup_missing_row() {
        let db = create_test_db();
        let session = Session::new(&db);

        let stmt = session.prepare("SELECT * FROM users WHERE id = ?").unwrap();
        let result = session
            .execute_prepared(&stmt, &[SqlValue::Integer(99)])
            .unwrap();

        assert_eq!(result.rows().unwrap().len(), 0);
    }

    #[test]
    fn test_session_pk_lookup_projection() {
        let db = create_test_db();
        let session = Session::new(&db);

        let stmt = session.prepare("SELECT name FROM users WHERE id = ?").unwrap();
        let result = session
            .execute_prepared(&stmt, &[SqlValue::Integer(2)])
            .unwrap();

        let rows = result.rows().unwrap();
        assert_eq!(rows.len(), 1);
        assert_eq!(rows[0].values.len(), 1);
        assert_eq!(rows[0].values[0], SqlValue::Varchar("Bob".into()));
    }

    #[test]
    fn test_session_mut_insert() {
        let mut db = create_test_db();
        let mut session = SessionMut::new(&mut db);

        let stmt = session
            .prepare("INSERT INTO users (id, name) VALUES (?, ?)")
            .unwrap();

        let result = session
            .execute_prepared_mut(
                &stmt,
                &[SqlValue::Integer(4), SqlValue::Varchar("David".into())],
            )
            .unwrap();

        assert_eq!(result.rows_affected(), Some(1));

        let select_stmt = session.prepare("SELECT * FROM users WHERE id = ?").unwrap();
        let select_result = session
            .execute_prepared(&select_stmt, &[SqlValue::Integer(4)])
            .unwrap();

        assert_eq!(select_result.rows().unwrap().len(), 1);
        assert_eq!(
            select_result.rows().unwrap()[0].values[1],
            SqlValue::Varchar("David".into())
        );
    }

    #[test]
    fn test_session_mut_update() {
        let mut db = create_test_db();
        let mut session = SessionMut::new(&mut db);

        let stmt = session
            .prepare("UPDATE users SET name = ? WHERE id = ?")
            .unwrap();

        let result = session
            .execute_prepared_mut(
                &stmt,
                &[SqlValue::Varchar("Alicia".into()), SqlValue::Integer(1)],
            )
            .unwrap();

        assert_eq!(result.rows_affected(), Some(1));

        let select_stmt = session.prepare("SELECT * FROM users WHERE id = ?").unwrap();
        let select_result = session
            .execute_prepared(&select_stmt, &[SqlValue::Integer(1)])
            .unwrap();

        assert_eq!(
            select_result.rows().unwrap()[0].values[1],
            SqlValue::Varchar("Alicia".into())
        );
    }

    #[test]
    fn test_session_mut_delete() {
        let mut db = create_test_db();
        let mut session = SessionMut::new(&mut db);

        let stmt = session.prepare("DELETE FROM users WHERE id = ?").unwrap();

        let result = session
            .execute_prepared_mut(&stmt, &[SqlValue::Integer(1)])
            .unwrap();

        assert_eq!(result.rows_affected(), Some(1));

        let select_stmt = session.prepare("SELECT * FROM users WHERE id = ?").unwrap();
        let select_result = session
            .execute_prepared(&select_stmt, &[SqlValue::Integer(1)])
            .unwrap();

        assert_eq!(select_result.rows().unwrap().len(), 0);
    }

    #[test]
    fn test_shared_cache() {
        let db = create_test_db();

        let session1 = Session::new(&db);
        let stmt = session1
            .prepare("SELECT * FROM users WHERE id = ?")
            .unwrap();

        let shared_cache = session1.shared_cache();
        let initial_misses = session1.cache().stats().misses;

        let session2 = Session::with_shared_cache(&db, shared_cache);

        // Preparing the same SQL in the second session must hit the shared cache.
        let _stmt2 = session2
            .prepare("SELECT * FROM users WHERE id = ?")
            .unwrap();

        assert_eq!(session2.cache().stats().misses, initial_misses);

        let result1 = session1
            .execute_prepared(&stmt, &[SqlValue::Integer(1)])
            .unwrap();
        let result2 = session2
            .execute_prepared(&stmt, &[SqlValue::Integer(2)])
            .unwrap();

        assert_eq!(
            result1.rows().unwrap()[0].values[1],
            SqlValue::Varchar("Alice".into())
        );
        assert_eq!(
            result2.rows().unwrap()[0].values[1],
            SqlValue::Varchar("Bob".into())
        );
    }

    #[test]
    fn test_no_params_statement() {
        let db = create_test_db();
        let session = Session::new(&db);

        let stmt = session.prepare("SELECT * FROM users").unwrap();
        assert_eq!(stmt.param_count(), 0);

        let result = session.execute_prepared(&stmt, &[]).unwrap();
        assert_eq!(result.rows().unwrap().len(), 3);
    }

    #[test]
    fn test_multiple_placeholders() {
        let db = create_test_db();
        let session = Session::new(&db);

        let stmt = session
            .prepare("SELECT * FROM users WHERE id >= ? AND id <= ?")
            .unwrap();
        assert_eq!(stmt.param_count(), 2);

        let result = session
            .execute_prepared(&stmt, &[SqlValue::Integer(1), SqlValue::Integer(2)])
            .unwrap();

        assert_eq!(result.rows().unwrap().len(), 2);
    }

    #[test]
    fn test_session_prepare_arena() {
        let db = create_test_db();
        let session = Session::new(&db);

        let stmt = session
            .prepare_arena("SELECT * FROM users WHERE id = ?")
            .unwrap();
        assert_eq!(stmt.param_count(), 1);

        assert!(stmt.tables().contains("USERS"));

        // A second prepare of the same SQL must return the cached instance.
        let stmt2 = session
            .prepare_arena("SELECT * FROM users WHERE id = ?")
            .unwrap();
        assert_eq!(stmt2.param_count(), 1);

        assert!(std::sync::Arc::ptr_eq(&stmt, &stmt2));
    }

    #[test]
    fn test_session_prepare_arena_no_params() {
        let db = create_test_db();
        let session = Session::new(&db);

        let stmt = session
            .prepare_arena("SELECT * FROM users")
            .unwrap();
        assert_eq!(stmt.param_count(), 0);
    }

    #[test]
    fn test_session_prepare_arena_join() {
        let db = create_test_db();
        let session = Session::new(&db);

        // Table extraction happens at prepare time, so `orders` need not exist.
        let stmt = session
            .prepare_arena("SELECT u.id FROM users u JOIN orders o ON u.id = o.user_id")
            .unwrap();

        let tables = stmt.tables();
        assert!(tables.contains("USERS"), "Expected USERS in {:?}", tables);
        assert!(tables.contains("ORDERS"), "Expected ORDERS in {:?}", tables);
    }

    #[test]
    fn test_session_mut_prepare_arena() {
        let mut db = create_test_db();
        let session = SessionMut::new(&mut db);

        let stmt = session
            .prepare_arena("SELECT * FROM users WHERE id = ?")
            .unwrap();
        assert_eq!(stmt.param_count(), 1);
    }
}