sql_middleware/postgres/typed/
tx.rs1use tokio::runtime::Handle;
2
3use crate::middleware::SqlMiddlewareDbError;
4
5use super::core::{Idle, InTx, PgConnection, SKIP_DROP_ROLLBACK};
6
7impl PgConnection<Idle> {
8 pub async fn begin(mut self) -> Result<PgConnection<InTx>, SqlMiddlewareDbError> {
13 let conn = self.take_conn()?;
14 conn.simple_query("BEGIN").await.map_err(|e| {
15 SqlMiddlewareDbError::ExecutionError(format!("postgres begin error: {e}"))
16 })?;
17 Ok(PgConnection::new(conn, true))
18 }
19}
20
21impl PgConnection<InTx> {
22 pub async fn commit(self) -> Result<PgConnection<Idle>, SqlMiddlewareDbError> {
27 self.finish_tx("COMMIT", "commit").await
28 }
29
30 pub async fn rollback(self) -> Result<PgConnection<Idle>, SqlMiddlewareDbError> {
35 self.finish_tx("ROLLBACK", "rollback").await
36 }
37
38 async fn finish_tx(
39 mut self,
40 sql: &str,
41 action: &str,
42 ) -> Result<PgConnection<Idle>, SqlMiddlewareDbError> {
43 let conn = self.take_conn()?;
44 match conn.simple_query(sql).await.map_err(|e| {
45 SqlMiddlewareDbError::ExecutionError(format!("postgres {action} error: {e}"))
46 }) {
47 Ok(_) => {
48 self.needs_rollback = false;
49 Ok(PgConnection::new(conn, false))
50 }
51 Err(err) => {
52 let _ = conn.simple_query("ROLLBACK").await;
54 self.conn = Some(conn);
55 Err(err)
56 }
57 }
58 }
59}
60
61fn skip_drop_rollback() -> bool {
66 SKIP_DROP_ROLLBACK.load(std::sync::atomic::Ordering::Relaxed)
67}
68
69impl<State> Drop for PgConnection<State> {
70 fn drop(&mut self) {
71 if self.needs_rollback
72 && !skip_drop_rollback()
73 && let Some(conn) = self.conn.take()
74 && let Ok(handle) = Handle::try_current()
75 {
76 handle.spawn(async move {
77 let _ = conn.simple_query("ROLLBACK").await;
78 });
79 }
80 }
81}
82
83#[doc(hidden)]
86pub fn set_skip_drop_rollback_for_tests(skip: bool) {
87 SKIP_DROP_ROLLBACK.store(skip, std::sync::atomic::Ordering::Relaxed);
88}