1use std::collections::HashMap;
28use vibesql_storage::Row;
29
30#[derive(Debug, Clone, PartialEq)]
32#[allow(clippy::large_enum_variant)]
33pub enum TransactionChange {
34 Insert {
36 table_name: String,
37 row: Row,
38 },
39 Update {
41 table_name: String,
42 row_index: usize,
43 old_row: Row,
44 new_row: Row,
45 },
46 Delete {
48 table_name: String,
49 row_index: usize,
50 row: Row,
51 },
52 CreateTable {
54 table_name: String,
55 },
56 DropTable {
58 table_name: String,
59 },
60 CreateIndex {
62 index_name: String,
63 table_name: String,
64 },
65 DropIndex {
67 index_name: String,
68 },
69}
70
71#[derive(Debug)]
76pub struct TransactionState {
77 pub id: u64,
79 pub active: bool,
81 changes: Vec<TransactionChange>,
83 inserted_rows: HashMap<String, Vec<Row>>,
85 deleted_indices: HashMap<String, Vec<usize>>,
87 updated_rows: HashMap<String, HashMap<usize, Row>>,
89}
90
91impl TransactionState {
92 pub fn new(id: u64) -> Self {
94 Self {
95 id,
96 active: true,
97 changes: Vec::new(),
98 inserted_rows: HashMap::new(),
99 deleted_indices: HashMap::new(),
100 updated_rows: HashMap::new(),
101 }
102 }
103
104 pub fn record_insert(&mut self, table_name: String, row: Row) {
106 self.changes.push(TransactionChange::Insert {
107 table_name: table_name.clone(),
108 row: row.clone(),
109 });
110 self.inserted_rows.entry(table_name).or_default().push(row);
111 }
112
113 pub fn record_update(
115 &mut self,
116 table_name: String,
117 row_index: usize,
118 old_row: Row,
119 new_row: Row,
120 ) {
121 self.changes.push(TransactionChange::Update {
122 table_name: table_name.clone(),
123 row_index,
124 old_row,
125 new_row: new_row.clone(),
126 });
127 self.updated_rows
128 .entry(table_name)
129 .or_default()
130 .insert(row_index, new_row);
131 }
132
133 pub fn record_delete(&mut self, table_name: String, row_index: usize, row: Row) {
135 self.changes.push(TransactionChange::Delete {
136 table_name: table_name.clone(),
137 row_index,
138 row,
139 });
140 self.deleted_indices.entry(table_name).or_default().push(row_index);
141 }
142
143 pub fn record_create_table(&mut self, table_name: String) {
145 self.changes.push(TransactionChange::CreateTable { table_name });
146 }
147
148 pub fn record_drop_table(&mut self, table_name: String) {
150 self.changes.push(TransactionChange::DropTable { table_name });
151 }
152
153 pub fn record_create_index(&mut self, index_name: String, table_name: String) {
155 self.changes.push(TransactionChange::CreateIndex { index_name, table_name });
156 }
157
158 pub fn record_drop_index(&mut self, index_name: String) {
160 self.changes.push(TransactionChange::DropIndex { index_name });
161 }
162
163 pub fn get_inserted_rows(&self, table_name: &str) -> Option<&Vec<Row>> {
165 self.inserted_rows.get(table_name)
166 }
167
168 pub fn get_deleted_indices(&self, table_name: &str) -> Option<&Vec<usize>> {
170 self.deleted_indices.get(table_name)
171 }
172
173 pub fn get_updated_rows(&self, table_name: &str) -> Option<&HashMap<usize, Row>> {
175 self.updated_rows.get(table_name)
176 }
177
178 pub fn is_deleted(&self, table_name: &str, row_index: usize) -> bool {
180 self.deleted_indices
181 .get(table_name)
182 .is_some_and(|indices| indices.contains(&row_index))
183 }
184
185 pub fn get_updated_row(&self, table_name: &str, row_index: usize) -> Option<&Row> {
187 self.updated_rows
188 .get(table_name)
189 .and_then(|updates| updates.get(&row_index))
190 }
191
192 pub fn take_changes(self) -> Vec<TransactionChange> {
194 self.changes
195 }
196
197 pub fn changes(&self) -> &[TransactionChange] {
199 &self.changes
200 }
201
202 pub fn has_changes(&self) -> bool {
204 !self.changes.is_empty()
205 }
206
207 pub fn clear(&mut self) {
209 self.changes.clear();
210 self.inserted_rows.clear();
211 self.deleted_indices.clear();
212 self.updated_rows.clear();
213 }
214}
215
216#[derive(Debug, Default)]
221pub struct SessionTransactionManager {
222 current: Option<TransactionState>,
224 next_id: u64,
226}
227
228impl SessionTransactionManager {
229 pub fn new() -> Self {
231 Self { current: None, next_id: 1 }
232 }
233
234 pub fn begin(&mut self) -> Result<u64, TransactionError> {
238 if self.current.is_some() {
239 return Err(TransactionError::AlreadyInTransaction);
240 }
241
242 let id = self.next_id;
243 self.next_id += 1;
244 self.current = Some(TransactionState::new(id));
245 Ok(id)
246 }
247
248 pub fn commit(&mut self) -> Result<Vec<TransactionChange>, TransactionError> {
252 let state = self.current.take().ok_or(TransactionError::NoActiveTransaction)?;
253 Ok(state.take_changes())
254 }
255
256 pub fn rollback(&mut self) -> Result<(), TransactionError> {
260 self.current.take().ok_or(TransactionError::NoActiveTransaction)?;
261 Ok(())
262 }
263
264 pub fn in_transaction(&self) -> bool {
266 self.current.as_ref().is_some_and(|s| s.active)
267 }
268
269 pub fn transaction_id(&self) -> Option<u64> {
271 self.current.as_ref().map(|s| s.id)
272 }
273
274 pub fn current_mut(&mut self) -> Option<&mut TransactionState> {
276 self.current.as_mut()
277 }
278
279 pub fn current(&self) -> Option<&TransactionState> {
281 self.current.as_ref()
282 }
283
284 pub fn record_insert(&mut self, table_name: String, row: Row) {
288 if let Some(state) = &mut self.current {
289 state.record_insert(table_name, row);
290 }
291 }
292
293 pub fn record_update(
297 &mut self,
298 table_name: String,
299 row_index: usize,
300 old_row: Row,
301 new_row: Row,
302 ) {
303 if let Some(state) = &mut self.current {
304 state.record_update(table_name, row_index, old_row, new_row);
305 }
306 }
307
308 pub fn record_delete(&mut self, table_name: String, row_index: usize, row: Row) {
312 if let Some(state) = &mut self.current {
313 state.record_delete(table_name, row_index, row);
314 }
315 }
316}
317
/// Failures reported by session transaction-control operations.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum TransactionError {
    /// A transaction was started while another one is still open.
    AlreadyInTransaction,
    /// A commit or rollback was requested with no transaction open.
    NoActiveTransaction,
    /// A commit could not be applied; the payload describes the conflict.
    /// (Not constructed anywhere in this module's visible code.)
    CommitConflict(String),
}

