Skip to main content

sea_orm/database/
executor.rs

1use super::transaction::run_async_transaction_callback;
2use crate::{
3    AccessMode, ConnectionTrait, DatabaseConnection, DatabaseTransaction, DbBackend, DbErr,
4    ExecResult, IsolationLevel, QueryResult, Statement, TransactionError, TransactionOptions,
5    TransactionTrait,
6};
7use crate::{Schema, SchemaBuilder};
8use std::future::Future;
9use std::pin::Pin;
10
11/// A wrapper that holds either a reference to a [`DatabaseConnection`] or [`DatabaseTransaction`],
12/// or an owned [`DatabaseTransaction`].
13#[derive(Debug)]
14pub enum DatabaseExecutor<'c> {
15    /// A reference to a database connection
16    Connection(&'c DatabaseConnection),
17    /// A reference to a database transaction
18    Transaction(&'c DatabaseTransaction),
19    /// An owned database transaction (used by migration's `SchemaManager::begin()`)
20    OwnedTransaction(DatabaseTransaction),
21}
22
23impl<'c> From<&'c DatabaseConnection> for DatabaseExecutor<'c> {
24    fn from(conn: &'c DatabaseConnection) -> Self {
25        Self::Connection(conn)
26    }
27}
28
29impl<'c> From<&'c DatabaseTransaction> for DatabaseExecutor<'c> {
30    fn from(trans: &'c DatabaseTransaction) -> Self {
31        Self::Transaction(trans)
32    }
33}
34
35#[async_trait::async_trait]
36impl ConnectionTrait for DatabaseExecutor<'_> {
37    fn get_database_backend(&self) -> DbBackend {
38        match self {
39            DatabaseExecutor::Connection(conn) => conn.get_database_backend(),
40            DatabaseExecutor::Transaction(trans) => trans.get_database_backend(),
41            DatabaseExecutor::OwnedTransaction(trans) => trans.get_database_backend(),
42        }
43    }
44
45    async fn execute_raw(&self, stmt: Statement) -> Result<ExecResult, DbErr> {
46        match self {
47            DatabaseExecutor::Connection(conn) => conn.execute_raw(stmt).await,
48            DatabaseExecutor::Transaction(trans) => trans.execute_raw(stmt).await,
49            DatabaseExecutor::OwnedTransaction(trans) => trans.execute_raw(stmt).await,
50        }
51    }
52
53    async fn execute_unprepared(&self, sql: &str) -> Result<ExecResult, DbErr> {
54        match self {
55            DatabaseExecutor::Connection(conn) => conn.execute_unprepared(sql).await,
56            DatabaseExecutor::Transaction(trans) => trans.execute_unprepared(sql).await,
57            DatabaseExecutor::OwnedTransaction(trans) => trans.execute_unprepared(sql).await,
58        }
59    }
60
61    async fn query_one_raw(&self, stmt: Statement) -> Result<Option<QueryResult>, DbErr> {
62        match self {
63            DatabaseExecutor::Connection(conn) => conn.query_one_raw(stmt).await,
64            DatabaseExecutor::Transaction(trans) => trans.query_one_raw(stmt).await,
65            DatabaseExecutor::OwnedTransaction(trans) => trans.query_one_raw(stmt).await,
66        }
67    }
68
69    async fn query_all_raw(&self, stmt: Statement) -> Result<Vec<QueryResult>, DbErr> {
70        match self {
71            DatabaseExecutor::Connection(conn) => conn.query_all_raw(stmt).await,
72            DatabaseExecutor::Transaction(trans) => trans.query_all_raw(stmt).await,
73            DatabaseExecutor::OwnedTransaction(trans) => trans.query_all_raw(stmt).await,
74        }
75    }
76}
77
78#[async_trait::async_trait]
79impl TransactionTrait for DatabaseExecutor<'_> {
80    type Transaction = DatabaseTransaction;
81
82    async fn begin(&self) -> Result<DatabaseTransaction, DbErr> {
83        match self {
84            DatabaseExecutor::Connection(conn) => conn.begin().await,
85            DatabaseExecutor::Transaction(trans) => trans.begin().await,
86            DatabaseExecutor::OwnedTransaction(trans) => trans.begin().await,
87        }
88    }
89
90    async fn begin_with_config(
91        &self,
92        isolation_level: Option<IsolationLevel>,
93        access_mode: Option<AccessMode>,
94    ) -> Result<DatabaseTransaction, DbErr> {
95        match self {
96            DatabaseExecutor::Connection(conn) => {
97                conn.begin_with_config(isolation_level, access_mode).await
98            }
99            DatabaseExecutor::Transaction(trans) => {
100                trans.begin_with_config(isolation_level, access_mode).await
101            }
102            DatabaseExecutor::OwnedTransaction(trans) => {
103                trans.begin_with_config(isolation_level, access_mode).await
104            }
105        }
106    }
107
108    async fn begin_with_options(
109        &self,
110        options: TransactionOptions,
111    ) -> Result<DatabaseTransaction, DbErr> {
112        match self {
113            DatabaseExecutor::Connection(conn) => conn.begin_with_options(options).await,
114            DatabaseExecutor::Transaction(trans) => trans.begin_with_options(options).await,
115            DatabaseExecutor::OwnedTransaction(trans) => trans.begin_with_options(options).await,
116        }
117    }
118
119    async fn transaction<F, T, E>(&self, callback: F) -> Result<T, TransactionError<E>>
120    where
121        F: for<'c> FnOnce(
122                &'c DatabaseTransaction,
123            ) -> Pin<Box<dyn Future<Output = Result<T, E>> + Send + 'c>>
124            + Send,
125        T: Send,
126        E: std::fmt::Display + std::fmt::Debug + Send,
127    {
128        match self {
129            DatabaseExecutor::Connection(conn) => conn.transaction(callback).await,
130            DatabaseExecutor::Transaction(trans) => trans.transaction(callback).await,
131            DatabaseExecutor::OwnedTransaction(trans) => trans.transaction(callback).await,
132        }
133    }
134
135    async fn transaction_with_config<F, T, E>(
136        &self,
137        callback: F,
138        isolation_level: Option<IsolationLevel>,
139        access_mode: Option<AccessMode>,
140    ) -> Result<T, TransactionError<E>>
141    where
142        F: for<'c> FnOnce(
143                &'c DatabaseTransaction,
144            ) -> Pin<Box<dyn Future<Output = Result<T, E>> + Send + 'c>>
145            + Send,
146        T: Send,
147        E: std::fmt::Display + std::fmt::Debug + Send,
148    {
149        match self {
150            DatabaseExecutor::Connection(conn) => {
151                conn.transaction_with_config(callback, isolation_level, access_mode)
152                    .await
153            }
154            DatabaseExecutor::Transaction(trans) => {
155                trans
156                    .transaction_with_config(callback, isolation_level, access_mode)
157                    .await
158            }
159            DatabaseExecutor::OwnedTransaction(trans) => {
160                trans
161                    .transaction_with_config(callback, isolation_level, access_mode)
162                    .await
163            }
164        }
165    }
166}
167
168/// A trait for converting into [`DatabaseExecutor`]
169pub trait IntoDatabaseExecutor<'c>: Send
170where
171    Self: 'c,
172{
173    /// Convert into a [`DatabaseExecutor`]
174    fn into_database_executor(self) -> DatabaseExecutor<'c>;
175}
176
177impl<'c> IntoDatabaseExecutor<'c> for DatabaseExecutor<'c> {
178    fn into_database_executor(self) -> DatabaseExecutor<'c> {
179        self
180    }
181}
182
183impl<'c> IntoDatabaseExecutor<'c> for &'c DatabaseConnection {
184    fn into_database_executor(self) -> DatabaseExecutor<'c> {
185        DatabaseExecutor::Connection(self)
186    }
187}
188
189impl<'c> IntoDatabaseExecutor<'c> for &'c DatabaseTransaction {
190    fn into_database_executor(self) -> DatabaseExecutor<'c> {
191        DatabaseExecutor::Transaction(self)
192    }
193}
194
195impl IntoDatabaseExecutor<'static> for DatabaseTransaction {
196    fn into_database_executor(self) -> DatabaseExecutor<'static> {
197        DatabaseExecutor::OwnedTransaction(self)
198    }
199}
200
201impl DatabaseExecutor<'_> {
202    /// Execute the function inside a transaction.
203    /// If the function returns an error, the transaction will be rolled back.
204    /// Otherwise, the transaction will be committed.
205    pub async fn transaction_async<F, T, E>(&self, callback: F) -> Result<T, TransactionError<E>>
206    where
207        F: for<'c> AsyncFnOnce(&'c DatabaseTransaction) -> Result<T, E> + Send,
208        T: Send,
209        E: std::fmt::Display + std::fmt::Debug + Send,
210    {
211        let transaction = self.begin().await.map_err(TransactionError::Connection)?;
212        run_async_transaction_callback(transaction, callback).await
213    }
214
215    /// Execute the function inside a transaction with isolation level and/or access mode.
216    /// If the function returns an error, the transaction will be rolled back.
217    /// Otherwise, the transaction will be committed.
218    pub async fn transaction_with_config_async<F, T, E>(
219        &self,
220        callback: F,
221        isolation_level: Option<IsolationLevel>,
222        access_mode: Option<AccessMode>,
223    ) -> Result<T, TransactionError<E>>
224    where
225        F: for<'c> AsyncFnOnce(&'c DatabaseTransaction) -> Result<T, E> + Send,
226        T: Send,
227        E: std::fmt::Display + std::fmt::Debug + Send,
228    {
229        let transaction = self
230            .begin_with_config(isolation_level, access_mode)
231            .await
232            .map_err(TransactionError::Connection)?;
233        run_async_transaction_callback(transaction, callback).await
234    }
235
236    /// Returns `true` if this executor is backed by a transaction (borrowed or owned).
237    pub fn is_transaction(&self) -> bool {
238        matches!(
239            self,
240            DatabaseExecutor::Transaction(_) | DatabaseExecutor::OwnedTransaction(_)
241        )
242    }
243
244    /// Creates a [`SchemaBuilder`] for this backend
245    pub fn get_schema_builder(&self) -> SchemaBuilder {
246        Schema::new(self.get_database_backend()).builder()
247    }
248
249    #[cfg(feature = "entity-registry")]
250    #[cfg_attr(docsrs, doc(cfg(feature = "entity-registry")))]
251    /// Builds a schema for all the entities in the given module
252    pub fn get_schema_registry(&self, prefix: &str) -> SchemaBuilder {
253        let schema = Schema::new(self.get_database_backend());
254        crate::EntityRegistry::build_schema(schema, prefix)
255    }
256}