1use std::sync::Arc;
46
47use vibesql_ast::Statement;
48use vibesql_storage::{Database, Row};
49use vibesql_types::SqlValue;
50
51use crate::{
52 cache::{
53 ArenaParseError, ArenaPreparedStatement, CachedPlan, PkPointLookupPlan, PreparedStatement,
54 PreparedStatementCache, PreparedStatementError, ProjectionPlan, ResolvedProjection,
55 },
56 errors::ExecutorError,
57 DeleteExecutor, InsertExecutor, SelectExecutor, SelectResult, UpdateExecutor,
58};
59
60#[derive(Debug)]
62pub enum PreparedExecutionResult {
63 Select(SelectResult),
65 RowsAffected(usize),
67 Ok,
69}
70
71impl PreparedExecutionResult {
72 pub fn rows(&self) -> Option<&[Row]> {
74 match self {
75 PreparedExecutionResult::Select(result) => Some(&result.rows),
76 _ => None,
77 }
78 }
79
80 pub fn rows_affected(&self) -> Option<usize> {
82 match self {
83 PreparedExecutionResult::RowsAffected(n) => Some(*n),
84 _ => None,
85 }
86 }
87
88 pub fn into_select_result(self) -> Option<SelectResult> {
90 match self {
91 PreparedExecutionResult::Select(result) => Some(result),
92 _ => None,
93 }
94 }
95}
96
97#[derive(Debug)]
99pub enum SessionError {
100 PreparedStatement(PreparedStatementError),
102 Execution(ExecutorError),
104 UnsupportedStatement(String),
106}
107
108impl std::fmt::Display for SessionError {
109 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
110 match self {
111 SessionError::PreparedStatement(e) => write!(f, "Prepared statement error: {}", e),
112 SessionError::Execution(e) => write!(f, "Execution error: {:?}", e),
113 SessionError::UnsupportedStatement(msg) => write!(f, "Unsupported statement: {}", msg),
114 }
115 }
116}
117
118impl std::error::Error for SessionError {}
119
120impl From<PreparedStatementError> for SessionError {
121 fn from(e: PreparedStatementError) -> Self {
122 SessionError::PreparedStatement(e)
123 }
124}
125
126impl From<ExecutorError> for SessionError {
127 fn from(e: ExecutorError) -> Self {
128 SessionError::Execution(e)
129 }
130}
131
132pub struct Session<'a> {
137 db: &'a Database,
138 cache: Arc<PreparedStatementCache>,
139}
140
141impl<'a> Session<'a> {
142 pub fn new(db: &'a Database) -> Self {
146 Self { db, cache: Arc::new(PreparedStatementCache::default_cache()) }
147 }
148
149 pub fn with_cache_size(db: &'a Database, cache_size: usize) -> Self {
151 Self { db, cache: Arc::new(PreparedStatementCache::new(cache_size)) }
152 }
153
154 pub fn with_shared_cache(db: &'a Database, cache: Arc<PreparedStatementCache>) -> Self {
159 Self { db, cache }
160 }
161
162 pub fn database(&self) -> &Database {
164 self.db
165 }
166
167 pub fn cache(&self) -> &PreparedStatementCache {
169 &self.cache
170 }
171
172 pub fn shared_cache(&self) -> Arc<PreparedStatementCache> {
174 Arc::clone(&self.cache)
175 }
176
177 pub fn prepare(&self, sql: &str) -> Result<Arc<PreparedStatement>, SessionError> {
191 self.cache.get_or_prepare(sql).map_err(SessionError::from)
192 }
193
194 pub fn prepare_arena(&self, sql: &str) -> Result<Arc<ArenaPreparedStatement>, ArenaParseError> {
216 self.cache.get_or_prepare_arena(sql)
217 }
218
219 pub fn execute_prepared(
234 &self,
235 stmt: &PreparedStatement,
236 params: &[SqlValue],
237 ) -> Result<PreparedExecutionResult, SessionError> {
238 if params.len() != stmt.param_count() {
240 return Err(SessionError::PreparedStatement(
241 PreparedStatementError::ParameterCountMismatch {
242 expected: stmt.param_count(),
243 actual: params.len(),
244 },
245 ));
246 }
247
248 match stmt.cached_plan() {
250 CachedPlan::PkPointLookup(plan) => {
251 if let Some(result) = self.try_execute_pk_lookup(plan, params)? {
252 return Ok(result);
253 }
254 }
256 CachedPlan::SimpleFastPath(plan) => {
257 let bound_stmt = stmt.bind(params)?;
260 if let Statement::Select(select_stmt) = &bound_stmt {
261 let executor = SelectExecutor::new(self.db);
262
263 let columns = plan.get_or_resolve_columns(|| {
265 executor.derive_fast_path_column_names(select_stmt).ok()
266 });
267
268 match columns {
269 Some(cached_columns) => {
270 let rows = executor.execute_fast_path(select_stmt)?;
272 return Ok(PreparedExecutionResult::Select(SelectResult {
273 columns: cached_columns.iter().cloned().collect(),
274 rows,
275 }));
276 }
277 None => {
278 let result = executor.execute_fast_path_with_columns(select_stmt)?;
280 return Ok(PreparedExecutionResult::Select(result));
281 }
282 }
283 }
284 }
286 CachedPlan::PkDelete(_) => {
287 }
290 CachedPlan::Standard => {
291 }
293 }
294
295 let bound_stmt = stmt.bind(params)?;
297
298 self.execute_statement(&bound_stmt)
300 }
301
302 fn try_execute_pk_lookup(
308 &self,
309 plan: &PkPointLookupPlan,
310 params: &[SqlValue],
311 ) -> Result<Option<PreparedExecutionResult>, SessionError> {
312 let table = match self.db.get_table(&plan.table_name) {
314 Some(t) => t,
315 None => return Ok(None), };
317
318 let actual_pk_columns = match &table.schema.primary_key {
320 Some(cols) if cols.len() == plan.pk_columns.len() => cols,
321 _ => return Ok(None), };
323
324 for (param_idx, pk_col_idx) in &plan.param_to_pk_col {
327 if *param_idx >= params.len() || *pk_col_idx >= plan.pk_columns.len() {
328 return Ok(None); }
330
331 let expected_col = &plan.pk_columns[*pk_col_idx];
333 let actual_col = &actual_pk_columns[*pk_col_idx];
334 if !expected_col.eq_ignore_ascii_case(actual_col) {
335 return Ok(None); }
337 }
338
339 let resolved = match plan
341 .get_or_resolve(|proj| self.resolve_projection(proj, &table.schema.columns))
342 {
343 Some(r) => r,
344 None => return Ok(None), };
346
347 let row = if plan.param_to_pk_col.len() == 1 {
350 let (param_idx, _) = plan.param_to_pk_col[0];
351 self.db
352 .get_row_by_pk(&plan.table_name, ¶ms[param_idx])
353 .map_err(|e| SessionError::Execution(ExecutorError::StorageError(e.to_string())))?
354 } else {
355 let pk_values: Vec<SqlValue> = plan
357 .param_to_pk_col
358 .iter()
359 .map(|(param_idx, _)| params[*param_idx].clone())
360 .collect();
361 self.db
362 .get_row_by_composite_pk(&plan.table_name, &pk_values)
363 .map_err(|e| SessionError::Execution(ExecutorError::StorageError(e.to_string())))?
364 };
365
366 let columns: Vec<String> = resolved.column_names.iter().cloned().collect();
368
369 let rows = match row {
370 Some(r) => {
371 if resolved.column_indices.is_empty() {
373 vec![r.clone()]
375 } else {
376 let projected_values: Vec<SqlValue> =
378 resolved.column_indices.iter().map(|&i| r.values[i].clone()).collect();
379 vec![Row::new(projected_values)]
380 }
381 }
382 None => vec![],
383 };
384
385 Ok(Some(PreparedExecutionResult::Select(SelectResult { columns, rows })))
386 }
387
388 fn resolve_projection(
392 &self,
393 proj: &ProjectionPlan,
394 schema_columns: &[vibesql_catalog::ColumnSchema],
395 ) -> Option<ResolvedProjection> {
396 match proj {
397 ProjectionPlan::Wildcard => {
398 let column_names: Arc<[String]> =
401 schema_columns.iter().map(|c| c.name.clone()).collect();
402 Some(ResolvedProjection { column_indices: vec![], column_names })
403 }
404 ProjectionPlan::Columns(projections) => {
405 let mut col_indices = Vec::with_capacity(projections.len());
406 let mut column_names = Vec::with_capacity(projections.len());
407
408 for proj in projections {
409 let idx = schema_columns
410 .iter()
411 .position(|c| c.name.eq_ignore_ascii_case(&proj.column_name))?;
412
413 col_indices.push(idx);
414 column_names
415 .push(proj.alias.clone().unwrap_or_else(|| proj.column_name.clone()));
416 }
417
418 Some(ResolvedProjection {
419 column_indices: col_indices,
420 column_names: column_names.into(),
421 })
422 }
423 }
424 }
425
426 fn execute_statement(&self, stmt: &Statement) -> Result<PreparedExecutionResult, SessionError> {
428 match stmt {
429 Statement::Select(select_stmt) => {
430 let executor = SelectExecutor::new(self.db);
431 let result = executor.execute_with_columns(select_stmt)?;
432 Ok(PreparedExecutionResult::Select(result))
433 }
434 _ => Err(SessionError::UnsupportedStatement(
435 "Only SELECT is supported for read-only sessions. Use SessionMut for DML.".into(),
436 )),
437 }
438 }
439}
440
441pub struct SessionMut<'a> {
446 db: &'a mut Database,
447 cache: Arc<PreparedStatementCache>,
448}
449
450impl<'a> SessionMut<'a> {
451 pub fn new(db: &'a mut Database) -> Self {
453 Self { db, cache: Arc::new(PreparedStatementCache::default_cache()) }
454 }
455
456 pub fn with_cache_size(db: &'a mut Database, cache_size: usize) -> Self {
458 Self { db, cache: Arc::new(PreparedStatementCache::new(cache_size)) }
459 }
460
461 pub fn with_shared_cache(db: &'a mut Database, cache: Arc<PreparedStatementCache>) -> Self {
463 Self { db, cache }
464 }
465
466 pub fn database(&self) -> &Database {
468 self.db
469 }
470
471 pub fn database_mut(&mut self) -> &mut Database {
473 self.db
474 }
475
476 pub fn cache(&self) -> &PreparedStatementCache {
478 &self.cache
479 }
480
481 pub fn shared_cache(&self) -> Arc<PreparedStatementCache> {
483 Arc::clone(&self.cache)
484 }
485
486 pub fn prepare(&self, sql: &str) -> Result<Arc<PreparedStatement>, SessionError> {
488 self.cache.get_or_prepare(sql).map_err(SessionError::from)
489 }
490
491 pub fn prepare_arena(&self, sql: &str) -> Result<Arc<ArenaPreparedStatement>, ArenaParseError> {
495 self.cache.get_or_prepare_arena(sql)
496 }
497
498 pub fn execute_prepared(
502 &self,
503 stmt: &PreparedStatement,
504 params: &[SqlValue],
505 ) -> Result<PreparedExecutionResult, SessionError> {
506 let bound_stmt = stmt.bind(params)?;
507 self.execute_statement_readonly(&bound_stmt)
508 }
509
510 pub fn execute_prepared_mut(
514 &mut self,
515 stmt: &PreparedStatement,
516 params: &[SqlValue],
517 ) -> Result<PreparedExecutionResult, SessionError> {
518 if let CachedPlan::PkDelete(plan) = stmt.cached_plan() {
520 if let Some(result) = self.try_execute_pk_delete(plan, params)? {
521 return Ok(result);
522 }
523 }
525
526 let bound_stmt = stmt.bind(params)?;
527 self.execute_statement_mut(&bound_stmt)
528 }
529
530 fn try_execute_pk_delete(
535 &mut self,
536 plan: &crate::cache::PkDeletePlan,
537 params: &[SqlValue],
538 ) -> Result<Option<PreparedExecutionResult>, SessionError> {
539 if let Some(valid) = plan.is_fast_path_valid() {
541 if !valid {
542 return Ok(None); }
544 } else {
546 let valid = self.validate_delete_fast_path(plan);
548 plan.set_fast_path_valid(valid);
549 if !valid {
550 return Ok(None);
551 }
552 }
553
554 let pk_values = plan.build_pk_values(params);
556
557 match self.db.delete_by_pk_fast(&plan.table_name, &pk_values) {
559 Ok(deleted) => {
560 let rows_affected = if deleted { 1 } else { 0 };
561 self.db.set_last_changes_count(rows_affected);
563 self.db.increment_total_changes_count(rows_affected);
564 Ok(Some(PreparedExecutionResult::RowsAffected(rows_affected)))
565 }
566 Err(_) => Ok(None), }
568 }
569
570 fn validate_delete_fast_path(&self, plan: &crate::cache::PkDeletePlan) -> bool {
573 let has_triggers = self
575 .db
576 .catalog
577 .get_triggers_for_table(&plan.table_name, Some(vibesql_ast::TriggerEvent::Delete))
578 .next()
579 .is_some();
580
581 if has_triggers {
582 return false;
583 }
584
585 let schema = match self.db.catalog.get_table(&plan.table_name) {
587 Some(s) => s,
588 None => return false, };
590
591 let has_pk = schema.get_primary_key_indices().is_some();
592 if has_pk {
593 let has_referencing_fks = self.db.catalog.list_tables().iter().any(|t| {
594 self.db
595 .catalog
596 .get_table(t)
597 .map(|s| {
598 s.foreign_keys
599 .iter()
600 .any(|fk| fk.parent_table.eq_ignore_ascii_case(&plan.table_name))
601 })
602 .unwrap_or(false)
603 });
604
605 if has_referencing_fks {
606 return false;
607 }
608 }
609
610 true }
612
613 fn execute_statement_readonly(
615 &self,
616 stmt: &Statement,
617 ) -> Result<PreparedExecutionResult, SessionError> {
618 match stmt {
619 Statement::Select(select_stmt) => {
620 let executor = SelectExecutor::new(self.db);
621 let result = executor.execute_with_columns(select_stmt)?;
622 Ok(PreparedExecutionResult::Select(result))
623 }
624 _ => Err(SessionError::UnsupportedStatement(
625 "Use execute_prepared_mut for DML statements".into(),
626 )),
627 }
628 }
629
630 fn execute_statement_mut(
632 &mut self,
633 stmt: &Statement,
634 ) -> Result<PreparedExecutionResult, SessionError> {
635 match stmt {
636 Statement::Select(select_stmt) => {
637 let executor = SelectExecutor::new(self.db);
638 let result = executor.execute_with_columns(select_stmt)?;
639 Ok(PreparedExecutionResult::Select(result))
640 }
641 Statement::Insert(insert_stmt) => {
642 let rows_affected = InsertExecutor::execute(self.db, insert_stmt)?;
643 self.db.set_last_changes_count(rows_affected);
645 self.db.increment_total_changes_count(rows_affected);
646 Ok(PreparedExecutionResult::RowsAffected(rows_affected))
651 }
652 Statement::Update(update_stmt) => {
653 let rows_affected = UpdateExecutor::execute(update_stmt, self.db)?;
654 self.db.set_last_changes_count(rows_affected);
656 self.db.increment_total_changes_count(rows_affected);
657 Ok(PreparedExecutionResult::RowsAffected(rows_affected))
662 }
663 Statement::Delete(delete_stmt) => {
664 let rows_affected = DeleteExecutor::execute(delete_stmt, self.db)?;
665 self.db.set_last_changes_count(rows_affected);
667 self.db.increment_total_changes_count(rows_affected);
668 Ok(PreparedExecutionResult::RowsAffected(rows_affected))
673 }
674 _ => Err(SessionError::UnsupportedStatement(format!(
675 "Statement type {:?} not supported for prepared execution",
676 std::mem::discriminant(stmt)
677 ))),
678 }
679 }
680}
681
#[cfg(test)]
mod tests {
    use vibesql_catalog::{ColumnSchema, TableSchema};
    use vibesql_types::DataType;

