1use std::sync::Arc;
46
47use vibesql_ast::Statement;
48use vibesql_storage::{Database, Row};
49use vibesql_types::SqlValue;
50
51use crate::cache::{
52 ArenaParseError, ArenaPreparedStatement, CachedPlan, PkPointLookupPlan, PreparedStatement,
53 PreparedStatementCache, PreparedStatementError, ProjectionPlan, ResolvedProjection,
54};
55use crate::errors::ExecutorError;
56use crate::{DeleteExecutor, InsertExecutor, SelectExecutor, SelectResult, UpdateExecutor};
57
58#[derive(Debug)]
60pub enum PreparedExecutionResult {
61 Select(SelectResult),
63 RowsAffected(usize),
65 Ok,
67}
68
69impl PreparedExecutionResult {
70 pub fn rows(&self) -> Option<&[Row]> {
72 match self {
73 PreparedExecutionResult::Select(result) => Some(&result.rows),
74 _ => None,
75 }
76 }
77
78 pub fn rows_affected(&self) -> Option<usize> {
80 match self {
81 PreparedExecutionResult::RowsAffected(n) => Some(*n),
82 _ => None,
83 }
84 }
85
86 pub fn into_select_result(self) -> Option<SelectResult> {
88 match self {
89 PreparedExecutionResult::Select(result) => Some(result),
90 _ => None,
91 }
92 }
93}
94
95#[derive(Debug)]
97pub enum SessionError {
98 PreparedStatement(PreparedStatementError),
100 Execution(ExecutorError),
102 UnsupportedStatement(String),
104}
105
106impl std::fmt::Display for SessionError {
107 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
108 match self {
109 SessionError::PreparedStatement(e) => write!(f, "Prepared statement error: {}", e),
110 SessionError::Execution(e) => write!(f, "Execution error: {:?}", e),
111 SessionError::UnsupportedStatement(msg) => write!(f, "Unsupported statement: {}", msg),
112 }
113 }
114}
115
116impl std::error::Error for SessionError {}
117
118impl From<PreparedStatementError> for SessionError {
119 fn from(e: PreparedStatementError) -> Self {
120 SessionError::PreparedStatement(e)
121 }
122}
123
124impl From<ExecutorError> for SessionError {
125 fn from(e: ExecutorError) -> Self {
126 SessionError::Execution(e)
127 }
128}
129
130pub struct Session<'a> {
135 db: &'a Database,
136 cache: Arc<PreparedStatementCache>,
137}
138
139impl<'a> Session<'a> {
140 pub fn new(db: &'a Database) -> Self {
144 Self { db, cache: Arc::new(PreparedStatementCache::default_cache()) }
145 }
146
147 pub fn with_cache_size(db: &'a Database, cache_size: usize) -> Self {
149 Self { db, cache: Arc::new(PreparedStatementCache::new(cache_size)) }
150 }
151
152 pub fn with_shared_cache(db: &'a Database, cache: Arc<PreparedStatementCache>) -> Self {
157 Self { db, cache }
158 }
159
160 pub fn database(&self) -> &Database {
162 self.db
163 }
164
165 pub fn cache(&self) -> &PreparedStatementCache {
167 &self.cache
168 }
169
170 pub fn shared_cache(&self) -> Arc<PreparedStatementCache> {
172 Arc::clone(&self.cache)
173 }
174
175 pub fn prepare(&self, sql: &str) -> Result<Arc<PreparedStatement>, SessionError> {
189 self.cache.get_or_prepare(sql).map_err(SessionError::from)
190 }
191
192 pub fn prepare_arena(&self, sql: &str) -> Result<Arc<ArenaPreparedStatement>, ArenaParseError> {
214 self.cache.get_or_prepare_arena(sql)
215 }
216
217 pub fn execute_prepared(
232 &self,
233 stmt: &PreparedStatement,
234 params: &[SqlValue],
235 ) -> Result<PreparedExecutionResult, SessionError> {
236 if params.len() != stmt.param_count() {
238 return Err(SessionError::PreparedStatement(
239 PreparedStatementError::ParameterCountMismatch {
240 expected: stmt.param_count(),
241 actual: params.len(),
242 },
243 ));
244 }
245
246 if let CachedPlan::PkPointLookup(plan) = stmt.cached_plan() {
248 if let Some(result) = self.try_execute_pk_lookup(plan, params)? {
249 return Ok(result);
250 }
251 }
253
254 let bound_stmt = stmt.bind(params)?;
256
257 self.execute_statement(&bound_stmt)
259 }
260
261 fn try_execute_pk_lookup(
267 &self,
268 plan: &PkPointLookupPlan,
269 params: &[SqlValue],
270 ) -> Result<Option<PreparedExecutionResult>, SessionError> {
271 let table = match self.db.get_table(&plan.table_name) {
273 Some(t) => t,
274 None => return Ok(None), };
276
277 let actual_pk_columns = match &table.schema.primary_key {
279 Some(cols) if cols.len() == plan.pk_columns.len() => cols,
280 _ => return Ok(None), };
282
283 for (param_idx, pk_col_idx) in &plan.param_to_pk_col {
286 if *param_idx >= params.len() || *pk_col_idx >= plan.pk_columns.len() {
287 return Ok(None); }
289
290 let expected_col = &plan.pk_columns[*pk_col_idx];
292 let actual_col = &actual_pk_columns[*pk_col_idx];
293 if !expected_col.eq_ignore_ascii_case(actual_col) {
294 return Ok(None); }
296 }
297
298 let resolved = match plan.get_or_resolve(|proj| {
300 self.resolve_projection(proj, &table.schema.columns)
301 }) {
302 Some(r) => r,
303 None => return Ok(None), };
305
306 let row = if plan.param_to_pk_col.len() == 1 {
309 let (param_idx, _) = plan.param_to_pk_col[0];
310 self.db
311 .get_row_by_pk(&plan.table_name, ¶ms[param_idx])
312 .map_err(|e| SessionError::Execution(ExecutorError::StorageError(e.to_string())))?
313 } else {
314 let pk_values: Vec<SqlValue> = plan
316 .param_to_pk_col
317 .iter()
318 .map(|(param_idx, _)| params[*param_idx].clone())
319 .collect();
320 self.db
321 .get_row_by_composite_pk(&plan.table_name, &pk_values)
322 .map_err(|e| SessionError::Execution(ExecutorError::StorageError(e.to_string())))?
323 };
324
325 let columns: Vec<String> = resolved.column_names.iter().cloned().collect();
327
328 let rows = match row {
329 Some(r) => {
330 if resolved.column_indices.is_empty() {
332 vec![r.clone()]
334 } else {
335 let projected_values: Vec<SqlValue> = resolved
337 .column_indices
338 .iter()
339 .map(|&i| r.values[i].clone())
340 .collect();
341 vec![Row::new(projected_values)]
342 }
343 }
344 None => vec![],
345 };
346
347 Ok(Some(PreparedExecutionResult::Select(SelectResult {
348 columns,
349 rows,
350 })))
351 }
352
353 fn resolve_projection(
357 &self,
358 proj: &ProjectionPlan,
359 schema_columns: &[vibesql_catalog::ColumnSchema],
360 ) -> Option<ResolvedProjection> {
361 match proj {
362 ProjectionPlan::Wildcard => {
363 let column_names: Arc<[String]> = schema_columns
366 .iter()
367 .map(|c| c.name.clone())
368 .collect();
369 Some(ResolvedProjection {
370 column_indices: vec![],
371 column_names,
372 })
373 }
374 ProjectionPlan::Columns(projections) => {
375 let mut col_indices = Vec::with_capacity(projections.len());
376 let mut column_names = Vec::with_capacity(projections.len());
377
378 for proj in projections {
379 let idx = schema_columns
380 .iter()
381 .position(|c| c.name.eq_ignore_ascii_case(&proj.column_name))?;
382
383 col_indices.push(idx);
384 column_names.push(
385 proj.alias
386 .clone()
387 .unwrap_or_else(|| proj.column_name.clone()),
388 );
389 }
390
391 Some(ResolvedProjection {
392 column_indices: col_indices,
393 column_names: column_names.into(),
394 })
395 }
396 }
397 }
398
399 fn execute_statement(&self, stmt: &Statement) -> Result<PreparedExecutionResult, SessionError> {
401 match stmt {
402 Statement::Select(select_stmt) => {
403 let executor = SelectExecutor::new(self.db);
404 let result = executor.execute_with_columns(select_stmt)?;
405 Ok(PreparedExecutionResult::Select(result))
406 }
407 _ => Err(SessionError::UnsupportedStatement(
408 "Only SELECT is supported for read-only sessions. Use SessionMut for DML.".into(),
409 )),
410 }
411 }
412}
413
414pub struct SessionMut<'a> {
419 db: &'a mut Database,
420 cache: Arc<PreparedStatementCache>,
421}
422
423impl<'a> SessionMut<'a> {
424 pub fn new(db: &'a mut Database) -> Self {
426 Self { db, cache: Arc::new(PreparedStatementCache::default_cache()) }
427 }
428
429 pub fn with_cache_size(db: &'a mut Database, cache_size: usize) -> Self {
431 Self { db, cache: Arc::new(PreparedStatementCache::new(cache_size)) }
432 }
433
434 pub fn with_shared_cache(db: &'a mut Database, cache: Arc<PreparedStatementCache>) -> Self {
436 Self { db, cache }
437 }
438
439 pub fn database(&self) -> &Database {
441 self.db
442 }
443
444 pub fn database_mut(&mut self) -> &mut Database {
446 self.db
447 }
448
449 pub fn cache(&self) -> &PreparedStatementCache {
451 &self.cache
452 }
453
454 pub fn shared_cache(&self) -> Arc<PreparedStatementCache> {
456 Arc::clone(&self.cache)
457 }
458
459 pub fn prepare(&self, sql: &str) -> Result<Arc<PreparedStatement>, SessionError> {
461 self.cache.get_or_prepare(sql).map_err(SessionError::from)
462 }
463
464 pub fn prepare_arena(&self, sql: &str) -> Result<Arc<ArenaPreparedStatement>, ArenaParseError> {
468 self.cache.get_or_prepare_arena(sql)
469 }
470
471 pub fn execute_prepared(
475 &self,
476 stmt: &PreparedStatement,
477 params: &[SqlValue],
478 ) -> Result<PreparedExecutionResult, SessionError> {
479 let bound_stmt = stmt.bind(params)?;
480 self.execute_statement_readonly(&bound_stmt)
481 }
482
483 pub fn execute_prepared_mut(
487 &mut self,
488 stmt: &PreparedStatement,
489 params: &[SqlValue],
490 ) -> Result<PreparedExecutionResult, SessionError> {
491 let bound_stmt = stmt.bind(params)?;
492 self.execute_statement_mut(&bound_stmt)
493 }
494
495 fn execute_statement_readonly(
497 &self,
498 stmt: &Statement,
499 ) -> Result<PreparedExecutionResult, SessionError> {
500 match stmt {
501 Statement::Select(select_stmt) => {
502 let executor = SelectExecutor::new(self.db);
503 let result = executor.execute_with_columns(select_stmt)?;
504 Ok(PreparedExecutionResult::Select(result))
505 }
506 _ => Err(SessionError::UnsupportedStatement(
507 "Use execute_prepared_mut for DML statements".into(),
508 )),
509 }
510 }
511
512 fn execute_statement_mut(
514 &mut self,
515 stmt: &Statement,
516 ) -> Result<PreparedExecutionResult, SessionError> {
517 match stmt {
518 Statement::Select(select_stmt) => {
519 let executor = SelectExecutor::new(self.db);
520 let result = executor.execute_with_columns(select_stmt)?;
521 Ok(PreparedExecutionResult::Select(result))
522 }
523 Statement::Insert(insert_stmt) => {
524 let rows_affected = InsertExecutor::execute(self.db, insert_stmt)?;
525 self.cache.invalidate_table(&insert_stmt.table_name);
527 Ok(PreparedExecutionResult::RowsAffected(rows_affected))
528 }
529 Statement::Update(update_stmt) => {
530 let rows_affected = UpdateExecutor::execute(update_stmt, self.db)?;
531 self.cache.invalidate_table(&update_stmt.table_name);
533 Ok(PreparedExecutionResult::RowsAffected(rows_affected))
534 }
535 Statement::Delete(delete_stmt) => {
536 let rows_affected = DeleteExecutor::execute(delete_stmt, self.db)?;
537 self.cache.invalidate_table(&delete_stmt.table_name);
539 Ok(PreparedExecutionResult::RowsAffected(rows_affected))
540 }
541 _ => Err(SessionError::UnsupportedStatement(format!(
542 "Statement type {:?} not supported for prepared execution",
543 std::mem::discriminant(stmt)
544 ))),
545 }
546 }
547}
548
#[cfg(test)]
mod tests {
    use super::*;
    use vibesql_catalog::{ColumnSchema, TableSchema};
    use vibesql_types::DataType;