impl std::fmt::Display for TransactionError {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        match self {
            Self::AlreadyInTransaction => f.write_str("Transaction already in progress"),
            Self::NoActiveTransaction => f.write_str("No transaction in progress"),
            Self::CommitConflict(detail) => write!(f, "Commit conflict: {}", detail),
        }
    }
}

impl std::error::Error for TransactionError {}
346
#[cfg(test)]
mod tests {
    use super::*;
    use vibesql_types::SqlValue;

    fn make_row(values: Vec<SqlValue>) -> Row {
        Row::new(values)
    }

    #[test]
    fn test_begin_transaction() {
        let mut mgr = SessionTransactionManager::new();

        assert!(!mgr.in_transaction());
        assert_eq!(mgr.transaction_id(), None);

        let id = mgr.begin().unwrap();
        assert_eq!(id, 1);
        assert!(mgr.in_transaction());
        assert_eq!(mgr.transaction_id(), Some(1));
    }

    #[test]
    fn test_double_begin_fails() {
        let mut mgr = SessionTransactionManager::new();

        mgr.begin().unwrap();
        let result = mgr.begin();
        assert_eq!(result, Err(TransactionError::AlreadyInTransaction));
    }

    #[test]
    fn test_commit_without_transaction_fails() {
        let mut mgr = SessionTransactionManager::new();

        let result = mgr.commit();
        assert_eq!(result, Err(TransactionError::NoActiveTransaction));
    }