    use super::*;

    /// Shorthand for a `SqlValue::Varchar` built from a string literal.
    fn varchar(s: &str) -> SqlValue {
        SqlValue::Varchar(arcstr::ArcStr::from(s))
    }

    /// Builds a database with `users(id INTEGER PRIMARY KEY, name VARCHAR(100))` holding
    /// (1, Alice), (2, Bob) and (3, Charlie), with case-insensitive identifiers.
    fn create_test_db() -> Database {
        let mut db = Database::new();
        db.catalog.set_case_sensitive_identifiers(false);

        let columns = vec![
            ColumnSchema::new("id".to_string(), DataType::Integer, false),
            ColumnSchema::new(
                "name".to_string(),
                DataType::Varchar { max_length: Some(100) },
                true,
            ),
        ];
        let schema =
            TableSchema::with_primary_key("users".to_string(), columns, vec!["id".to_string()]);
        db.create_table(schema).unwrap();

        db.insert_row("users", Row::new(vec![SqlValue::Integer(1), varchar("Alice")])).unwrap();
        db.insert_row("users", Row::new(vec![SqlValue::Integer(2), varchar("Bob")])).unwrap();
        db.insert_row("users", Row::new(vec![SqlValue::Integer(3), varchar("Charlie")]))
            .unwrap();

        db
    }