    /// Builds a database with case-insensitive identifiers and a `users`
    /// table (`id` integer primary key, nullable `name` varchar) holding
    /// three rows: Alice, Bob and Charlie.
    fn create_test_db() -> Database {
        let mut db = Database::new();
        db.catalog.set_case_sensitive_identifiers(false);

        let columns = vec![
            ColumnSchema::new("id".to_string(), DataType::Integer, false),
            ColumnSchema::new(
                "name".to_string(),
                DataType::Varchar { max_length: Some(100) },
                true,
            ),
        ];
        let schema =
            TableSchema::with_primary_key("users".to_string(), columns, vec!["id".to_string()]);
        db.create_table(schema).unwrap();

        let row1 = Row::new(vec![SqlValue::Integer(1), SqlValue::Varchar("Alice".into())]);
        let row2 = Row::new(vec![SqlValue::Integer(2), SqlValue::Varchar("Bob".into())]);
        let row3 = Row::new(vec![SqlValue::Integer(3), SqlValue::Varchar("Charlie".into())]);

        db.insert_row("users", row1).unwrap();
        db.insert_row("users", row2).unwrap();
        db.insert_row("users", row3).unwrap();

        db
    }

    #[test]
    fn test_session_prepare() {
        let db = create_test_db();
        let session = Session::new(&db);

        let stmt = session.prepare("SELECT * FROM users WHERE id = ?").unwrap();
        assert_eq!(stmt.param_count(), 1);
    }