    #[test]
    fn test_rollback_without_transaction_fails() {
        let mut mgr = SessionTransactionManager::new();

        let result = mgr.rollback();
        assert_eq!(result, Err(TransactionError::NoActiveTransaction));
    }

    #[test]
    fn test_record_insert() {
        let mut mgr = SessionTransactionManager::new();
        mgr.begin().unwrap();

        let row = make_row(vec![SqlValue::Integer(1), SqlValue::Varchar(arcstr::ArcStr::from("test"))]);
        mgr.record_insert("users".to_string(), row.clone());

        let state = mgr.current().unwrap();
        assert!(state.has_changes());

        let inserted = state.get_inserted_rows("users").unwrap();
        assert_eq!(inserted.len(), 1);
        assert_eq!(inserted[0].values, row.values);
    }

    #[test]
    fn test_record_delete() {
        let mut mgr = SessionTransactionManager::new();
        mgr.begin().unwrap();

        let row = make_row(vec![SqlValue::Integer(1)]);
        mgr.record_delete("users".to_string(), 5, row);

        let state = mgr.current().unwrap();
        assert!(state.is_deleted("users", 5));
        assert!(!state.is_deleted("users", 6));
        assert!(!state.is_deleted("other_table", 5));
    }

    #[test]
    fn test_record_update() {
        let mut mgr = SessionTransactionManager::new();
        mgr.begin().unwrap();

        let old_row = make_row(vec![SqlValue::Integer(1), SqlValue::Varchar(arcstr::ArcStr::from("old"))]);
        let new_row = make_row(vec![SqlValue::Integer(1), SqlValue::Varchar(arcstr::ArcStr::from("new"))]);
        mgr.record_update("users".to_string(), 3, old_row, new_row.clone());

        let state = mgr.current().unwrap();
        let updated = state.get_updated_row("users", 3).unwrap();
        assert_eq!(updated.values, new_row.values);
        assert!(state.get_updated_row("users", 4).is_none());
    }

    #[test]
    fn test_repeated_update_keeps_latest_row() {
        let mut mgr = SessionTransactionManager::new();
        mgr.begin().unwrap();

        let first = make_row(vec![SqlValue::Integer(1)]);
        let second = make_row(vec![SqlValue::Integer(2)]);
        let third = make_row(vec![SqlValue::Integer(3)]);
        mgr.record_update("users".to_string(), 3, first, second.clone());
        mgr.record_update("users".to_string(), 3, second, third.clone());

        let state = mgr.current().unwrap();
        // The lookup holds only the newest value, but the log keeps both updates.
        assert_eq!(state.get_updated_row("users", 3).unwrap().values, third.values);
        assert_eq!(state.get_updated_rows("users").unwrap().len(), 1);
        assert_eq!(state.changes().len(), 2);
    }

    #[test]
    fn test_commit_returns_changes() {
        let mut mgr = SessionTransactionManager::new();
        mgr.begin().unwrap();

        let row1 = make_row(vec![SqlValue::Integer(1)]);
        let row2 = make_row(vec![SqlValue::Integer(2)]);
        mgr.record_insert("users".to_string(), row1);
        mgr.record_insert("users".to_string(), row2);

        let changes = mgr.commit().unwrap();
        assert_eq!(changes.len(), 2);
        assert!(!mgr.in_transaction());
    }

