1use crate::{
2 debug_print, error::*, AccessMode, ConnectionTrait, DbBackend, DbErr, ExecResult,
3 InnerConnection, IsolationLevel, QueryResult, Statement, StreamTrait, TransactionStream,
4 TransactionTrait,
5};
6#[cfg(feature = "sqlx-dep")]
7use crate::{sqlx_error_to_exec_err, sqlx_error_to_query_err};
8use futures_util::lock::Mutex;
9#[cfg(feature = "sqlx-dep")]
10use sqlx::TransactionManager;
11use std::{future::Future, pin::Pin, sync::Arc};
12use tracing::instrument;
13
/// A transaction running on a shared database connection.
///
/// Created by [`DatabaseTransaction::begin`] (or, for nested transactions, via this type's
/// [`TransactionTrait`] impl). If a handle is dropped while still open — neither
/// `commit` nor `rollback` completed successfully — its `Drop` impl starts a rollback.
pub struct DatabaseTransaction {
    /// Underlying connection; nested transactions begun from this one share the same `Arc`.
    conn: Arc<Mutex<InnerConnection>>,
    /// Backend reported by `get_database_backend`.
    backend: DbBackend,
    /// `true` until `commit` or `rollback` succeeds; `start_rollback` is a no-op once false.
    open: bool,
    /// Optional callback handed to `metric!` around prepared statements and to streams.
    metric_callback: Option<crate::metric::Callback>,
}
23
24impl std::fmt::Debug for DatabaseTransaction {
25 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
26 write!(f, "DatabaseTransaction")
27 }
28}
29
impl DatabaseTransaction {
    /// Starts a transaction on `conn` and returns a handle marked as open.
    ///
    /// `isolation_level` / `access_mode` are applied through the backend driver's
    /// `set_transaction_config`. On MySQL and SQLite this happens before `BEGIN`; on
    /// Postgres it happens after `BEGIN`.
    ///
    /// # Errors
    ///
    /// Returns a [`DbErr`] if applying the configuration or starting the transaction fails,
    /// or `"Disconnected"` if the connection matches none of the enabled backends.
    #[instrument(level = "trace", skip(metric_callback))]
    pub(crate) async fn begin(
        conn: Arc<Mutex<InnerConnection>>,
        backend: DbBackend,
        metric_callback: Option<crate::metric::Callback>,
        isolation_level: Option<IsolationLevel>,
        access_mode: Option<AccessMode>,
    ) -> Result<DatabaseTransaction, DbErr> {
        // The handle exists (and is marked open) before BEGIN is sent. If a step below
        // fails, `?` drops `res`, and its `Drop` impl then calls `start_rollback`.
        let res = DatabaseTransaction {
            conn,
            backend,
            open: true,
            metric_callback,
        };
        // NOTE(review): the lock guard is a temporary of this statement. On an early return
        // it is released before `res` is dropped, which is what lets `Drop`'s `try_lock`
        // succeed. Keep the guard scoped to this one statement.
        match *res.conn.lock().await {
            #[cfg(feature = "sqlx-mysql")]
            InnerConnection::MySql(ref mut c) => {
                // MySQL: SET TRANSACTION is issued before BEGIN so it applies to the
                // transaction BEGIN opens.
                crate::driver::sqlx_mysql::set_transaction_config(c, isolation_level, access_mode)
                    .await?;
                <sqlx::MySql as sqlx::Database>::TransactionManager::begin(c)
                    .await
                    .map_err(sqlx_error_to_query_err)
            }
            #[cfg(feature = "sqlx-postgres")]
            InnerConnection::Postgres(ref mut c) => {
                // Postgres: SET TRANSACTION only affects the current transaction, so it is
                // issued after BEGIN.
                <sqlx::Postgres as sqlx::Database>::TransactionManager::begin(c)
                    .await
                    .map_err(sqlx_error_to_query_err)?;
                crate::driver::sqlx_postgres::set_transaction_config(
                    c,
                    isolation_level,
                    access_mode,
                )
                .await
            }
            #[cfg(feature = "sqlx-sqlite")]
            InnerConnection::Sqlite(ref mut c) => {
                // SQLite: the driver helper is invoked before BEGIN, as for MySQL.
                crate::driver::sqlx_sqlite::set_transaction_config(c, isolation_level, access_mode)
                    .await?;
                <sqlx::Sqlite as sqlx::Database>::TransactionManager::begin(c)
                    .await
                    .map_err(sqlx_error_to_query_err)
            }
            #[cfg(feature = "mock")]
            InnerConnection::Mock(ref mut c) => {
                c.begin();
                Ok(())
            }
            // Unreachable when every variant is covered by an enabled feature.
            #[allow(unreachable_patterns)]
            _ => Err(conn_err("Disconnected")),
        }?;
        Ok(res)
    }