    #[test]
    fn test_session_prepare() {
        let db = create_test_db();
        let session = Session::new(&db);

        let stmt = session.prepare("SELECT * FROM users WHERE id = ?").unwrap();
        assert_eq!(stmt.param_count(), 1);
    }

    #[test]
    fn test_session_execute_prepared() {
        let db = create_test_db();
        let session = Session::new(&db);

        let stmt = session.prepare("SELECT * FROM users WHERE id = ?").unwrap();

        let result = session.execute_prepared(&stmt, &[SqlValue::Integer(1)]).unwrap();

        if let PreparedExecutionResult::Select(select_result) = result {
            assert_eq!(select_result.rows.len(), 1);
            assert_eq!(select_result.rows[0].values[0], SqlValue::Integer(1));
            assert_eq!(select_result.rows[0].values[1], varchar("Alice"));
        } else {
            panic!("Expected Select result");
        }
    }

    #[test]
    fn test_session_reuse_prepared() {
        let db = create_test_db();
        let session = Session::new(&db);

        let stmt = session.prepare("SELECT * FROM users WHERE id = ?").unwrap();

        // The same prepared statement is executed with different parameters.
        let result1 = session.execute_prepared(&stmt, &[SqlValue::Integer(1)]).unwrap();
        let result2 = session.execute_prepared(&stmt, &[SqlValue::Integer(2)]).unwrap();
        let result3 = session.execute_prepared(&stmt, &[SqlValue::Integer(3)]).unwrap();

        assert_eq!(result1.rows().unwrap()[0].values[1], varchar("Alice"));
        assert_eq!(result2.rows().unwrap()[0].values[1], varchar("Bob"));
        assert_eq!(result3.rows().unwrap()[0].values[1], varchar("Charlie"));

        // The statement was prepared exactly once.
        let stats = session.cache().stats();
        assert_eq!(stats.misses, 1);
    }