    #[test]
    fn test_commit_preserves_change_order() {
        let mut mgr = SessionTransactionManager::new();
        mgr.begin().unwrap();

        let row = make_row(vec![SqlValue::Integer(1)]);
        mgr.record_insert("users".to_string(), row.clone());
        mgr.record_delete("orders".to_string(), 4, row);
        mgr.current_mut().unwrap().record_create_table("audit".to_string());

        let changes = mgr.commit().unwrap();
        assert_eq!(changes.len(), 3);
        assert!(matches!(
            &changes[0],
            TransactionChange::Insert { table_name, .. } if table_name == "users"
        ));
        assert!(matches!(
            &changes[1],
            TransactionChange::Delete { table_name, row_index: 4, .. } if table_name == "orders"
        ));
        assert!(matches!(
            &changes[2],
            TransactionChange::CreateTable { table_name } if table_name == "audit"
        ));
    }

    #[test]
    fn test_rollback_discards_changes() {
        let mut mgr = SessionTransactionManager::new();
        mgr.begin().unwrap();

        let row = make_row(vec![SqlValue::Integer(1)]);
        mgr.record_insert("users".to_string(), row);

        mgr.rollback().unwrap();
        assert!(!mgr.in_transaction());
        assert!(mgr.current().is_none());

        // A fresh transaction must not see any of the rolled-back work.
        mgr.begin().unwrap();
        assert!(mgr.in_transaction());
        assert_eq!(mgr.transaction_id(), Some(2));

        let state = mgr.current().unwrap();
        assert!(!state.has_changes());
        assert!(state.get_inserted_rows("users").is_none());
    }

    #[test]
    fn test_clear_resets_recorded_state() {
        let mut mgr = SessionTransactionManager::new();
        mgr.begin().unwrap();

        let row = make_row(vec![SqlValue::Integer(1)]);
        mgr.record_insert("users".to_string(), row.clone());
        mgr.record_delete("users".to_string(), 0, row.clone());
        mgr.record_update("users".to_string(), 1, row.clone(), row);

        let state = mgr.current_mut().unwrap();
        state.clear();
        assert!(!state.has_changes());
        assert!(state.get_inserted_rows("users").is_none());
        assert!(!state.is_deleted("users", 0));
        assert!(state.get_updated_row("users", 1).is_none());
        // clear() only drops recorded data; the transaction itself stays open.
        assert!(state.active);
    }

    #[test]
    fn test_transaction_id_increments() {
        let mut mgr = SessionTransactionManager::new();

        let id1 = mgr.begin().unwrap();
        mgr.commit().unwrap();

        let id2 = mgr.begin().unwrap();
        mgr.rollback().unwrap();

        let id3 = mgr.begin().unwrap();

        assert_eq!(id1, 1);
        assert_eq!(id2, 2);
        assert_eq!(id3, 3);
    }

    #[test]
    fn test_no_op_when_not_in_transaction() {
        let mut mgr = SessionTransactionManager::new();

        let row = make_row(vec![SqlValue::Integer(1)]);
        // None of these may implicitly open a transaction.
        mgr.record_insert("users".to_string(), row.clone());
        mgr.record_delete("users".to_string(), 0, row.clone());
        mgr.record_update("users".to_string(), 0, row.clone(), row);

        assert!(!mgr.in_transaction());
        assert!(mgr.current().is_none());
        // Recording outside a transaction must not consume transaction ids either.
        assert_eq!(mgr.begin().unwrap(), 1);
        assert!(!mgr.current().unwrap().has_changes());
    }

    #[test]
    fn test_error_display() {
        assert_eq!(
            TransactionError::AlreadyInTransaction.to_string(),
            "Transaction already in progress"
        );
        assert_eq!(
            TransactionError::NoActiveTransaction.to_string(),
            "No transaction in progress"
        );
        assert_eq!(
            TransactionError::CommitConflict("row 3 changed".to_string()).to_string(),
            "Commit conflict: row 3 changed"
        );
    }
}