    #[test]
    fn test_session_execute_prepared() {
        let db = create_test_db();
        let session = Session::new(&db);

        let stmt = session.prepare("SELECT * FROM users WHERE id = ?").unwrap();

        let result = session.execute_prepared(&stmt, &[SqlValue::Integer(1)]).unwrap();

        if let PreparedExecutionResult::Select(select_result) = result {
            assert_eq!(select_result.rows.len(), 1);
            assert_eq!(select_result.rows[0].values[0], SqlValue::Integer(1));
            assert_eq!(select_result.rows[0].values[1], SqlValue::Varchar("Alice".into()));
        } else {
            panic!("Expected Select result");
        }
    }

    #[test]
    fn test_session_reuse_prepared() {
        let db = create_test_db();
        let session = Session::new(&db);

        let stmt = session.prepare("SELECT * FROM users WHERE id = ?").unwrap();

        // Same statement, three different parameter values.
        let result1 = session.execute_prepared(&stmt, &[SqlValue::Integer(1)]).unwrap();
        let result2 = session.execute_prepared(&stmt, &[SqlValue::Integer(2)]).unwrap();
        let result3 = session.execute_prepared(&stmt, &[SqlValue::Integer(3)]).unwrap();

        assert_eq!(result1.rows().unwrap()[0].values[1], SqlValue::Varchar("Alice".into()));
        assert_eq!(result2.rows().unwrap()[0].values[1], SqlValue::Varchar("Bob".into()));
        assert_eq!(result3.rows().unwrap()[0].values[1], SqlValue::Varchar("Charlie".into()));

        // The statement was prepared exactly once.
        let stats = session.cache().stats();
        assert_eq!(stats.misses, 1);
        // The hit counter is read but deliberately not asserted.
        let _hits = stats.hits;
    }