    #[test]
    fn test_session_param_count_mismatch() {
        let db = create_test_db();
        let session = Session::new(&db);

        let stmt = session.prepare("SELECT * FROM users WHERE id = ?").unwrap();

        // Too few parameters.
        let result = session.execute_prepared(&stmt, &[]);
        assert!(result.is_err());

        // Too many parameters.
        let result = session.execute_prepared(&stmt, &[SqlValue::Integer(1), SqlValue::Integer(2)]);
        assert!(result.is_err());
    }

    #[test]
    fn test_session_mut_insert() {
        let mut db = create_test_db();
        let mut session = SessionMut::new(&mut db);

        let stmt = session.prepare("INSERT INTO users (id, name) VALUES (?, ?)").unwrap();

        let result = session
            .execute_prepared_mut(&stmt, &[SqlValue::Integer(4), varchar("David")])
            .unwrap();

        assert_eq!(result.rows_affected(), Some(1));

        let select_stmt = session.prepare("SELECT * FROM users WHERE id = ?").unwrap();
        let select_result =
            session.execute_prepared(&select_stmt, &[SqlValue::Integer(4)]).unwrap();

        assert_eq!(select_result.rows().unwrap().len(), 1);
        assert_eq!(select_result.rows().unwrap()[0].values[1], varchar("David"));
    }

