1use std::{
2 cell::RefCell,
3 sync::atomic::{AtomicU32, AtomicU64, Ordering},
4};
5
6use rustrails_support::{database, runtime};
7use sea_orm::{ConnectionTrait, DatabaseConnection};
8
9use crate::base::{Record, RecordError};
10
/// A deferred hook registered through `after_commit` / `after_rollback`.
///
/// Boxed so closures of different types can share one queue; it must be `Send`
/// and own everything it captures.
type TransactionCallback = Box<dyn FnOnce() + Send + 'static>;

/// Bookkeeping for one entry of the thread-local transaction stack.
///
/// The bottom entry is pushed for the outermost `BEGIN`; every entry above it is
/// pushed for a nested scope backed by a `SAVEPOINT`.
#[derive(Default)]
struct TransactionLevel {
    /// `None` for the outermost transaction, `Some(name)` for a savepoint.
    savepoint_name: Option<String>,
    /// `after_commit` callbacks queued on this level, including those absorbed
    /// from released (committed) nested levels.
    after_commit: Vec<TransactionCallback>,
    /// `after_rollback` callbacks queued on this level, including those absorbed
    /// from released (committed) nested levels.
    after_rollback: Vec<TransactionCallback>,
}

impl TransactionLevel {
    /// Level for a top-level `BEGIN`: no savepoint and empty callback queues.
    fn outermost() -> Self {
        Self {
            savepoint_name: None,
            after_commit: Vec::new(),
            after_rollback: Vec::new(),
        }
    }

    /// Level for a nested scope opened with `SAVEPOINT <savepoint_name>`.
    fn nested(savepoint_name: String) -> Self {
        let mut level = Self::outermost();
        level.savepoint_name = Some(savepoint_name);
        level
    }

    /// Appends a released nested level's callbacks after the ones already queued
    /// here, so registration order is preserved. The nested savepoint name is
    /// discarded.
    fn absorb(&mut self, nested: Self) {
        let Self {
            after_commit,
            after_rollback,
            ..
        } = nested;
        self.after_commit.extend(after_commit);
        self.after_rollback.extend(after_rollback);
    }
}
37
/// How the outermost transaction is being finalized; selects which callback
/// queue (`after_commit` or `after_rollback`) is handed back to the caller.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
enum FinalizeAction {
    /// Finalized as committed: the `after_commit` queue is returned.
    Commit,
    /// Finalized as rolled back: the `after_rollback` queue is returned.
    Rollback,
}

