sql_middleware/sqlite/typed/
tx.rs

1use std::sync::Arc;
2use std::sync::atomic::Ordering;
3
4use tokio::runtime::Handle;
5use tokio::task::block_in_place;
6
7use crate::middleware::SqlMiddlewareDbError;
8
9use super::SqliteTypedConnection;
10use super::core::{SKIP_DROP_ROLLBACK, begin_from_conn, run_blocking};
11use crate::sqlite::config::SharedSqliteConnection;
12
13impl SqliteTypedConnection<super::core::Idle> {
14    /// Begin an explicit transaction.
15    ///
16    /// # Errors
17    /// Returns `SqlMiddlewareDbError` if transitioning into a transaction fails.
18    pub async fn begin(
19        mut self,
20    ) -> Result<SqliteTypedConnection<super::core::InTx>, SqlMiddlewareDbError> {
21        begin_from_conn(self.take_conn()?).await
22    }
23}
24
25impl SqliteTypedConnection<super::core::InTx> {
26    /// Commit and return to idle.
27    ///
28    /// # Errors
29    /// Returns `SqlMiddlewareDbError` if committing the transaction fails.
30    pub async fn commit(
31        mut self,
32    ) -> Result<SqliteTypedConnection<super::core::Idle>, SqlMiddlewareDbError> {
33        let conn_handle = self.conn_handle()?;
34        let commit_result = run_blocking(Arc::clone(&conn_handle), |guard| {
35            guard
36                .execute_batch("COMMIT")
37                .map_err(SqlMiddlewareDbError::SqliteError)
38        })
39        .await;
40
41        match commit_result {
42            Ok(()) => {
43                let conn = self.take_conn()?;
44                Ok(SqliteTypedConnection {
45                    conn: Some(conn),
46                    needs_rollback: false,
47                    _state: std::marker::PhantomData,
48                })
49            }
50            Err(err) => {
51                // Best-effort rollback; keep needs_rollback = true so Drop can retry if needed.
52                let _ = run_blocking(conn_handle, |guard| {
53                    guard
54                        .execute_batch("ROLLBACK")
55                        .map_err(SqlMiddlewareDbError::SqliteError)
56                })
57                .await;
58                Err(err)
59            }
60        }
61    }
62
63    /// Rollback and return to idle.
64    ///
65    /// # Errors
66    /// Returns `SqlMiddlewareDbError` if rolling back the transaction fails.
67    pub async fn rollback(
68        mut self,
69    ) -> Result<SqliteTypedConnection<super::core::Idle>, SqlMiddlewareDbError> {
70        let conn_handle = self.conn_handle()?;
71        let rollback_result = run_blocking(Arc::clone(&conn_handle), |guard| {
72            guard
73                .execute_batch("ROLLBACK")
74                .map_err(SqlMiddlewareDbError::SqliteError)
75        })
76        .await;
77
78        match rollback_result {
79            Ok(()) => {
80                let conn = self.take_conn()?;
81                Ok(SqliteTypedConnection {
82                    conn: Some(conn),
83                    needs_rollback: false,
84                    _state: std::marker::PhantomData,
85                })
86            }
87            Err(err) => {
88                // Keep connection + needs_rollback so Drop can attempt cleanup.
89                Err(err)
90            }
91        }
92    }
93}
94
95impl<State> Drop for SqliteTypedConnection<State> {
96    fn drop(&mut self) {
97        if self.needs_rollback
98            && !skip_drop_rollback()
99            && let Some(conn) = self.conn.take()
100        {
101            let conn_handle: SharedSqliteConnection = Arc::clone(&*conn);
102            // Rollback synchronously so the connection is clean before it
103            // goes back into the pool. Avoid async fire-and-forget, which
104            // could race with the next checkout.
105            if Handle::try_current().is_ok() {
106                block_in_place(|| {
107                    let _ = conn_handle.execute_blocking(|conn| {
108                        conn.execute_batch("ROLLBACK")
109                            .map_err(SqlMiddlewareDbError::SqliteError)
110                    });
111                });
112            } else {
113                let _ = conn_handle.execute_blocking(|conn| {
114                    conn.execute_batch("ROLLBACK")
115                        .map_err(SqlMiddlewareDbError::SqliteError)
116                });
117            }
118        }
119    }
120}
121
122fn skip_drop_rollback() -> bool {
123    SKIP_DROP_ROLLBACK.load(Ordering::Relaxed)
124}
125
126/// Test-only escape hatch to simulate legacy behavior where dropping an in-flight transaction
127/// leaked the transaction back to the pool. Do not use outside tests.
128#[doc(hidden)]
129pub fn set_skip_drop_rollback_for_tests(skip: bool) {
130    SKIP_DROP_ROLLBACK.store(skip, Ordering::Relaxed);
131}