    #[test]
    fn test_session_mut_update() {
        let mut db = create_test_db();
        let mut session = SessionMut::new(&mut db);

        let stmt = session.prepare("UPDATE users SET name = ? WHERE id = ?").unwrap();

        let result = session
            .execute_prepared_mut(&stmt, &[varchar("Alicia"), SqlValue::Integer(1)])
            .unwrap();

        assert_eq!(result.rows_affected(), Some(1));

        let select_stmt = session.prepare("SELECT * FROM users WHERE id = ?").unwrap();
        let select_result =
            session.execute_prepared(&select_stmt, &[SqlValue::Integer(1)]).unwrap();

        assert_eq!(select_result.rows().unwrap()[0].values[1], varchar("Alicia"));
    }

    #[test]
    fn test_session_mut_delete() {
        let mut db = create_test_db();
        let mut session = SessionMut::new(&mut db);

        let stmt = session.prepare("DELETE FROM users WHERE id = ?").unwrap();

        let result = session.execute_prepared_mut(&stmt, &[SqlValue::Integer(1)]).unwrap();

        assert_eq!(result.rows_affected(), Some(1));

        let select_stmt = session.prepare("SELECT * FROM users WHERE id = ?").unwrap();
        let select_result =
            session.execute_prepared(&select_stmt, &[SqlValue::Integer(1)]).unwrap();

        assert_eq!(select_result.rows().unwrap().len(), 0);
    }

