// testkit_core/handlers/with_database.rs

use crate::{
    DatabaseBackend, DatabasePool, TestContext, TestDatabaseInstance, TransactionStarter,
    transaction::DBTransactionManager,
};

use async_trait::async_trait;
use std::fmt::Debug;

9/// Trait for handlers that can be executed in the context of a database
10#[async_trait]
11pub trait DatabaseHandler<DB>: Send + Sync
12where
13    DB: DatabaseBackend + Send + Sync + Debug + 'static,
14{
15    /// Execute the handler with the given context
16    async fn execute(&self, ctx: &mut TestContext<DB>) -> Result<(), DB::Error>;
17}
18
19// Implementation of DatabaseHandler for FnOnce closures
20#[async_trait]
21impl<DB, F, Fut> DatabaseHandler<DB> for F
22where
23    DB: DatabaseBackend + Send + Sync + Debug + 'static,
24    F: FnOnce(&mut TestContext<DB>) -> Fut + Send + Sync + Clone,
25    Fut: std::future::Future<Output = Result<(), DB::Error>> + Send + 'static,
26{
27    async fn execute(&self, ctx: &mut TestContext<DB>) -> Result<(), DB::Error> {
28        self.clone()(ctx).await
29    }
30}
31
32#[derive(Debug)]
33#[must_use]
34pub struct TransactionHandler<DB, S, TFn>
35where
36    DB: DatabaseBackend + Send + Sync + Debug + 'static,
37    S: Send + Sync + 'static,
38    TFn: Send + Sync + 'static,
39{
40    db: TestDatabaseInstance<DB>,
41    setup_fn: S,
42    transaction_fn: TFn,
43}
44
45impl<DB, S, TFn> TransactionHandler<DB, S, TFn>
46where
47    DB: DatabaseBackend + Send + Sync + Debug + 'static,
48    S: Send + Sync + 'static,
49    TFn: Send + Sync + 'static,
50{
51    #[allow(dead_code)]
52    #[inline]
53    pub fn new(db: TestDatabaseInstance<DB>, setup_fn: S, transaction_fn: TFn) -> Self {
54        Self {
55            db,
56            setup_fn,
57            transaction_fn,
58        }
59    }
60}
61
62impl<DB, S, Fut, TFn, TxFut> TransactionHandler<DB, S, TFn>
63where
64    DB: DatabaseBackend + Send + Sync + Debug + 'static,
65    S: FnOnce(&mut <DB::Pool as DatabasePool>::Connection) -> Fut + Send + Sync + 'static,
66    Fut: std::future::Future<Output = Result<(), DB::Error>> + Send + 'static,
67    TxFut: std::future::Future<Output = Result<(), DB::Error>> + Send + 'static,
68    for<'tx> TFn: FnOnce(&'tx mut <TestContext<DB> as TransactionStarter<DB>>::Transaction) -> TxFut
69        + Send
70        + Sync
71        + 'static,
72    TestContext<DB>: TransactionStarter<DB>
73        + DBTransactionManager<
74            <TestContext<DB> as TransactionStarter<DB>>::Transaction,
75            <TestContext<DB> as TransactionStarter<DB>>::Connection,
76            Error = DB::Error,
77            Tx = <TestContext<DB> as TransactionStarter<DB>>::Transaction,
78        >,
79{
80    /// Execute the entire operation chain
81    #[allow(dead_code)]
82    pub async fn execute(self) -> Result<TestContext<DB>, DB::Error> {
83        // Execute the setup function
84        self.db.setup(self.setup_fn).await?;
85
86        // Create the context
87        let mut ctx = TestContext::new(self.db);
88
89        // Begin the transaction using explicit types
90        let mut tx = <TestContext<DB> as DBTransactionManager<
91            <TestContext<DB> as TransactionStarter<DB>>::Transaction,
92            <TestContext<DB> as TransactionStarter<DB>>::Connection,
93        >>::begin_transaction(&mut ctx)
94        .await?;
95
96        // Execute the transaction function
97        match (self.transaction_fn)(&mut tx).await {
98            Ok(_) => {
99                // Commit the transaction
100                <TestContext<DB> as DBTransactionManager<
101                    <TestContext<DB> as TransactionStarter<DB>>::Transaction,
102                    <TestContext<DB> as TransactionStarter<DB>>::Connection,
103                >>::commit_transaction(&mut tx)
104                .await?;
105            }
106            Err(e) => {
107                // Rollback the transaction on error
108                let _ = <TestContext<DB> as DBTransactionManager<
109                    <TestContext<DB> as TransactionStarter<DB>>::Transaction,
110                    <TestContext<DB> as TransactionStarter<DB>>::Connection,
111                >>::rollback_transaction(&mut tx)
112                .await;
113                return Err(e);
114            }
115        }
116
117        Ok(ctx)
118    }
119}