/// Kind of scope at the top of the transaction stack, i.e. which SQL statements
/// close it.
#[derive(Debug, Clone, PartialEq, Eq)]
enum TransactionBoundary {
    /// Closed with `COMMIT` / `ROLLBACK`.
    Outermost,
    /// Closed with `RELEASE SAVEPOINT` / `ROLLBACK TO SAVEPOINT` on the named savepoint.
    Nested(String),
}
47
48thread_local! {
49 static OPEN_TRANSACTION_COUNT: AtomicU32 = const { AtomicU32::new(0) };
50 static SAVEPOINT_SEQUENCE: AtomicU32 = const { AtomicU32::new(0) };
51 static CURRENT_TRANSACTION_ID: RefCell<Option<String>> = const { RefCell::new(None) };
52 static TRANSACTION_LEVELS: RefCell<Vec<TransactionLevel>> = const { RefCell::new(Vec::new()) };
53}
54
55static NEXT_TRANSACTION_ID: AtomicU64 = AtomicU64::new(1);
56
57#[must_use]
59pub fn open_transactions() -> u32 {
60 OPEN_TRANSACTION_COUNT.with(|count| count.load(Ordering::SeqCst))
61}
62
63#[must_use]
65pub fn transaction_open() -> bool {
66 open_transactions() > 0
67}
68
69#[must_use]
71pub fn current_transaction_id() -> Option<String> {
72 CURRENT_TRANSACTION_ID.with(|current| current.borrow().clone())
73}
74
75pub fn after_commit<F>(callback: F)
79where
80 F: FnOnce() + Send + 'static,
81{
82 let mut callback = Some(Box::new(callback) as TransactionCallback);
83 let registered = TRANSACTION_LEVELS.with(|levels| {
84 let mut levels = levels.borrow_mut();
85 if let Some(level) = levels.last_mut() {
86 level
87 .after_commit
88 .push(callback.take().expect("after_commit callback should exist"));
89 true
90 } else {
91 false
92 }
93 });
94
95 if !registered {
96 callback.expect("after_commit callback should exist outside transactions")();
97 }
98}
99
100pub fn after_rollback<F>(callback: F)
102where
103 F: FnOnce() + Send + 'static,
104{
105 let mut callback = Some(Box::new(callback) as TransactionCallback);
106 TRANSACTION_LEVELS.with(|levels| {
107 let mut levels = levels.borrow_mut();
108 if let Some(level) = levels.last_mut() {
109 level.after_rollback.push(
110 callback
111 .take()
112 .expect("after_rollback callback should exist"),
113 );
114 }
115 });
116}
117
118pub async fn transaction<F, Fut, T>(db: &DatabaseConnection, f: F) -> Result<T, RecordError>
125where
126 F: FnOnce(&DatabaseConnection) -> Fut + Send,
127 Fut: std::future::Future<Output = Result<T, RecordError>> + Send,
128 T: Send,
129{
130 begin_transaction_scope(db).await?;
131
132 match f(db).await {
133 Ok(value) => {
134 commit_transaction_scope(db).await?;
135 Ok(value)
136 }
137 Err(error) => {
138 rollback_transaction_scope(db).await?;
139 Err(error)
140 }
141 }
142}
143
144pub fn transaction_sync<F, Fut, T>(f: F) -> Result<T, RecordError>
146where
147 F: FnOnce(&DatabaseConnection) -> Fut + Send,
148 Fut: std::future::Future<Output = Result<T, RecordError>> + Send,
149 T: Send,
150{
151 database::with_db(|db| runtime::block_on(transaction(db, f)))
152}
153
154async fn begin_transaction_scope(db: &DatabaseConnection) -> Result<(), RecordError> {
155 if transaction_open() {
156 let savepoint_name = next_savepoint_name();
157 execute_transaction_control(db, &format!("SAVEPOINT {savepoint_name}")).await?;
158 OPEN_TRANSACTION_COUNT.with(|count| {
159 count.fetch_add(1, Ordering::SeqCst);
160 });
161 TRANSACTION_LEVELS.with(|levels| {
162 levels
163 .borrow_mut()
164 .push(TransactionLevel::nested(savepoint_name));
165 });
166 } else {
167 let transaction_id = next_transaction_id();
168 execute_transaction_control(db, "BEGIN").await?;
169 OPEN_TRANSACTION_COUNT.with(|count| count.store(1, Ordering::SeqCst));
170 SAVEPOINT_SEQUENCE.with(|sequence| sequence.store(0, Ordering::SeqCst));
171 CURRENT_TRANSACTION_ID.with(|current| {
172 current.replace(Some(transaction_id));
173 });
174 TRANSACTION_LEVELS.with(|levels| {
175 levels.borrow_mut().push(TransactionLevel::outermost());
176 });
177 }
178
179 Ok(())
180}
181
182async fn commit_transaction_scope(db: &DatabaseConnection) -> Result<(), RecordError> {
183 match current_transaction_boundary() {
184 Some(TransactionBoundary::Outermost) => {
185 execute_or_reset_state(db, "COMMIT").await?;
186 let callbacks = finish_outermost_transaction(FinalizeAction::Commit);
187 run_callbacks(callbacks);
188 Ok(())
189 }
190 Some(TransactionBoundary::Nested(savepoint_name)) => {
191 execute_or_reset_state(db, &format!("RELEASE SAVEPOINT {savepoint_name}")).await?;
192 merge_nested_transaction_into_parent();
193 Ok(())
194 }
195 None => Ok(()),
196 }
197}
198
199async fn rollback_transaction_scope(db: &DatabaseConnection) -> Result<(), RecordError> {
200 match current_transaction_boundary() {
201 Some(TransactionBoundary::Outermost) => {
202 execute_or_reset_state(db, "ROLLBACK").await?;
203 let callbacks = finish_outermost_transaction(FinalizeAction::Rollback);
204 run_callbacks(callbacks);
205 Ok(())
206 }
207 Some(TransactionBoundary::Nested(savepoint_name)) => {
208 execute_or_reset_state(db, &format!("ROLLBACK TO SAVEPOINT {savepoint_name}")).await?;
209 let callbacks = rollback_nested_transaction();
210 run_callbacks(callbacks);
211 Ok(())
212 }
213 None => Ok(()),
214 }
215}
216
217async fn execute_or_reset_state(db: &DatabaseConnection, sql: &str) -> Result<(), RecordError> {
218 if let Err(error) = execute_transaction_control(db, sql).await {
219 reset_transaction_state();
220 Err(error)
221 } else {
222 Ok(())
223 }
224}
225
226async fn execute_transaction_control(
227 db: &DatabaseConnection,
228 sql: &str,
229) -> Result<(), RecordError> {
230 db.execute_unprepared(sql).await?;
231 Ok(())
232}
233
234fn current_transaction_boundary() -> Option<TransactionBoundary> {
235 TRANSACTION_LEVELS.with(|levels| {
236 let levels = levels.borrow();
237 levels.last().map(|level| match &level.savepoint_name {
238 Some(savepoint_name) => TransactionBoundary::Nested(savepoint_name.clone()),
239 None => TransactionBoundary::Outermost,
240 })
241 })
242}
243
244fn merge_nested_transaction_into_parent() {
245 TRANSACTION_LEVELS.with(|levels| {
246 let mut levels = levels.borrow_mut();
247 let nested = levels
248 .pop()
249 .expect("nested transaction state should exist during commit");
250 let parent = levels
251 .last_mut()
252 .expect("parent transaction state should exist during nested commit");
253 parent.absorb(nested);
254 });
255 OPEN_TRANSACTION_COUNT.with(|count| {
256 count.fetch_sub(1, Ordering::SeqCst);
257 });
258}
259
260fn rollback_nested_transaction() -> Vec<TransactionCallback> {
261 let callbacks = TRANSACTION_LEVELS.with(|levels| {
262 let mut levels = levels.borrow_mut();
263 levels
264 .pop()
265 .expect("nested transaction state should exist during rollback")
266 .after_rollback
267 });
268 OPEN_TRANSACTION_COUNT.with(|count| {
269 count.fetch_sub(1, Ordering::SeqCst);
270 });
271 callbacks
272}
273
274fn finish_outermost_transaction(action: FinalizeAction) -> Vec<TransactionCallback> {
275 OPEN_TRANSACTION_COUNT.with(|count| count.store(0, Ordering::SeqCst));
276 SAVEPOINT_SEQUENCE.with(|sequence| sequence.store(0, Ordering::SeqCst));
277 CURRENT_TRANSACTION_ID.with(|current| {
278 current.replace(None);
279 });
280
281 TRANSACTION_LEVELS.with(|levels| {
282 let mut levels = levels.borrow_mut();
283 let outermost = levels
284 .pop()
285 .expect("outermost transaction state should exist during finalization");
286 levels.clear();
287 match action {
288 FinalizeAction::Commit => outermost.after_commit,
289 FinalizeAction::Rollback => outermost.after_rollback,
290 }
291 })
292}
293
294fn next_transaction_id() -> String {
295 format!("tx-{}", NEXT_TRANSACTION_ID.fetch_add(1, Ordering::Relaxed))
296}
297
298fn next_savepoint_name() -> String {
299 SAVEPOINT_SEQUENCE.with(|sequence| {
300 let next = sequence.fetch_add(1, Ordering::SeqCst) + 1;
301 format!("sp_{next}")
302 })
303}
304
305fn reset_transaction_state() {
306 OPEN_TRANSACTION_COUNT.with(|count| count.store(0, Ordering::SeqCst));
307 SAVEPOINT_SEQUENCE.with(|sequence| sequence.store(0, Ordering::SeqCst));
308 CURRENT_TRANSACTION_ID.with(|current| {
309 current.replace(None);
310 });
311 TRANSACTION_LEVELS.with(|levels| levels.borrow_mut().clear());
312}
313
314fn run_callbacks(callbacks: Vec<TransactionCallback>) {
315 for callback in callbacks {
316 callback();
317 }
318}
319
320pub trait Transactional: Record {
322 async fn transaction<F, Fut, T>(db: &DatabaseConnection, f: F) -> Result<T, RecordError>
324 where
325 F: FnOnce(&DatabaseConnection) -> Fut + Send,
326 Fut: std::future::Future<Output = Result<T, RecordError>> + Send,
327 T: Send,
328 {
329 crate::transactions::transaction(db, f).await
330 }
331}
332
333#[cfg(test)]
334mod tests {
335 use std::{
336 collections::HashMap,
337 sync::{
338 Arc, Mutex,
339 atomic::{AtomicUsize, Ordering as AtomicOrdering},
340 },
341 };
342
343 use sea_orm::{ConnectionTrait, Schema};
344 use serde_json::{Value, json};
345
346 use super::{
347 Transactional, after_commit, after_rollback, current_transaction_id, open_transactions,
348 transaction, transaction_open, transaction_sync,
349 };
350 use crate::{
351 Record, RecordError,
352 base::test_support::{TestUser, setup_db, test_user},
353 persistence::AsyncPersistence,
354 querying::AsyncQuerying,
355 };
356 use rustrails_support::{database, runtime};
357
358 fn run_sync_transaction_test(test: impl FnOnce() + Send + 'static) {
359 std::thread::spawn(move || {
360 let _rt = runtime::init_runtime();
361 database::establish("sqlite::memory:")
362 .expect("sqlite in-memory connection should succeed");
363 runtime::block_on(async {
364 let db = database::db();
365 let schema = Schema::new(db.get_database_backend());
366 db.execute(&schema.create_table_from_entity(test_user::Entity))
367 .await
368 .expect("test_users table should be created");
369 });
370 test();
371 })
372 .join()
373 .unwrap();
374 }
375
376 fn user_attrs(name: &str, email: &str) -> HashMap<String, Value> {
377 HashMap::from([
378 ("name".to_owned(), json!(name)),
379 ("email".to_owned(), json!(email)),
380 ])
381 }
382
383 impl Transactional for TestUser {}
384
385 #[tokio::test]
386 async fn transaction_commits_on_success() {
387 let db = setup_db().await;
388
389 transaction(&db, |txn| {
390 let txn = txn.clone();
391 async move {
392 TestUser::create(
393 HashMap::from([
394 ("name".to_owned(), json!("Alice")),
395 ("email".to_owned(), json!("alice@example.com")),
396 ]),
397 &txn,
398 )
399 .await?;
400 Ok(())
401 }
402 })
403 .await
404 .expect("transaction should commit");
405
406 assert_eq!(TestUser::count(&db).await.expect("count should succeed"), 1);
407 }
408
409 #[tokio::test]
410 async fn transaction_rolls_back_on_error() {
411 let db = setup_db().await;
412
413 let error = transaction(&db, |txn| {
414 let txn = txn.clone();
415 async move {
416 TestUser::create(
417 HashMap::from([
418 ("name".to_owned(), json!("Alice")),
419 ("email".to_owned(), json!("alice@example.com")),
420 ]),
421 &txn,
422 )
423 .await?;
424 Err::<(), RecordError>(RecordError::Invalid("force rollback".to_owned()))
425 }
426 })
427 .await
428 .expect_err("transaction should fail");
429
430 assert!(matches!(error, RecordError::Invalid(message) if message == "force rollback"));
431 assert_eq!(TestUser::count(&db).await.expect("count should succeed"), 0);
432 }
433
434 #[tokio::test]
435 async fn transaction_returns_closure_value() {
436 let db = setup_db().await;
437
438 let id = transaction(&db, |txn| {
439 let txn = txn.clone();
440 async move {
441 let user = TestUser::create(
442 HashMap::from([
443 ("name".to_owned(), json!("Alice")),
444 ("email".to_owned(), json!("alice@example.com")),
445 ]),
446 &txn,
447 )
448 .await?;
449 user.id().ok_or(RecordError::NotSaved)
450 }
451 })
452 .await
453 .expect("transaction should return a value");
454
455 assert_eq!(id, 1);
456 }
457
458 #[tokio::test]
459 async fn transactional_trait_delegates_to_helper() {
460 let db = setup_db().await;
461
462 TestUser::transaction(&db, |txn| {
463 let txn = txn.clone();
464 async move {
465 TestUser::create(
466 HashMap::from([
467 ("name".to_owned(), json!("Bob")),
468 ("email".to_owned(), json!("bob@example.com")),
469 ]),
470 &txn,
471 )
472 .await?;
473 Ok(())
474 }
475 })
476 .await
477 .expect("trait helper should commit");
478
479 let user = TestUser::find(1, &db).await.expect("user should exist");
480 assert_eq!(user.name, "Bob");
481 }
482
483 #[tokio::test]
484 async fn rollback_preserves_rows_outside_failed_transaction() {
485 let db = setup_db().await;
486
487 TestUser::create(
488 HashMap::from([
489 ("name".to_owned(), json!("Alice")),
490 ("email".to_owned(), json!("alice@example.com")),
491 ]),
492 &db,
493 )
494 .await
495 .expect("seed insert should succeed");
496
497 let _: Result<(), RecordError> = transaction(&db, |txn| {
498 let txn = txn.clone();
499 async move {
500 TestUser::create(
501 HashMap::from([
502 ("name".to_owned(), json!("Bob")),
503 ("email".to_owned(), json!("bob@example.com")),
504 ]),
505 &txn,
506 )
507 .await?;
508 Err(RecordError::Invalid("rollback".to_owned()))
509 }
510 })
511 .await;
512
513 assert_eq!(TestUser::count(&db).await.expect("count should succeed"), 1);
514 }
515
516 #[tokio::test]
517 async fn transaction_commits_multiple_writes() {
518 let db = setup_db().await;
519
520 transaction(&db, |txn| {
521 let txn = txn.clone();
522 async move {
523 for (name, email) in [("Alice", "alice@example.com"), ("Bob", "bob@example.com")] {
524 TestUser::create(
525 HashMap::from([
526 ("name".to_owned(), json!(name)),
527 ("email".to_owned(), json!(email)),
528 ]),
529 &txn,
530 )
531 .await?;
532 }
533 Ok(())
534 }
535 })
536 .await
537 .expect("multi-write transaction should commit");
538
539 assert_eq!(TestUser::count(&db).await.expect("count should succeed"), 2);
540 }
541
542 #[tokio::test]
543 async fn transaction_rolls_back_multiple_writes() {
544 let db = setup_db().await;
545
546 let error = transaction(&db, |txn| {
547 let txn = txn.clone();
548 async move {
549 for (name, email) in [("Alice", "alice@example.com"), ("Bob", "bob@example.com")] {
550 TestUser::create(
551 HashMap::from([
552 ("name".to_owned(), json!(name)),
553 ("email".to_owned(), json!(email)),
554 ]),
555 &txn,
556 )
557 .await?;
558 }
559 Err::<(), RecordError>(RecordError::Invalid("rollback all writes".to_owned()))
560 }
561 })
562 .await
563 .expect_err("multi-write transaction should fail");
564
565 assert!(matches!(error, RecordError::Invalid(message) if message == "rollback all writes"));
566 assert_eq!(TestUser::count(&db).await.expect("count should succeed"), 0);
567 }
568
569 #[tokio::test]
570 async fn transaction_exposes_uncommitted_writes_inside_closure() {
571 let db = setup_db().await;
572
573 let visible_count = transaction(&db, |txn| {
574 let txn = txn.clone();
575 async move {
576 TestUser::create(
577 HashMap::from([
578 ("name".to_owned(), json!("Alice")),
579 ("email".to_owned(), json!("alice@example.com")),
580 ]),
581 &txn,
582 )
583 .await?;
584 TestUser::count(&txn).await
585 }
586 })
587 .await
588 .expect("transaction should return the in-transaction count");
589
590 assert_eq!(visible_count, 1);
591 assert_eq!(TestUser::count(&db).await.expect("count should succeed"), 1);
592 }
593
594 #[tokio::test]
595 async fn transaction_rolls_back_writes_visible_inside_failed_closure() {
596 let db = setup_db().await;
597
598 let visible_count = transaction(&db, |txn| {
599 let txn = txn.clone();
600 async move {
601 TestUser::create(
602 HashMap::from([
603 ("name".to_owned(), json!("Alice")),
604 ("email".to_owned(), json!("alice@example.com")),
605 ]),
606 &txn,
607 )
608 .await?;
609 let count = TestUser::count(&txn).await?;
610 Err::<u64, RecordError>(RecordError::Invalid(format!(
611 "count before rollback: {count}"
612 )))
613 }
614 })
615 .await
616 .expect_err("transaction should fail");
617
618 assert!(
619 matches!(visible_count, RecordError::Invalid(message) if message == "count before rollback: 1")
620 );
621 assert_eq!(TestUser::count(&db).await.expect("count should succeed"), 0);
622 }
623
624 #[tokio::test]
625 async fn transaction_can_read_seeded_rows_and_insert_more() {
626 let db = setup_db().await;
627
628 TestUser::create(
629 HashMap::from([
630 ("name".to_owned(), json!("Alice")),
631 ("email".to_owned(), json!("alice@example.com")),
632 ]),
633 &db,
634 )
635 .await
636 .expect("seed insert should succeed");
637
638 let counts = transaction(&db, |txn| {
639 let txn = txn.clone();
640 async move {
641 let before = TestUser::count(&txn).await?;
642 TestUser::create(
643 HashMap::from([
644 ("name".to_owned(), json!("Bob")),
645 ("email".to_owned(), json!("bob@example.com")),
646 ]),
647 &txn,
648 )
649 .await?;
650 let after = TestUser::count(&txn).await?;
651 Ok((before, after))
652 }
653 })
654 .await
655 .expect("transaction should commit");
656
657 assert_eq!(counts, (1, 2));
658 assert_eq!(TestUser::count(&db).await.expect("count should succeed"), 2);
659 }
660
661 #[tokio::test]
662 #[ignore = "Nested savepoint-backed transactions are supported on the same connection"]
663 async fn nested_transaction_on_same_connection_returns_database_error() {
664 let db = setup_db().await;
665
666 let error = transaction(&db, |txn| {
667 let txn = txn.clone();
668 async move {
669 TestUser::create(
670 HashMap::from([
671 ("name".to_owned(), json!("Outer")),
672 ("email".to_owned(), json!("outer@example.com")),
673 ]),
674 &txn,
675 )
676 .await?;
677
678 let nested = transaction(&txn, |inner_txn| {
679 let inner_txn = inner_txn.clone();
680 async move {
681 TestUser::create(
682 HashMap::from([
683 ("name".to_owned(), json!("Inner")),
684 ("email".to_owned(), json!("inner@example.com")),
685 ]),
686 &inner_txn,
687 )
688 .await?;
689 Ok(())
690 }
691 })
692 .await;
693
694 assert!(nested.is_err());
695 nested
696 }
697 })
698 .await
699 .expect_err("nested transaction should fail on the same connection");
700
701 assert!(matches!(error, RecordError::Database(_)));
702 assert_eq!(TestUser::count(&db).await.expect("count should succeed"), 0);
703 }
704
705 #[tokio::test]
706 async fn transaction_rollback_preserves_seeded_rows_on_multiwrite_failure() {
707 let db = setup_db().await;
708
709 TestUser::create(
710 HashMap::from([
711 ("name".to_owned(), json!("Alice")),
712 ("email".to_owned(), json!("alice@example.com")),
713 ]),
714 &db,
715 )
716 .await
717 .expect("seed insert should succeed");
718
719 let _ = transaction(&db, |txn| {
720 let txn = txn.clone();
721 async move {
722 for (name, email) in [("Bob", "bob@example.com"), ("Carol", "carol@example.com")] {
723 TestUser::create(
724 HashMap::from([
725 ("name".to_owned(), json!(name)),
726 ("email".to_owned(), json!(email)),
727 ]),
728 &txn,
729 )
730 .await?;
731 }
732 Err::<(), RecordError>(RecordError::NotSaved)
733 }
734 })
735 .await;
736
737 assert_eq!(TestUser::count(&db).await.expect("count should succeed"), 1);
738 assert_eq!(
739 TestUser::find(1, &db)
740 .await
741 .expect("seed row should still exist")
742 .name,
743 "Alice"
744 );
745 }
746
747 #[tokio::test]
748 async fn manual_rollback_via_not_saved_error_rolls_back() {
749 let db = setup_db().await;
750
751 let error = transaction(&db, |txn| {
752 let txn = txn.clone();
753 async move {
754 TestUser::create(
755 HashMap::from([
756 ("name".to_owned(), json!("Alice")),
757 ("email".to_owned(), json!("alice@example.com")),
758 ]),
759 &txn,
760 )
761 .await?;
762 Err::<(), RecordError>(RecordError::NotSaved)
763 }
764 })
765 .await
766 .expect_err("manual rollback should bubble the original error");
767
768 assert!(matches!(error, RecordError::NotSaved));
769 assert_eq!(TestUser::count(&db).await.expect("count should succeed"), 0);
770 }
771
772 #[tokio::test]
773 async fn transactional_trait_can_return_tuple_values() {
774 let db = setup_db().await;
775
776 let result = TestUser::transaction(&db, |txn| {
777 let txn = txn.clone();
778 async move {
779 let user = TestUser::create(
780 HashMap::from([
781 ("name".to_owned(), json!("Alice")),
782 ("email".to_owned(), json!("alice@example.com")),
783 ]),
784 &txn,
785 )
786 .await?;
787 Ok((
788 user.id().expect("id should be assigned"),
789 TestUser::count(&txn).await?,
790 ))
791 }
792 })
793 .await
794 .expect("trait helper should return tuple values");
795
796 assert_eq!(result, (1, 1));
797 }
798
799 #[tokio::test]
800 async fn transaction_without_writes_can_return_existing_count() {
801 let db = setup_db().await;
802
803 TestUser::create(
804 HashMap::from([
805 ("name".to_owned(), json!("Alice")),
806 ("email".to_owned(), json!("alice@example.com")),
807 ]),
808 &db,
809 )
810 .await
811 .expect("seed insert should succeed");
812
813 let count = transaction(&db, |txn| {
814 let txn = txn.clone();
815 async move { TestUser::count(&txn).await }
816 })
817 .await
818 .expect("read-only transaction should commit");
819
820 assert_eq!(count, 1);
821 }
822
823 #[tokio::test]
824 async fn transaction_commits_updates_to_existing_rows() {
825 let db = setup_db().await;
826 let user = TestUser::create(
827 HashMap::from([
828 ("name".to_owned(), json!("Alice")),
829 ("email".to_owned(), json!("alice@example.com")),
830 ]),
831 &db,
832 )
833 .await
834 .expect("seed insert should succeed");
835 let id = user.id().expect("seed row should have an id");
836
837 transaction(&db, |txn| {
838 let txn = txn.clone();
839 async move {
840 let mut user = TestUser::find(id, &txn).await?;
841 user.update_attributes(
842 HashMap::from([("name".to_owned(), json!("Updated Alice"))]),
843 &txn,
844 )
845 .await?;
846 Ok(())
847 }
848 })
849 .await
850 .expect("update transaction should commit");
851
852 let reloaded = TestUser::find(id, &db)
853 .await
854 .expect("updated row should load after commit");
855 assert_eq!(reloaded.name, "Updated Alice");
856 assert_eq!(reloaded.email, "alice@example.com");
857 }
858
859 #[tokio::test]
860 async fn transaction_rolls_back_updates_to_existing_rows() {
861 let db = setup_db().await;
862 let user = TestUser::create(
863 HashMap::from([
864 ("name".to_owned(), json!("Alice")),
865 ("email".to_owned(), json!("alice@example.com")),
866 ]),
867 &db,
868 )
869 .await
870 .expect("seed insert should succeed");
871 let id = user.id().expect("seed row should have an id");
872
873 let error = transaction(&db, |txn| {
874 let txn = txn.clone();
875 async move {
876 let mut user = TestUser::find(id, &txn).await?;
877 user.update_attributes(
878 HashMap::from([("name".to_owned(), json!("Updated Alice"))]),
879 &txn,
880 )
881 .await?;
882 Err::<(), RecordError>(RecordError::Invalid("rollback update".to_owned()))
883 }
884 })
885 .await
886 .expect_err("update transaction should fail");
887
888 assert!(matches!(error, RecordError::Invalid(message) if message == "rollback update"));
889
890 let reloaded = TestUser::find(id, &db)
891 .await
892 .expect("seed row should still load after rollback");
893 assert_eq!(reloaded.name, "Alice");
894 assert_eq!(reloaded.email, "alice@example.com");
895 }
896
897 #[tokio::test]
898 async fn nested_transaction_commits_with_savepoint_release() {
899 let db = setup_db().await;
900
901 let (outer_id, inner_id) = transaction(&db, |txn| {
902 let txn = txn.clone();
903 async move {
904 TestUser::create(user_attrs("Outer", "outer@example.com"), &txn).await?;
905 let outer_id = current_transaction_id().expect("outer transaction id should exist");
906
907 let inner_id = transaction(&txn, |inner_txn| {
908 let inner_txn = inner_txn.clone();
909 async move {
910 assert_eq!(open_transactions(), 2);
911 TestUser::create(user_attrs("Inner", "inner@example.com"), &inner_txn)
912 .await?;
913 Ok::<String, RecordError>(
914 current_transaction_id()
915 .expect("nested transaction should reuse the outer id"),
916 )
917 }
918 })
919 .await?;
920
921 assert_eq!(TestUser::count(&txn).await?, 2);
922 Ok((outer_id, inner_id))
923 }
924 })
925 .await
926 .expect("nested transaction should commit");
927
928 assert_eq!(outer_id, inner_id);
929 assert_eq!(TestUser::count(&db).await.expect("count should succeed"), 2);
930 }
931
932 #[test]
933 fn transaction_sync_commits_on_success() {
934 run_sync_transaction_test(|| {
935 transaction_sync(|txn| {
936 let txn = txn.clone();
937 async move {
938 TestUser::create(
939 HashMap::from([
940 ("name".to_owned(), json!("Alice")),
941 ("email".to_owned(), json!("alice@example.com")),
942 ]),
943 &txn,
944 )
945 .await?;
946 Ok(())
947 }
948 })
949 .expect("transaction should commit");
950
951 let count = runtime::block_on(async {
952 let db = database::db();
953 TestUser::count(&db).await.expect("count should succeed")
954 });
955 assert_eq!(count, 1);
956 });
957 }
958
959 #[tokio::test]
960 async fn nested_transaction_rollback_to_savepoint_preserves_outer_changes() {
961 let db = setup_db().await;
962
963 transaction(&db, |txn| {
964 let txn = txn.clone();
965 async move {
966 TestUser::create(user_attrs("Outer", "outer@example.com"), &txn).await?;
967
968 let error = transaction(&txn, |inner_txn| {
969 let inner_txn = inner_txn.clone();
970 async move {
971 TestUser::create(user_attrs("Inner", "inner@example.com"), &inner_txn)
972 .await?;
973 Err::<(), RecordError>(RecordError::Invalid("rollback inner".to_owned()))
974 }
975 })
976 .await
977 .expect_err("inner transaction should roll back to its savepoint");
978
979 assert!(
980 matches!(error, RecordError::Invalid(message) if message == "rollback inner")
981 );
982 assert_eq!(open_transactions(), 1);
983 assert_eq!(TestUser::count(&txn).await?, 1);
984
985 TestUser::create(user_attrs("AfterInner", "after@example.com"), &txn).await?;
986 Ok(())
987 }
988 })
989 .await
990 .expect("outer transaction should still commit");
991
992 assert_eq!(TestUser::count(&db).await.expect("count should succeed"), 2);
993 assert_eq!(
994 TestUser::find(1, &db)
995 .await
996 .expect("outer row should persist")
997 .name,
998 "Outer"
999 );
1000 assert_eq!(
1001 TestUser::find(2, &db)
1002 .await
1003 .expect("post-rollback outer row should persist")
1004 .name,
1005 "AfterInner"
1006 );
1007 }
1008
1009 #[tokio::test]
1010 async fn after_commit_fires_after_outermost_commit_only() {
1011 let db = setup_db().await;
1012 let calls = Arc::new(AtomicUsize::new(0));
1013 let transaction_calls = Arc::clone(&calls);
1014
1015 transaction(&db, move |txn| {
1016 let txn = txn.clone();
1017 let calls = Arc::clone(&transaction_calls);
1018 let outer_calls = Arc::clone(&calls);
1019 async move {
1020 after_commit(move || {
1021 outer_calls.fetch_add(1, AtomicOrdering::SeqCst);
1022 });
1023
1024 let nested_calls = Arc::clone(&calls);
1025 transaction(&txn, move |inner_txn| {
1026 let inner_txn = inner_txn.clone();
1027 let calls = Arc::clone(&nested_calls);
1028 let inner_calls = Arc::clone(&calls);
1029 async move {
1030 TestUser::create(user_attrs("Inner", "inner@example.com"), &inner_txn)
1031 .await?;
1032 after_commit(move || {
1033 inner_calls.fetch_add(1, AtomicOrdering::SeqCst);
1034 });
1035 assert_eq!(calls.load(AtomicOrdering::SeqCst), 0);
1036 Ok(())
1037 }
1038 })
1039 .await?;
1040
1041 assert_eq!(calls.load(AtomicOrdering::SeqCst), 0);
1042 Ok(())
1043 }
1044 })
1045 .await
1046 .expect("outer transaction should commit");
1047
1048 assert_eq!(calls.load(AtomicOrdering::SeqCst), 2);
1049 }
1050
1051 #[tokio::test]
1052 async fn after_commit_callbacks_fire_in_registration_order_across_nested_transactions() {
1053 let db = setup_db().await;
1054 let events = Arc::new(Mutex::new(Vec::new()));
1055 let transaction_events = Arc::clone(&events);
1056
1057 transaction(&db, move |txn| {
1058 let txn = txn.clone();
1059 let events = Arc::clone(&transaction_events);
1060 let outer_events = Arc::clone(&events);
1061 async move {
1062 after_commit(move || outer_events.lock().unwrap().push("outer-1".to_owned()));
1063
1064 let nested_events = Arc::clone(&events);
1065 transaction(&txn, move |inner_txn| {
1066 let inner_txn = inner_txn.clone();
1067 let inner_events = Arc::clone(&nested_events);
1068 async move {
1069 after_commit(move || {
1070 inner_events.lock().unwrap().push("inner-1".to_owned());
1071 });
1072 TestUser::create(user_attrs("Inner", "inner@example.com"), &inner_txn)
1073 .await?;
1074 Ok(())
1075 }
1076 })
1077 .await?;
1078
1079 let trailing_events = Arc::clone(&events);
1080 after_commit(move || {
1081 trailing_events.lock().unwrap().push("outer-2".to_owned());
1082 });
1083 Ok(())
1084 }
1085 })
1086 .await
1087 .expect("callback ordering transaction should commit");
1088
1089 assert_eq!(
1090 *events.lock().unwrap(),
1091 vec![
1092 "outer-1".to_owned(),
1093 "inner-1".to_owned(),
1094 "outer-2".to_owned()
1095 ]
1096 );
1097 }
1098
1099 #[tokio::test]
1100 async fn after_commit_callbacks_do_not_fire_when_outer_transaction_rolls_back() {
1101 let db = setup_db().await;
1102 let calls = Arc::new(AtomicUsize::new(0));
1103 let transaction_calls = Arc::clone(&calls);
1104
1105 let error = transaction(&db, move |txn| {
1106 let txn = txn.clone();
1107 let calls = Arc::clone(&transaction_calls);
1108 let outer_calls = Arc::clone(&calls);
1109 async move {
1110 after_commit(move || {
1111 outer_calls.fetch_add(1, AtomicOrdering::SeqCst);
1112 });
1113
1114 let nested_calls = Arc::clone(&calls);
1115 transaction(&txn, move |inner_txn| {
1116 let inner_txn = inner_txn.clone();
1117 let calls = Arc::clone(&nested_calls);
1118 let inner_calls = Arc::clone(&calls);
1119 async move {
1120 after_commit(move || {
1121 inner_calls.fetch_add(1, AtomicOrdering::SeqCst);
1122 });
1123 TestUser::create(user_attrs("Inner", "inner@example.com"), &inner_txn)
1124 .await?;
1125 Ok(())
1126 }
1127 })
1128 .await?;
1129
1130 Err::<(), RecordError>(RecordError::Invalid("rollback outer".to_owned()))
1131 }
1132 })
1133 .await
1134 .expect_err("outer transaction should roll back");
1135
1136 assert!(matches!(error, RecordError::Invalid(message) if message == "rollback outer"));
1137 assert_eq!(calls.load(AtomicOrdering::SeqCst), 0);
1138 }
1139
1140 #[tokio::test]
1141 async fn after_commit_callbacks_are_cleared_after_commit() {
1142 let db = setup_db().await;
1143 let calls = Arc::new(AtomicUsize::new(0));
1144
1145 transaction(&db, |_| {
1146 let calls = Arc::clone(&calls);
1147 async move {
1148 after_commit(move || {
1149 calls.fetch_add(1, AtomicOrdering::SeqCst);
1150 });
1151 Ok(())
1152 }
1153 })
1154 .await
1155 .expect("first transaction should commit");
1156
1157 transaction(&db, |_| async move { Ok(()) })
1158 .await
1159 .expect("second transaction should commit");
1160
1161 assert_eq!(calls.load(AtomicOrdering::SeqCst), 1);
1162 }
1163
1164 #[tokio::test]
1165 async fn after_rollback_fires_on_outermost_rollback() {
1166 let db = setup_db().await;
1167 let calls = Arc::new(AtomicUsize::new(0));
1168 let callback_calls = Arc::clone(&calls);
1169
1170 let error = transaction(&db, move |_txn| {
1171 let outer_calls = Arc::clone(&callback_calls);
1172 async move {
1173 after_rollback(move || {
1174 outer_calls.fetch_add(1, AtomicOrdering::SeqCst);
1175 });
1176 Err::<(), RecordError>(RecordError::Invalid("rollback outer".to_owned()))
1177 }
1178 })
1179 .await
1180 .expect_err("transaction should roll back");
1181
1182 assert!(matches!(error, RecordError::Invalid(message) if message == "rollback outer"));
1183 assert_eq!(calls.load(AtomicOrdering::SeqCst), 1);
1184 }
1185
1186 #[tokio::test]
1187 async fn nested_rollback_fires_only_inner_after_rollback_callbacks() {
1188 let db = setup_db().await;
1189 let calls = Arc::new(Mutex::new(Vec::new()));
1190 let transaction_calls = Arc::clone(&calls);
1191
1192 transaction(&db, move |txn| {
1193 let txn = txn.clone();
1194 let calls = Arc::clone(&transaction_calls);
1195 let outer_calls = Arc::clone(&calls);
1196 async move {
1197 after_rollback(move || outer_calls.lock().unwrap().push("outer".to_owned()));
1198
1199 let nested_calls = Arc::clone(&calls);
1200 let error = transaction(&txn, move |inner_txn| {
1201 let inner_txn = inner_txn.clone();
1202 let calls = Arc::clone(&nested_calls);
1203 let inner_calls = Arc::clone(&calls);
1204 async move {
1205 after_rollback(move || {
1206 inner_calls.lock().unwrap().push("inner".to_owned())
1207 });
1208 TestUser::create(user_attrs("Inner", "inner@example.com"), &inner_txn)
1209 .await?;
1210 Err::<(), RecordError>(RecordError::Invalid("rollback inner".to_owned()))
1211 }
1212 })
1213 .await
1214 .expect_err("inner transaction should roll back");
1215
1216 assert!(
1217 matches!(error, RecordError::Invalid(message) if message == "rollback inner")
1218 );
1219 assert_eq!(*calls.lock().unwrap(), vec!["inner".to_owned()]);
1220 Ok(())
1221 }
1222 })
1223 .await
1224 .expect("outer transaction should commit");
1225
1226 assert_eq!(*calls.lock().unwrap(), vec!["inner".to_owned()]);
1227 }
1228
1229 #[tokio::test]
1230 async fn nested_successful_after_rollback_callbacks_fire_if_outer_transaction_rolls_back() {
1231 let db = setup_db().await;
1232 let calls = Arc::new(Mutex::new(Vec::new()));
1233 let transaction_calls = Arc::clone(&calls);
1234
1235 let error = transaction(&db, move |txn| {
1236 let txn = txn.clone();
1237 let calls = Arc::clone(&transaction_calls);
1238 let outer_calls = Arc::clone(&calls);
1239 async move {
1240 after_rollback(move || outer_calls.lock().unwrap().push("outer".to_owned()));
1241
1242 let nested_calls = Arc::clone(&calls);
1243 transaction(&txn, move |inner_txn| {
1244 let inner_txn = inner_txn.clone();
1245 let calls = Arc::clone(&nested_calls);
1246 let inner_calls = Arc::clone(&calls);
1247 async move {
1248 after_rollback(move || {
1249 inner_calls.lock().unwrap().push("inner".to_owned())
1250 });
1251 TestUser::create(user_attrs("Inner", "inner@example.com"), &inner_txn)
1252 .await?;
1253 Ok(())
1254 }
1255 })
1256 .await?;
1257
1258 Err::<(), RecordError>(RecordError::Invalid("rollback outer".to_owned()))
1259 }
1260 })
1261 .await
1262 .expect_err("outer transaction should roll back");
1263
1264 assert!(matches!(error, RecordError::Invalid(message) if message == "rollback outer"));
1265 assert_eq!(
1266 *calls.lock().unwrap(),
1267 vec!["outer".to_owned(), "inner".to_owned()]
1268 );
1269 }
1270
1271 #[tokio::test]
1272 async fn after_rollback_callbacks_are_cleared_after_rollback() {
1273 let db = setup_db().await;
1274 let calls = Arc::new(AtomicUsize::new(0));
1275 let callback_calls = Arc::clone(&calls);
1276
1277 let _ = transaction(&db, move |_txn| {
1278 let outer_calls = Arc::clone(&callback_calls);
1279 async move {
1280 after_rollback(move || {
1281 outer_calls.fetch_add(1, AtomicOrdering::SeqCst);
1282 });
1283 Err::<(), RecordError>(RecordError::Invalid("rollback once".to_owned()))
1284 }
1285 })
1286 .await;
1287
1288 transaction(&db, |_| async move { Ok(()) })
1289 .await
1290 .expect("later transaction should commit");
1291
1292 assert_eq!(calls.load(AtomicOrdering::SeqCst), 1);
1293 }
1294
1295 #[tokio::test]
1296 async fn open_transactions_starts_closed() {
1297 assert_eq!(open_transactions(), 0);
1298 assert!(!transaction_open());
1299 }
1300
1301 #[tokio::test]
1302 async fn open_transactions_tracks_nested_depth() {
1303 let db = setup_db().await;
1304
1305 assert_eq!(open_transactions(), 0);
1306
1307 transaction(&db, |txn| {
1308 let txn = txn.clone();
1309 async move {
1310 assert_eq!(open_transactions(), 1);
1311 transaction(&txn, |_| async move {
1312 assert_eq!(open_transactions(), 2);
1313 Ok(())
1314 })
1315 .await?;
1316 assert_eq!(open_transactions(), 1);
1317 Ok(())
1318 }
1319 })
1320 .await
1321 .expect("nested transaction should commit");
1322
1323 assert_eq!(open_transactions(), 0);
1324 }
1325
1326 #[tokio::test]
1327 async fn transaction_open_reflects_current_state() {
1328 let db = setup_db().await;
1329
1330 assert!(!transaction_open());
1331
1332 transaction(&db, |txn| {
1333 let txn = txn.clone();
1334 async move {
1335 assert!(transaction_open());
1336 transaction(&txn, |_| async move {
1337 assert!(transaction_open());
1338 Ok(())
1339 })
1340 .await?;
1341 assert!(transaction_open());
1342 Ok(())
1343 }
1344 })
1345 .await
1346 .expect("transaction should commit");
1347
1348 assert!(!transaction_open());
1349 }
1350
1351 #[tokio::test]
1352 async fn current_transaction_id_is_none_outside_transactions() {
1353 assert_eq!(current_transaction_id(), None);
1354
1355 let db = setup_db().await;
1356 transaction(&db, |_| async move { Ok(()) })
1357 .await
1358 .expect("transaction should commit");
1359
1360 assert_eq!(current_transaction_id(), None);
1361 }
1362
1363 #[tokio::test]
1364 async fn current_transaction_id_is_stable_across_nested_transactions() {
1365 let db = setup_db().await;
1366
1367 let (outer_id, inner_id, after_inner_id) = transaction(&db, |txn| {
1368 let txn = txn.clone();
1369 async move {
1370 let outer_id = current_transaction_id().expect("outer transaction id should exist");
1371 let inner_id = transaction(&txn, |_| async move {
1372 Ok::<String, RecordError>(
1373 current_transaction_id().expect("nested transaction id should exist"),
1374 )
1375 })
1376 .await?;
1377 let after_inner_id =
1378 current_transaction_id().expect("outer transaction id should still exist");
1379 Ok((outer_id, inner_id, after_inner_id))
1380 }
1381 })
1382 .await
1383 .expect("transaction should commit");
1384
1385 assert_eq!(outer_id, inner_id);
1386 assert_eq!(outer_id, after_inner_id);
1387 }
1388
1389 #[tokio::test]
1390 async fn current_transaction_id_changes_between_outer_transactions() {
1391 let db = setup_db().await;
1392
1393 let first = transaction(&db, |_| async move {
1394 Ok::<String, RecordError>(
1395 current_transaction_id().expect("transaction id should exist"),
1396 )
1397 })
1398 .await
1399 .expect("first transaction should commit");
1400
1401 let second = transaction(&db, |_| async move {
1402 Ok::<String, RecordError>(
1403 current_transaction_id().expect("transaction id should exist"),
1404 )
1405 })
1406 .await
1407 .expect("second transaction should commit");
1408
1409 assert_ne!(first, second);
1410 }
1411
1412 #[tokio::test]
1413 async fn transaction_state_clears_after_outer_rollback() {
1414 let db = setup_db().await;
1415
1416 let error = transaction(&db, |_| async move {
1417 assert_eq!(open_transactions(), 1);
1418 assert!(transaction_open());
1419 assert!(current_transaction_id().is_some());
1420 Err::<(), RecordError>(RecordError::Invalid("rollback outer".to_owned()))
1421 })
1422 .await
1423 .expect_err("transaction should roll back");
1424
1425 assert!(matches!(error, RecordError::Invalid(message) if message == "rollback outer"));
1426 assert_eq!(open_transactions(), 0);
1427 assert!(!transaction_open());
1428 assert_eq!(current_transaction_id(), None);
1429 }
1430
1431 #[tokio::test]
1432 async fn nested_rollback_restores_outer_transaction_state() {
1433 let db = setup_db().await;
1434
1435 transaction(&db, |txn| {
1436 let txn = txn.clone();
1437 async move {
1438 let outer_id = current_transaction_id().expect("outer transaction id should exist");
1439 let _ = transaction(&txn, |_| async move {
1440 assert_eq!(open_transactions(), 2);
1441 Err::<(), RecordError>(RecordError::Invalid("rollback inner".to_owned()))
1442 })
1443 .await
1444 .expect_err("inner transaction should roll back");
1445
1446 assert_eq!(open_transactions(), 1);
1447 assert!(transaction_open());
1448 assert_eq!(
1449 current_transaction_id().expect("outer transaction id should remain"),
1450 outer_id
1451 );
1452 Ok(())
1453 }
1454 })
1455 .await
1456 .expect("outer transaction should commit");
1457
1458 assert_eq!(open_transactions(), 0);
1459 assert_eq!(current_transaction_id(), None);
1460 }
1461}