    #[test]
    fn test_shared_cache() {
        let db = create_test_db();

        let session1 = Session::new(&db);
        let stmt = session1.prepare("SELECT * FROM users WHERE id = ?").unwrap();

        let shared_cache = session1.shared_cache();
        let initial_misses = session1.cache().stats().misses;

        let session2 = Session::with_shared_cache(&db, shared_cache);

        // Preparing the same SQL through the shared cache must not add a miss.
        let _stmt2 = session2.prepare("SELECT * FROM users WHERE id = ?").unwrap();

        assert_eq!(session2.cache().stats().misses, initial_misses);

        let result1 = session1.execute_prepared(&stmt, &[SqlValue::Integer(1)]).unwrap();
        let result2 = session2.execute_prepared(&stmt, &[SqlValue::Integer(2)]).unwrap();

        assert_eq!(result1.rows().unwrap()[0].values[1], varchar("Alice"));
        assert_eq!(result2.rows().unwrap()[0].values[1], varchar("Bob"));
    }

    #[test]
    fn test_no_params_statement() {
        let db = create_test_db();
        let session = Session::new(&db);

        let stmt = session.prepare("SELECT * FROM users").unwrap();
        assert_eq!(stmt.param_count(), 0);

        let result = session.execute_prepared(&stmt, &[]).unwrap();
        assert_eq!(result.rows().unwrap().len(), 3);
    }

    #[test]
    fn test_multiple_placeholders() {
        let db = create_test_db();
        let session = Session::new(&db);

        let stmt = session.prepare("SELECT * FROM users WHERE id >= ? AND id <= ?").unwrap();
        assert_eq!(stmt.param_count(), 2);

        let result =
            session.execute_prepared(&stmt, &[SqlValue::Integer(1), SqlValue::Integer(2)]).unwrap();

        assert_eq!(result.rows().unwrap().len(), 2);
    }

    #[test]
    fn test_session_prepare_arena() {
        let db = create_test_db();
        let session = Session::new(&db);

        let stmt = session.prepare_arena("SELECT * FROM users WHERE id = ?").unwrap();
        assert_eq!(stmt.param_count(), 1);

        assert!(stmt.tables().contains("users"));

        // A second prepare of the same SQL returns the cached statement.
        let stmt2 = session.prepare_arena("SELECT * FROM users WHERE id = ?").unwrap();
        assert_eq!(stmt2.param_count(), 1);

        assert!(Arc::ptr_eq(&stmt, &stmt2));
    }

    #[test]
    fn test_session_prepare_arena_no_params() {
        let db = create_test_db();
        let session = Session::new(&db);

        let stmt = session.prepare_arena("SELECT * FROM users").unwrap();
        assert_eq!(stmt.param_count(), 0);
    }

    #[test]
    fn test_session_prepare_arena_join() {
        let mut db = create_test_db();

        // Register an `orders` table so the join refers to a real relation.
        let orders_columns = vec![
            ColumnSchema::new("id".to_string(), DataType::Integer, false),
            ColumnSchema::new("user_id".to_string(), DataType::Integer, false),
        ];
        let orders_schema = TableSchema::with_primary_key(
            "orders".to_string(),
            orders_columns,
            vec!["id".to_string()],
        );
        db.create_table(orders_schema).unwrap();

        let session = Session::new(&db);

        let stmt = session
            .prepare_arena("SELECT u.id FROM users u JOIN orders o ON u.id = o.user_id")
            .unwrap();

        let tables = stmt.tables();
        assert!(tables.contains("users"), "Expected users in {:?}", tables);
        assert!(tables.contains("orders"), "Expected orders in {:?}", tables);
    }

    #[test]
    fn test_session_mut_prepare_arena() {
        let mut db = create_test_db();
        let session = SessionMut::new(&mut db);

        let stmt = session.prepare_arena("SELECT * FROM users WHERE id = ?").unwrap();
        assert_eq!(stmt.param_count(), 1);
    }

