1use std::sync::Arc;
46
47use vibesql_ast::Statement;
48use vibesql_storage::{Database, Row};
49use vibesql_types::SqlValue;
50
51use crate::cache::{
52 ArenaParseError, ArenaPreparedStatement, CachedPlan, PkPointLookupPlan, PreparedStatement,
53 PreparedStatementCache, PreparedStatementError, ProjectionPlan, ResolvedProjection,
54};
55use crate::errors::ExecutorError;
56use crate::{DeleteExecutor, InsertExecutor, SelectExecutor, SelectResult, UpdateExecutor};
57
58#[derive(Debug)]
60pub enum PreparedExecutionResult {
61 Select(SelectResult),
63 RowsAffected(usize),
65 Ok,
67}
68
69impl PreparedExecutionResult {
70 pub fn rows(&self) -> Option<&[Row]> {
72 match self {
73 PreparedExecutionResult::Select(result) => Some(&result.rows),
74 _ => None,
75 }
76 }
77
78 pub fn rows_affected(&self) -> Option<usize> {
80 match self {
81 PreparedExecutionResult::RowsAffected(n) => Some(*n),
82 _ => None,
83 }
84 }
85
86 pub fn into_select_result(self) -> Option<SelectResult> {
88 match self {
89 PreparedExecutionResult::Select(result) => Some(result),
90 _ => None,
91 }
92 }
93}
94
95#[derive(Debug)]
97pub enum SessionError {
98 PreparedStatement(PreparedStatementError),
100 Execution(ExecutorError),
102 UnsupportedStatement(String),
104}
105
106impl std::fmt::Display for SessionError {
107 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
108 match self {
109 SessionError::PreparedStatement(e) => write!(f, "Prepared statement error: {}", e),
110 SessionError::Execution(e) => write!(f, "Execution error: {:?}", e),
111 SessionError::UnsupportedStatement(msg) => write!(f, "Unsupported statement: {}", msg),
112 }
113 }
114}
115
116impl std::error::Error for SessionError {}
117
118impl From<PreparedStatementError> for SessionError {
119 fn from(e: PreparedStatementError) -> Self {
120 SessionError::PreparedStatement(e)
121 }
122}
123
124impl From<ExecutorError> for SessionError {
125 fn from(e: ExecutorError) -> Self {
126 SessionError::Execution(e)
127 }
128}
129
130pub struct Session<'a> {
135 db: &'a Database,
136 cache: Arc<PreparedStatementCache>,
137}
138
139impl<'a> Session<'a> {
140 pub fn new(db: &'a Database) -> Self {
144 Self { db, cache: Arc::new(PreparedStatementCache::default_cache()) }
145 }
146
147 pub fn with_cache_size(db: &'a Database, cache_size: usize) -> Self {
149 Self { db, cache: Arc::new(PreparedStatementCache::new(cache_size)) }
150 }
151
152 pub fn with_shared_cache(db: &'a Database, cache: Arc<PreparedStatementCache>) -> Self {
157 Self { db, cache }
158 }
159
160 pub fn database(&self) -> &Database {
162 self.db
163 }
164
165 pub fn cache(&self) -> &PreparedStatementCache {
167 &self.cache
168 }
169
170 pub fn shared_cache(&self) -> Arc<PreparedStatementCache> {
172 Arc::clone(&self.cache)
173 }
174
175 pub fn prepare(&self, sql: &str) -> Result<Arc<PreparedStatement>, SessionError> {
189 self.cache.get_or_prepare(sql).map_err(SessionError::from)
190 }
191
192 pub fn prepare_arena(&self, sql: &str) -> Result<Arc<ArenaPreparedStatement>, ArenaParseError> {
214 self.cache.get_or_prepare_arena(sql)
215 }
216
217 pub fn execute_prepared(
232 &self,
233 stmt: &PreparedStatement,
234 params: &[SqlValue],
235 ) -> Result<PreparedExecutionResult, SessionError> {
236 if params.len() != stmt.param_count() {
238 return Err(SessionError::PreparedStatement(
239 PreparedStatementError::ParameterCountMismatch {
240 expected: stmt.param_count(),
241 actual: params.len(),
242 },
243 ));
244 }
245
246 match stmt.cached_plan() {
248 CachedPlan::PkPointLookup(plan) => {
249 if let Some(result) = self.try_execute_pk_lookup(plan, params)? {
250 return Ok(result);
251 }
252 }
254 CachedPlan::SimpleFastPath(plan) => {
255 let bound_stmt = stmt.bind(params)?;
258 if let Statement::Select(select_stmt) = &bound_stmt {
259 let executor = SelectExecutor::new(self.db);
260
261 let columns = plan.get_or_resolve_columns(|| {
263 executor.derive_fast_path_column_names(select_stmt).ok()
264 });
265
266 match columns {
267 Some(cached_columns) => {
268 let rows = executor.execute_fast_path(select_stmt)?;
270 return Ok(PreparedExecutionResult::Select(SelectResult {
271 columns: cached_columns.iter().cloned().collect(),
272 rows,
273 }));
274 }
275 None => {
276 let result = executor.execute_fast_path_with_columns(select_stmt)?;
278 return Ok(PreparedExecutionResult::Select(result));
279 }
280 }
281 }
282 }
284 CachedPlan::PkDelete(_) => {
285 }
288 CachedPlan::Standard => {
289 }
291 }
292
293 let bound_stmt = stmt.bind(params)?;
295
296 self.execute_statement(&bound_stmt)
298 }
299
300 fn try_execute_pk_lookup(
306 &self,
307 plan: &PkPointLookupPlan,
308 params: &[SqlValue],
309 ) -> Result<Option<PreparedExecutionResult>, SessionError> {
310 let table = match self.db.get_table(&plan.table_name) {
312 Some(t) => t,
313 None => return Ok(None), };
315
316 let actual_pk_columns = match &table.schema.primary_key {
318 Some(cols) if cols.len() == plan.pk_columns.len() => cols,
319 _ => return Ok(None), };
321
322 for (param_idx, pk_col_idx) in &plan.param_to_pk_col {
325 if *param_idx >= params.len() || *pk_col_idx >= plan.pk_columns.len() {
326 return Ok(None); }
328
329 let expected_col = &plan.pk_columns[*pk_col_idx];
331 let actual_col = &actual_pk_columns[*pk_col_idx];
332 if !expected_col.eq_ignore_ascii_case(actual_col) {
333 return Ok(None); }
335 }
336
337 let resolved = match plan.get_or_resolve(|proj| {
339 self.resolve_projection(proj, &table.schema.columns)
340 }) {
341 Some(r) => r,
342 None => return Ok(None), };
344
345 let row = if plan.param_to_pk_col.len() == 1 {
348 let (param_idx, _) = plan.param_to_pk_col[0];
349 self.db
350 .get_row_by_pk(&plan.table_name, ¶ms[param_idx])
351 .map_err(|e| SessionError::Execution(ExecutorError::StorageError(e.to_string())))?
352 } else {
353 let pk_values: Vec<SqlValue> = plan
355 .param_to_pk_col
356 .iter()
357 .map(|(param_idx, _)| params[*param_idx].clone())
358 .collect();
359 self.db
360 .get_row_by_composite_pk(&plan.table_name, &pk_values)
361 .map_err(|e| SessionError::Execution(ExecutorError::StorageError(e.to_string())))?
362 };
363
364 let columns: Vec<String> = resolved.column_names.iter().cloned().collect();
366
367 let rows = match row {
368 Some(r) => {
369 if resolved.column_indices.is_empty() {
371 vec![r.clone()]
373 } else {
374 let projected_values: Vec<SqlValue> = resolved
376 .column_indices
377 .iter()
378 .map(|&i| r.values[i].clone())
379 .collect();
380 vec![Row::new(projected_values)]
381 }
382 }
383 None => vec![],
384 };
385
386 Ok(Some(PreparedExecutionResult::Select(SelectResult {
387 columns,
388 rows,
389 })))
390 }
391
392 fn resolve_projection(
396 &self,
397 proj: &ProjectionPlan,
398 schema_columns: &[vibesql_catalog::ColumnSchema],
399 ) -> Option<ResolvedProjection> {
400 match proj {
401 ProjectionPlan::Wildcard => {
402 let column_names: Arc<[String]> = schema_columns
405 .iter()
406 .map(|c| c.name.clone())
407 .collect();
408 Some(ResolvedProjection {
409 column_indices: vec![],
410 column_names,
411 })
412 }
413 ProjectionPlan::Columns(projections) => {
414 let mut col_indices = Vec::with_capacity(projections.len());
415 let mut column_names = Vec::with_capacity(projections.len());
416
417 for proj in projections {
418 let idx = schema_columns
419 .iter()
420 .position(|c| c.name.eq_ignore_ascii_case(&proj.column_name))?;
421
422 col_indices.push(idx);
423 column_names.push(
424 proj.alias
425 .clone()
426 .unwrap_or_else(|| proj.column_name.clone()),
427 );
428 }
429
430 Some(ResolvedProjection {
431 column_indices: col_indices,
432 column_names: column_names.into(),
433 })
434 }
435 }
436 }
437
438 fn execute_statement(&self, stmt: &Statement) -> Result<PreparedExecutionResult, SessionError> {
440 match stmt {
441 Statement::Select(select_stmt) => {
442 let executor = SelectExecutor::new(self.db);
443 let result = executor.execute_with_columns(select_stmt)?;
444 Ok(PreparedExecutionResult::Select(result))
445 }
446 _ => Err(SessionError::UnsupportedStatement(
447 "Only SELECT is supported for read-only sessions. Use SessionMut for DML.".into(),
448 )),
449 }
450 }
451}
452
453pub struct SessionMut<'a> {
458 db: &'a mut Database,
459 cache: Arc<PreparedStatementCache>,
460}
461
462impl<'a> SessionMut<'a> {
463 pub fn new(db: &'a mut Database) -> Self {
465 Self { db, cache: Arc::new(PreparedStatementCache::default_cache()) }
466 }
467
468 pub fn with_cache_size(db: &'a mut Database, cache_size: usize) -> Self {
470 Self { db, cache: Arc::new(PreparedStatementCache::new(cache_size)) }
471 }
472
473 pub fn with_shared_cache(db: &'a mut Database, cache: Arc<PreparedStatementCache>) -> Self {
475 Self { db, cache }
476 }
477
478 pub fn database(&self) -> &Database {
480 self.db
481 }
482
483 pub fn database_mut(&mut self) -> &mut Database {
485 self.db
486 }
487
488 pub fn cache(&self) -> &PreparedStatementCache {
490 &self.cache
491 }
492
493 pub fn shared_cache(&self) -> Arc<PreparedStatementCache> {
495 Arc::clone(&self.cache)
496 }
497
498 pub fn prepare(&self, sql: &str) -> Result<Arc<PreparedStatement>, SessionError> {
500 self.cache.get_or_prepare(sql).map_err(SessionError::from)
501 }
502
503 pub fn prepare_arena(&self, sql: &str) -> Result<Arc<ArenaPreparedStatement>, ArenaParseError> {
507 self.cache.get_or_prepare_arena(sql)
508 }
509
510 pub fn execute_prepared(
514 &self,
515 stmt: &PreparedStatement,
516 params: &[SqlValue],
517 ) -> Result<PreparedExecutionResult, SessionError> {
518 let bound_stmt = stmt.bind(params)?;
519 self.execute_statement_readonly(&bound_stmt)
520 }
521
522 pub fn execute_prepared_mut(
526 &mut self,
527 stmt: &PreparedStatement,
528 params: &[SqlValue],
529 ) -> Result<PreparedExecutionResult, SessionError> {
530 if let CachedPlan::PkDelete(plan) = stmt.cached_plan() {
532 if let Some(result) = self.try_execute_pk_delete(plan, params)? {
533 return Ok(result);
534 }
535 }
537
538 let bound_stmt = stmt.bind(params)?;
539 self.execute_statement_mut(&bound_stmt)
540 }
541
542 fn try_execute_pk_delete(
547 &mut self,
548 plan: &crate::cache::PkDeletePlan,
549 params: &[SqlValue],
550 ) -> Result<Option<PreparedExecutionResult>, SessionError> {
551 if let Some(valid) = plan.is_fast_path_valid() {
553 if !valid {
554 return Ok(None); }
556 } else {
558 let valid = self.validate_delete_fast_path(plan);
560 plan.set_fast_path_valid(valid);
561 if !valid {
562 return Ok(None);
563 }
564 }
565
566 let pk_values = plan.build_pk_values(params);
568
569 match self.db.delete_by_pk_fast(&plan.table_name, &pk_values) {
571 Ok(deleted) => Ok(Some(PreparedExecutionResult::RowsAffected(if deleted {
572 1
573 } else {
574 0
575 }))),
576 Err(_) => Ok(None), }
578 }
579
580 fn validate_delete_fast_path(&self, plan: &crate::cache::PkDeletePlan) -> bool {
583 let has_triggers = self
585 .db
586 .catalog
587 .get_triggers_for_table(&plan.table_name, Some(vibesql_ast::TriggerEvent::Delete))
588 .next()
589 .is_some();
590
591 if has_triggers {
592 return false;
593 }
594
595 let schema = match self.db.catalog.get_table(&plan.table_name) {
597 Some(s) => s,
598 None => return false, };
600
601 let has_pk = schema.get_primary_key_indices().is_some();
602 if has_pk {
603 let has_referencing_fks = self.db.catalog.list_tables().iter().any(|t| {
604 self.db
605 .catalog
606 .get_table(t)
607 .map(|s| {
608 s.foreign_keys
609 .iter()
610 .any(|fk| fk.parent_table.eq_ignore_ascii_case(&plan.table_name))
611 })
612 .unwrap_or(false)
613 });
614
615 if has_referencing_fks {
616 return false;
617 }
618 }
619
620 true }
622
623 fn execute_statement_readonly(
625 &self,
626 stmt: &Statement,
627 ) -> Result<PreparedExecutionResult, SessionError> {
628 match stmt {
629 Statement::Select(select_stmt) => {
630 let executor = SelectExecutor::new(self.db);
631 let result = executor.execute_with_columns(select_stmt)?;
632 Ok(PreparedExecutionResult::Select(result))
633 }
634 _ => Err(SessionError::UnsupportedStatement(
635 "Use execute_prepared_mut for DML statements".into(),
636 )),
637 }
638 }
639
640 fn execute_statement_mut(
642 &mut self,
643 stmt: &Statement,
644 ) -> Result<PreparedExecutionResult, SessionError> {
645 match stmt {
646 Statement::Select(select_stmt) => {
647 let executor = SelectExecutor::new(self.db);
648 let result = executor.execute_with_columns(select_stmt)?;
649 Ok(PreparedExecutionResult::Select(result))
650 }
651 Statement::Insert(insert_stmt) => {
652 let rows_affected = InsertExecutor::execute(self.db, insert_stmt)?;
653 Ok(PreparedExecutionResult::RowsAffected(rows_affected))
658 }
659 Statement::Update(update_stmt) => {
660 let rows_affected = UpdateExecutor::execute(update_stmt, self.db)?;
661 Ok(PreparedExecutionResult::RowsAffected(rows_affected))
666 }
667 Statement::Delete(delete_stmt) => {
668 let rows_affected = DeleteExecutor::execute(delete_stmt, self.db)?;
669 Ok(PreparedExecutionResult::RowsAffected(rows_affected))
674 }
675 _ => Err(SessionError::UnsupportedStatement(format!(
676 "Statement type {:?} not supported for prepared execution",
677 std::mem::discriminant(stmt)
678 ))),
679 }
680 }
681}
682
#[cfg(test)]
mod tests {
    use super::*;
    use vibesql_catalog::{ColumnSchema, TableSchema};
    use vibesql_types::DataType;