    #[test]
    fn test_session_param_count_mismatch() {
        let db = create_test_db();
        let session = Session::new(&db);

        let stmt = session.prepare("SELECT * FROM users WHERE id = ?").unwrap();

        // Too few parameters.
        let result = session.execute_prepared(&stmt, &[]);
        assert!(result.is_err());

        // Too many parameters.
        let result = session.execute_prepared(&stmt, &[SqlValue::Integer(1), SqlValue::Integer(2)]);
        assert!(result.is_err());
    }

    #[test]
    fn test_session_mut_insert() {
        let mut db = create_test_db();
        let mut session = SessionMut::new(&mut db);

        let stmt = session.prepare("INSERT INTO users (id, name) VALUES (?, ?)").unwrap();

        let result = session
            .execute_prepared_mut(&stmt, &[SqlValue::Integer(4), SqlValue::Varchar("David".into())])
            .unwrap();

        assert_eq!(result.rows_affected(), Some(1));

        // Read the new row back through the same session.
        let select_stmt = session.prepare("SELECT * FROM users WHERE id = ?").unwrap();
        let select_result =
            session.execute_prepared(&select_stmt, &[SqlValue::Integer(4)]).unwrap();

        assert_eq!(select_result.rows().unwrap().len(), 1);
        assert_eq!(select_result.rows().unwrap()[0].values[1], SqlValue::Varchar("David".into()));
    }

