1#![allow(unused_assignments)]
2use crate::{
3 AccessMode, ConnectionTrait, DbBackend, DbErr, ExecResult, InnerConnection, IsolationLevel,
4 QueryResult, Statement, StreamTrait, TransactionSession, TransactionStream, TransactionTrait,
5 debug_print, error::*,
6};
7#[cfg(feature = "sqlx-dep")]
8use crate::{sqlx_error_to_exec_err, sqlx_error_to_query_err};
9#[cfg(feature = "sqlx-dep")]
10use sqlx::TransactionManager;
11use std::sync::Arc;
12use std::sync::Mutex;
13use tracing::instrument;
14
/// A database transaction running on a connection shared behind `Arc<Mutex<_>>`.
///
/// Finish it with [`DatabaseTransaction::commit`] or [`DatabaseTransaction::rollback`];
/// if it is dropped while still open, its `Drop` impl starts a rollback.
pub struct DatabaseTransaction {
    /// Connection shared (via `Arc`) with any transaction begun from this one; each
    /// operation locks it while it runs.
    conn: Arc<Mutex<InnerConnection>>,
    /// Backend reported by `get_database_backend`.
    backend: DbBackend,
    /// `true` until `commit` or `rollback` succeeds; the rollback on drop only runs while set.
    open: bool,
    /// Optional metrics hook forwarded to statement execution and cloned into nested
    /// transactions and streams.
    metric_callback: Option<crate::metric::Callback>,
}
25
26impl std::fmt::Debug for DatabaseTransaction {
27 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
28 write!(f, "DatabaseTransaction")
29 }
30}
31
32impl DatabaseTransaction {
33 #[instrument(level = "trace", skip(metric_callback))]
34 pub(crate) fn begin(
35 conn: Arc<Mutex<InnerConnection>>,
36 backend: DbBackend,
37 metric_callback: Option<crate::metric::Callback>,
38 isolation_level: Option<IsolationLevel>,
39 access_mode: Option<AccessMode>,
40 ) -> Result<DatabaseTransaction, DbErr> {
41 let res = DatabaseTransaction {
42 conn,
43 backend,
44 open: true,
45 metric_callback,
46 };
47 {
48 #[cfg(not(feature = "sync"))]
49 let conn = &mut *res.conn.lock();
50 #[cfg(feature = "sync")]
51 let conn = &mut *res.conn.lock().map_err(|_| DbErr::MutexPoisonError)?;
52
53 match conn {
54 #[cfg(feature = "sqlx-mysql")]
55 InnerConnection::MySql(c) => {
56 crate::driver::sqlx_mysql::set_transaction_config(
58 c,
59 isolation_level,
60 access_mode,
61 )?;
62 <sqlx::MySql as sqlx::Database>::TransactionManager::begin(c, None)
63 .map_err(sqlx_error_to_query_err)
64 }
65 #[cfg(feature = "sqlx-postgres")]
66 InnerConnection::Postgres(c) => {
67 <sqlx::Postgres as sqlx::Database>::TransactionManager::begin(c, None)
68 .map_err(sqlx_error_to_query_err)?;
69 crate::driver::sqlx_postgres::set_transaction_config(
71 c,
72 isolation_level,
73 access_mode,
74 )
75 }
76 #[cfg(feature = "sqlx-sqlite")]
77 InnerConnection::Sqlite(c) => {
78 crate::driver::sqlx_sqlite::set_transaction_config(
80 c,
81 isolation_level,
82 access_mode,
83 )?;
84 <sqlx::Sqlite as sqlx::Database>::TransactionManager::begin(c, None)
85 .map_err(sqlx_error_to_query_err)
86 }
87 #[cfg(feature = "rusqlite")]
88 InnerConnection::Rusqlite(c) => c.begin(),
89 #[cfg(feature = "mock")]
90 InnerConnection::Mock(c) => {
91 c.begin();
92 Ok(())
93 }
94 #[cfg(feature = "proxy")]
95 InnerConnection::Proxy(c) => {
96 c.begin();
97 Ok(())
98 }
99 #[allow(unreachable_patterns)]
100 _ => Err(conn_err("Disconnected")),
101 }?
102 };
103
104 Ok(res)
105 }
106
107 #[instrument(level = "trace", skip(callback))]
110 pub(crate) fn run<F, T, E>(self, callback: F) -> Result<T, TransactionError<E>>
111 where
112 F: for<'b> FnOnce(&'b DatabaseTransaction) -> Result<T, E>,
113 E: std::fmt::Display + std::fmt::Debug,
114 {
115 let res = callback(&self).map_err(TransactionError::Transaction);
116 if res.is_ok() {
117 self.commit().map_err(TransactionError::Connection)?;
118 } else {
119 self.rollback().map_err(TransactionError::Connection)?;
120 }
121 res
122 }
123
124 #[instrument(level = "trace")]
126 #[allow(unreachable_code, unused_mut)]
127 pub fn commit(mut self) -> Result<(), DbErr> {
128 #[cfg(not(feature = "sync"))]
129 let conn = &mut *self.conn.lock();
130 #[cfg(feature = "sync")]
131 let conn = &mut *self.conn.lock().map_err(|_| DbErr::MutexPoisonError)?;
132
133 match conn {
134 #[cfg(feature = "sqlx-mysql")]
135 InnerConnection::MySql(c) => {
136 <sqlx::MySql as sqlx::Database>::TransactionManager::commit(c)
137 .map_err(sqlx_error_to_query_err)
138 }
139 #[cfg(feature = "sqlx-postgres")]
140 InnerConnection::Postgres(c) => {
141 <sqlx::Postgres as sqlx::Database>::TransactionManager::commit(c)
142 .map_err(sqlx_error_to_query_err)
143 }
144 #[cfg(feature = "sqlx-sqlite")]
145 InnerConnection::Sqlite(c) => {
146 <sqlx::Sqlite as sqlx::Database>::TransactionManager::commit(c)
147 .map_err(sqlx_error_to_query_err)
148 }
149 #[cfg(feature = "rusqlite")]
150 InnerConnection::Rusqlite(c) => c.commit(),
151 #[cfg(feature = "mock")]
152 InnerConnection::Mock(c) => {
153 c.commit();
154 Ok(())
155 }
156 #[cfg(feature = "proxy")]
157 InnerConnection::Proxy(c) => {
158 c.commit();
159 Ok(())
160 }
161 #[allow(unreachable_patterns)]
162 _ => Err(conn_err("Disconnected")),
163 }?;
164 self.open = false; Ok(())
166 }
167
168 #[instrument(level = "trace")]
170 #[allow(unreachable_code, unused_mut)]
171 pub fn rollback(mut self) -> Result<(), DbErr> {
172 #[cfg(not(feature = "sync"))]
173 let conn = &mut *self.conn.lock();
174 #[cfg(feature = "sync")]
175 let conn = &mut *self.conn.lock().map_err(|_| DbErr::MutexPoisonError)?;
176
177 match conn {
178 #[cfg(feature = "sqlx-mysql")]
179 InnerConnection::MySql(c) => {
180 <sqlx::MySql as sqlx::Database>::TransactionManager::rollback(c)
181 .map_err(sqlx_error_to_query_err)
182 }
183 #[cfg(feature = "sqlx-postgres")]
184 InnerConnection::Postgres(c) => {
185 <sqlx::Postgres as sqlx::Database>::TransactionManager::rollback(c)
186 .map_err(sqlx_error_to_query_err)
187 }
188 #[cfg(feature = "sqlx-sqlite")]
189 InnerConnection::Sqlite(c) => {
190 <sqlx::Sqlite as sqlx::Database>::TransactionManager::rollback(c)
191 .map_err(sqlx_error_to_query_err)
192 }
193 #[cfg(feature = "rusqlite")]
194 InnerConnection::Rusqlite(c) => c.rollback(),
195 #[cfg(feature = "mock")]
196 InnerConnection::Mock(c) => {
197 c.rollback();
198 Ok(())
199 }
200 #[cfg(feature = "proxy")]
201 InnerConnection::Proxy(c) => {
202 c.rollback();
203 Ok(())
204 }
205 #[allow(unreachable_patterns)]
206 _ => Err(conn_err("Disconnected")),
207 }?;
208 self.open = false; Ok(())
210 }
211
212 #[instrument(level = "trace")]
214 fn start_rollback(&mut self) -> Result<(), DbErr> {
215 if self.open {
216 if let Some(mut conn) = self.conn.try_lock().ok() {
217 match &mut *conn {
218 #[cfg(feature = "sqlx-mysql")]
219 InnerConnection::MySql(c) => {
220 <sqlx::MySql as sqlx::Database>::TransactionManager::start_rollback(c);
221 }
222 #[cfg(feature = "sqlx-postgres")]
223 InnerConnection::Postgres(c) => {
224 <sqlx::Postgres as sqlx::Database>::TransactionManager::start_rollback(c);
225 }
226 #[cfg(feature = "sqlx-sqlite")]
227 InnerConnection::Sqlite(c) => {
228 <sqlx::Sqlite as sqlx::Database>::TransactionManager::start_rollback(c);
229 }
230 #[cfg(feature = "rusqlite")]
231 InnerConnection::Rusqlite(c) => {
232 c.start_rollback()?;
233 }
234 #[cfg(feature = "mock")]
235 InnerConnection::Mock(c) => {
236 c.rollback();
237 }
238 #[cfg(feature = "proxy")]
239 InnerConnection::Proxy(c) => {
240 c.start_rollback();
241 }
242 #[allow(unreachable_patterns)]
243 _ => return Err(conn_err("Disconnected")),
244 }
245 } else {
246 return Err(conn_err("Dropping a locked Transaction"));
248 }
249 }
250 Ok(())
251 }
252}
253
254impl TransactionSession for DatabaseTransaction {
255 fn commit(self) -> Result<(), DbErr> {
256 self.commit()
257 }
258
259 fn rollback(self) -> Result<(), DbErr> {
260 self.rollback()
261 }
262}
263
264impl Drop for DatabaseTransaction {
265 fn drop(&mut self) {
266 self.start_rollback().expect("Fail to rollback transaction");
267 }
268}
269
270impl ConnectionTrait for DatabaseTransaction {
271 fn get_database_backend(&self) -> DbBackend {
272 self.backend
274 }
275
276 #[instrument(level = "trace")]
277 #[allow(unused_variables)]
278 fn execute_raw(&self, stmt: Statement) -> Result<ExecResult, DbErr> {
279 debug_print!("{}", stmt);
280
281 #[cfg(not(feature = "sync"))]
282 let conn = &mut *self.conn.lock();
283 #[cfg(feature = "sync")]
284 let conn = &mut *self.conn.lock().map_err(|_| DbErr::MutexPoisonError)?;
285
286 match conn {
287 #[cfg(feature = "sqlx-mysql")]
288 InnerConnection::MySql(conn) => {
289 let query = crate::driver::sqlx_mysql::sqlx_query(&stmt);
290 let conn: &mut sqlx::MySqlConnection = &mut *conn;
291 crate::metric::metric!(self.metric_callback, &stmt, {
292 query.execute(conn).map(Into::into)
293 })
294 .map_err(sqlx_error_to_exec_err)
295 }
296 #[cfg(feature = "sqlx-postgres")]
297 InnerConnection::Postgres(conn) => {
298 let query = crate::driver::sqlx_postgres::sqlx_query(&stmt);
299 let conn: &mut sqlx::PgConnection = &mut *conn;
300 crate::metric::metric!(self.metric_callback, &stmt, {
301 query.execute(conn).map(Into::into)
302 })
303 .map_err(sqlx_error_to_exec_err)
304 }
305 #[cfg(feature = "sqlx-sqlite")]
306 InnerConnection::Sqlite(conn) => {
307 let query = crate::driver::sqlx_sqlite::sqlx_query(&stmt);
308 let conn: &mut sqlx::SqliteConnection = &mut *conn;
309 crate::metric::metric!(self.metric_callback, &stmt, {
310 query.execute(conn).map(Into::into)
311 })
312 .map_err(sqlx_error_to_exec_err)
313 }
314 #[cfg(feature = "rusqlite")]
315 InnerConnection::Rusqlite(conn) => conn.execute(stmt, &self.metric_callback),
316 #[cfg(feature = "mock")]
317 InnerConnection::Mock(conn) => return conn.execute(stmt),
318 #[cfg(feature = "proxy")]
319 InnerConnection::Proxy(conn) => return conn.execute(stmt),
320 #[allow(unreachable_patterns)]
321 _ => Err(conn_err("Disconnected")),
322 }
323 }
324
325 #[instrument(level = "trace")]
326 #[allow(unused_variables)]
327 fn execute_unprepared(&self, sql: &str) -> Result<ExecResult, DbErr> {
328 debug_print!("{}", sql);
329
330 #[cfg(not(feature = "sync"))]
331 let conn = &mut *self.conn.lock();
332 #[cfg(feature = "sync")]
333 let conn = &mut *self.conn.lock().map_err(|_| DbErr::MutexPoisonError)?;
334
335 match conn {
336 #[cfg(feature = "sqlx-mysql")]
337 InnerConnection::MySql(conn) => {
338 let conn: &mut sqlx::MySqlConnection = &mut *conn;
339 sqlx::Executor::execute(conn, sql)
340 .map(Into::into)
341 .map_err(sqlx_error_to_exec_err)
342 }
343 #[cfg(feature = "sqlx-postgres")]
344 InnerConnection::Postgres(conn) => {
345 let conn: &mut sqlx::PgConnection = &mut *conn;
346 sqlx::Executor::execute(conn, sql)
347 .map(Into::into)
348 .map_err(sqlx_error_to_exec_err)
349 }
350 #[cfg(feature = "sqlx-sqlite")]
351 InnerConnection::Sqlite(conn) => {
352 let conn: &mut sqlx::SqliteConnection = &mut *conn;
353 sqlx::Executor::execute(conn, sql)
354 .map(Into::into)
355 .map_err(sqlx_error_to_exec_err)
356 }
357 #[cfg(feature = "rusqlite")]
358 InnerConnection::Rusqlite(conn) => conn.execute_unprepared(sql),
359 #[cfg(feature = "mock")]
360 InnerConnection::Mock(conn) => {
361 let db_backend = conn.get_database_backend();
362 let stmt = Statement::from_string(db_backend, sql);
363 conn.execute(stmt)
364 }
365 #[cfg(feature = "proxy")]
366 InnerConnection::Proxy(conn) => {
367 let db_backend = conn.get_database_backend();
368 let stmt = Statement::from_string(db_backend, sql);
369 conn.execute(stmt)
370 }
371 #[allow(unreachable_patterns)]
372 _ => Err(conn_err("Disconnected")),
373 }
374 }
375
376 #[instrument(level = "trace")]
377 #[allow(unused_variables)]
378 fn query_one_raw(&self, stmt: Statement) -> Result<Option<QueryResult>, DbErr> {
379 debug_print!("{}", stmt);
380
381 #[cfg(not(feature = "sync"))]
382 let conn = &mut *self.conn.lock();
383 #[cfg(feature = "sync")]
384 let conn = &mut *self.conn.lock().map_err(|_| DbErr::MutexPoisonError)?;
385
386 match conn {
387 #[cfg(feature = "sqlx-mysql")]
388 InnerConnection::MySql(conn) => {
389 let query = crate::driver::sqlx_mysql::sqlx_query(&stmt);
390 let conn: &mut sqlx::MySqlConnection = &mut *conn;
391 crate::metric::metric!(self.metric_callback, &stmt, {
392 crate::sqlx_map_err_ignore_not_found(
393 query.fetch_one(conn).map(|row| Some(row.into())),
394 )
395 })
396 }
397 #[cfg(feature = "sqlx-postgres")]
398 InnerConnection::Postgres(conn) => {
399 let query = crate::driver::sqlx_postgres::sqlx_query(&stmt);
400 let conn: &mut sqlx::PgConnection = &mut *conn;
401 crate::metric::metric!(self.metric_callback, &stmt, {
402 crate::sqlx_map_err_ignore_not_found(
403 query.fetch_one(conn).map(|row| Some(row.into())),
404 )
405 })
406 }
407 #[cfg(feature = "sqlx-sqlite")]
408 InnerConnection::Sqlite(conn) => {
409 let query = crate::driver::sqlx_sqlite::sqlx_query(&stmt);
410 let conn: &mut sqlx::SqliteConnection = &mut *conn;
411 crate::metric::metric!(self.metric_callback, &stmt, {
412 crate::sqlx_map_err_ignore_not_found(
413 query.fetch_one(conn).map(|row| Some(row.into())),
414 )
415 })
416 }
417 #[cfg(feature = "rusqlite")]
418 InnerConnection::Rusqlite(conn) => conn.query_one(stmt, &self.metric_callback),
419 #[cfg(feature = "mock")]
420 InnerConnection::Mock(conn) => return conn.query_one(stmt),
421 #[cfg(feature = "proxy")]
422 InnerConnection::Proxy(conn) => return conn.query_one(stmt),
423 #[allow(unreachable_patterns)]
424 _ => Err(conn_err("Disconnected")),
425 }
426 }
427
428 #[instrument(level = "trace")]
429 #[allow(unused_variables)]
430 fn query_all_raw(&self, stmt: Statement) -> Result<Vec<QueryResult>, DbErr> {
431 debug_print!("{}", stmt);
432
433 #[cfg(not(feature = "sync"))]
434 let conn = &mut *self.conn.lock();
435 #[cfg(feature = "sync")]
436 let conn = &mut *self.conn.lock().map_err(|_| DbErr::MutexPoisonError)?;
437
438 match conn {
439 #[cfg(feature = "sqlx-mysql")]
440 InnerConnection::MySql(conn) => {
441 let query = crate::driver::sqlx_mysql::sqlx_query(&stmt);
442 let conn: &mut sqlx::MySqlConnection = &mut *conn;
443 crate::metric::metric!(self.metric_callback, &stmt, {
444 query
445 .fetch_all(conn)
446 .map(|rows| rows.into_iter().map(|r| r.into()).collect())
447 .map_err(sqlx_error_to_query_err)
448 })
449 }
450 #[cfg(feature = "sqlx-postgres")]
451 InnerConnection::Postgres(conn) => {
452 let query = crate::driver::sqlx_postgres::sqlx_query(&stmt);
453 let conn: &mut sqlx::PgConnection = &mut *conn;
454 crate::metric::metric!(self.metric_callback, &stmt, {
455 query
456 .fetch_all(conn)
457 .map(|rows| rows.into_iter().map(|r| r.into()).collect())
458 .map_err(sqlx_error_to_query_err)
459 })
460 }
461 #[cfg(feature = "sqlx-sqlite")]
462 InnerConnection::Sqlite(conn) => {
463 let query = crate::driver::sqlx_sqlite::sqlx_query(&stmt);
464 let conn: &mut sqlx::SqliteConnection = &mut *conn;
465 crate::metric::metric!(self.metric_callback, &stmt, {
466 query
467 .fetch_all(conn)
468 .map(|rows| rows.into_iter().map(|r| r.into()).collect())
469 .map_err(sqlx_error_to_query_err)
470 })
471 }
472 #[cfg(feature = "rusqlite")]
473 InnerConnection::Rusqlite(conn) => conn.query_all(stmt, &self.metric_callback),
474 #[cfg(feature = "mock")]
475 InnerConnection::Mock(conn) => return conn.query_all(stmt),
476 #[cfg(feature = "proxy")]
477 InnerConnection::Proxy(conn) => return conn.query_all(stmt),
478 #[allow(unreachable_patterns)]
479 _ => Err(conn_err("Disconnected")),
480 }
481 }
482}
483
484impl StreamTrait for DatabaseTransaction {
485 type Stream<'a> = TransactionStream<'a>;
486
487 fn get_database_backend(&self) -> DbBackend {
488 self.backend
489 }
490
491 #[instrument(level = "trace")]
492 fn stream_raw<'a>(&'a self, stmt: Statement) -> Result<Self::Stream<'a>, DbErr> {
493 ({
494 #[cfg(not(feature = "sync"))]
495 let conn = self.conn.lock();
496 #[cfg(feature = "sync")]
497 let conn = self.conn.lock().map_err(|_| DbErr::MutexPoisonError)?;
498 Ok(crate::TransactionStream::build(
499 conn,
500 stmt,
501 self.metric_callback.clone(),
502 ))
503 })
504 }
505}
506
507impl TransactionTrait for DatabaseTransaction {
508 type Transaction = DatabaseTransaction;
509
510 #[instrument(level = "trace")]
511 fn begin(&self) -> Result<DatabaseTransaction, DbErr> {
512 DatabaseTransaction::begin(
513 Arc::clone(&self.conn),
514 self.backend,
515 self.metric_callback.clone(),
516 None,
517 None,
518 )
519 }
520
521 #[instrument(level = "trace")]
522 fn begin_with_config(
523 &self,
524 isolation_level: Option<IsolationLevel>,
525 access_mode: Option<AccessMode>,
526 ) -> Result<DatabaseTransaction, DbErr> {
527 DatabaseTransaction::begin(
528 Arc::clone(&self.conn),
529 self.backend,
530 self.metric_callback.clone(),
531 isolation_level,
532 access_mode,
533 )
534 }
535
536 #[instrument(level = "trace", skip(_callback))]
540 fn transaction<F, T, E>(&self, _callback: F) -> Result<T, TransactionError<E>>
541 where
542 F: for<'c> FnOnce(&'c DatabaseTransaction) -> Result<T, E>,
543 E: std::fmt::Display + std::fmt::Debug,
544 {
545 let transaction = self.begin().map_err(TransactionError::Connection)?;
546 transaction.run(_callback)
547 }
548
549 #[instrument(level = "trace", skip(_callback))]
553 fn transaction_with_config<F, T, E>(
554 &self,
555 _callback: F,
556 isolation_level: Option<IsolationLevel>,
557 access_mode: Option<AccessMode>,
558 ) -> Result<T, TransactionError<E>>
559 where
560 F: for<'c> FnOnce(&'c DatabaseTransaction) -> Result<T, E>,
561 E: std::fmt::Display + std::fmt::Debug,
562 {
563 let transaction = self
564 .begin_with_config(isolation_level, access_mode)
565 .map_err(TransactionError::Connection)?;
566 transaction.run(_callback)
567 }
568}
569
/// Error returned by [`TransactionTrait::transaction`] and
/// [`TransactionTrait::transaction_with_config`].
#[derive(Debug)]
pub enum TransactionError<E> {
    /// A database-level failure, such as beginning, committing or rolling back the
    /// transaction (also produced from any [`DbErr`] via `From`).
    Connection(DbErr),
    /// The error returned by the user-supplied transaction callback.
    Transaction(E),
}
578
579impl<E> std::fmt::Display for TransactionError<E>
580where
581 E: std::fmt::Display + std::fmt::Debug,
582{
583 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
584 match self {
585 TransactionError::Connection(e) => std::fmt::Display::fmt(e, f),
586 TransactionError::Transaction(e) => std::fmt::Display::fmt(e, f),
587 }
588 }
589}
590
591impl<E> std::error::Error for TransactionError<E> where E: std::fmt::Display + std::fmt::Debug {}
592
593impl<E> From<DbErr> for TransactionError<E>
594where
595 E: std::fmt::Display + std::fmt::Debug,
596{
597 fn from(e: DbErr) -> Self {
598 Self::Connection(e)
599 }
600}