    /// Builds the `SqlValue::Varchar` value the fixtures and assertions compare against.
    fn text(value: &str) -> SqlValue {
        SqlValue::Varchar(arcstr::ArcStr::from(value))
    }

    /// Creates a database with a `users(id INTEGER PRIMARY KEY, name VARCHAR(100))` table
    /// holding three rows: (1, Alice), (2, Bob), (3, Charlie).
    fn create_test_db() -> Database {
        let mut db = Database::new();
        db.catalog.set_case_sensitive_identifiers(false);

        let columns = vec![
            ColumnSchema::new("id".to_string(), DataType::Integer, false),
            ColumnSchema::new(
                "name".to_string(),
                DataType::Varchar { max_length: Some(100) },
                true,
            ),
        ];
        let schema =
            TableSchema::with_primary_key("users".to_string(), columns, vec!["id".to_string()]);
        db.create_table(schema).unwrap();

        for &(id, name) in &[(1, "Alice"), (2, "Bob"), (3, "Charlie")] {
            let row = Row::new(vec![SqlValue::Integer(id), text(name)]);
            db.insert_row("users", row).unwrap();
        }

        db
    }

    /// Preparing a single-placeholder query reports one parameter.
    #[test]
    fn test_session_prepare() {
        let db = create_test_db();
        let session = Session::new(&db);

        let stmt = session.prepare("SELECT * FROM users WHERE id = ?").unwrap();
        assert_eq!(stmt.param_count(), 1);
    }