    #[test]
    fn test_session_mut_update() {
        let mut db = create_test_db();
        let mut session = SessionMut::new(&mut db);

        let stmt = session.prepare("UPDATE users SET name = ? WHERE id = ?").unwrap();

        let result = session
            .execute_prepared_mut(
                &stmt,
                &[SqlValue::Varchar("Alicia".into()), SqlValue::Integer(1)],
            )
            .unwrap();

        assert_eq!(result.rows_affected(), Some(1));

        // Read the updated row back through the same session.
        let select_stmt = session.prepare("SELECT * FROM users WHERE id = ?").unwrap();
        let select_result =
            session.execute_prepared(&select_stmt, &[SqlValue::Integer(1)]).unwrap();

        assert_eq!(select_result.rows().unwrap()[0].values[1], SqlValue::Varchar("Alicia".into()));
    }

    #[test]
    fn test_session_mut_delete() {
        let mut db = create_test_db();
        let mut session = SessionMut::new(&mut db);

        let stmt = session.prepare("DELETE FROM users WHERE id = ?").unwrap();

        let result = session.execute_prepared_mut(&stmt, &[SqlValue::Integer(1)]).unwrap();

        assert_eq!(result.rows_affected(), Some(1));

        // The deleted row must no longer be returned.
        let select_stmt = session.prepare("SELECT * FROM users WHERE id = ?").unwrap();
        let select_result =
            session.execute_prepared(&select_stmt, &[SqlValue::Integer(1)]).unwrap();

        assert_eq!(select_result.rows().unwrap().len(), 0);
    }

