1#![allow(unused_assignments)]
2use crate::{
3 AccessMode, ConnectionTrait, DbBackend, DbErr, ExecResult, InnerConnection, IsolationLevel,
4 QueryResult, Statement, StreamTrait, TransactionSession, TransactionStream, TransactionTrait,
5 debug_print, error::*,
6};
7#[cfg(feature = "sqlx-dep")]
8use crate::{sqlx_error_to_exec_err, sqlx_error_to_query_err};
9#[cfg(feature = "sqlx-dep")]
10use sqlx::TransactionManager;
11use std::sync::Arc;
12use std::sync::Mutex;
13use tracing::instrument;
14
15pub struct DatabaseTransaction {
20 conn: Arc<Mutex<InnerConnection>>,
21 backend: DbBackend,
22 open: bool,
23 metric_callback: Option<crate::metric::Callback>,
24}
25
26impl std::fmt::Debug for DatabaseTransaction {
27 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
28 write!(f, "DatabaseTransaction")
29 }
30}
31
32impl DatabaseTransaction {
33 #[instrument(level = "trace", skip(metric_callback))]
34 pub(crate) fn begin(
35 conn: Arc<Mutex<InnerConnection>>,
36 backend: DbBackend,
37 metric_callback: Option<crate::metric::Callback>,
38 isolation_level: Option<IsolationLevel>,
39 access_mode: Option<AccessMode>,
40 ) -> Result<DatabaseTransaction, DbErr> {
41 let res = DatabaseTransaction {
42 conn,
43 backend,
44 open: true,
45 metric_callback,
46 };
47
48 let begin_result: Result<(), DbErr> = super::tracing_spans::with_db_span!(
49 "sea_orm.begin",
50 backend,
51 "BEGIN",
52 record_stmt = false,
53 {
54 #[cfg(not(feature = "sync"))]
55 let conn = &mut *res.conn.lock();
56 #[cfg(feature = "sync")]
57 let conn = &mut *res.conn.lock().map_err(|_| DbErr::MutexPoisonError)?;
58
59 match conn {
60 #[cfg(feature = "sqlx-mysql")]
61 InnerConnection::MySql(c) => {
62 crate::driver::sqlx_mysql::set_transaction_config(
64 c,
65 isolation_level,
66 access_mode,
67 )?;
68 <sqlx::MySql as sqlx::Database>::TransactionManager::begin(c, None)
69 .map_err(sqlx_error_to_query_err)
70 }
71 #[cfg(feature = "sqlx-postgres")]
72 InnerConnection::Postgres(c) => {
73 <sqlx::Postgres as sqlx::Database>::TransactionManager::begin(c, None)
74 .map_err(sqlx_error_to_query_err)?;
75 crate::driver::sqlx_postgres::set_transaction_config(
77 c,
78 isolation_level,
79 access_mode,
80 )
81 }
82 #[cfg(feature = "sqlx-sqlite")]
83 InnerConnection::Sqlite(c) => {
84 crate::driver::sqlx_sqlite::set_transaction_config(
86 c,
87 isolation_level,
88 access_mode,
89 )?;
90 <sqlx::Sqlite as sqlx::Database>::TransactionManager::begin(c, None)
91 .map_err(sqlx_error_to_query_err)
92 }
93 #[cfg(feature = "rusqlite")]
94 InnerConnection::Rusqlite(c) => c.begin(),
95 #[cfg(feature = "mock")]
96 InnerConnection::Mock(c) => {
97 c.begin();
98 Ok(())
99 }
100 #[cfg(feature = "proxy")]
101 InnerConnection::Proxy(c) => {
102 c.begin();
103 Ok(())
104 }
105 #[allow(unreachable_patterns)]
106 _ => Err(conn_err("Disconnected")),
107 }
108 }
109 );
110
111 begin_result?;
112 Ok(res)
113 }
114
115 #[instrument(level = "trace", skip(callback))]
118 pub(crate) fn run<F, T, E>(self, callback: F) -> Result<T, TransactionError<E>>
119 where
120 F: for<'b> FnOnce(&'b DatabaseTransaction) -> Result<T, E>,
121 E: std::fmt::Display + std::fmt::Debug,
122 {
123 let res = callback(&self).map_err(TransactionError::Transaction);
124 if res.is_ok() {
125 self.commit().map_err(TransactionError::Connection)?;
126 } else {
127 self.rollback().map_err(TransactionError::Connection)?;
128 }
129 res
130 }
131
132 #[instrument(level = "trace")]
134 #[allow(unreachable_code, unused_mut)]
135 pub fn commit(mut self) -> Result<(), DbErr> {
136 let result: Result<(), DbErr> = super::tracing_spans::with_db_span!(
137 "sea_orm.commit",
138 self.backend,
139 "COMMIT",
140 record_stmt = false,
141 {
142 #[cfg(not(feature = "sync"))]
143 let conn = &mut *self.conn.lock();
144 #[cfg(feature = "sync")]
145 let conn = &mut *self.conn.lock().map_err(|_| DbErr::MutexPoisonError)?;
146
147 match conn {
148 #[cfg(feature = "sqlx-mysql")]
149 InnerConnection::MySql(c) => {
150 <sqlx::MySql as sqlx::Database>::TransactionManager::commit(c)
151 .map_err(sqlx_error_to_query_err)
152 }
153 #[cfg(feature = "sqlx-postgres")]
154 InnerConnection::Postgres(c) => {
155 <sqlx::Postgres as sqlx::Database>::TransactionManager::commit(c)
156 .map_err(sqlx_error_to_query_err)
157 }
158 #[cfg(feature = "sqlx-sqlite")]
159 InnerConnection::Sqlite(c) => {
160 <sqlx::Sqlite as sqlx::Database>::TransactionManager::commit(c)
161 .map_err(sqlx_error_to_query_err)
162 }
163 #[cfg(feature = "rusqlite")]
164 InnerConnection::Rusqlite(c) => c.commit(),
165 #[cfg(feature = "mock")]
166 InnerConnection::Mock(c) => {
167 c.commit();
168 Ok(())
169 }
170 #[cfg(feature = "proxy")]
171 InnerConnection::Proxy(c) => {
172 c.commit();
173 Ok(())
174 }
175 #[allow(unreachable_patterns)]
176 _ => Err(conn_err("Disconnected")),
177 }
178 }
179 );
180
181 result?;
182 self.open = false; Ok(())
184 }
185
186 #[instrument(level = "trace")]
188 #[allow(unreachable_code, unused_mut)]
189 pub fn rollback(mut self) -> Result<(), DbErr> {
190 let result: Result<(), DbErr> = super::tracing_spans::with_db_span!(
191 "sea_orm.rollback",
192 self.backend,
193 "ROLLBACK",
194 record_stmt = false,
195 {
196 #[cfg(not(feature = "sync"))]
197 let conn = &mut *self.conn.lock();
198 #[cfg(feature = "sync")]
199 let conn = &mut *self.conn.lock().map_err(|_| DbErr::MutexPoisonError)?;
200
201 match conn {
202 #[cfg(feature = "sqlx-mysql")]
203 InnerConnection::MySql(c) => {
204 <sqlx::MySql as sqlx::Database>::TransactionManager::rollback(c)
205 .map_err(sqlx_error_to_query_err)
206 }
207 #[cfg(feature = "sqlx-postgres")]
208 InnerConnection::Postgres(c) => {
209 <sqlx::Postgres as sqlx::Database>::TransactionManager::rollback(c)
210 .map_err(sqlx_error_to_query_err)
211 }
212 #[cfg(feature = "sqlx-sqlite")]
213 InnerConnection::Sqlite(c) => {
214 <sqlx::Sqlite as sqlx::Database>::TransactionManager::rollback(c)
215 .map_err(sqlx_error_to_query_err)
216 }
217 #[cfg(feature = "rusqlite")]
218 InnerConnection::Rusqlite(c) => c.rollback(),
219 #[cfg(feature = "mock")]
220 InnerConnection::Mock(c) => {
221 c.rollback();
222 Ok(())
223 }
224 #[cfg(feature = "proxy")]
225 InnerConnection::Proxy(c) => {
226 c.rollback();
227 Ok(())
228 }
229 #[allow(unreachable_patterns)]
230 _ => Err(conn_err("Disconnected")),
231 }
232 }
233 );
234
235 result?;
236 self.open = false; Ok(())
238 }
239
240 #[instrument(level = "trace")]
242 fn start_rollback(&mut self) -> Result<(), DbErr> {
243 if self.open {
244 if let Some(mut conn) = self.conn.try_lock().ok() {
245 match &mut *conn {
246 #[cfg(feature = "sqlx-mysql")]
247 InnerConnection::MySql(c) => {
248 <sqlx::MySql as sqlx::Database>::TransactionManager::start_rollback(c);
249 }
250 #[cfg(feature = "sqlx-postgres")]
251 InnerConnection::Postgres(c) => {
252 <sqlx::Postgres as sqlx::Database>::TransactionManager::start_rollback(c);
253 }
254 #[cfg(feature = "sqlx-sqlite")]
255 InnerConnection::Sqlite(c) => {
256 <sqlx::Sqlite as sqlx::Database>::TransactionManager::start_rollback(c);
257 }
258 #[cfg(feature = "rusqlite")]
259 InnerConnection::Rusqlite(c) => {
260 c.start_rollback()?;
261 }
262 #[cfg(feature = "mock")]
263 InnerConnection::Mock(c) => {
264 c.rollback();
265 }
266 #[cfg(feature = "proxy")]
267 InnerConnection::Proxy(c) => {
268 c.start_rollback();
269 }
270 #[allow(unreachable_patterns)]
271 _ => return Err(conn_err("Disconnected")),
272 }
273 } else {
274 return Err(conn_err("Dropping a locked Transaction"));
276 }
277 }
278 Ok(())
279 }
280}
281
282impl TransactionSession for DatabaseTransaction {
283 fn commit(self) -> Result<(), DbErr> {
284 self.commit()
285 }
286
287 fn rollback(self) -> Result<(), DbErr> {
288 self.rollback()
289 }
290}
291
292impl Drop for DatabaseTransaction {
293 fn drop(&mut self) {
294 self.start_rollback().expect("Fail to rollback transaction");
295 }
296}
297
/// Statement execution inside the transaction.
///
/// Every method locks the shared connection mutex and dispatches on the
/// driver variant; each arm is compiled in only with its cargo feature. The
/// call is wrapped in `with_db_span!` for tracing. A poisoned mutex yields
/// `DbErr::MutexPoisonError` under the `sync` feature. A variant with no
/// enabled arm yields a "Disconnected" connection error.
impl ConnectionTrait for DatabaseTransaction {
    /// Returns the backend this transaction was opened on.
    fn get_database_backend(&self) -> DbBackend {
        self.backend
    }