    /// Executing a prepared point query returns exactly the matching row.
    #[test]
    fn test_session_execute_prepared() {
        let db = create_test_db();
        let session = Session::new(&db);

        let stmt = session.prepare("SELECT * FROM users WHERE id = ?").unwrap();
        let result = session.execute_prepared(&stmt, &[SqlValue::Integer(1)]).unwrap();

        let select_result = result.into_select_result().expect("Expected Select result");
        assert_eq!(select_result.rows.len(), 1);
        assert_eq!(select_result.rows[0].values[0], SqlValue::Integer(1));
        assert_eq!(select_result.rows[0].values[1], text("Alice"));
    }

    /// One prepared statement can be executed repeatedly with different parameters.
    #[test]
    fn test_session_reuse_prepared() {
        let db = create_test_db();
        let session = Session::new(&db);

        let stmt = session.prepare("SELECT * FROM users WHERE id = ?").unwrap();

        for &(id, expected_name) in &[(1, "Alice"), (2, "Bob"), (3, "Charlie")] {
            let result = session.execute_prepared(&stmt, &[SqlValue::Integer(id)]).unwrap();
            assert_eq!(result.rows().unwrap()[0].values[1], text(expected_name));
        }

        let stats = session.cache().stats();
        assert_eq!(stats.misses, 1);
        let _hits = stats.hits;
    }

