1use crate::{
2 AccessMode, ConnectionTrait, DbBackend, DbErr, ExecResult, InnerConnection, IsolationLevel,
3 QueryResult, Statement, StreamTrait, TransactionStream, TransactionTrait, debug_print,
4 error::*,
5};
6#[cfg(feature = "sqlx-dep")]
7use crate::{sqlx_error_to_exec_err, sqlx_error_to_query_err};
8use futures_util::lock::Mutex;
9#[cfg(feature = "sqlx-dep")]
10use sqlx::TransactionManager;
11use std::{future::Future, pin::Pin, sync::Arc};
12use tracing::instrument;
13
/// A transaction opened on a connection that is shared through `Arc<Mutex<_>>`.
///
/// Every operation on the transaction locks `conn` for its duration. If the
/// value is dropped while `open` is still `true`, the `Drop` impl calls
/// `start_rollback`.
pub struct DatabaseTransaction {
    /// Connection shared with the creator of this transaction; the same `Arc`
    /// is cloned into any transaction begun from this one.
    conn: Arc<Mutex<InnerConnection>>,
    /// Backend reported by `get_database_backend`.
    backend: DbBackend,
    /// `true` until `commit` or `rollback` completes successfully.
    open: bool,
    /// Optional metrics hook, passed to `crate::metric::metric!` and to
    /// `TransactionStream::build`.
    metric_callback: Option<crate::metric::Callback>,
}
24
25impl std::fmt::Debug for DatabaseTransaction {
26 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
27 write!(f, "DatabaseTransaction")
28 }
29}
30
impl DatabaseTransaction {
    /// Opens a transaction on `conn`, optionally applying `isolation_level`
    /// and `access_mode`.
    ///
    /// `res` is built with `open: true` before the driver-level `BEGIN` runs.
    /// If any step fails, `?` returns early and `res` is dropped, so the `Drop`
    /// impl calls `start_rollback`. The lock guard is a temporary of the `match`
    /// statement, so it has already been released at that point and
    /// `try_lock` can succeed.
    ///
    /// # Errors
    /// - a failed sqlx `begin`, mapped through `sqlx_error_to_query_err`;
    /// - any error returned by the driver's `set_transaction_config` helper;
    /// - `conn_err("Disconnected")` for a connection variant not matched here.
    #[instrument(level = "trace", skip(metric_callback))]
    pub(crate) async fn begin(
        conn: Arc<Mutex<InnerConnection>>,
        backend: DbBackend,
        metric_callback: Option<crate::metric::Callback>,
        isolation_level: Option<IsolationLevel>,
        access_mode: Option<AccessMode>,
    ) -> Result<DatabaseTransaction, DbErr> {
        let res = DatabaseTransaction {
            conn,
            backend,
            open: true,
            metric_callback,
        };
        match *res.conn.lock().await {
            #[cfg(feature = "sqlx-mysql")]
            InnerConnection::MySql(ref mut c) => {
                // MySQL: the transaction config is applied *before* `BEGIN`.
                crate::driver::sqlx_mysql::set_transaction_config(c, isolation_level, access_mode)
                    .await?;
                <sqlx::MySql as sqlx::Database>::TransactionManager::begin(c, None)
                    .await
                    .map_err(sqlx_error_to_query_err)
            }
            #[cfg(feature = "sqlx-postgres")]
            InnerConnection::Postgres(ref mut c) => {
                // Postgres: `BEGIN` first, then the config. This is presumably because the
                // helper issues a statement scoped to the current transaction
                // (NOTE(review): confirm in `sqlx_postgres::set_transaction_config`).
                // If the config step fails after `BEGIN`, dropping `res` starts a rollback.
                <sqlx::Postgres as sqlx::Database>::TransactionManager::begin(c, None)
                    .await
                    .map_err(sqlx_error_to_query_err)?;
                crate::driver::sqlx_postgres::set_transaction_config(
                    c,
                    isolation_level,
                    access_mode,
                )
                .await
            }
            #[cfg(feature = "sqlx-sqlite")]
            InnerConnection::Sqlite(ref mut c) => {
                // SQLite: like MySQL, the config is applied before `BEGIN`.
                crate::driver::sqlx_sqlite::set_transaction_config(c, isolation_level, access_mode)
                    .await?;
                <sqlx::Sqlite as sqlx::Database>::TransactionManager::begin(c, None)
                    .await
                    .map_err(sqlx_error_to_query_err)
            }
            #[cfg(feature = "mock")]
            InnerConnection::Mock(ref mut c) => {
                // `isolation_level` / `access_mode` are not passed to the mock.
                c.begin();
                Ok(())
            }
            #[cfg(feature = "proxy")]
            InnerConnection::Proxy(ref mut c) => {
                // `isolation_level` / `access_mode` are not passed to the proxy.
                c.begin().await;
                Ok(())
            }
            #[allow(unreachable_patterns)]
            _ => Err(conn_err("Disconnected")),
        }?;
        Ok(res)
    }