    /// Executes `stmt` and returns the driver's execution result.
    ///
    /// The sqlx arms run the query inside `crate::metric::metric!` and map
    /// failures with `sqlx_error_to_exec_err`. Rusqlite receives the metric
    /// callback directly. Mock and proxy connections execute the statement
    /// themselves.
    #[instrument(level = "trace")]
    // NOTE(review): presumably silences unused-variable warnings when no driver
    // feature is enabled and only the fallback arm remains.
    #[allow(unused_variables)]
    fn execute_raw(&self, stmt: Statement) -> Result<ExecResult, DbErr> {
        debug_print!("{}", stmt);

        super::tracing_spans::with_db_span!(
            "sea_orm.execute",
            self.backend,
            stmt.sql.as_str(),
            record_stmt = true,
            {
                #[cfg(not(feature = "sync"))]
                let conn = &mut *self.conn.lock();
                #[cfg(feature = "sync")]
                let conn = &mut *self.conn.lock().map_err(|_| DbErr::MutexPoisonError)?;

                match conn {
                    #[cfg(feature = "sqlx-mysql")]
                    InnerConnection::MySql(conn) => {
                        let query = crate::driver::sqlx_mysql::sqlx_query(&stmt);
                        let conn: &mut sqlx::MySqlConnection = &mut *conn;
                        crate::metric::metric!(self.metric_callback, &stmt, {
                            query.execute(conn).map(Into::into)
                        })
                        .map_err(sqlx_error_to_exec_err)
                    }
                    #[cfg(feature = "sqlx-postgres")]
                    InnerConnection::Postgres(conn) => {
                        let query = crate::driver::sqlx_postgres::sqlx_query(&stmt);
                        let conn: &mut sqlx::PgConnection = &mut *conn;
                        crate::metric::metric!(self.metric_callback, &stmt, {
                            query.execute(conn).map(Into::into)
                        })
                        .map_err(sqlx_error_to_exec_err)
                    }
                    #[cfg(feature = "sqlx-sqlite")]
                    InnerConnection::Sqlite(conn) => {
                        let query = crate::driver::sqlx_sqlite::sqlx_query(&stmt);
                        let conn: &mut sqlx::SqliteConnection = &mut *conn;
                        crate::metric::metric!(self.metric_callback, &stmt, {
                            query.execute(conn).map(Into::into)
                        })
                        .map_err(sqlx_error_to_exec_err)
                    }
                    #[cfg(feature = "rusqlite")]
                    InnerConnection::Rusqlite(conn) => conn.execute(stmt, &self.metric_callback),
                    #[cfg(feature = "mock")]
                    InnerConnection::Mock(conn) => conn.execute(stmt),
                    #[cfg(feature = "proxy")]
                    InnerConnection::Proxy(conn) => conn.execute(stmt),
                    // Reached only when no enabled arm above matches the variant.
                    #[allow(unreachable_patterns)]
                    _ => Err(conn_err("Disconnected")),
                }
            }
        )
    }