    /// Too few or too many parameters are rejected.
    #[test]
    fn test_session_param_count_mismatch() {
        let db = create_test_db();
        let session = Session::new(&db);

        let stmt = session.prepare("SELECT * FROM users WHERE id = ?").unwrap();

        let too_few = session.execute_prepared(&stmt, &[]);
        assert!(too_few.is_err());

        let too_many =
            session.execute_prepared(&stmt, &[SqlValue::Integer(1), SqlValue::Integer(2)]);
        assert!(too_many.is_err());
    }

    /// A prepared INSERT reports one affected row and the row becomes visible.
    #[test]
    fn test_session_mut_insert() {
        let mut db = create_test_db();
        let mut session = SessionMut::new(&mut db);

        let stmt = session.prepare("INSERT INTO users (id, name) VALUES (?, ?)").unwrap();
        let insert_params = [SqlValue::Integer(4), text("David")];
        let result = session.execute_prepared_mut(&stmt, &insert_params).unwrap();
        assert_eq!(result.rows_affected(), Some(1));

        let select_stmt = session.prepare("SELECT * FROM users WHERE id = ?").unwrap();
        let select_result =
            session.execute_prepared(&select_stmt, &[SqlValue::Integer(4)]).unwrap();

        let rows = select_result.rows().unwrap();
        assert_eq!(rows.len(), 1);
        assert_eq!(rows[0].values[1], text("David"));
    }

