Skip to main content

rustrails_record/
transactions.rs

1use std::{
2    cell::RefCell,
3    sync::atomic::{AtomicU32, AtomicU64, Ordering},
4};
5
6use rustrails_support::{database, runtime};
7use sea_orm::{ConnectionTrait, DatabaseConnection};
8
9use crate::base::{Record, RecordError};
10
/// A deferred hook queued on a transaction scope; being `FnOnce`, it can be
/// invoked at most once.
type TransactionCallback = Box<dyn FnOnce() + Send>;

/// Bookkeeping for a single open transaction scope.
///
/// The outermost scope carries no savepoint name; every nested scope records
/// the savepoint it was opened with.
#[derive(Default)]
struct TransactionLevel {
    /// Savepoint guarding this scope, or `None` for the outermost scope.
    savepoint_name: Option<String>,
    /// Hooks to run once the outermost transaction has committed.
    after_commit: Vec<TransactionCallback>,
    /// Hooks to run if this scope ends up being rolled back.
    after_rollback: Vec<TransactionCallback>,
}

impl TransactionLevel {
    /// Builds the state for an outermost scope: no savepoint, no hooks.
    fn outermost() -> Self {
        Self {
            savepoint_name: None,
            after_commit: Vec::new(),
            after_rollback: Vec::new(),
        }
    }

    /// Builds the state for a nested scope guarded by `savepoint_name`.
    fn nested(savepoint_name: String) -> Self {
        let mut level = Self::outermost();
        level.savepoint_name = Some(savepoint_name);
        level
    }

    /// Moves every hook queued on `nested` to the end of this level's queues,
    /// preserving registration order. The nested savepoint name is discarded.
    fn absorb(&mut self, nested: Self) {
        let Self {
            mut after_commit,
            mut after_rollback,
            ..
        } = nested;
        self.after_commit.append(&mut after_commit);
        self.after_rollback.append(&mut after_rollback);
    }
}
37
/// Selects which callback queue `finish_outermost_transaction` hands back.
enum FinalizeAction {
    /// The outermost transaction committed: return its `after_commit` callbacks.
    Commit,
    /// The outermost transaction rolled back: return its `after_rollback` callbacks.
    Rollback,
}
42
/// Describes the innermost open scope, as reported by `current_transaction_boundary`.
enum TransactionBoundary {
    /// The innermost scope is the outermost transaction (it has no savepoint).
    Outermost,
    /// The innermost scope is nested; the payload is its savepoint name.
    Nested(String),
}
47
thread_local! {
    /// Number of transaction scopes open on this thread: set to 1 when the
    /// outermost scope begins, incremented for each nested scope.
    static OPEN_TRANSACTION_COUNT: AtomicU32 = const { AtomicU32::new(0) };
    /// Counter behind the generated `sp_N` savepoint names; reset to 0 whenever
    /// an outermost transaction begins or finishes.
    static SAVEPOINT_SEQUENCE: AtomicU32 = const { AtomicU32::new(0) };
    /// Identifier (`tx-N`) of the outermost transaction open on this thread, if any.
    static CURRENT_TRANSACTION_ID: RefCell<Option<String>> = const { RefCell::new(None) };
    /// Stack of per-scope state; the last element is the innermost open scope.
    static TRANSACTION_LEVELS: RefCell<Vec<TransactionLevel>> = const { RefCell::new(Vec::new()) };
}