    /// Runs `callback` with a reference to this transaction, then commits when
    /// the callback returned `Ok` and rolls back otherwise.
    ///
    /// # Errors
    /// - `TransactionError::Transaction` carries the callback's error, returned
    ///   after the rollback has succeeded;
    /// - `TransactionError::Connection` carries a failed commit or rollback; a
    ///   failed rollback is returned instead of the callback's own error.
    #[instrument(level = "trace", skip(callback))]
    pub(crate) async fn run<F, T, E>(self, callback: F) -> Result<T, TransactionError<E>>
    where
        F: for<'b> FnOnce(
                &'b DatabaseTransaction,
            ) -> Pin<Box<dyn Future<Output = Result<T, E>> + Send + 'b>>
            + Send,
        T: Send,
        E: std::fmt::Display + std::fmt::Debug + Send,
    {
        let res = callback(&self).await.map_err(TransactionError::Transaction);
        if res.is_ok() {
            self.commit().await.map_err(TransactionError::Connection)?;
        } else {
            self.rollback()
                .await
                .map_err(TransactionError::Connection)?;
        }
        res
    }

    /// Commits the transaction, consuming it.
    ///
    /// `open` is cleared only after the driver reports success. If the commit
    /// fails, `?` returns early with `open` still `true`, so dropping `self`
    /// then calls `start_rollback`.
    ///
    /// # Errors
    /// A sqlx commit failure mapped through `sqlx_error_to_query_err`, or
    /// `conn_err("Disconnected")` for a connection variant not matched here.
    #[instrument(level = "trace")]
    // Presumably silences warnings for feature combinations in which no
    // backend arm is compiled in.
    #[allow(unreachable_code, unused_mut)]
    pub async fn commit(mut self) -> Result<(), DbErr> {
        match *self.conn.lock().await {
            #[cfg(feature = "sqlx-mysql")]
            InnerConnection::MySql(ref mut c) => {
                <sqlx::MySql as sqlx::Database>::TransactionManager::commit(c)
                    .await
                    .map_err(sqlx_error_to_query_err)
            }
            #[cfg(feature = "sqlx-postgres")]
            InnerConnection::Postgres(ref mut c) => {
                <sqlx::Postgres as sqlx::Database>::TransactionManager::commit(c)
                    .await
                    .map_err(sqlx_error_to_query_err)
            }
            #[cfg(feature = "sqlx-sqlite")]
            InnerConnection::Sqlite(ref mut c) => {
                <sqlx::Sqlite as sqlx::Database>::TransactionManager::commit(c)
                    .await
                    .map_err(sqlx_error_to_query_err)
            }
            #[cfg(feature = "mock")]
            InnerConnection::Mock(ref mut c) => {
                c.commit();
                Ok(())
            }
            #[cfg(feature = "proxy")]
            InnerConnection::Proxy(ref mut c) => {
                c.commit().await;
                Ok(())
            }
            #[allow(unreachable_patterns)]
            _ => Err(conn_err("Disconnected")),
        }?;
        // Only reached on success; disarms the rollback in `Drop`.
        self.open = false;
        Ok(())
    }

    /// Rolls the transaction back, consuming it.
    ///
    /// `open` is cleared only after the driver reports success. If the
    /// rollback fails, dropping `self` still calls `start_rollback`.
    ///
    /// # Errors
    /// A sqlx rollback failure mapped through `sqlx_error_to_query_err`, or
    /// `conn_err("Disconnected")` for a connection variant not matched here.
    #[instrument(level = "trace")]
    // Presumably silences warnings for feature combinations in which no
    // backend arm is compiled in.
    #[allow(unreachable_code, unused_mut)]
    pub async fn rollback(mut self) -> Result<(), DbErr> {
        match *self.conn.lock().await {
            #[cfg(feature = "sqlx-mysql")]
            InnerConnection::MySql(ref mut c) => {
                <sqlx::MySql as sqlx::Database>::TransactionManager::rollback(c)
                    .await
                    .map_err(sqlx_error_to_query_err)
            }
            #[cfg(feature = "sqlx-postgres")]
            InnerConnection::Postgres(ref mut c) => {
                <sqlx::Postgres as sqlx::Database>::TransactionManager::rollback(c)
                    .await
                    .map_err(sqlx_error_to_query_err)
            }
            #[cfg(feature = "sqlx-sqlite")]
            InnerConnection::Sqlite(ref mut c) => {
                <sqlx::Sqlite as sqlx::Database>::TransactionManager::rollback(c)
                    .await
                    .map_err(sqlx_error_to_query_err)
            }
            #[cfg(feature = "mock")]
            InnerConnection::Mock(ref mut c) => {
                c.rollback();
                Ok(())
            }
            #[cfg(feature = "proxy")]
            InnerConnection::Proxy(ref mut c) => {
                c.rollback().await;
                Ok(())
            }
            #[allow(unreachable_patterns)]
            _ => Err(conn_err("Disconnected")),
        }?;
        // Only reached on success; disarms the rollback in `Drop`.
        self.open = false;
        Ok(())
    }

    /// Synchronously starts rolling back a still-open transaction. It is called
    /// from `Drop`, where `.await` is not available.
    ///
    /// Does nothing when `open` is `false`. Because it cannot wait for the
    /// mutex, it uses `try_lock`. `open` is not cleared here.
    ///
    /// # Errors
    /// - `conn_err("Dropping a locked Transaction")` if the connection mutex is
    ///   currently held;
    /// - `conn_err("Disconnected")` for a connection variant not matched here.
    #[instrument(level = "trace")]
    fn start_rollback(&mut self) -> Result<(), DbErr> {
        if self.open {
            if let Some(mut conn) = self.conn.try_lock() {
                match &mut *conn {
                    #[cfg(feature = "sqlx-mysql")]
                    InnerConnection::MySql(c) => {
                        <sqlx::MySql as sqlx::Database>::TransactionManager::start_rollback(c);
                    }
                    #[cfg(feature = "sqlx-postgres")]
                    InnerConnection::Postgres(c) => {
                        <sqlx::Postgres as sqlx::Database>::TransactionManager::start_rollback(c);
                    }
                    #[cfg(feature = "sqlx-sqlite")]
                    InnerConnection::Sqlite(c) => {
                        <sqlx::Sqlite as sqlx::Database>::TransactionManager::start_rollback(c);
                    }
                    #[cfg(feature = "mock")]
                    InnerConnection::Mock(c) => {
                        // The mock arm calls its regular `rollback` directly.
                        c.rollback();
                    }
                    #[cfg(feature = "proxy")]
                    InnerConnection::Proxy(c) => {
                        c.start_rollback();
                    }
                    #[allow(unreachable_patterns)]
                    _ => return Err(conn_err("Disconnected")),
                }
            } else {
                // Someone else holds the connection mutex; we cannot block here.
                return Err(conn_err("Dropping a locked Transaction"));
            }
        }
        Ok(())
    }
}
234
235impl Drop for DatabaseTransaction {
236 fn drop(&mut self) {
237 self.start_rollback().expect("Fail to rollback transaction");
238 }
239}
240
241#[async_trait::async_trait]
242impl ConnectionTrait for DatabaseTransaction {
243 fn get_database_backend(&self) -> DbBackend {
244 self.backend
246 }
247
248 #[instrument(level = "trace")]
249 #[allow(unused_variables)]
250 async fn execute_raw(&self, stmt: Statement) -> Result<ExecResult, DbErr> {
251 debug_print!("{}", stmt);
252
253 match &mut *self.conn.lock().await {
254 #[cfg(feature = "sqlx-mysql")]
255 InnerConnection::MySql(conn) => {
256 let query = crate::driver::sqlx_mysql::sqlx_query(&stmt);
257 let conn: &mut sqlx::MySqlConnection = &mut *conn;
258 crate::metric::metric!(self.metric_callback, &stmt, {
259 query.execute(conn).await.map(Into::into)
260 })
261 .map_err(sqlx_error_to_exec_err)
262 }
263 #[cfg(feature = "sqlx-postgres")]
264 InnerConnection::Postgres(conn) => {
265 let query = crate::driver::sqlx_postgres::sqlx_query(&stmt);
266 let conn: &mut sqlx::PgConnection = &mut *conn;
267 crate::metric::metric!(self.metric_callback, &stmt, {
268 query.execute(conn).await.map(Into::into)
269 })
270 .map_err(sqlx_error_to_exec_err)
271 }
272 #[cfg(feature = "sqlx-sqlite")]
273 InnerConnection::Sqlite(conn) => {
274 let query = crate::driver::sqlx_sqlite::sqlx_query(&stmt);
275 let conn: &mut sqlx::SqliteConnection = &mut *conn;
276 crate::metric::metric!(self.metric_callback, &stmt, {
277 query.execute(conn).await.map(Into::into)
278 })
279 .map_err(sqlx_error_to_exec_err)
280 }
281 #[cfg(feature = "mock")]
282 InnerConnection::Mock(conn) => return conn.execute(stmt),
283 #[cfg(feature = "proxy")]
284 InnerConnection::Proxy(conn) => return conn.execute(stmt).await,
285 #[allow(unreachable_patterns)]
286 _ => Err(conn_err("Disconnected")),
287 }
288 }
289
290 #[instrument(level = "trace")]
291 #[allow(unused_variables)]
292 async fn execute_unprepared(&self, sql: &str) -> Result<ExecResult, DbErr> {
293 debug_print!("{}", sql);
294
295 match &mut *self.conn.lock().await {
296 #[cfg(feature = "sqlx-mysql")]
297 InnerConnection::MySql(conn) => {
298 let conn: &mut sqlx::MySqlConnection = &mut *conn;
299 sqlx::Executor::execute(conn, sql)
300 .await
301 .map(Into::into)
302 .map_err(sqlx_error_to_exec_err)
303 }
304 #[cfg(feature = "sqlx-postgres")]
305 InnerConnection::Postgres(conn) => {
306 let conn: &mut sqlx::PgConnection = &mut *conn;
307 sqlx::Executor::execute(conn, sql)
308 .await
309 .map(Into::into)
310 .map_err(sqlx_error_to_exec_err)
311 }
312 #[cfg(feature = "sqlx-sqlite")]
313 InnerConnection::Sqlite(conn) => {
314 let conn: &mut sqlx::SqliteConnection = &mut *conn;
315 sqlx::Executor::execute(conn, sql)
316 .await
317 .map(Into::into)
318 .map_err(sqlx_error_to_exec_err)
319 }
320 #[cfg(feature = "mock")]
321 InnerConnection::Mock(conn) => {
322 let db_backend = conn.get_database_backend();
323 let stmt = Statement::from_string(db_backend, sql);
324 conn.execute(stmt)
325 }
326 #[cfg(feature = "proxy")]
327 InnerConnection::Proxy(conn) => {
328 let db_backend = conn.get_database_backend();
329 let stmt = Statement::from_string(db_backend, sql);
330 conn.execute(stmt).await
331 }
332 #[allow(unreachable_patterns)]
333 _ => Err(conn_err("Disconnected")),
334 }
335 }
336
337 #[instrument(level = "trace")]
338 #[allow(unused_variables)]
339 async fn query_one_raw(&self, stmt: Statement) -> Result<Option<QueryResult>, DbErr> {
340 debug_print!("{}", stmt);
341
342 match &mut *self.conn.lock().await {
343 #[cfg(feature = "sqlx-mysql")]
344 InnerConnection::MySql(conn) => {
345 let query = crate::driver::sqlx_mysql::sqlx_query(&stmt);
346 let conn: &mut sqlx::MySqlConnection = &mut *conn;
347 crate::metric::metric!(self.metric_callback, &stmt, {
348 crate::sqlx_map_err_ignore_not_found(
349 query.fetch_one(conn).await.map(|row| Some(row.into())),
350 )
351 })
352 }
353 #[cfg(feature = "sqlx-postgres")]
354 InnerConnection::Postgres(conn) => {
355 let query = crate::driver::sqlx_postgres::sqlx_query(&stmt);
356 let conn: &mut sqlx::PgConnection = &mut *conn;
357 crate::metric::metric!(self.metric_callback, &stmt, {
358 crate::sqlx_map_err_ignore_not_found(
359 query.fetch_one(conn).await.map(|row| Some(row.into())),
360 )
361 })
362 }
363 #[cfg(feature = "sqlx-sqlite")]
364 InnerConnection::Sqlite(conn) => {
365 let query = crate::driver::sqlx_sqlite::sqlx_query(&stmt);
366 let conn: &mut sqlx::SqliteConnection = &mut *conn;
367 crate::metric::metric!(self.metric_callback, &stmt, {
368 crate::sqlx_map_err_ignore_not_found(
369 query.fetch_one(conn).await.map(|row| Some(row.into())),
370 )
371 })
372 }
373 #[cfg(feature = "mock")]
374 InnerConnection::Mock(conn) => return conn.query_one(stmt),
375 #[cfg(feature = "proxy")]
376 InnerConnection::Proxy(conn) => return conn.query_one(stmt).await,
377 #[allow(unreachable_patterns)]
378 _ => Err(conn_err("Disconnected")),
379 }
380 }
381
382 #[instrument(level = "trace")]
383 #[allow(unused_variables)]
384 async fn query_all_raw(&self, stmt: Statement) -> Result<Vec<QueryResult>, DbErr> {
385 debug_print!("{}", stmt);
386
387 match &mut *self.conn.lock().await {
388 #[cfg(feature = "sqlx-mysql")]
389 InnerConnection::MySql(conn) => {
390 let query = crate::driver::sqlx_mysql::sqlx_query(&stmt);
391 let conn: &mut sqlx::MySqlConnection = &mut *conn;
392 crate::metric::metric!(self.metric_callback, &stmt, {
393 query
394 .fetch_all(conn)
395 .await
396 .map(|rows| rows.into_iter().map(|r| r.into()).collect())
397 .map_err(sqlx_error_to_query_err)
398 })
399 }
400 #[cfg(feature = "sqlx-postgres")]
401 InnerConnection::Postgres(conn) => {
402 let query = crate::driver::sqlx_postgres::sqlx_query(&stmt);
403 let conn: &mut sqlx::PgConnection = &mut *conn;
404 crate::metric::metric!(self.metric_callback, &stmt, {
405 query
406 .fetch_all(conn)
407 .await
408 .map(|rows| rows.into_iter().map(|r| r.into()).collect())
409 .map_err(sqlx_error_to_query_err)
410 })
411 }
412 #[cfg(feature = "sqlx-sqlite")]
413 InnerConnection::Sqlite(conn) => {
414 let query = crate::driver::sqlx_sqlite::sqlx_query(&stmt);
415 let conn: &mut sqlx::SqliteConnection = &mut *conn;
416 crate::metric::metric!(self.metric_callback, &stmt, {
417 query
418 .fetch_all(conn)
419 .await
420 .map(|rows| rows.into_iter().map(|r| r.into()).collect())
421 .map_err(sqlx_error_to_query_err)
422 })
423 }
424 #[cfg(feature = "mock")]
425 InnerConnection::Mock(conn) => return conn.query_all(stmt),
426 #[cfg(feature = "proxy")]
427 InnerConnection::Proxy(conn) => return conn.query_all(stmt).await,
428 #[allow(unreachable_patterns)]
429 _ => Err(conn_err("Disconnected")),
430 }
431 }
432}
433
434impl StreamTrait for DatabaseTransaction {
435 type Stream<'a> = TransactionStream<'a>;
436
437 fn get_database_backend(&self) -> DbBackend {
438 self.backend
439 }
440
441 #[instrument(level = "trace")]
442 fn stream_raw<'a>(
443 &'a self,
444 stmt: Statement,
445 ) -> Pin<Box<dyn Future<Output = Result<Self::Stream<'a>, DbErr>> + 'a + Send>> {
446 Box::pin(async move {
447 let conn = self.conn.lock().await;
448 Ok(crate::TransactionStream::build(
449 conn,
450 stmt,
451 self.metric_callback.clone(),
452 ))
453 })
454 }
455}
456
457#[async_trait::async_trait]
458impl TransactionTrait for DatabaseTransaction {
459 type Transaction = DatabaseTransaction;
460
461 #[instrument(level = "trace")]
462 async fn begin(&self) -> Result<DatabaseTransaction, DbErr> {
463 DatabaseTransaction::begin(
464 Arc::clone(&self.conn),
465 self.backend,
466 self.metric_callback.clone(),
467 None,
468 None,
469 )
470 .await
471 }
472
473 #[instrument(level = "trace")]
474 async fn begin_with_config(
475 &self,
476 isolation_level: Option<IsolationLevel>,
477 access_mode: Option<AccessMode>,
478 ) -> Result<DatabaseTransaction, DbErr> {
479 DatabaseTransaction::begin(
480 Arc::clone(&self.conn),
481 self.backend,
482 self.metric_callback.clone(),
483 isolation_level,
484 access_mode,
485 )
486 .await
487 }
488
489 #[instrument(level = "trace", skip(_callback))]
493 async fn transaction<F, T, E>(&self, _callback: F) -> Result<T, TransactionError<E>>
494 where
495 F: for<'c> FnOnce(
496 &'c DatabaseTransaction,
497 ) -> Pin<Box<dyn Future<Output = Result<T, E>> + Send + 'c>>
498 + Send,
499 T: Send,
500 E: std::fmt::Display + std::fmt::Debug + Send,
501 {
502 let transaction = self.begin().await.map_err(TransactionError::Connection)?;
503 transaction.run(_callback).await
504 }
505
506 #[instrument(level = "trace", skip(_callback))]
510 async fn transaction_with_config<F, T, E>(
511 &self,
512 _callback: F,
513 isolation_level: Option<IsolationLevel>,
514 access_mode: Option<AccessMode>,
515 ) -> Result<T, TransactionError<E>>
516 where
517 F: for<'c> FnOnce(
518 &'c DatabaseTransaction,
519 ) -> Pin<Box<dyn Future<Output = Result<T, E>> + Send + 'c>>
520 + Send,
521 T: Send,
522 E: std::fmt::Display + std::fmt::Debug + Send,
523 {
524 let transaction = self
525 .begin_with_config(isolation_level, access_mode)
526 .await
527 .map_err(TransactionError::Connection)?;
528 transaction.run(_callback).await
529 }
530}
531
/// Error returned by the transaction helpers (`TransactionTrait::transaction`
/// and `transaction_with_config`).
#[derive(Debug)]
pub enum TransactionError<E> {
    /// The database failed to begin, commit or roll back the transaction.
    Connection(DbErr),
    /// The user-supplied transaction callback returned this error.
    Transaction(E),
}
540
541impl<E> std::fmt::Display for TransactionError<E>
542where
543 E: std::fmt::Display + std::fmt::Debug,
544{
545 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
546 match self {
547 TransactionError::Connection(e) => std::fmt::Display::fmt(e, f),
548 TransactionError::Transaction(e) => std::fmt::Display::fmt(e, f),
549 }
550 }
551}
552
553impl<E> std::error::Error for TransactionError<E> where E: std::fmt::Display + std::fmt::Debug {}
554
555impl<E> From<DbErr> for TransactionError<E>
556where
557 E: std::fmt::Display + std::fmt::Debug,
558{
559 fn from(e: DbErr) -> Self {
560 Self::Connection(e)
561 }
562}