sql_middleware/sqlite/transaction.rs

1use std::sync::Arc;
2
3use crate::middleware::{
4    ConversionMode, ParamConverter, ResultSet, RowValues, SqlMiddlewareDbError,
5};
6use crate::pool::MiddlewarePoolConnection;
7use crate::tx_outcome::TxOutcome;
8
9use super::connection::SqliteConnection;
10use super::params::Params;
11
12use std::sync::atomic::{AtomicBool, Ordering};
13
/// Test-only switch. When `true`, a connection whose rollback failed is still put back into the
/// pool wrapper instead of being marked broken.
static REWRAP_ON_ROLLBACK_FAILURE: AtomicBool = AtomicBool::new(false);

/// Turn the "restore the connection even after a failed rollback" test hook on or off.
///
/// Hidden from the public docs; only tests are expected to call this.
#[doc(hidden)]
pub fn set_rewrap_on_rollback_failure_for_tests(rewrap: bool) {
    let flag = &REWRAP_ON_ROLLBACK_FAILURE;
    flag.store(rewrap, Ordering::Relaxed);
}

/// Current value of the test hook set by `set_rewrap_on_rollback_failure_for_tests`.
fn rewrap_on_rollback_failure_for_tests() -> bool {
    let flag = &REWRAP_ON_ROLLBACK_FAILURE;
    flag.load(Ordering::Relaxed)
}
24
25/// Transaction handle that owns the `SQLite` connection until completion.
26pub struct Tx<'a> {
27    conn: Option<SqliteConnection>,
28    conn_slot: &'a mut MiddlewarePoolConnection,
29}
30
/// Prepared statement tied to a `SQLite` transaction.
///
/// Only the SQL text is stored. It is handed to the transaction's connection each time the
/// statement runs through [`Tx::execute_prepared`] or [`Tx::query_prepared`]. The text sits
/// behind an `Arc`, so cloning a `Prepared` is a reference-count bump rather than a string copy.
#[derive(Debug, Clone)]
pub struct Prepared {
    sql: Arc<String>,
}
35
36/// Begin a transaction, temporarily taking ownership of the pooled `SQLite` connection
37/// until commit/rollback (or drop) returns it to the wrapper.
38///
39/// # Errors
40/// Returns `SqlMiddlewareDbError` if the transaction cannot be started.
41pub async fn begin_transaction(
42    conn_slot: &mut MiddlewarePoolConnection,
43) -> Result<Tx<'_>, SqlMiddlewareDbError> {
44    #[cfg(any(feature = "postgres", feature = "mssql", feature = "turso"))]
45    let MiddlewarePoolConnection::Sqlite { conn, .. } = conn_slot else {
46        return Err(SqlMiddlewareDbError::Unimplemented(
47            "begin_transaction is only available for SQLite connections".into(),
48        ));
49    };
50    #[cfg(not(any(feature = "postgres", feature = "mssql", feature = "turso")))]
51    let MiddlewarePoolConnection::Sqlite { conn, .. } = conn_slot;
52
53    let mut conn = conn.take().ok_or_else(|| {
54        SqlMiddlewareDbError::ExecutionError(
55            "SQLite connection already taken from pool wrapper".into(),
56        )
57    })?;
58    conn.begin().await?;
59    Ok(Tx {
60        conn: Some(conn),
61        conn_slot,
62    })
63}
64
65impl Tx<'_> {
66    fn conn_mut(&mut self) -> Result<&mut SqliteConnection, SqlMiddlewareDbError> {
67        self.conn.as_mut().ok_or_else(|| {
68            SqlMiddlewareDbError::ExecutionError("SQLite transaction already completed".into())
69        })
70    }
71
72    /// Prepare a statement within this transaction.
73    ///
74    /// # Errors
75    /// Returns `SqlMiddlewareDbError` if the transaction has already completed.
76    pub fn prepare(&self, sql: &str) -> Result<Prepared, SqlMiddlewareDbError> {
77        if self.conn.is_none() {
78            return Err(SqlMiddlewareDbError::ExecutionError(
79                "SQLite transaction already completed".into(),
80            ));
81        }
82        Ok(Prepared {
83            sql: Arc::new(sql.to_owned()),
84        })
85    }
86
87    /// Execute a prepared statement as DML within this transaction.
88    ///
89    /// # Errors
90    /// Returns `SqlMiddlewareDbError` if parameter conversion or execution fails.
91    pub async fn execute_prepared(
92        &mut self,
93        prepared: &Prepared,
94        params: &[RowValues],
95    ) -> Result<usize, SqlMiddlewareDbError> {
96        let converted =
97            <Params as ParamConverter>::convert_sql_params(params, ConversionMode::Execute)?;
98        let conn = self.conn_mut()?;
99        conn.execute_dml_in_tx(prepared.sql.as_ref(), &converted.0)
100            .await
101    }
102
103    /// Execute a prepared statement as a query within this transaction.
104    ///
105    /// # Errors
106    /// Returns `SqlMiddlewareDbError` if parameter conversion or execution fails.
107    pub async fn query_prepared(
108        &mut self,
109        prepared: &Prepared,
110        params: &[RowValues],
111    ) -> Result<ResultSet, SqlMiddlewareDbError> {
112        let converted =
113            <Params as ParamConverter>::convert_sql_params(params, ConversionMode::Query)?;
114        let conn = self.conn_mut()?;
115        conn.execute_select_in_tx(
116            prepared.sql.as_ref(),
117            &converted.0,
118            super::query::build_result_set,
119        )
120        .await
121    }
122
123    /// Execute a batch inside the open transaction.
124    ///
125    /// # Errors
126    /// Returns `SqlMiddlewareDbError` if executing the batch fails.
127    pub async fn execute_batch(&mut self, sql: &str) -> Result<(), SqlMiddlewareDbError> {
128        let conn = self.conn_mut()?;
129        conn.execute_batch_in_tx(sql).await
130    }
131
132    /// Commit the transaction and rewrap the pooled connection.
133    ///
134    /// # Errors
135    /// Returns `SqlMiddlewareDbError` if committing the transaction fails.
136    pub async fn commit(mut self) -> Result<TxOutcome, SqlMiddlewareDbError> {
137        let mut conn = self.conn.take().ok_or_else(|| {
138            SqlMiddlewareDbError::ExecutionError("SQLite transaction already completed".into())
139        })?;
140        match conn.commit().await {
141            Ok(()) => {
142                self.rewrap(conn);
143                Ok(TxOutcome::without_restored_connection())
144            }
145            Err(err) => {
146                let handle = conn.conn_handle();
147                let rollback_result =
148                    super::connection::rollback_with_busy_retries(&handle).await;
149                if rollback_result.is_ok() || rewrap_on_rollback_failure_for_tests() {
150                    conn.in_transaction = false;
151                    self.rewrap(conn);
152                }
153                if rollback_result.is_err() && !rewrap_on_rollback_failure_for_tests() {
154                    handle.mark_broken();
155                }
156                Err(err)
157            }
158        }
159    }
160
161    /// Roll back the transaction and rewrap the pooled connection.
162    ///
163    /// # Errors
164    /// Returns `SqlMiddlewareDbError` if rolling back fails.
165    pub async fn rollback(mut self) -> Result<TxOutcome, SqlMiddlewareDbError> {
166        let mut conn = self.conn.take().ok_or_else(|| {
167            SqlMiddlewareDbError::ExecutionError("SQLite transaction already completed".into())
168        })?;
169        let handle = conn.conn_handle();
170        match super::connection::rollback_with_busy_retries(&handle).await {
171            Ok(()) => {
172                conn.in_transaction = false;
173                self.rewrap(conn);
174                Ok(TxOutcome::without_restored_connection())
175            }
176            Err(err) => {
177                if rewrap_on_rollback_failure_for_tests() {
178                    conn.in_transaction = false;
179                    self.rewrap(conn);
180                }
181                if !rewrap_on_rollback_failure_for_tests() {
182                    handle.mark_broken();
183                }
184                Err(err)
185            }
186        }
187    }
188
189    fn rewrap(&mut self, conn: SqliteConnection) {
190        #[cfg(any(feature = "postgres", feature = "mssql", feature = "turso"))]
191        let MiddlewarePoolConnection::Sqlite { conn: slot, .. } = self.conn_slot else {
192            return;
193        };
194        #[cfg(not(any(feature = "postgres", feature = "mssql", feature = "turso")))]
195        let MiddlewarePoolConnection::Sqlite { conn: slot, .. } = self.conn_slot;
196        debug_assert!(slot.is_none(), "sqlite conn slot should be empty during tx");
197        *slot = Some(conn);
198    }
199}
200
201impl Drop for Tx<'_> {
202    /// Rolls back on drop to avoid leaking open transactions; the rollback is best-effort and
203    /// `SQLite` may report "no transaction is active" if the transaction was already completed
204    /// by user code (e.g., via `execute_batch_in_tx`). Such errors are ignored because the goal
205    /// is simply to leave the connection in a clean state before returning it to the pool.
206    fn drop(&mut self) {
207        if let Some(mut conn) = self.conn.take() {
208            let handle = conn.conn_handle();
209            let rollback_result =
210                super::connection::rollback_with_busy_retries_blocking(&handle);
211            if rollback_result.is_ok() || rewrap_on_rollback_failure_for_tests() {
212                conn.in_transaction = false;
213                self.rewrap(conn);
214            } else {
215                // Mark broken so the pool will drop and replace this connection instead of
216                // handing out one that might still be mid-transaction.
217                handle.mark_broken();
218            }
219        }
220    }
221}