sql_middleware/postgres/typed/
tx.rs

1use tokio::runtime::Handle;
2
3use crate::middleware::SqlMiddlewareDbError;
4
5use super::core::{Idle, InTx, PgConnection, SKIP_DROP_ROLLBACK};
6
7impl PgConnection<Idle> {
8    /// Begin an explicit transaction.
9    ///
10    /// # Errors
11    /// Returns `SqlMiddlewareDbError` if starting the transaction fails.
12    pub async fn begin(mut self) -> Result<PgConnection<InTx>, SqlMiddlewareDbError> {
13        let conn = self.take_conn()?;
14        conn.simple_query("BEGIN").await.map_err(|e| {
15            SqlMiddlewareDbError::ExecutionError(format!("postgres begin error: {e}"))
16        })?;
17        Ok(PgConnection::new(conn, true))
18    }
19}
20
21impl PgConnection<InTx> {
22    /// Commit and return to idle.
23    ///
24    /// # Errors
25    /// Returns `SqlMiddlewareDbError` if the commit fails.
26    pub async fn commit(self) -> Result<PgConnection<Idle>, SqlMiddlewareDbError> {
27        self.finish_tx("COMMIT", "commit").await
28    }
29
30    /// Rollback and return to idle.
31    ///
32    /// # Errors
33    /// Returns `SqlMiddlewareDbError` if the rollback fails.
34    pub async fn rollback(self) -> Result<PgConnection<Idle>, SqlMiddlewareDbError> {
35        self.finish_tx("ROLLBACK", "rollback").await
36    }
37
38    async fn finish_tx(
39        mut self,
40        sql: &str,
41        action: &str,
42    ) -> Result<PgConnection<Idle>, SqlMiddlewareDbError> {
43        let conn = self.take_conn()?;
44        match conn.simple_query(sql).await.map_err(|e| {
45            SqlMiddlewareDbError::ExecutionError(format!("postgres {action} error: {e}"))
46        }) {
47            Ok(_) => {
48                self.needs_rollback = false;
49                Ok(PgConnection::new(conn, false))
50            }
51            Err(err) => {
52                // Best-effort rollback; keep needs_rollback so Drop can retry.
53                let _ = conn.simple_query("ROLLBACK").await;
54                self.conn = Some(conn);
55                Err(err)
56            }
57        }
58    }
59}
60
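// NOTE: Rust cannot specialize `Drop` for `PgConnection<InTx>` alone, so the `Drop`
// impl below is generic over the state and is gated on `needs_rollback` instead.
// Callers should still finalize transactions explicitly with `commit()` or `rollback()`.
// If a transaction is dropped without being finalized, `Drop` makes a best-effort attempt
// to send `ROLLBACK` from a spawned task; this only happens when a Tokio runtime is
// available and the test-only skip flag is unset. Without that rollback, the open
// transaction would leak back to the pool along with the connection.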
65fn skip_drop_rollback() -> bool {
66    SKIP_DROP_ROLLBACK.load(std::sync::atomic::Ordering::Relaxed)
67}
68
69impl<State> Drop for PgConnection<State> {
70    fn drop(&mut self) {
71        if self.needs_rollback
72            && !skip_drop_rollback()
73            && let Some(conn) = self.conn.take()
74            && let Ok(handle) = Handle::try_current()
75        {
76            handle.spawn(async move {
77                let _ = conn.simple_query("ROLLBACK").await;
78            });
79        }
80    }
81}
82
83/// Test-only escape hatch to simulate legacy behavior where dropping an in-flight transaction
84/// leaked the transaction back to the pool. Do not use outside tests.
85#[doc(hidden)]
86pub fn set_skip_drop_rollback_for_tests(skip: bool) {
87    SKIP_DROP_ROLLBACK.store(skip, std::sync::atomic::Ordering::Relaxed);
88}