    #[test]
    fn test_shared_cache() {
        let db = create_test_db();

        let session1 = Session::new(&db);
        let stmt = session1.prepare("SELECT * FROM users WHERE id = ?").unwrap();

        let shared_cache = session1.shared_cache();
        let initial_misses = session1.cache().stats().misses;

        let session2 = Session::with_shared_cache(&db, shared_cache);

        let _stmt2 = session2.prepare("SELECT * FROM users WHERE id = ?").unwrap();

        // Preparing the same SQL through the second session must not add a miss.
        assert_eq!(session2.cache().stats().misses, initial_misses);

        // A statement prepared by one session is executable by both.
        let result1 = session1.execute_prepared(&stmt, &[SqlValue::Integer(1)]).unwrap();
        let result2 = session2.execute_prepared(&stmt, &[SqlValue::Integer(2)]).unwrap();

        assert_eq!(result1.rows().unwrap()[0].values[1], SqlValue::Varchar("Alice".into()));
        assert_eq!(result2.rows().unwrap()[0].values[1], SqlValue::Varchar("Bob".into()));
    }

    #[test]
    fn test_no_params_statement() {
        let db = create_test_db();
        let session = Session::new(&db);

        let stmt = session.prepare("SELECT * FROM users").unwrap();
        assert_eq!(stmt.param_count(), 0);

        let result = session.execute_prepared(&stmt, &[]).unwrap();
        assert_eq!(result.rows().unwrap().len(), 3);
    }

    #[test]
    fn test_multiple_placeholders() {
        let db = create_test_db();
        let session = Session::new(&db);

        let stmt = session.prepare("SELECT * FROM users WHERE id >= ? AND id <= ?").unwrap();
        assert_eq!(stmt.param_count(), 2);

        let result =
            session.execute_prepared(&stmt, &[SqlValue::Integer(1), SqlValue::Integer(2)]).unwrap();

        assert_eq!(result.rows().unwrap().len(), 2);
    }

    #[test]
    fn test_session_prepare_arena() {
        let db = create_test_db();
        let session = Session::new(&db);

        let stmt = session.prepare_arena("SELECT * FROM users WHERE id = ?").unwrap();
        assert_eq!(stmt.param_count(), 1);

        // The referenced table is expected under its upper-cased name.
        assert!(stmt.tables().contains("USERS"));

        let stmt2 = session.prepare_arena("SELECT * FROM users WHERE id = ?").unwrap();
        assert_eq!(stmt2.param_count(), 1);

        // Preparing identical SQL twice must hand back the same cached allocation.
        assert!(Arc::ptr_eq(&stmt, &stmt2));
    }

    #[test]
    fn test_session_prepare_arena_no_params() {
        let db = create_test_db();
        let session = Session::new(&db);

        let stmt = session.prepare_arena("SELECT * FROM users").unwrap();
        assert_eq!(stmt.param_count(), 0);
    }

    #[test]
    fn test_session_prepare_arena_join() {
        // Register the second table before the session takes its shared
        // borrow, so the join refers to tables that both exist.
        let mut db = create_test_db();
        let orders_columns = vec![
            ColumnSchema::new("id".to_string(), DataType::Integer, false),
            ColumnSchema::new("user_id".to_string(), DataType::Integer, false),
        ];
        let orders_schema = TableSchema::with_primary_key(
            "orders".to_string(),
            orders_columns,
            vec!["id".to_string()],
        );
        db.create_table(orders_schema).unwrap();

        let session = Session::new(&db);

        let stmt = session
            .prepare_arena("SELECT u.id FROM users u JOIN orders o ON u.id = o.user_id")
            .unwrap();

        // Both joined tables must be reported.
        let tables = stmt.tables();
        assert!(tables.contains("USERS"), "Expected USERS in {:?}", tables);
        assert!(tables.contains("ORDERS"), "Expected ORDERS in {:?}", tables);
    }

    #[test]
    fn test_session_mut_prepare_arena() {
        let mut db = create_test_db();
        let session = SessionMut::new(&mut db);

        let stmt = session.prepare_arena("SELECT * FROM users WHERE id = ?").unwrap();
        assert_eq!(stmt.param_count(), 1);
    }
}