    #[test]
    fn test_delete_fast_path_plan() {
        let mut db = create_test_db();
        let mut session = SessionMut::new(&mut db);

        let stmt = session.prepare("DELETE FROM users WHERE id = ?").unwrap();

        // A single-column PK delete is planned as PkDelete with validity not yet decided.
        match stmt.cached_plan() {
            CachedPlan::PkDelete(plan) => {
                assert_eq!(plan.table_name, "users");
                assert_eq!(plan.pk_columns, vec!["id"]);
                assert_eq!(plan.param_to_pk_col, vec![(0, 0)]);
                assert!(plan.is_fast_path_valid().is_none());
            }
            other => panic!("Expected PkDelete plan, got {:?}", other),
        }

        let result = session.execute_prepared_mut(&stmt, &[SqlValue::Integer(1)]).unwrap();
        assert_eq!(result.rows_affected(), Some(1));

        // The first execution records the fast-path verdict on the plan.
        match stmt.cached_plan() {
            CachedPlan::PkDelete(plan) => {
                assert_eq!(
                    plan.is_fast_path_valid(),
                    Some(true),
                    "Fast path should be valid after execution"
                );
            }
            _ => panic!("Plan should still be PkDelete"),
        }

        let select_stmt = session.prepare("SELECT * FROM users WHERE id = ?").unwrap();
        let select_result =
            session.execute_prepared(&select_stmt, &[SqlValue::Integer(1)]).unwrap();
        assert_eq!(select_result.rows().unwrap().len(), 0);
    }

    #[test]
    fn test_concurrent_sessions_coexist() {
        let db = create_test_db();

        // Several read-only sessions may borrow the same database at once.
        let session1 = Session::new(&db);
        let session2 = Session::new(&db);
        let session3 = Session::new(&db);

        let stmt1 = session1.prepare("SELECT * FROM users WHERE id = ?").unwrap();
        let stmt2 = session2.prepare("SELECT * FROM users WHERE id = ?").unwrap();
        let stmt3 = session3.prepare("SELECT name FROM users WHERE id = ?").unwrap();

        let result1 = session1.execute_prepared(&stmt1, &[SqlValue::Integer(1)]).unwrap();
        let result2 = session2.execute_prepared(&stmt2, &[SqlValue::Integer(2)]).unwrap();
        let result3 = session3.execute_prepared(&stmt3, &[SqlValue::Integer(3)]).unwrap();

        assert_eq!(result1.rows().unwrap()[0].values[1], varchar("Alice"));
        assert_eq!(result2.rows().unwrap()[0].values[1], varchar("Bob"));
        assert_eq!(result3.rows().unwrap()[0].values[0], varchar("Charlie"));
    }

    #[test]
    fn test_concurrent_sessions_shared_cache() {
        let db = create_test_db();

        let shared_cache = Arc::new(PreparedStatementCache::default_cache());

        let session1 = Session::with_shared_cache(&db, Arc::clone(&shared_cache));
        let session2 = Session::with_shared_cache(&db, Arc::clone(&shared_cache));

        let stmt = session1.prepare("SELECT * FROM users WHERE id = ?").unwrap();
        let stats_after_first = shared_cache.stats();
        assert_eq!(stats_after_first.misses, 1);

        // The second session hits the entry the first one created.
        let _stmt2 = session2.prepare("SELECT * FROM users WHERE id = ?").unwrap();
        let stats_after_second = shared_cache.stats();
        assert_eq!(stats_after_second.misses, 1);
        assert!(stats_after_second.hits >= 1);

        let r1 = session1.execute_prepared(&stmt, &[SqlValue::Integer(1)]).unwrap();
        let r2 = session2.execute_prepared(&stmt, &[SqlValue::Integer(2)]).unwrap();

        assert_eq!(r1.rows().unwrap().len(), 1);
        assert_eq!(r2.rows().unwrap().len(), 1);
    }

    #[test]
    fn test_concurrent_sessions_different_queries() {
        let db = create_test_db();

        let session1 = Session::new(&db);
        let session2 = Session::new(&db);

        let point_query = session1.prepare("SELECT * FROM users WHERE id = ?").unwrap();
        let range_query =
            session2.prepare("SELECT * FROM users WHERE id >= ? AND id <= ?").unwrap();
        let all_query = session1.prepare("SELECT * FROM users").unwrap();
        let projection_query = session2.prepare("SELECT name FROM users WHERE id = ?").unwrap();

        let r1 = session1.execute_prepared(&point_query, &[SqlValue::Integer(1)]).unwrap();
        let r2 = session2
            .execute_prepared(&range_query, &[SqlValue::Integer(1), SqlValue::Integer(2)])
            .unwrap();
        let r3 = session1.execute_prepared(&all_query, &[]).unwrap();
        let r4 = session2.execute_prepared(&projection_query, &[SqlValue::Integer(3)]).unwrap();

        assert_eq!(r1.rows().unwrap().len(), 1);
        assert_eq!(r2.rows().unwrap().len(), 2);
        assert_eq!(r3.rows().unwrap().len(), 3);
        assert_eq!(r4.rows().unwrap().len(), 1);
    }