/// Process-wide source of the numeric part of transaction identifiers, starting at 1.
static NEXT_TRANSACTION_ID: AtomicU64 = AtomicU64::new(1);
56
57/// Returns the number of currently open transaction scopes on this thread.
58#[must_use]
59pub fn open_transactions() -> u32 {
60    OPEN_TRANSACTION_COUNT.with(|count| count.load(Ordering::SeqCst))
61}
62
63/// Returns `true` when any transaction scope is currently open on this thread.
64#[must_use]
65pub fn transaction_open() -> bool {
66    open_transactions() > 0
67}
68
69/// Returns the current outermost transaction identifier, if one is active.
70#[must_use]
71pub fn current_transaction_id() -> Option<String> {
72    CURRENT_TRANSACTION_ID.with(|current| current.borrow().clone())
73}
74
75/// Registers a callback to run after the outermost transaction commits.
76///
77/// When no transaction is open, the callback runs immediately.
78pub fn after_commit<F>(callback: F)
79where
80    F: FnOnce() + Send + 'static,
81{
82    let mut callback = Some(Box::new(callback) as TransactionCallback);
83    let registered = TRANSACTION_LEVELS.with(|levels| {
84        let mut levels = levels.borrow_mut();
85        if let Some(level) = levels.last_mut() {
86            level
87                .after_commit
88                .push(callback.take().expect("after_commit callback should exist"));
89            true
90        } else {
91            false
92        }
93    });
94
95    if !registered {
96        callback.expect("after_commit callback should exist outside transactions")();
97    }
98}
99
100/// Registers a callback to run when the current transaction scope rolls back.
101pub fn after_rollback<F>(callback: F)
102where
103    F: FnOnce() + Send + 'static,
104{
105    let mut callback = Some(Box::new(callback) as TransactionCallback);
106    TRANSACTION_LEVELS.with(|levels| {
107        let mut levels = levels.borrow_mut();
108        if let Some(level) = levels.last_mut() {
109            level.after_rollback.push(
110                callback
111                    .take()
112                    .expect("after_rollback callback should exist"),
113            );
114        }
115    });
116}
117
118/// Executes a closure inside a database transaction.
119///
120/// The helper uses explicit `BEGIN` / `COMMIT` / `ROLLBACK` statements against the
121/// provided connection so the closure can keep working with a [`DatabaseConnection`].
122/// Nested calls use savepoints so inner failures can roll back without aborting the
123/// outermost transaction.
124pub async fn transaction<F, Fut, T>(db: &DatabaseConnection, f: F) -> Result<T, RecordError>
125where
126    F: FnOnce(&DatabaseConnection) -> Fut + Send,
127    Fut: std::future::Future<Output = Result<T, RecordError>> + Send,
128    T: Send,
129{
130    begin_transaction_scope(db).await?;
131
132    match f(db).await {
133        Ok(value) => {
134            commit_transaction_scope(db).await?;
135            Ok(value)
136        }
137        Err(error) => {
138            rollback_transaction_scope(db).await?;
139            Err(error)
140        }
141    }
142}
143
144/// Synchronous wrapper for [`transaction`].
145pub fn transaction_sync<F, Fut, T>(f: F) -> Result<T, RecordError>
146where
147    F: FnOnce(&DatabaseConnection) -> Fut + Send,
148    Fut: std::future::Future<Output = Result<T, RecordError>> + Send,
149    T: Send,
150{
151    database::with_db(|db| runtime::block_on(transaction(db, f)))
152}
153
154async fn begin_transaction_scope(db: &DatabaseConnection) -> Result<(), RecordError> {
155    if transaction_open() {
156        let savepoint_name = next_savepoint_name();
157        execute_transaction_control(db, &format!("SAVEPOINT {savepoint_name}")).await?;
158        OPEN_TRANSACTION_COUNT.with(|count| {
159            count.fetch_add(1, Ordering::SeqCst);
160        });
161        TRANSACTION_LEVELS.with(|levels| {
162            levels
163                .borrow_mut()
164                .push(TransactionLevel::nested(savepoint_name));
165        });
166    } else {
167        let transaction_id = next_transaction_id();
168        execute_transaction_control(db, "BEGIN").await?;
169        OPEN_TRANSACTION_COUNT.with(|count| count.store(1, Ordering::SeqCst));
170        SAVEPOINT_SEQUENCE.with(|sequence| sequence.store(0, Ordering::SeqCst));
171        CURRENT_TRANSACTION_ID.with(|current| {
172            current.replace(Some(transaction_id));
173        });
174        TRANSACTION_LEVELS.with(|levels| {
175            levels.borrow_mut().push(TransactionLevel::outermost());
176        });
177    }
178
179    Ok(())
180}
181
182async fn commit_transaction_scope(db: &DatabaseConnection) -> Result<(), RecordError> {
183    match current_transaction_boundary() {
184        Some(TransactionBoundary::Outermost) => {
185            execute_or_reset_state(db, "COMMIT").await?;
186            let callbacks = finish_outermost_transaction(FinalizeAction::Commit);
187            run_callbacks(callbacks);
188            Ok(())
189        }
190        Some(TransactionBoundary::Nested(savepoint_name)) => {
191            execute_or_reset_state(db, &format!("RELEASE SAVEPOINT {savepoint_name}")).await?;
192            merge_nested_transaction_into_parent();
193            Ok(())
194        }
195        None => Ok(()),
196    }
197}
198
199async fn rollback_transaction_scope(db: &DatabaseConnection) -> Result<(), RecordError> {
200    match current_transaction_boundary() {
201        Some(TransactionBoundary::Outermost) => {
202            execute_or_reset_state(db, "ROLLBACK").await?;
203            let callbacks = finish_outermost_transaction(FinalizeAction::Rollback);
204            run_callbacks(callbacks);
205            Ok(())
206        }
207        Some(TransactionBoundary::Nested(savepoint_name)) => {
208            execute_or_reset_state(db, &format!("ROLLBACK TO SAVEPOINT {savepoint_name}")).await?;
209            let callbacks = rollback_nested_transaction();
210            run_callbacks(callbacks);
211            Ok(())
212        }
213        None => Ok(()),
214    }
215}
216
217async fn execute_or_reset_state(db: &DatabaseConnection, sql: &str) -> Result<(), RecordError> {
218    if let Err(error) = execute_transaction_control(db, sql).await {
219        reset_transaction_state();
220        Err(error)
221    } else {
222        Ok(())
223    }
224}
225
226async fn execute_transaction_control(
227    db: &DatabaseConnection,
228    sql: &str,
229) -> Result<(), RecordError> {
230    db.execute_unprepared(sql).await?;
231    Ok(())
232}
233
234fn current_transaction_boundary() -> Option<TransactionBoundary> {
235    TRANSACTION_LEVELS.with(|levels| {
236        let levels = levels.borrow();
237        levels.last().map(|level| match &level.savepoint_name {
238            Some(savepoint_name) => TransactionBoundary::Nested(savepoint_name.clone()),
239            None => TransactionBoundary::Outermost,
240        })
241    })
242}
243
244fn merge_nested_transaction_into_parent() {
245    TRANSACTION_LEVELS.with(|levels| {
246        let mut levels = levels.borrow_mut();
247        let nested = levels
248            .pop()
249            .expect("nested transaction state should exist during commit");
250        let parent = levels
251            .last_mut()
252            .expect("parent transaction state should exist during nested commit");
253        parent.absorb(nested);
254    });
255    OPEN_TRANSACTION_COUNT.with(|count| {
256        count.fetch_sub(1, Ordering::SeqCst);
257    });
258}
259
260fn rollback_nested_transaction() -> Vec<TransactionCallback> {
261    let callbacks = TRANSACTION_LEVELS.with(|levels| {
262        let mut levels = levels.borrow_mut();
263        levels
264            .pop()
265            .expect("nested transaction state should exist during rollback")
266            .after_rollback
267    });
268    OPEN_TRANSACTION_COUNT.with(|count| {
269        count.fetch_sub(1, Ordering::SeqCst);
270    });
271    callbacks
272}
273
274fn finish_outermost_transaction(action: FinalizeAction) -> Vec<TransactionCallback> {
275    OPEN_TRANSACTION_COUNT.with(|count| count.store(0, Ordering::SeqCst));
276    SAVEPOINT_SEQUENCE.with(|sequence| sequence.store(0, Ordering::SeqCst));
277    CURRENT_TRANSACTION_ID.with(|current| {
278        current.replace(None);
279    });
280
281    TRANSACTION_LEVELS.with(|levels| {
282        let mut levels = levels.borrow_mut();
283        let outermost = levels
284            .pop()
285            .expect("outermost transaction state should exist during finalization");
286        levels.clear();
287        match action {
288            FinalizeAction::Commit => outermost.after_commit,
289            FinalizeAction::Rollback => outermost.after_rollback,
290        }
291    })
292}
293
294fn next_transaction_id() -> String {
295    format!("tx-{}", NEXT_TRANSACTION_ID.fetch_add(1, Ordering::Relaxed))
296}
297
298fn next_savepoint_name() -> String {
299    SAVEPOINT_SEQUENCE.with(|sequence| {
300        let next = sequence.fetch_add(1, Ordering::SeqCst) + 1;
301        format!("sp_{next}")
302    })
303}
304
305fn reset_transaction_state() {
306    OPEN_TRANSACTION_COUNT.with(|count| count.store(0, Ordering::SeqCst));
307    SAVEPOINT_SEQUENCE.with(|sequence| sequence.store(0, Ordering::SeqCst));
308    CURRENT_TRANSACTION_ID.with(|current| {
309        current.replace(None);
310    });
311    TRANSACTION_LEVELS.with(|levels| levels.borrow_mut().clear());
312}
313
314fn run_callbacks(callbacks: Vec<TransactionCallback>) {
315    for callback in callbacks {
316        callback();
317    }
318}
319
320/// Trait for records that support transactional operations.
321pub trait Transactional: Record {
322    /// Executes a closure inside a database transaction.
323    async fn transaction<F, Fut, T>(db: &DatabaseConnection, f: F) -> Result<T, RecordError>
324    where
325        F: FnOnce(&DatabaseConnection) -> Fut + Send,
326        Fut: std::future::Future<Output = Result<T, RecordError>> + Send,
327        T: Send,
328    {
329        crate::transactions::transaction(db, f).await
330    }
331}
332
333#[cfg(test)]
334mod tests {
335    use std::{
336        collections::HashMap,
337        sync::{
338            Arc, Mutex,
339            atomic::{AtomicUsize, Ordering as AtomicOrdering},
340        },
341    };
342
343    use sea_orm::{ConnectionTrait, Schema};
344    use serde_json::{Value, json};
345
346    use super::{
347        Transactional, after_commit, after_rollback, current_transaction_id, open_transactions,
348        transaction, transaction_open, transaction_sync,
349    };
350    use crate::{
351        Record, RecordError,
352        base::test_support::{TestUser, setup_db, test_user},
353        persistence::AsyncPersistence,
354        querying::AsyncQuerying,
355    };
356    use rustrails_support::{database, runtime};
357
358    fn run_sync_transaction_test(test: impl FnOnce() + Send + 'static) {
359        std::thread::spawn(move || {
360            let _rt = runtime::init_runtime();
361            database::establish("sqlite::memory:")
362                .expect("sqlite in-memory connection should succeed");
363            runtime::block_on(async {
364                let db = database::db();
365                let schema = Schema::new(db.get_database_backend());
366                db.execute(&schema.create_table_from_entity(test_user::Entity))
367                    .await
368                    .expect("test_users table should be created");
369            });
370            test();
371        })
372        .join()
373        .unwrap();
374    }
375
376    fn user_attrs(name: &str, email: &str) -> HashMap<String, Value> {
377        HashMap::from([
378            ("name".to_owned(), json!(name)),
379            ("email".to_owned(), json!(email)),
380        ])
381    }
382
    // Opt the shared test model into `Transactional` so the trait's default
    // `transaction` method can be exercised through `TestUser::transaction`.
    impl Transactional for TestUser {}
384
385    #[tokio::test]
386    async fn transaction_commits_on_success() {
387        let db = setup_db().await;
388
389        transaction(&db, |txn| {
390            let txn = txn.clone();
391            async move {
392                TestUser::create(
393                    HashMap::from([
394                        ("name".to_owned(), json!("Alice")),
395                        ("email".to_owned(), json!("alice@example.com")),
396                    ]),
397                    &txn,
398                )
399                .await?;
400                Ok(())
401            }
402        })
403        .await
404        .expect("transaction should commit");
405
406        assert_eq!(TestUser::count(&db).await.expect("count should succeed"), 1);
407    }
408
409    #[tokio::test]
410    async fn transaction_rolls_back_on_error() {
411        let db = setup_db().await;
412
413        let error = transaction(&db, |txn| {
414            let txn = txn.clone();
415            async move {
416                TestUser::create(
417                    HashMap::from([
418                        ("name".to_owned(), json!("Alice")),
419                        ("email".to_owned(), json!("alice@example.com")),
420                    ]),
421                    &txn,
422                )
423                .await?;
424                Err::<(), RecordError>(RecordError::Invalid("force rollback".to_owned()))
425            }
426        })
427        .await
428        .expect_err("transaction should fail");
429
430        assert!(matches!(error, RecordError::Invalid(message) if message == "force rollback"));
431        assert_eq!(TestUser::count(&db).await.expect("count should succeed"), 0);
432    }
433
434    #[tokio::test]
435    async fn transaction_returns_closure_value() {
436        let db = setup_db().await;
437
438        let id = transaction(&db, |txn| {
439            let txn = txn.clone();
440            async move {
441                let user = TestUser::create(
442                    HashMap::from([
443                        ("name".to_owned(), json!("Alice")),
444                        ("email".to_owned(), json!("alice@example.com")),
445                    ]),
446                    &txn,
447                )
448                .await?;
449                user.id().ok_or(RecordError::NotSaved)
450            }
451        })
452        .await
453        .expect("transaction should return a value");
454
455        assert_eq!(id, 1);
456    }
457
458    #[tokio::test]
459    async fn transactional_trait_delegates_to_helper() {
460        let db = setup_db().await;
461
462        TestUser::transaction(&db, |txn| {
463            let txn = txn.clone();
464            async move {
465                TestUser::create(
466                    HashMap::from([
467                        ("name".to_owned(), json!("Bob")),
468                        ("email".to_owned(), json!("bob@example.com")),
469                    ]),
470                    &txn,
471                )
472                .await?;
473                Ok(())
474            }
475        })
476        .await
477        .expect("trait helper should commit");
478
479        let user = TestUser::find(1, &db).await.expect("user should exist");
480        assert_eq!(user.name, "Bob");
481    }
482
483    #[tokio::test]
484    async fn rollback_preserves_rows_outside_failed_transaction() {
485        let db = setup_db().await;
486
487        TestUser::create(
488            HashMap::from([
489                ("name".to_owned(), json!("Alice")),
490                ("email".to_owned(), json!("alice@example.com")),
491            ]),
492            &db,
493        )
494        .await
495        .expect("seed insert should succeed");
496
497        let _: Result<(), RecordError> = transaction(&db, |txn| {
498            let txn = txn.clone();
499            async move {
500                TestUser::create(
501                    HashMap::from([
502                        ("name".to_owned(), json!("Bob")),
503                        ("email".to_owned(), json!("bob@example.com")),
504                    ]),
505                    &txn,
506                )
507                .await?;
508                Err(RecordError::Invalid("rollback".to_owned()))
509            }
510        })
511        .await;
512
513        assert_eq!(TestUser::count(&db).await.expect("count should succeed"), 1);
514    }
515
516    #[tokio::test]
517    async fn transaction_commits_multiple_writes() {
518        let db = setup_db().await;
519
520        transaction(&db, |txn| {
521            let txn = txn.clone();
522            async move {
523                for (name, email) in [("Alice", "alice@example.com"), ("Bob", "bob@example.com")] {
524                    TestUser::create(
525                        HashMap::from([
526                            ("name".to_owned(), json!(name)),
527                            ("email".to_owned(), json!(email)),
528                        ]),
529                        &txn,
530                    )
531                    .await?;
532                }
533                Ok(())
534            }
535        })
536        .await
537        .expect("multi-write transaction should commit");
538
539        assert_eq!(TestUser::count(&db).await.expect("count should succeed"), 2);
540    }
541
542    #[tokio::test]
543    async fn transaction_rolls_back_multiple_writes() {
544        let db = setup_db().await;
545
546        let error = transaction(&db, |txn| {
547            let txn = txn.clone();
548            async move {
549                for (name, email) in [("Alice", "alice@example.com"), ("Bob", "bob@example.com")] {
550                    TestUser::create(
551                        HashMap::from([
552                            ("name".to_owned(), json!(name)),
553                            ("email".to_owned(), json!(email)),
554                        ]),
555                        &txn,
556                    )
557                    .await?;
558                }
559                Err::<(), RecordError>(RecordError::Invalid("rollback all writes".to_owned()))
560            }
561        })
562        .await
563        .expect_err("multi-write transaction should fail");
564
565        assert!(matches!(error, RecordError::Invalid(message) if message == "rollback all writes"));
566        assert_eq!(TestUser::count(&db).await.expect("count should succeed"), 0);
567    }
568
569    #[tokio::test]
570    async fn transaction_exposes_uncommitted_writes_inside_closure() {
571        let db = setup_db().await;
572
573        let visible_count = transaction(&db, |txn| {
574            let txn = txn.clone();
575            async move {
576                TestUser::create(
577                    HashMap::from([
578                        ("name".to_owned(), json!("Alice")),
579                        ("email".to_owned(), json!("alice@example.com")),
580                    ]),
581                    &txn,
582                )
583                .await?;
584                TestUser::count(&txn).await
585            }
586        })
587        .await
588        .expect("transaction should return the in-transaction count");
589
590        assert_eq!(visible_count, 1);
591        assert_eq!(TestUser::count(&db).await.expect("count should succeed"), 1);
592    }
593
594    #[tokio::test]
595    async fn transaction_rolls_back_writes_visible_inside_failed_closure() {
596        let db = setup_db().await;
597
598        let visible_count = transaction(&db, |txn| {
599            let txn = txn.clone();
600            async move {
601                TestUser::create(
602                    HashMap::from([
603                        ("name".to_owned(), json!("Alice")),
604                        ("email".to_owned(), json!("alice@example.com")),
605                    ]),
606                    &txn,
607                )
608                .await?;
609                let count = TestUser::count(&txn).await?;
610                Err::<u64, RecordError>(RecordError::Invalid(format!(
611                    "count before rollback: {count}"
612                )))
613            }
614        })
615        .await
616        .expect_err("transaction should fail");
617
618        assert!(
619            matches!(visible_count, RecordError::Invalid(message) if message == "count before rollback: 1")
620        );
621        assert_eq!(TestUser::count(&db).await.expect("count should succeed"), 0);
622    }
623
624    #[tokio::test]
625    async fn transaction_can_read_seeded_rows_and_insert_more() {
626        let db = setup_db().await;
627
628        TestUser::create(
629            HashMap::from([
630                ("name".to_owned(), json!("Alice")),
631                ("email".to_owned(), json!("alice@example.com")),
632            ]),
633            &db,
634        )
635        .await
636        .expect("seed insert should succeed");
637
638        let counts = transaction(&db, |txn| {
639            let txn = txn.clone();
640            async move {
641                let before = TestUser::count(&txn).await?;
642                TestUser::create(
643                    HashMap::from([
644                        ("name".to_owned(), json!("Bob")),
645                        ("email".to_owned(), json!("bob@example.com")),
646                    ]),
647                    &txn,
648                )
649                .await?;
650                let after = TestUser::count(&txn).await?;
651                Ok((before, after))
652            }
653        })
654        .await
655        .expect("transaction should commit");
656
657        assert_eq!(counts, (1, 2));
658        assert_eq!(TestUser::count(&db).await.expect("count should succeed"), 2);
659    }
660
661    #[tokio::test]
662    #[ignore = "Nested savepoint-backed transactions are supported on the same connection"]
663    async fn nested_transaction_on_same_connection_returns_database_error() {
664        let db = setup_db().await;
665
666        let error = transaction(&db, |txn| {
667            let txn = txn.clone();
668            async move {
669                TestUser::create(
670                    HashMap::from([
671                        ("name".to_owned(), json!("Outer")),
672                        ("email".to_owned(), json!("outer@example.com")),
673                    ]),
674                    &txn,
675                )
676                .await?;
677
678                let nested = transaction(&txn, |inner_txn| {
679                    let inner_txn = inner_txn.clone();
680                    async move {
681                        TestUser::create(
682                            HashMap::from([
683                                ("name".to_owned(), json!("Inner")),
684                                ("email".to_owned(), json!("inner@example.com")),
685                            ]),
686                            &inner_txn,
687                        )
688                        .await?;
689                        Ok(())
690                    }
691                })
692                .await;
693
694                assert!(nested.is_err());
695                nested
696            }
697        })
698        .await
699        .expect_err("nested transaction should fail on the same connection");
700
701        assert!(matches!(error, RecordError::Database(_)));
702        assert_eq!(TestUser::count(&db).await.expect("count should succeed"), 0);
703    }
704
705    #[tokio::test]
706    async fn transaction_rollback_preserves_seeded_rows_on_multiwrite_failure() {
707        let db = setup_db().await;
708
709        TestUser::create(
710            HashMap::from([
711                ("name".to_owned(), json!("Alice")),
712                ("email".to_owned(), json!("alice@example.com")),
713            ]),
714            &db,
715        )
716        .await
717        .expect("seed insert should succeed");
718
719        let _ = transaction(&db, |txn| {
720            let txn = txn.clone();
721            async move {
722                for (name, email) in [("Bob", "bob@example.com"), ("Carol", "carol@example.com")] {
723                    TestUser::create(
724                        HashMap::from([
725                            ("name".to_owned(), json!(name)),
726                            ("email".to_owned(), json!(email)),
727                        ]),
728                        &txn,
729                    )
730                    .await?;
731                }
732                Err::<(), RecordError>(RecordError::NotSaved)
733            }
734        })
735        .await;
736
737        assert_eq!(TestUser::count(&db).await.expect("count should succeed"), 1);
738        assert_eq!(
739            TestUser::find(1, &db)
740                .await
741                .expect("seed row should still exist")
742                .name,
743            "Alice"
744        );
745    }
746
747    #[tokio::test]
748    async fn manual_rollback_via_not_saved_error_rolls_back() {
749        let db = setup_db().await;
750
751        let error = transaction(&db, |txn| {
752            let txn = txn.clone();
753            async move {
754                TestUser::create(
755                    HashMap::from([
756                        ("name".to_owned(), json!("Alice")),
757                        ("email".to_owned(), json!("alice@example.com")),
758                    ]),
759                    &txn,
760                )
761                .await?;
762                Err::<(), RecordError>(RecordError::NotSaved)
763            }
764        })
765        .await
766        .expect_err("manual rollback should bubble the original error");
767
768        assert!(matches!(error, RecordError::NotSaved));
769        assert_eq!(TestUser::count(&db).await.expect("count should succeed"), 0);
770    }
771
772    #[tokio::test]
773    async fn transactional_trait_can_return_tuple_values() {
774        let db = setup_db().await;
775
776        let result = TestUser::transaction(&db, |txn| {
777            let txn = txn.clone();
778            async move {
779                let user = TestUser::create(
780                    HashMap::from([
781                        ("name".to_owned(), json!("Alice")),
782                        ("email".to_owned(), json!("alice@example.com")),
783                    ]),
784                    &txn,
785                )
786                .await?;
787                Ok((
788                    user.id().expect("id should be assigned"),
789                    TestUser::count(&txn).await?,
790                ))
791            }
792        })
793        .await
794        .expect("trait helper should return tuple values");
795
796        assert_eq!(result, (1, 1));
797    }
798
799    #[tokio::test]
800    async fn transaction_without_writes_can_return_existing_count() {
801        let db = setup_db().await;
802
803        TestUser::create(
804            HashMap::from([
805                ("name".to_owned(), json!("Alice")),
806                ("email".to_owned(), json!("alice@example.com")),
807            ]),
808            &db,
809        )
810        .await
811        .expect("seed insert should succeed");
812
813        let count = transaction(&db, |txn| {
814            let txn = txn.clone();
815            async move { TestUser::count(&txn).await }
816        })
817        .await
818        .expect("read-only transaction should commit");
819
820        assert_eq!(count, 1);
821    }
822
823    #[tokio::test]
824    async fn transaction_commits_updates_to_existing_rows() {
825        let db = setup_db().await;
826        let user = TestUser::create(
827            HashMap::from([
828                ("name".to_owned(), json!("Alice")),
829                ("email".to_owned(), json!("alice@example.com")),
830            ]),
831            &db,
832        )
833        .await
834        .expect("seed insert should succeed");
835        let id = user.id().expect("seed row should have an id");
836
837        transaction(&db, |txn| {
838            let txn = txn.clone();
839            async move {
840                let mut user = TestUser::find(id, &txn).await?;
841                user.update_attributes(
842                    HashMap::from([("name".to_owned(), json!("Updated Alice"))]),
843                    &txn,
844                )
845                .await?;
846                Ok(())
847            }
848        })
849        .await
850        .expect("update transaction should commit");
851
852        let reloaded = TestUser::find(id, &db)
853            .await
854            .expect("updated row should load after commit");
855        assert_eq!(reloaded.name, "Updated Alice");
856        assert_eq!(reloaded.email, "alice@example.com");
857    }
858
859    #[tokio::test]
860    async fn transaction_rolls_back_updates_to_existing_rows() {
861        let db = setup_db().await;
862        let user = TestUser::create(
863            HashMap::from([
864                ("name".to_owned(), json!("Alice")),
865                ("email".to_owned(), json!("alice@example.com")),
866            ]),
867            &db,
868        )
869        .await
870        .expect("seed insert should succeed");
871        let id = user.id().expect("seed row should have an id");
872
873        let error = transaction(&db, |txn| {
874            let txn = txn.clone();
875            async move {
876                let mut user = TestUser::find(id, &txn).await?;
877                user.update_attributes(
878                    HashMap::from([("name".to_owned(), json!("Updated Alice"))]),
879                    &txn,
880                )
881                .await?;
882                Err::<(), RecordError>(RecordError::Invalid("rollback update".to_owned()))
883            }
884        })
885        .await
886        .expect_err("update transaction should fail");
887
888        assert!(matches!(error, RecordError::Invalid(message) if message == "rollback update"));
889
890        let reloaded = TestUser::find(id, &db)
891            .await
892            .expect("seed row should still load after rollback");
893        assert_eq!(reloaded.name, "Alice");
894        assert_eq!(reloaded.email, "alice@example.com");
895    }
896
897    #[tokio::test]
898    async fn nested_transaction_commits_with_savepoint_release() {
899        let db = setup_db().await;
900
901        let (outer_id, inner_id) = transaction(&db, |txn| {
902            let txn = txn.clone();
903            async move {
904                TestUser::create(user_attrs("Outer", "outer@example.com"), &txn).await?;
905                let outer_id = current_transaction_id().expect("outer transaction id should exist");
906
907                let inner_id = transaction(&txn, |inner_txn| {
908                    let inner_txn = inner_txn.clone();
909                    async move {
910                        assert_eq!(open_transactions(), 2);
911                        TestUser::create(user_attrs("Inner", "inner@example.com"), &inner_txn)
912                            .await?;
913                        Ok::<String, RecordError>(
914                            current_transaction_id()
915                                .expect("nested transaction should reuse the outer id"),
916                        )
917                    }
918                })
919                .await?;
920
921                assert_eq!(TestUser::count(&txn).await?, 2);
922                Ok((outer_id, inner_id))
923            }
924        })
925        .await
926        .expect("nested transaction should commit");
927
928        assert_eq!(outer_id, inner_id);
929        assert_eq!(TestUser::count(&db).await.expect("count should succeed"), 2);
930    }
931
/// `transaction_sync` with a succeeding closure persists the inserted row; the
/// row is counted afterwards through a connection from `database::db()`.
#[test]
fn transaction_sync_commits_on_success() {
    run_sync_transaction_test(|| {
        let outcome = transaction_sync(|txn| {
            let txn = txn.clone();
            async move {
                let attrs = HashMap::from([
                    ("name".to_owned(), json!("Alice")),
                    ("email".to_owned(), json!("alice@example.com")),
                ]);
                TestUser::create(attrs, &txn).await?;
                Ok(())
            }
        });
        outcome.expect("transaction should commit");

        // This is a synchronous test, so drive the async count explicitly.
        let persisted = runtime::block_on(async {
            let db = database::db();
            TestUser::count(&db).await.expect("count should succeed")
        });
        assert_eq!(persisted, 1);
    });
}
958
959    #[tokio::test]
960    async fn nested_transaction_rollback_to_savepoint_preserves_outer_changes() {
961        let db = setup_db().await;
962
963        transaction(&db, |txn| {
964            let txn = txn.clone();
965            async move {
966                TestUser::create(user_attrs("Outer", "outer@example.com"), &txn).await?;
967
968                let error = transaction(&txn, |inner_txn| {
969                    let inner_txn = inner_txn.clone();
970                    async move {
971                        TestUser::create(user_attrs("Inner", "inner@example.com"), &inner_txn)
972                            .await?;
973                        Err::<(), RecordError>(RecordError::Invalid("rollback inner".to_owned()))
974                    }
975                })
976                .await
977                .expect_err("inner transaction should roll back to its savepoint");
978
979                assert!(
980                    matches!(error, RecordError::Invalid(message) if message == "rollback inner")
981                );
982                assert_eq!(open_transactions(), 1);
983                assert_eq!(TestUser::count(&txn).await?, 1);
984
985                TestUser::create(user_attrs("AfterInner", "after@example.com"), &txn).await?;
986                Ok(())
987            }
988        })
989        .await
990        .expect("outer transaction should still commit");
991
992        assert_eq!(TestUser::count(&db).await.expect("count should succeed"), 2);
993        assert_eq!(
994            TestUser::find(1, &db)
995                .await
996                .expect("outer row should persist")
997                .name,
998            "Outer"
999        );
1000        assert_eq!(
1001            TestUser::find(2, &db)
1002                .await
1003                .expect("post-rollback outer row should persist")
1004                .name,
1005            "AfterInner"
1006        );
1007    }
1008
1009    #[tokio::test]
1010    async fn after_commit_fires_after_outermost_commit_only() {
1011        let db = setup_db().await;
1012        let calls = Arc::new(AtomicUsize::new(0));
1013        let transaction_calls = Arc::clone(&calls);
1014
1015        transaction(&db, move |txn| {
1016            let txn = txn.clone();
1017            let calls = Arc::clone(&transaction_calls);
1018            let outer_calls = Arc::clone(&calls);
1019            async move {
1020                after_commit(move || {
1021                    outer_calls.fetch_add(1, AtomicOrdering::SeqCst);
1022                });
1023
1024                let nested_calls = Arc::clone(&calls);
1025                transaction(&txn, move |inner_txn| {
1026                    let inner_txn = inner_txn.clone();
1027                    let calls = Arc::clone(&nested_calls);
1028                    let inner_calls = Arc::clone(&calls);
1029                    async move {
1030                        TestUser::create(user_attrs("Inner", "inner@example.com"), &inner_txn)
1031                            .await?;
1032                        after_commit(move || {
1033                            inner_calls.fetch_add(1, AtomicOrdering::SeqCst);
1034                        });
1035                        assert_eq!(calls.load(AtomicOrdering::SeqCst), 0);
1036                        Ok(())
1037                    }
1038                })
1039                .await?;
1040
1041                assert_eq!(calls.load(AtomicOrdering::SeqCst), 0);
1042                Ok(())
1043            }
1044        })
1045        .await
1046        .expect("outer transaction should commit");
1047
1048        assert_eq!(calls.load(AtomicOrdering::SeqCst), 2);
1049    }
1050
1051    #[tokio::test]
1052    async fn after_commit_callbacks_fire_in_registration_order_across_nested_transactions() {
1053        let db = setup_db().await;
1054        let events = Arc::new(Mutex::new(Vec::new()));
1055        let transaction_events = Arc::clone(&events);
1056
1057        transaction(&db, move |txn| {
1058            let txn = txn.clone();
1059            let events = Arc::clone(&transaction_events);
1060            let outer_events = Arc::clone(&events);
1061            async move {
1062                after_commit(move || outer_events.lock().unwrap().push("outer-1".to_owned()));
1063
1064                let nested_events = Arc::clone(&events);
1065                transaction(&txn, move |inner_txn| {
1066                    let inner_txn = inner_txn.clone();
1067                    let inner_events = Arc::clone(&nested_events);
1068                    async move {
1069                        after_commit(move || {
1070                            inner_events.lock().unwrap().push("inner-1".to_owned());
1071                        });
1072                        TestUser::create(user_attrs("Inner", "inner@example.com"), &inner_txn)
1073                            .await?;
1074                        Ok(())
1075                    }
1076                })
1077                .await?;
1078
1079                let trailing_events = Arc::clone(&events);
1080                after_commit(move || {
1081                    trailing_events.lock().unwrap().push("outer-2".to_owned());
1082                });
1083                Ok(())
1084            }
1085        })
1086        .await
1087        .expect("callback ordering transaction should commit");
1088
1089        assert_eq!(
1090            *events.lock().unwrap(),
1091            vec![
1092                "outer-1".to_owned(),
1093                "inner-1".to_owned(),
1094                "outer-2".to_owned()
1095            ]
1096        );
1097    }
1098
1099    #[tokio::test]
1100    async fn after_commit_callbacks_do_not_fire_when_outer_transaction_rolls_back() {
1101        let db = setup_db().await;
1102        let calls = Arc::new(AtomicUsize::new(0));
1103        let transaction_calls = Arc::clone(&calls);
1104
1105        let error = transaction(&db, move |txn| {
1106            let txn = txn.clone();
1107            let calls = Arc::clone(&transaction_calls);
1108            let outer_calls = Arc::clone(&calls);
1109            async move {
1110                after_commit(move || {
1111                    outer_calls.fetch_add(1, AtomicOrdering::SeqCst);
1112                });
1113
1114                let nested_calls = Arc::clone(&calls);
1115                transaction(&txn, move |inner_txn| {
1116                    let inner_txn = inner_txn.clone();
1117                    let calls = Arc::clone(&nested_calls);
1118                    let inner_calls = Arc::clone(&calls);
1119                    async move {
1120                        after_commit(move || {
1121                            inner_calls.fetch_add(1, AtomicOrdering::SeqCst);
1122                        });
1123                        TestUser::create(user_attrs("Inner", "inner@example.com"), &inner_txn)
1124                            .await?;
1125                        Ok(())
1126                    }
1127                })
1128                .await?;
1129
1130                Err::<(), RecordError>(RecordError::Invalid("rollback outer".to_owned()))
1131            }
1132        })
1133        .await
1134        .expect_err("outer transaction should roll back");
1135
1136        assert!(matches!(error, RecordError::Invalid(message) if message == "rollback outer"));
1137        assert_eq!(calls.load(AtomicOrdering::SeqCst), 0);
1138    }
1139
1140    #[tokio::test]
1141    async fn after_commit_callbacks_are_cleared_after_commit() {
1142        let db = setup_db().await;
1143        let calls = Arc::new(AtomicUsize::new(0));
1144
1145        transaction(&db, |_| {
1146            let calls = Arc::clone(&calls);
1147            async move {
1148                after_commit(move || {
1149                    calls.fetch_add(1, AtomicOrdering::SeqCst);
1150                });
1151                Ok(())
1152            }
1153        })
1154        .await
1155        .expect("first transaction should commit");
1156
1157        transaction(&db, |_| async move { Ok(()) })
1158            .await
1159            .expect("second transaction should commit");
1160
1161        assert_eq!(calls.load(AtomicOrdering::SeqCst), 1);
1162    }
1163
1164    #[tokio::test]
1165    async fn after_rollback_fires_on_outermost_rollback() {
1166        let db = setup_db().await;
1167        let calls = Arc::new(AtomicUsize::new(0));
1168        let callback_calls = Arc::clone(&calls);
1169
1170        let error = transaction(&db, move |_txn| {
1171            let outer_calls = Arc::clone(&callback_calls);
1172            async move {
1173                after_rollback(move || {
1174                    outer_calls.fetch_add(1, AtomicOrdering::SeqCst);
1175                });
1176                Err::<(), RecordError>(RecordError::Invalid("rollback outer".to_owned()))
1177            }
1178        })
1179        .await
1180        .expect_err("transaction should roll back");
1181
1182        assert!(matches!(error, RecordError::Invalid(message) if message == "rollback outer"));
1183        assert_eq!(calls.load(AtomicOrdering::SeqCst), 1);
1184    }
1185
1186    #[tokio::test]
1187    async fn nested_rollback_fires_only_inner_after_rollback_callbacks() {
1188        let db = setup_db().await;
1189        let calls = Arc::new(Mutex::new(Vec::new()));
1190        let transaction_calls = Arc::clone(&calls);
1191
1192        transaction(&db, move |txn| {
1193            let txn = txn.clone();
1194            let calls = Arc::clone(&transaction_calls);
1195            let outer_calls = Arc::clone(&calls);
1196            async move {
1197                after_rollback(move || outer_calls.lock().unwrap().push("outer".to_owned()));
1198
1199                let nested_calls = Arc::clone(&calls);
1200                let error = transaction(&txn, move |inner_txn| {
1201                    let inner_txn = inner_txn.clone();
1202                    let calls = Arc::clone(&nested_calls);
1203                    let inner_calls = Arc::clone(&calls);
1204                    async move {
1205                        after_rollback(move || {
1206                            inner_calls.lock().unwrap().push("inner".to_owned())
1207                        });
1208                        TestUser::create(user_attrs("Inner", "inner@example.com"), &inner_txn)
1209                            .await?;
1210                        Err::<(), RecordError>(RecordError::Invalid("rollback inner".to_owned()))
1211                    }
1212                })
1213                .await
1214                .expect_err("inner transaction should roll back");
1215
1216                assert!(
1217                    matches!(error, RecordError::Invalid(message) if message == "rollback inner")
1218                );
1219                assert_eq!(*calls.lock().unwrap(), vec!["inner".to_owned()]);
1220                Ok(())
1221            }
1222        })
1223        .await
1224        .expect("outer transaction should commit");
1225
1226        assert_eq!(*calls.lock().unwrap(), vec!["inner".to_owned()]);
1227    }
1228
1229    #[tokio::test]
1230    async fn nested_successful_after_rollback_callbacks_fire_if_outer_transaction_rolls_back() {
1231        let db = setup_db().await;
1232        let calls = Arc::new(Mutex::new(Vec::new()));
1233        let transaction_calls = Arc::clone(&calls);
1234
1235        let error = transaction(&db, move |txn| {
1236            let txn = txn.clone();
1237            let calls = Arc::clone(&transaction_calls);
1238            let outer_calls = Arc::clone(&calls);
1239            async move {
1240                after_rollback(move || outer_calls.lock().unwrap().push("outer".to_owned()));
1241
1242                let nested_calls = Arc::clone(&calls);
1243                transaction(&txn, move |inner_txn| {
1244                    let inner_txn = inner_txn.clone();
1245                    let calls = Arc::clone(&nested_calls);
1246                    let inner_calls = Arc::clone(&calls);
1247                    async move {
1248                        after_rollback(move || {
1249                            inner_calls.lock().unwrap().push("inner".to_owned())
1250                        });
1251                        TestUser::create(user_attrs("Inner", "inner@example.com"), &inner_txn)
1252                            .await?;
1253                        Ok(())
1254                    }
1255                })
1256                .await?;
1257
1258                Err::<(), RecordError>(RecordError::Invalid("rollback outer".to_owned()))
1259            }
1260        })
1261        .await
1262        .expect_err("outer transaction should roll back");
1263
1264        assert!(matches!(error, RecordError::Invalid(message) if message == "rollback outer"));
1265        assert_eq!(
1266            *calls.lock().unwrap(),
1267            vec!["outer".to_owned(), "inner".to_owned()]
1268        );
1269    }
1270
1271    #[tokio::test]
1272    async fn after_rollback_callbacks_are_cleared_after_rollback() {
1273        let db = setup_db().await;
1274        let calls = Arc::new(AtomicUsize::new(0));
1275        let callback_calls = Arc::clone(&calls);
1276
1277        let _ = transaction(&db, move |_txn| {
1278            let outer_calls = Arc::clone(&callback_calls);
1279            async move {
1280                after_rollback(move || {
1281                    outer_calls.fetch_add(1, AtomicOrdering::SeqCst);
1282                });
1283                Err::<(), RecordError>(RecordError::Invalid("rollback once".to_owned()))
1284            }
1285        })
1286        .await;
1287
1288        transaction(&db, |_| async move { Ok(()) })
1289            .await
1290            .expect("later transaction should commit");
1291
1292        assert_eq!(calls.load(AtomicOrdering::SeqCst), 1);
1293    }
1294
1295    #[tokio::test]
1296    async fn open_transactions_starts_closed() {
1297        assert_eq!(open_transactions(), 0);
1298        assert!(!transaction_open());
1299    }
1300
1301    #[tokio::test]
1302    async fn open_transactions_tracks_nested_depth() {
1303        let db = setup_db().await;
1304
1305        assert_eq!(open_transactions(), 0);
1306
1307        transaction(&db, |txn| {
1308            let txn = txn.clone();
1309            async move {
1310                assert_eq!(open_transactions(), 1);
1311                transaction(&txn, |_| async move {
1312                    assert_eq!(open_transactions(), 2);
1313                    Ok(())
1314                })
1315                .await?;
1316                assert_eq!(open_transactions(), 1);
1317                Ok(())
1318            }
1319        })
1320        .await
1321        .expect("nested transaction should commit");
1322
1323        assert_eq!(open_transactions(), 0);
1324    }
1325
1326    #[tokio::test]
1327    async fn transaction_open_reflects_current_state() {
1328        let db = setup_db().await;
1329
1330        assert!(!transaction_open());
1331
1332        transaction(&db, |txn| {
1333            let txn = txn.clone();
1334            async move {
1335                assert!(transaction_open());
1336                transaction(&txn, |_| async move {
1337                    assert!(transaction_open());
1338                    Ok(())
1339                })
1340                .await?;
1341                assert!(transaction_open());
1342                Ok(())
1343            }
1344        })
1345        .await
1346        .expect("transaction should commit");
1347
1348        assert!(!transaction_open());
1349    }
1350
1351    #[tokio::test]
1352    async fn current_transaction_id_is_none_outside_transactions() {
1353        assert_eq!(current_transaction_id(), None);
1354
1355        let db = setup_db().await;
1356        transaction(&db, |_| async move { Ok(()) })
1357            .await
1358            .expect("transaction should commit");
1359
1360        assert_eq!(current_transaction_id(), None);
1361    }
1362
1363    #[tokio::test]
1364    async fn current_transaction_id_is_stable_across_nested_transactions() {
1365        let db = setup_db().await;
1366
1367        let (outer_id, inner_id, after_inner_id) = transaction(&db, |txn| {
1368            let txn = txn.clone();
1369            async move {
1370                let outer_id = current_transaction_id().expect("outer transaction id should exist");
1371                let inner_id = transaction(&txn, |_| async move {
1372                    Ok::<String, RecordError>(
1373                        current_transaction_id().expect("nested transaction id should exist"),
1374                    )
1375                })
1376                .await?;
1377                let after_inner_id =
1378                    current_transaction_id().expect("outer transaction id should still exist");
1379                Ok((outer_id, inner_id, after_inner_id))
1380            }
1381        })
1382        .await
1383        .expect("transaction should commit");
1384
1385        assert_eq!(outer_id, inner_id);
1386        assert_eq!(outer_id, after_inner_id);
1387    }
1388
1389    #[tokio::test]
1390    async fn current_transaction_id_changes_between_outer_transactions() {
1391        let db = setup_db().await;
1392
1393        let first = transaction(&db, |_| async move {
1394            Ok::<String, RecordError>(
1395                current_transaction_id().expect("transaction id should exist"),
1396            )
1397        })
1398        .await
1399        .expect("first transaction should commit");
1400
1401        let second = transaction(&db, |_| async move {
1402            Ok::<String, RecordError>(
1403                current_transaction_id().expect("transaction id should exist"),
1404            )
1405        })
1406        .await
1407        .expect("second transaction should commit");
1408
1409        assert_ne!(first, second);
1410    }
1411
1412    #[tokio::test]
1413    async fn transaction_state_clears_after_outer_rollback() {
1414        let db = setup_db().await;
1415
1416        let error = transaction(&db, |_| async move {
1417            assert_eq!(open_transactions(), 1);
1418            assert!(transaction_open());
1419            assert!(current_transaction_id().is_some());
1420            Err::<(), RecordError>(RecordError::Invalid("rollback outer".to_owned()))
1421        })
1422        .await
1423        .expect_err("transaction should roll back");
1424
1425        assert!(matches!(error, RecordError::Invalid(message) if message == "rollback outer"));
1426        assert_eq!(open_transactions(), 0);
1427        assert!(!transaction_open());
1428        assert_eq!(current_transaction_id(), None);
1429    }
1430
1431    #[tokio::test]
1432    async fn nested_rollback_restores_outer_transaction_state() {
1433        let db = setup_db().await;
1434
1435        transaction(&db, |txn| {
1436            let txn = txn.clone();
1437            async move {
1438                let outer_id = current_transaction_id().expect("outer transaction id should exist");
1439                let _ = transaction(&txn, |_| async move {
1440                    assert_eq!(open_transactions(), 2);
1441                    Err::<(), RecordError>(RecordError::Invalid("rollback inner".to_owned()))
1442                })
1443                .await
1444                .expect_err("inner transaction should roll back");
1445
1446                assert_eq!(open_transactions(), 1);
1447                assert!(transaction_open());
1448                assert_eq!(
1449                    current_transaction_id().expect("outer transaction id should remain"),
1450                    outer_id
1451                );
1452                Ok(())
1453            }
1454        })
1455        .await
1456        .expect("outer transaction should commit");
1457
1458        assert_eq!(open_transactions(), 0);
1459        assert_eq!(current_transaction_id(), None);
1460    }
1461}