    /// A prepared UPDATE reports one affected row and the new value is readable.
    #[test]
    fn test_session_mut_update() {
        let mut db = create_test_db();
        let mut session = SessionMut::new(&mut db);

        let stmt = session.prepare("UPDATE users SET name = ? WHERE id = ?").unwrap();
        let update_params = [text("Alicia"), SqlValue::Integer(1)];
        let result = session.execute_prepared_mut(&stmt, &update_params).unwrap();
        assert_eq!(result.rows_affected(), Some(1));

        let select_stmt = session.prepare("SELECT * FROM users WHERE id = ?").unwrap();
        let select_result =
            session.execute_prepared(&select_stmt, &[SqlValue::Integer(1)]).unwrap();

        assert_eq!(select_result.rows().unwrap()[0].values[1], text("Alicia"));
    }

    /// A prepared DELETE reports one affected row and the row is gone afterwards.
    #[test]
    fn test_session_mut_delete() {
        let mut db = create_test_db();
        let mut session = SessionMut::new(&mut db);

        let stmt = session.prepare("DELETE FROM users WHERE id = ?").unwrap();
        let result = session.execute_prepared_mut(&stmt, &[SqlValue::Integer(1)]).unwrap();
        assert_eq!(result.rows_affected(), Some(1));

        let select_stmt = session.prepare("SELECT * FROM users WHERE id = ?").unwrap();
        let select_result =
            session.execute_prepared(&select_stmt, &[SqlValue::Integer(1)]).unwrap();

        assert_eq!(select_result.rows().unwrap().len(), 0);
    }