    /// Executes raw SQL text directly, without building a `Statement` for the
    /// sqlx or rusqlite paths.
    ///
    /// Unlike `execute_raw`, the sqlx arms here are not wrapped in
    /// `crate::metric::metric!`. Mock and proxy connections wrap `sql` in a
    /// `Statement` for their own backend and execute that.
    #[instrument(level = "trace")]
    // NOTE(review): presumably silences unused-variable warnings when no driver
    // feature is enabled.
    #[allow(unused_variables)]
    fn execute_unprepared(&self, sql: &str) -> Result<ExecResult, DbErr> {
        debug_print!("{}", sql);

        super::tracing_spans::with_db_span!(
            "sea_orm.execute_unprepared",
            self.backend,
            sql,
            record_stmt = false,
            {
                #[cfg(not(feature = "sync"))]
                let conn = &mut *self.conn.lock();
                #[cfg(feature = "sync")]
                let conn = &mut *self.conn.lock().map_err(|_| DbErr::MutexPoisonError)?;

                match conn {
                    #[cfg(feature = "sqlx-mysql")]
                    InnerConnection::MySql(conn) => {
                        let conn: &mut sqlx::MySqlConnection = &mut *conn;
                        sqlx::Executor::execute(conn, sql)
                            .map(Into::into)
                            .map_err(sqlx_error_to_exec_err)
                    }
                    #[cfg(feature = "sqlx-postgres")]
                    InnerConnection::Postgres(conn) => {
                        let conn: &mut sqlx::PgConnection = &mut *conn;
                        sqlx::Executor::execute(conn, sql)
                            .map(Into::into)
                            .map_err(sqlx_error_to_exec_err)
                    }
                    #[cfg(feature = "sqlx-sqlite")]
                    InnerConnection::Sqlite(conn) => {
                        let conn: &mut sqlx::SqliteConnection = &mut *conn;
                        sqlx::Executor::execute(conn, sql)
                            .map(Into::into)
                            .map_err(sqlx_error_to_exec_err)
                    }
                    #[cfg(feature = "rusqlite")]
                    InnerConnection::Rusqlite(conn) => conn.execute_unprepared(sql),
                    #[cfg(feature = "mock")]
                    InnerConnection::Mock(conn) => {
                        let db_backend = conn.get_database_backend();
                        let stmt = Statement::from_string(db_backend, sql);
                        conn.execute(stmt)
                    }
                    #[cfg(feature = "proxy")]
                    InnerConnection::Proxy(conn) => {
                        let db_backend = conn.get_database_backend();
                        let stmt = Statement::from_string(db_backend, sql);
                        conn.execute(stmt)
                    }
                    #[allow(unreachable_patterns)]
                    _ => Err(conn_err("Disconnected")),
                }
            }
        )
    }