    /// Runs `callback` inside this transaction, then commits if it returned `Ok` and rolls
    /// back if it returned `Err`.
    ///
    /// # Errors
    ///
    /// - [`TransactionError::Transaction`] carrying the callback's error, after the
    ///   rollback succeeded.
    /// - [`TransactionError::Connection`] if the commit or rollback itself fails. In the
    ///   rollback case, this replaces the callback's error.
    #[instrument(level = "trace", skip(callback))]
    pub(crate) async fn run<F, T, E>(self, callback: F) -> Result<T, TransactionError<E>>
    where
        F: for<'b> FnOnce(
                &'b DatabaseTransaction,
            ) -> Pin<Box<dyn Future<Output = Result<T, E>> + Send + 'b>>
            + Send,
        T: Send,
        E: std::error::Error + Send,
    {
        let res = callback(&self).await.map_err(TransactionError::Transaction);
        if res.is_ok() {
            self.commit().await.map_err(TransactionError::Connection)?;
        } else {
            self.rollback()
                .await
                .map_err(TransactionError::Connection)?;
        }
        res
    }

    /// Commits the transaction, consuming the handle.
    ///
    /// On success the handle is marked closed, so dropping it does nothing further. If the
    /// commit fails, `self` is dropped while still marked open and `Drop` starts a rollback.
    ///
    /// # Errors
    ///
    /// Returns a [`DbErr`] if the backend commit fails, or `"Disconnected"` if the
    /// connection matches none of the enabled backends.
    #[instrument(level = "trace")]
    // Presumably silences lints in builds where no backend feature is enabled and only the
    // `_` arm remains -- TODO confirm.
    #[allow(unreachable_code, unused_mut)]
    pub async fn commit(mut self) -> Result<(), DbErr> {
        match *self.conn.lock().await {
            #[cfg(feature = "sqlx-mysql")]
            InnerConnection::MySql(ref mut c) => {
                <sqlx::MySql as sqlx::Database>::TransactionManager::commit(c)
                    .await
                    .map_err(sqlx_error_to_query_err)
            }
            #[cfg(feature = "sqlx-postgres")]
            InnerConnection::Postgres(ref mut c) => {
                <sqlx::Postgres as sqlx::Database>::TransactionManager::commit(c)
                    .await
                    .map_err(sqlx_error_to_query_err)
            }
            #[cfg(feature = "sqlx-sqlite")]
            InnerConnection::Sqlite(ref mut c) => {
                <sqlx::Sqlite as sqlx::Database>::TransactionManager::commit(c)
                    .await
                    .map_err(sqlx_error_to_query_err)
            }
            #[cfg(feature = "mock")]
            InnerConnection::Mock(ref mut c) => {
                c.commit();
                Ok(())
            }
            #[allow(unreachable_patterns)]
            _ => Err(conn_err("Disconnected")),
        }?;
        // Only reached on success: prevents `Drop` from starting a rollback.
        self.open = false;
        Ok(())
    }

    /// Rolls the transaction back, consuming the handle.
    ///
    /// On success the handle is marked closed. If the rollback fails, `self` is dropped
    /// while still marked open and `Drop` calls `start_rollback`.
    ///
    /// # Errors
    ///
    /// Returns a [`DbErr`] if the backend rollback fails, or `"Disconnected"` if the
    /// connection matches none of the enabled backends.
    #[instrument(level = "trace")]
    // Presumably silences lints in builds where no backend feature is enabled and only the
    // `_` arm remains -- TODO confirm.
    #[allow(unreachable_code, unused_mut)]
    pub async fn rollback(mut self) -> Result<(), DbErr> {
        match *self.conn.lock().await {
            #[cfg(feature = "sqlx-mysql")]
            InnerConnection::MySql(ref mut c) => {
                <sqlx::MySql as sqlx::Database>::TransactionManager::rollback(c)
                    .await
                    .map_err(sqlx_error_to_query_err)
            }
            #[cfg(feature = "sqlx-postgres")]
            InnerConnection::Postgres(ref mut c) => {
                <sqlx::Postgres as sqlx::Database>::TransactionManager::rollback(c)
                    .await
                    .map_err(sqlx_error_to_query_err)
            }
            #[cfg(feature = "sqlx-sqlite")]
            InnerConnection::Sqlite(ref mut c) => {
                <sqlx::Sqlite as sqlx::Database>::TransactionManager::rollback(c)
                    .await
                    .map_err(sqlx_error_to_query_err)
            }
            #[cfg(feature = "mock")]
            InnerConnection::Mock(ref mut c) => {
                c.rollback();
                Ok(())
            }
            #[allow(unreachable_patterns)]
            _ => Err(conn_err("Disconnected")),
        }?;
        // Only reached on success: prevents `Drop` from starting another rollback.
        self.open = false;
        Ok(())
    }

    /// Synchronously starts rolling back a still-open transaction. It is called from
    /// `Drop`, which cannot `.await`.
    ///
    /// Does nothing if the transaction is no longer open. For sqlx backends it calls
    /// `TransactionManager::start_rollback` without awaiting completion; for the mock it
    /// records a rollback.
    ///
    /// # Errors
    ///
    /// - `"Dropping a locked Transaction"` if the connection mutex is currently held. It
    ///   uses `try_lock` and never waits.
    /// - `"Disconnected"` if the connection matches none of the enabled backends.
    #[instrument(level = "trace")]
    fn start_rollback(&mut self) -> Result<(), DbErr> {
        if self.open {
            if let Some(mut conn) = self.conn.try_lock() {
                match &mut *conn {
                    #[cfg(feature = "sqlx-mysql")]
                    InnerConnection::MySql(c) => {
                        <sqlx::MySql as sqlx::Database>::TransactionManager::start_rollback(c);
                    }
                    #[cfg(feature = "sqlx-postgres")]
                    InnerConnection::Postgres(c) => {
                        <sqlx::Postgres as sqlx::Database>::TransactionManager::start_rollback(c);
                    }
                    #[cfg(feature = "sqlx-sqlite")]
                    InnerConnection::Sqlite(c) => {
                        <sqlx::Sqlite as sqlx::Database>::TransactionManager::start_rollback(c);
                    }
                    #[cfg(feature = "mock")]
                    InnerConnection::Mock(c) => {
                        c.rollback();
                    }
                    #[allow(unreachable_patterns)]
                    _ => return Err(conn_err("Disconnected")),
                }
            } else {
                return Err(conn_err("Dropping a locked Transaction"));
            }
        }
        Ok(())
    }
}
214
215impl Drop for DatabaseTransaction {
216 fn drop(&mut self) {
217 self.start_rollback().expect("Fail to rollback transaction");
218 }
219}
220
221#[async_trait::async_trait]
222impl ConnectionTrait for DatabaseTransaction {
223 fn get_database_backend(&self) -> DbBackend {
224 self.backend
226 }
227
228 #[instrument(level = "trace")]
229 #[allow(unused_variables)]
230 async fn execute(&self, stmt: Statement) -> Result<ExecResult, DbErr> {
231 debug_print!("{}", stmt);
232
233 match &mut *self.conn.lock().await {
234 #[cfg(feature = "sqlx-mysql")]
235 InnerConnection::MySql(conn) => {
236 let query = crate::driver::sqlx_mysql::sqlx_query(&stmt);
237 let conn: &mut sqlx::MySqlConnection = &mut *conn;
238 crate::metric::metric!(self.metric_callback, &stmt, {
239 query.execute(conn).await.map(Into::into)
240 })
241 .map_err(sqlx_error_to_exec_err)
242 }
243 #[cfg(feature = "sqlx-postgres")]
244 InnerConnection::Postgres(conn) => {
245 let query = crate::driver::sqlx_postgres::sqlx_query(&stmt);
246 let conn: &mut sqlx::PgConnection = &mut *conn;
247 crate::metric::metric!(self.metric_callback, &stmt, {
248 query.execute(conn).await.map(Into::into)
249 })
250 .map_err(sqlx_error_to_exec_err)
251 }
252 #[cfg(feature = "sqlx-sqlite")]
253 InnerConnection::Sqlite(conn) => {
254 let query = crate::driver::sqlx_sqlite::sqlx_query(&stmt);
255 let conn: &mut sqlx::SqliteConnection = &mut *conn;
256 crate::metric::metric!(self.metric_callback, &stmt, {
257 query.execute(conn).await.map(Into::into)
258 })
259 .map_err(sqlx_error_to_exec_err)
260 }
261 #[cfg(feature = "mock")]
262 InnerConnection::Mock(conn) => return conn.execute(stmt),
263 #[allow(unreachable_patterns)]
264 _ => Err(conn_err("Disconnected")),
265 }
266 }
267
268 #[instrument(level = "trace")]
269 #[allow(unused_variables)]
270 async fn execute_unprepared(&self, sql: &str) -> Result<ExecResult, DbErr> {
271 debug_print!("{}", sql);
272
273 match &mut *self.conn.lock().await {
274 #[cfg(feature = "sqlx-mysql")]
275 InnerConnection::MySql(conn) => {
276 let conn: &mut sqlx::MySqlConnection = &mut *conn;
277 sqlx::Executor::execute(conn, sql)
278 .await
279 .map(Into::into)
280 .map_err(sqlx_error_to_exec_err)
281 }
282 #[cfg(feature = "sqlx-postgres")]
283 InnerConnection::Postgres(conn) => {
284 let conn: &mut sqlx::PgConnection = &mut *conn;
285 sqlx::Executor::execute(conn, sql)
286 .await
287 .map(Into::into)
288 .map_err(sqlx_error_to_exec_err)
289 }
290 #[cfg(feature = "sqlx-sqlite")]
291 InnerConnection::Sqlite(conn) => {
292 let conn: &mut sqlx::SqliteConnection = &mut *conn;
293 sqlx::Executor::execute(conn, sql)
294 .await
295 .map(Into::into)
296 .map_err(sqlx_error_to_exec_err)
297 }
298 #[cfg(feature = "mock")]
299 InnerConnection::Mock(conn) => {
300 let db_backend = conn.get_database_backend();
301 let stmt = Statement::from_string(db_backend, sql);
302 conn.execute(stmt)
303 }
304 #[allow(unreachable_patterns)]
305 _ => Err(conn_err("Disconnected")),
306 }
307 }
308
309 #[instrument(level = "trace")]
310 #[allow(unused_variables)]
311 async fn query_one(&self, stmt: Statement) -> Result<Option<QueryResult>, DbErr> {
312 debug_print!("{}", stmt);
313
314 match &mut *self.conn.lock().await {
315 #[cfg(feature = "sqlx-mysql")]
316 InnerConnection::MySql(conn) => {
317 let query = crate::driver::sqlx_mysql::sqlx_query(&stmt);
318 let conn: &mut sqlx::MySqlConnection = &mut *conn;
319 crate::metric::metric!(self.metric_callback, &stmt, {
320 crate::sqlx_map_err_ignore_not_found(
321 query.fetch_one(conn).await.map(|row| Some(row.into())),
322 )
323 })
324 }
325 #[cfg(feature = "sqlx-postgres")]
326 InnerConnection::Postgres(conn) => {
327 let query = crate::driver::sqlx_postgres::sqlx_query(&stmt);
328 let conn: &mut sqlx::PgConnection = &mut *conn;
329 crate::metric::metric!(self.metric_callback, &stmt, {
330 crate::sqlx_map_err_ignore_not_found(
331 query.fetch_one(conn).await.map(|row| Some(row.into())),
332 )
333 })
334 }
335 #[cfg(feature = "sqlx-sqlite")]
336 InnerConnection::Sqlite(conn) => {
337 let query = crate::driver::sqlx_sqlite::sqlx_query(&stmt);
338 let conn: &mut sqlx::SqliteConnection = &mut *conn;
339 crate::metric::metric!(self.metric_callback, &stmt, {
340 crate::sqlx_map_err_ignore_not_found(
341 query.fetch_one(conn).await.map(|row| Some(row.into())),
342 )
343 })
344 }
345 #[cfg(feature = "mock")]
346 InnerConnection::Mock(conn) => return conn.query_one(stmt),
347 #[allow(unreachable_patterns)]
348 _ => Err(conn_err("Disconnected")),
349 }
350 }
351
352 #[instrument(level = "trace")]
353 #[allow(unused_variables)]
354 async fn query_all(&self, stmt: Statement) -> Result<Vec<QueryResult>, DbErr> {
355 debug_print!("{}", stmt);
356
357 match &mut *self.conn.lock().await {
358 #[cfg(feature = "sqlx-mysql")]
359 InnerConnection::MySql(conn) => {
360 let query = crate::driver::sqlx_mysql::sqlx_query(&stmt);
361 let conn: &mut sqlx::MySqlConnection = &mut *conn;
362 crate::metric::metric!(self.metric_callback, &stmt, {
363 query
364 .fetch_all(conn)
365 .await
366 .map(|rows| rows.into_iter().map(|r| r.into()).collect())
367 .map_err(sqlx_error_to_query_err)
368 })
369 }
370 #[cfg(feature = "sqlx-postgres")]
371 InnerConnection::Postgres(conn) => {
372 let query = crate::driver::sqlx_postgres::sqlx_query(&stmt);
373 let conn: &mut sqlx::PgConnection = &mut *conn;
374 crate::metric::metric!(self.metric_callback, &stmt, {
375 query
376 .fetch_all(conn)
377 .await
378 .map(|rows| rows.into_iter().map(|r| r.into()).collect())
379 .map_err(sqlx_error_to_query_err)
380 })
381 }
382 #[cfg(feature = "sqlx-sqlite")]
383 InnerConnection::Sqlite(conn) => {
384 let query = crate::driver::sqlx_sqlite::sqlx_query(&stmt);
385 let conn: &mut sqlx::SqliteConnection = &mut *conn;
386 crate::metric::metric!(self.metric_callback, &stmt, {
387 query
388 .fetch_all(conn)
389 .await
390 .map(|rows| rows.into_iter().map(|r| r.into()).collect())
391 .map_err(sqlx_error_to_query_err)
392 })
393 }
394 #[cfg(feature = "mock")]
395 InnerConnection::Mock(conn) => return conn.query_all(stmt),
396 #[allow(unreachable_patterns)]
397 _ => Err(conn_err("Disconnected")),
398 }
399 }
400}
401
402impl StreamTrait for DatabaseTransaction {
403 type Stream<'a> = TransactionStream<'a>;
404
405 #[instrument(level = "trace")]
406 fn stream<'a>(
407 &'a self,
408 stmt: Statement,
409 ) -> Pin<Box<dyn Future<Output = Result<Self::Stream<'a>, DbErr>> + 'a + Send>> {
410 Box::pin(async move {
411 let conn = self.conn.lock().await;
412 Ok(crate::TransactionStream::build(
413 conn,
414 stmt,
415 self.metric_callback.clone(),
416 ))
417 })
418 }
419}
420
421#[async_trait::async_trait]
422impl TransactionTrait for DatabaseTransaction {
423 #[instrument(level = "trace")]
424 async fn begin(&self) -> Result<DatabaseTransaction, DbErr> {
425 DatabaseTransaction::begin(
426 Arc::clone(&self.conn),
427 self.backend,
428 self.metric_callback.clone(),
429 None,
430 None,
431 )
432 .await
433 }
434
435 #[instrument(level = "trace")]
436 async fn begin_with_config(
437 &self,
438 isolation_level: Option<IsolationLevel>,
439 access_mode: Option<AccessMode>,
440 ) -> Result<DatabaseTransaction, DbErr> {
441 DatabaseTransaction::begin(
442 Arc::clone(&self.conn),
443 self.backend,
444 self.metric_callback.clone(),
445 isolation_level,
446 access_mode,
447 )
448 .await
449 }
450
451 #[instrument(level = "trace", skip(_callback))]
454 async fn transaction<F, T, E>(&self, _callback: F) -> Result<T, TransactionError<E>>
455 where
456 F: for<'c> FnOnce(
457 &'c DatabaseTransaction,
458 ) -> Pin<Box<dyn Future<Output = Result<T, E>> + Send + 'c>>
459 + Send,
460 T: Send,
461 E: std::error::Error + Send,
462 {
463 let transaction = self.begin().await.map_err(TransactionError::Connection)?;
464 transaction.run(_callback).await
465 }
466
467 #[instrument(level = "trace", skip(_callback))]
470 async fn transaction_with_config<F, T, E>(
471 &self,
472 _callback: F,
473 isolation_level: Option<IsolationLevel>,
474 access_mode: Option<AccessMode>,
475 ) -> Result<T, TransactionError<E>>
476 where
477 F: for<'c> FnOnce(
478 &'c DatabaseTransaction,
479 ) -> Pin<Box<dyn Future<Output = Result<T, E>> + Send + 'c>>
480 + Send,
481 T: Send,
482 E: std::error::Error + Send,
483 {
484 let transaction = self
485 .begin_with_config(isolation_level, access_mode)
486 .await
487 .map_err(TransactionError::Connection)?;
488 transaction.run(_callback).await
489 }
490}
491
/// Error produced when running a closure inside a transaction (e.g. via
/// [`TransactionTrait::transaction`]).
#[derive(Debug)]
pub enum TransactionError<E>
where
    E: std::error::Error,
{
    /// A [`DbErr`] from the transaction machinery itself, e.g. failing to begin, commit
    /// or roll back.
    Connection(DbErr),
    /// The error returned by the user-supplied transaction callback.
    Transaction(E),
}
503
504impl<E> std::fmt::Display for TransactionError<E>
505where
506 E: std::error::Error,
507{
508 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
509 match self {
510 TransactionError::Connection(e) => std::fmt::Display::fmt(e, f),
511 TransactionError::Transaction(e) => std::fmt::Display::fmt(e, f),
512 }
513 }
514}
515
516impl<E> std::error::Error for TransactionError<E> where E: std::error::Error {}
517
518impl<E> From<DbErr> for TransactionError<E>
519where
520 E: std::error::Error,
521{
522 fn from(e: DbErr) -> Self {
523 Self::Connection(e)
524 }
525}