    /// A second session on the same cache does not record a new miss for known SQL.
    #[test]
    fn test_shared_cache() {
        let db = create_test_db();

        let session1 = Session::new(&db);
        let stmt = session1.prepare("SELECT * FROM users WHERE id = ?").unwrap();

        let shared_cache = session1.shared_cache();
        let initial_misses = session1.cache().stats().misses;

        let session2 = Session::with_shared_cache(&db, shared_cache);
        let _stmt2 = session2.prepare("SELECT * FROM users WHERE id = ?").unwrap();

        assert_eq!(session2.cache().stats().misses, initial_misses);

        let result1 = session1.execute_prepared(&stmt, &[SqlValue::Integer(1)]).unwrap();
        let result2 = session2.execute_prepared(&stmt, &[SqlValue::Integer(2)]).unwrap();

        assert_eq!(result1.rows().unwrap()[0].values[1], text("Alice"));
        assert_eq!(result2.rows().unwrap()[0].values[1], text("Bob"));
    }

    /// A statement without placeholders runs with an empty parameter list.
    #[test]
    fn test_no_params_statement() {
        let db = create_test_db();
        let session = Session::new(&db);

        let stmt = session.prepare("SELECT * FROM users").unwrap();
        assert_eq!(stmt.param_count(), 0);

        let result = session.execute_prepared(&stmt, &[]).unwrap();
        assert_eq!(result.rows().unwrap().len(), 3);
    }

    /// Two placeholders are counted and bound.
    #[test]
    fn test_multiple_placeholders() {
        let db = create_test_db();
        let session = Session::new(&db);

        let stmt = session.prepare("SELECT * FROM users WHERE id >= ? AND id <= ?").unwrap();
        assert_eq!(stmt.param_count(), 2);

        let range = [SqlValue::Integer(1), SqlValue::Integer(2)];
        let result = session.execute_prepared(&stmt, &range).unwrap();

        assert_eq!(result.rows().unwrap().len(), 2);
    }