    /// Runs `stmt` and returns at most one row.
    ///
    /// The sqlx arms use `fetch_one` wrapped in
    /// `crate::sqlx_map_err_ignore_not_found`. Presumably, judging by the
    /// helper's name, that turns sqlx's "no row" error into `Ok(None)`;
    /// confirm against the helper.
    #[instrument(level = "trace")]
    // NOTE(review): presumably silences unused-variable warnings when no driver
    // feature is enabled.
    #[allow(unused_variables)]
    fn query_one_raw(&self, stmt: Statement) -> Result<Option<QueryResult>, DbErr> {
        debug_print!("{}", stmt);

        super::tracing_spans::with_db_span!(
            "sea_orm.query_one",
            self.backend,
            stmt.sql.as_str(),
            record_stmt = true,
            {
                #[cfg(not(feature = "sync"))]
                let conn = &mut *self.conn.lock();
                #[cfg(feature = "sync")]
                let conn = &mut *self.conn.lock().map_err(|_| DbErr::MutexPoisonError)?;

                match conn {
                    #[cfg(feature = "sqlx-mysql")]
                    InnerConnection::MySql(conn) => {
                        let query = crate::driver::sqlx_mysql::sqlx_query(&stmt);
                        let conn: &mut sqlx::MySqlConnection = &mut *conn;
                        crate::metric::metric!(self.metric_callback, &stmt, {
                            crate::sqlx_map_err_ignore_not_found(
                                query.fetch_one(conn).map(|row| Some(row.into())),
                            )
                        })
                    }
                    #[cfg(feature = "sqlx-postgres")]
                    InnerConnection::Postgres(conn) => {
                        let query = crate::driver::sqlx_postgres::sqlx_query(&stmt);
                        let conn: &mut sqlx::PgConnection = &mut *conn;
                        crate::metric::metric!(self.metric_callback, &stmt, {
                            crate::sqlx_map_err_ignore_not_found(
                                query.fetch_one(conn).map(|row| Some(row.into())),
                            )
                        })
                    }
                    #[cfg(feature = "sqlx-sqlite")]
                    InnerConnection::Sqlite(conn) => {
                        let query = crate::driver::sqlx_sqlite::sqlx_query(&stmt);
                        let conn: &mut sqlx::SqliteConnection = &mut *conn;
                        crate::metric::metric!(self.metric_callback, &stmt, {
                            crate::sqlx_map_err_ignore_not_found(
                                query.fetch_one(conn).map(|row| Some(row.into())),
                            )
                        })
                    }
                    #[cfg(feature = "rusqlite")]
                    InnerConnection::Rusqlite(conn) => conn.query_one(stmt, &self.metric_callback),
                    #[cfg(feature = "mock")]
                    InnerConnection::Mock(conn) => conn.query_one(stmt),
                    #[cfg(feature = "proxy")]
                    InnerConnection::Proxy(conn) => conn.query_one(stmt),
                    #[allow(unreachable_patterns)]
                    _ => Err(conn_err("Disconnected")),
                }
            }
        )
    }

