sql_middleware/sqlite/typed/
tx.rs

1use std::sync::Arc;
2use std::sync::atomic::Ordering;
3
4use tokio::runtime::Handle;
5use tokio::task::block_in_place;
6
7use crate::middleware::SqlMiddlewareDbError;
8
9use super::SqliteTypedConnection;
10use super::core::{SKIP_DROP_ROLLBACK, begin_from_conn, run_blocking};
11use crate::sqlite::connection::{rollback_with_busy_retries, rollback_with_busy_retries_blocking};
12use crate::sqlite::config::SharedSqliteConnection;
13
14impl SqliteTypedConnection<super::core::Idle> {
15    /// Begin an explicit transaction.
16    ///
17    /// # Errors
18    /// Returns `SqlMiddlewareDbError` if transitioning into a transaction fails.
19    pub async fn begin(
20        mut self,
21    ) -> Result<SqliteTypedConnection<super::core::InTx>, SqlMiddlewareDbError> {
22        begin_from_conn(self.take_conn()?).await
23    }
24}
25
26impl SqliteTypedConnection<super::core::InTx> {
27    /// Commit and return to idle.
28    ///
29    /// # Errors
30    /// Returns `SqlMiddlewareDbError` if committing the transaction fails.
31    pub async fn commit(
32        mut self,
33    ) -> Result<SqliteTypedConnection<super::core::Idle>, SqlMiddlewareDbError> {
34        let conn_handle = self.conn_handle()?;
35        let commit_result = run_blocking(Arc::clone(&conn_handle), |guard| {
36            guard
37                .execute_batch("COMMIT")
38                .map_err(SqlMiddlewareDbError::SqliteError)
39        })
40        .await;
41
42        match commit_result {
43            Ok(()) => {
44                let conn = self.take_conn()?;
45                Ok(SqliteTypedConnection {
46                    conn: Some(conn),
47                    needs_rollback: false,
48                    _state: std::marker::PhantomData,
49                })
50            }
51            Err(err) => {
52                // Best-effort rollback; keep needs_rollback = true so Drop can retry if needed.
53                if rollback_with_busy_retries(&conn_handle).await.is_err() {
54                    conn_handle.mark_broken();
55                }
56                Err(err)
57            }
58        }
59    }
60
61    /// Rollback and return to idle.
62    ///
63    /// # Errors
64    /// Returns `SqlMiddlewareDbError` if rolling back the transaction fails.
65    pub async fn rollback(
66        mut self,
67    ) -> Result<SqliteTypedConnection<super::core::Idle>, SqlMiddlewareDbError> {
68        let conn_handle = self.conn_handle()?;
69        let rollback_result = rollback_with_busy_retries(&conn_handle).await;
70
71        match rollback_result {
72            Ok(()) => {
73                let conn = self.take_conn()?;
74                Ok(SqliteTypedConnection {
75                    conn: Some(conn),
76                    needs_rollback: false,
77                    _state: std::marker::PhantomData,
78                })
79            }
80            Err(err) => {
81                // Keep connection + needs_rollback so Drop can attempt cleanup.
82                conn_handle.mark_broken();
83                Err(err)
84            }
85        }
86    }
87}
88
89impl<State> Drop for SqliteTypedConnection<State> {
90    fn drop(&mut self) {
91        if self.needs_rollback
92            && !skip_drop_rollback()
93            && let Some(conn) = self.conn.take()
94        {
95            let conn_handle: SharedSqliteConnection = Arc::clone(&*conn);
96            // Rollback synchronously so the connection is clean before it
97            // goes back into the pool. Avoid async fire-and-forget, which
98            // could race with the next checkout.
99            let rollback = || rollback_with_busy_retries_blocking(&conn_handle);
100            let result = if Handle::try_current().is_ok() {
101                block_in_place(rollback)
102            } else {
103                rollback()
104            };
105
106            if result.is_err() {
107                conn_handle.mark_broken();
108            }
109        }
110    }
111}
112
113fn skip_drop_rollback() -> bool {
114    SKIP_DROP_ROLLBACK.load(Ordering::Relaxed)
115}
116
117/// Test-only escape hatch to simulate legacy behavior where dropping an in-flight transaction
118/// leaked the transaction back to the pool. Do not use outside tests.
119#[doc(hidden)]
120pub fn set_skip_drop_rollback_for_tests(skip: bool) {
121    SKIP_DROP_ROLLBACK.store(skip, Ordering::Relaxed);
122}