    /// Arena preparation reports parameters and tables, and hands back the same `Arc`
    /// when the same SQL is prepared twice.
    #[test]
    fn test_session_prepare_arena() {
        let db = create_test_db();
        let session = Session::new(&db);

        let first = session.prepare_arena("SELECT * FROM users WHERE id = ?").unwrap();
        assert_eq!(first.param_count(), 1);
        assert!(first.tables().contains("USERS"));

        let second = session.prepare_arena("SELECT * FROM users WHERE id = ?").unwrap();
        assert_eq!(second.param_count(), 1);

        assert!(std::sync::Arc::ptr_eq(&first, &second));
    }

    /// Arena preparation of a statement without placeholders reports zero parameters.
    #[test]
    fn test_session_prepare_arena_no_params() {
        let db = create_test_db();
        let session = Session::new(&db);

        let stmt = session.prepare_arena("SELECT * FROM users").unwrap();
        assert_eq!(stmt.param_count(), 0);
    }

    /// Arena preparation of a join lists both referenced tables.
    #[test]
    fn test_session_prepare_arena_join() {
        let db = create_test_db();
        let session = Session::new(&db);

        // The `orders` schema is built but never registered with the database; the test
        // only inspects the tables recorded by `prepare_arena`.
        let orders_columns = vec![
            ColumnSchema::new("id".to_string(), DataType::Integer, false),
            ColumnSchema::new("user_id".to_string(), DataType::Integer, false),
        ];
        let _orders_schema = TableSchema::with_primary_key(
            "orders".to_string(),
            orders_columns,
            vec!["id".to_string()],
        );

        let stmt = session
            .prepare_arena("SELECT u.id FROM users u JOIN orders o ON u.id = o.user_id")
            .unwrap();

        let tables = stmt.tables();
        assert!(tables.contains("USERS"), "Expected USERS in {:?}", tables);
        assert!(tables.contains("ORDERS"), "Expected ORDERS in {:?}", tables);
    }

    /// `SessionMut` exposes arena preparation as well.
    #[test]
    fn test_session_mut_prepare_arena() {
        let mut db = create_test_db();
        let session = SessionMut::new(&mut db);

        let stmt = session.prepare_arena("SELECT * FROM users WHERE id = ?").unwrap();
        assert_eq!(stmt.param_count(), 1);
    }

    /// A primary-key DELETE gets a `PkDelete` plan whose validity is unset until the
    /// first execution and `Some(true)` afterwards.
    #[test]
    fn test_delete_fast_path_plan() {
        use crate::cache::CachedPlan;

        let mut db = create_test_db();
        let mut session = SessionMut::new(&mut db);

        let stmt = session.prepare("DELETE FROM users WHERE id = ?").unwrap();

        match stmt.cached_plan() {
            CachedPlan::PkDelete(plan) => {
                assert_eq!(plan.table_name, "USERS");
                assert_eq!(plan.pk_columns, vec!["ID"]);
                assert_eq!(plan.param_to_pk_col, vec![(0, 0)]);
                assert!(plan.is_fast_path_valid().is_none());
            }
            other => panic!("Expected PkDelete plan, got {:?}", other),
        }

        let result = session.execute_prepared_mut(&stmt, &[SqlValue::Integer(1)]).unwrap();
        assert_eq!(result.rows_affected(), Some(1));

        match stmt.cached_plan() {
            CachedPlan::PkDelete(plan) => {
                assert_eq!(
                    plan.is_fast_path_valid(),
                    Some(true),
                    "Fast path should be valid after execution"
                );
            }
            _ => panic!("Plan should still be PkDelete"),
        }

        let select_stmt = session.prepare("SELECT * FROM users WHERE id = ?").unwrap();
        let select_result =
            session.execute_prepared(&select_stmt, &[SqlValue::Integer(1)]).unwrap();
        assert_eq!(select_result.rows().unwrap().len(), 0);
    }
}