    /// Runs `stmt` and collects every returned row into a `Vec<QueryResult>`.
    ///
    /// The sqlx arms map failures with `sqlx_error_to_query_err` inside the
    /// metric wrapper.
    #[instrument(level = "trace")]
    // NOTE(review): presumably silences unused-variable warnings when no driver
    // feature is enabled.
    #[allow(unused_variables)]
    fn query_all_raw(&self, stmt: Statement) -> Result<Vec<QueryResult>, DbErr> {
        debug_print!("{}", stmt);

        super::tracing_spans::with_db_span!(
            "sea_orm.query_all",
            self.backend,
            stmt.sql.as_str(),
            record_stmt = true,
            {
                #[cfg(not(feature = "sync"))]
                let conn = &mut *self.conn.lock();
                #[cfg(feature = "sync")]
                let conn = &mut *self.conn.lock().map_err(|_| DbErr::MutexPoisonError)?;

                match conn {
                    #[cfg(feature = "sqlx-mysql")]
                    InnerConnection::MySql(conn) => {
                        let query = crate::driver::sqlx_mysql::sqlx_query(&stmt);
                        let conn: &mut sqlx::MySqlConnection = &mut *conn;
                        crate::metric::metric!(self.metric_callback, &stmt, {
                            query
                                .fetch_all(conn)
                                .map(|rows| rows.into_iter().map(|r| r.into()).collect())
                                .map_err(sqlx_error_to_query_err)
                        })
                    }
                    #[cfg(feature = "sqlx-postgres")]
                    InnerConnection::Postgres(conn) => {
                        let query = crate::driver::sqlx_postgres::sqlx_query(&stmt);
                        let conn: &mut sqlx::PgConnection = &mut *conn;
                        crate::metric::metric!(self.metric_callback, &stmt, {
                            query
                                .fetch_all(conn)
                                .map(|rows| rows.into_iter().map(|r| r.into()).collect())
                                .map_err(sqlx_error_to_query_err)
                        })
                    }
                    #[cfg(feature = "sqlx-sqlite")]
                    InnerConnection::Sqlite(conn) => {
                        let query = crate::driver::sqlx_sqlite::sqlx_query(&stmt);
                        let conn: &mut sqlx::SqliteConnection = &mut *conn;
                        crate::metric::metric!(self.metric_callback, &stmt, {
                            query
                                .fetch_all(conn)
                                .map(|rows| rows.into_iter().map(|r| r.into()).collect())
                                .map_err(sqlx_error_to_query_err)
                        })
                    }
                    #[cfg(feature = "rusqlite")]
                    InnerConnection::Rusqlite(conn) => conn.query_all(stmt, &self.metric_callback),
                    #[cfg(feature = "mock")]
                    InnerConnection::Mock(conn) => conn.query_all(stmt),
                    #[cfg(feature = "proxy")]
                    InnerConnection::Proxy(conn) => conn.query_all(stmt),
                    #[allow(unreachable_patterns)]
                    _ => Err(conn_err("Disconnected")),
                }
            }
        )
    }
}
543
544impl StreamTrait for DatabaseTransaction {
545 type Stream<'a> = TransactionStream<'a>;
546
547 fn get_database_backend(&self) -> DbBackend {
548 self.backend
549 }
550
551 #[instrument(level = "trace")]
552 fn stream_raw<'a>(&'a self, stmt: Statement) -> Result<Self::Stream<'a>, DbErr> {
553 ({
554 #[cfg(not(feature = "sync"))]
555 let conn = self.conn.lock();
556 #[cfg(feature = "sync")]
557 let conn = self.conn.lock().map_err(|_| DbErr::MutexPoisonError)?;
558 Ok(crate::TransactionStream::build(
559 conn,
560 stmt,
561 self.metric_callback.clone(),
562 ))
563 })
564 }
565}
566
567impl TransactionTrait for DatabaseTransaction {
568 type Transaction = DatabaseTransaction;
569
570 #[instrument(level = "trace")]
571 fn begin(&self) -> Result<DatabaseTransaction, DbErr> {
572 DatabaseTransaction::begin(
573 Arc::clone(&self.conn),
574 self.backend,
575 self.metric_callback.clone(),
576 None,
577 None,
578 )
579 }
580
581 #[instrument(level = "trace")]
582 fn begin_with_config(
583 &self,
584 isolation_level: Option<IsolationLevel>,
585 access_mode: Option<AccessMode>,
586 ) -> Result<DatabaseTransaction, DbErr> {
587 DatabaseTransaction::begin(
588 Arc::clone(&self.conn),
589 self.backend,
590 self.metric_callback.clone(),
591 isolation_level,
592 access_mode,
593 )
594 }
595
596 #[instrument(level = "trace", skip(_callback))]
600 fn transaction<F, T, E>(&self, _callback: F) -> Result<T, TransactionError<E>>
601 where
602 F: for<'c> FnOnce(&'c DatabaseTransaction) -> Result<T, E>,
603 E: std::fmt::Display + std::fmt::Debug,
604 {
605 let transaction = self.begin().map_err(TransactionError::Connection)?;
606 transaction.run(_callback)
607 }
608
609 #[instrument(level = "trace", skip(_callback))]
613 fn transaction_with_config<F, T, E>(
614 &self,
615 _callback: F,
616 isolation_level: Option<IsolationLevel>,
617 access_mode: Option<AccessMode>,
618 ) -> Result<T, TransactionError<E>>
619 where
620 F: for<'c> FnOnce(&'c DatabaseTransaction) -> Result<T, E>,
621 E: std::fmt::Display + std::fmt::Debug,
622 {
623 let transaction = self
624 .begin_with_config(isolation_level, access_mode)
625 .map_err(TransactionError::Connection)?;
626 transaction.run(_callback)
627 }
628}
629
630#[derive(Debug)]
632pub enum TransactionError<E> {
633 Connection(DbErr),
635 Transaction(E),
637}
638
639impl<E> std::fmt::Display for TransactionError<E>
640where
641 E: std::fmt::Display + std::fmt::Debug,
642{
643 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
644 match self {
645 TransactionError::Connection(e) => std::fmt::Display::fmt(e, f),
646 TransactionError::Transaction(e) => std::fmt::Display::fmt(e, f),
647 }
648 }
649}
650
651impl<E> std::error::Error for TransactionError<E> where E: std::fmt::Display + std::fmt::Debug {}
652
653impl<E> From<DbErr> for TransactionError<E>
654where
655 E: std::fmt::Display + std::fmt::Debug,
656{
657 fn from(e: DbErr) -> Self {
658 Self::Connection(e)
659 }
660}