    #[test]
    fn test_concurrent_sessions_interleaved_execution() {
        let db = create_test_db();

        let session1 = Session::new(&db);
        let session2 = Session::new(&db);

        let stmt = session1.prepare("SELECT * FROM users WHERE id = ?").unwrap();

        // Alternate executions of one statement across two sessions.
        let r1 = session1.execute_prepared(&stmt, &[SqlValue::Integer(1)]).unwrap();
        let r2 = session2.execute_prepared(&stmt, &[SqlValue::Integer(2)]).unwrap();
        let r3 = session1.execute_prepared(&stmt, &[SqlValue::Integer(3)]).unwrap();
        let r4 = session2.execute_prepared(&stmt, &[SqlValue::Integer(1)]).unwrap();

        assert_eq!(r1.rows().unwrap()[0].values[0], SqlValue::Integer(1));
        assert_eq!(r2.rows().unwrap()[0].values[0], SqlValue::Integer(2));
        assert_eq!(r3.rows().unwrap()[0].values[0], SqlValue::Integer(3));
        assert_eq!(r4.rows().unwrap()[0].values[0], SqlValue::Integer(1));
    }

    #[test]
    fn test_session_immutable_borrow_allows_multiple() {
        let db = create_test_db();

        let session1 = Session::new(&db);
        let session2 = Session::new(&db);

        // Both sessions expose the very same database.
        let db_ref1 = session1.database();
        let db_ref2 = session2.database();

        assert!(std::ptr::eq(db_ref1, db_ref2));

        let stmt = session1.prepare("SELECT COUNT(*) FROM users").unwrap();
        let r1 = session1.execute_prepared(&stmt, &[]).unwrap();
        let r2 = session2.execute_prepared(&stmt, &[]).unwrap();

        assert_eq!(r1.rows().unwrap().len(), 1);
        assert_eq!(r2.rows().unwrap().len(), 1);
    }

    #[test]
    fn test_session_execute_uses_immutable_self() {
        let db = create_test_db();
        let session = Session::new(&db);

        let stmt = session.prepare("SELECT * FROM users WHERE id = ?").unwrap();

        // Repeated execution only needs `&self`.
        let _ = session.execute_prepared(&stmt, &[SqlValue::Integer(1)]);
        let _ = session.execute_prepared(&stmt, &[SqlValue::Integer(2)]);
        let _ = session.execute_prepared(&stmt, &[SqlValue::Integer(3)]);

        let cache = session.cache();
        let _stats = cache.stats();
    }

    #[test]
    fn test_concurrent_sessions_with_aggregates() {
        let db = create_test_db();

        let session1 = Session::new(&db);
        let session2 = Session::new(&db);

        let count_stmt = session1.prepare("SELECT COUNT(*) FROM users").unwrap();
        let filtered_count_stmt =
            session2.prepare("SELECT COUNT(*) FROM users WHERE id <= ?").unwrap();

        let r1 = session1.execute_prepared(&count_stmt, &[]).unwrap();
        let r2 = session2.execute_prepared(&filtered_count_stmt, &[SqlValue::Integer(2)]).unwrap();

        assert_eq!(r1.rows().unwrap()[0].values[0], SqlValue::Integer(3));
        assert_eq!(r2.rows().unwrap()[0].values[0], SqlValue::Integer(2));
    }

    #[test]
    fn test_concurrent_sessions_with_pk_fast_path() {
        let db = create_test_db();

        let session1 = Session::new(&db);
        let session2 = Session::new(&db);

        let stmt = session1.prepare("SELECT * FROM users WHERE id = ?").unwrap();

        let r1 = session1.execute_prepared(&stmt, &[SqlValue::Integer(1)]).unwrap();
        let r2 = session2.execute_prepared(&stmt, &[SqlValue::Integer(2)]).unwrap();

        assert_eq!(r1.rows().unwrap()[0].values[1], varchar("Alice"));
        assert_eq!(r2.rows().unwrap()[0].values[1], varchar("Bob"));
    }
}