Skip to main content

sql_middleware/sqlite/
transaction.rs

1use std::sync::Arc;
2
3use crate::adapters::params::convert_params;
4use crate::middleware::{ConversionMode, CustomDbRow, ResultSet, RowValues, SqlMiddlewareDbError};
5use crate::pool::MiddlewarePoolConnection;
6use crate::tx_outcome::TxOutcome;
7
8use super::connection::SqliteConnection;
9use super::params::Params;
10
11use std::sync::atomic::{AtomicBool, Ordering};
12
/// Process-wide test switch: when `true`, a failed rollback still hands the connection back
/// to the pool wrapper instead of marking its handle broken. Starts out `false`.
static REWRAP_ON_ROLLBACK_FAILURE: AtomicBool = AtomicBool::new(false);

/// Test hook: turn the "rewrap even if rollback failed" behaviour on or off.
///
/// Hidden from the rendered docs; the `_for_tests` suffix signals it is not meant for
/// production callers.
#[doc(hidden)]
pub fn set_rewrap_on_rollback_failure_for_tests(rewrap: bool) {
    // The flag stands alone (no other data is published through it), so a relaxed store is used.
    REWRAP_ON_ROLLBACK_FAILURE.store(rewrap, Ordering::Relaxed);
}

/// Read the current value of the test switch.
fn rewrap_on_rollback_failure_for_tests() -> bool {
    let enabled = REWRAP_ON_ROLLBACK_FAILURE.load(Ordering::Relaxed);
    enabled
}
23
24/// Transaction handle that owns the `SQLite` connection until completion.
25pub struct Tx<'a> {
26    conn: Option<SqliteConnection>,
27    conn_slot: &'a mut MiddlewarePoolConnection,
28}
29
/// Prepared statement tied to a `SQLite` transaction.
///
/// Only the SQL text is stored. It lives behind an [`Arc`], so cloning a `Prepared` bumps a
/// reference count instead of copying the string.
#[derive(Debug, Clone)]
pub struct Prepared {
    /// SQL text captured when the statement was prepared.
    sql: Arc<String>,
}
34
35/// Begin a transaction, temporarily taking ownership of the pooled `SQLite` connection
36/// until commit/rollback (or drop) returns it to the wrapper.
37///
38/// # Errors
39/// Returns `SqlMiddlewareDbError` if the transaction cannot be started.
40pub async fn begin_transaction(
41    conn_slot: &mut MiddlewarePoolConnection,
42) -> Result<Tx<'_>, SqlMiddlewareDbError> {
43    #[cfg(any(feature = "postgres", feature = "mssql", feature = "turso"))]
44    let MiddlewarePoolConnection::Sqlite { conn, .. } = conn_slot else {
45        return Err(SqlMiddlewareDbError::Unimplemented(
46            "begin_transaction is only available for SQLite connections".into(),
47        ));
48    };
49    #[cfg(not(any(feature = "postgres", feature = "mssql", feature = "turso")))]
50    let MiddlewarePoolConnection::Sqlite { conn, .. } = conn_slot;
51
52    let mut conn = conn.take().ok_or_else(|| {
53        SqlMiddlewareDbError::ExecutionError(
54            "SQLite connection already taken from pool wrapper".into(),
55        )
56    })?;
57    conn.begin().await?;
58    Ok(Tx {
59        conn: Some(conn),
60        conn_slot,
61    })
62}
63
64impl<'conn> Tx<'conn> {
65    fn conn_mut(&mut self) -> Result<&mut SqliteConnection, SqlMiddlewareDbError> {
66        self.conn.as_mut().ok_or_else(|| {
67            SqlMiddlewareDbError::ExecutionError("SQLite transaction already completed".into())
68        })
69    }
70
71    /// Prepare a statement within this transaction.
72    ///
73    /// # Errors
74    /// Returns `SqlMiddlewareDbError` if the transaction has already completed.
75    pub fn prepare(&self, sql: &str) -> Result<Prepared, SqlMiddlewareDbError> {
76        if self.conn.is_none() {
77            return Err(SqlMiddlewareDbError::ExecutionError(
78                "SQLite transaction already completed".into(),
79            ));
80        }
81        Ok(Prepared {
82            sql: Arc::new(sql.to_owned()),
83        })
84    }
85
86    /// Start configuring a prepared SELECT execution.
87    #[must_use]
88    pub fn select<'tx, 'prepared>(
89        &'tx mut self,
90        prepared: &'prepared Prepared,
91    ) -> PreparedSelect<'tx, 'prepared, 'static, 'conn> {
92        PreparedSelect {
93            tx: self,
94            prepared,
95            params: &[],
96        }
97    }
98
99    /// Start configuring a prepared DML execution.
100    #[must_use]
101    pub fn execute<'tx, 'prepared>(
102        &'tx mut self,
103        prepared: &'prepared Prepared,
104    ) -> PreparedExecute<'tx, 'prepared, 'static, 'conn> {
105        PreparedExecute {
106            tx: self,
107            prepared,
108            params: &[],
109        }
110    }
111
112    /// Execute a prepared statement as DML within this transaction.
113    ///
114    /// # Errors
115    /// Returns `SqlMiddlewareDbError` if parameter conversion or execution fails.
116    pub(crate) async fn execute_prepared(
117        &mut self,
118        prepared: &Prepared,
119        params: &[RowValues],
120    ) -> Result<usize, SqlMiddlewareDbError> {
121        let converted = convert_params::<Params>(params, ConversionMode::Execute)?;
122        let conn = self.conn_mut()?;
123        conn.execute_dml_in_tx(prepared.sql.as_ref(), &converted.0)
124            .await
125    }
126
127    /// Execute a prepared statement as a query within this transaction.
128    ///
129    /// # Errors
130    /// Returns `SqlMiddlewareDbError` if parameter conversion or execution fails.
131    pub(crate) async fn query_prepared(
132        &mut self,
133        prepared: &Prepared,
134        params: &[RowValues],
135    ) -> Result<ResultSet, SqlMiddlewareDbError> {
136        let converted = convert_params::<Params>(params, ConversionMode::Query)?;
137        let conn = self.conn_mut()?;
138        conn.execute_select_in_tx(
139            prepared.sql.as_ref(),
140            &converted.0,
141            super::query::build_result_set,
142        )
143        .await
144    }
145
146    /// Execute a batch inside the open transaction.
147    ///
148    /// # Errors
149    /// Returns `SqlMiddlewareDbError` if executing the batch fails.
150    pub async fn execute_batch(&mut self, sql: &str) -> Result<(), SqlMiddlewareDbError> {
151        let conn = self.conn_mut()?;
152        conn.execute_batch_in_tx(sql).await
153    }
154
155    /// Commit the transaction and rewrap the pooled connection.
156    ///
157    /// # Errors
158    /// Returns `SqlMiddlewareDbError` if committing the transaction fails.
159    pub async fn commit(mut self) -> Result<TxOutcome, SqlMiddlewareDbError> {
160        let mut conn = self.conn.take().ok_or_else(|| {
161            SqlMiddlewareDbError::ExecutionError("SQLite transaction already completed".into())
162        })?;
163        match conn.commit().await {
164            Ok(()) => {
165                self.rewrap(conn);
166                Ok(TxOutcome::without_restored_connection())
167            }
168            Err(err) => {
169                let handle = conn.conn_handle();
170                let rollback_result = super::connection::rollback_with_busy_retries(&handle).await;
171                if rollback_result.is_ok() || rewrap_on_rollback_failure_for_tests() {
172                    conn.in_transaction = false;
173                    self.rewrap(conn);
174                }
175                if rollback_result.is_err() && !rewrap_on_rollback_failure_for_tests() {
176                    handle.mark_broken();
177                }
178                Err(err)
179            }
180        }
181    }
182
183    /// Roll back the transaction and rewrap the pooled connection.
184    ///
185    /// # Errors
186    /// Returns `SqlMiddlewareDbError` if rolling back fails.
187    pub async fn rollback(mut self) -> Result<TxOutcome, SqlMiddlewareDbError> {
188        let mut conn = self.conn.take().ok_or_else(|| {
189            SqlMiddlewareDbError::ExecutionError("SQLite transaction already completed".into())
190        })?;
191        let handle = conn.conn_handle();
192        match super::connection::rollback_with_busy_retries(&handle).await {
193            Ok(()) => {
194                conn.in_transaction = false;
195                self.rewrap(conn);
196                Ok(TxOutcome::without_restored_connection())
197            }
198            Err(err) => {
199                if rewrap_on_rollback_failure_for_tests() {
200                    conn.in_transaction = false;
201                    self.rewrap(conn);
202                }
203                if !rewrap_on_rollback_failure_for_tests() {
204                    handle.mark_broken();
205                }
206                Err(err)
207            }
208        }
209    }
210
211    fn rewrap(&mut self, conn: SqliteConnection) {
212        #[cfg(any(feature = "postgres", feature = "mssql", feature = "turso"))]
213        let MiddlewarePoolConnection::Sqlite { conn: slot, .. } = self.conn_slot else {
214            return;
215        };
216        #[cfg(not(any(feature = "postgres", feature = "mssql", feature = "turso")))]
217        let MiddlewarePoolConnection::Sqlite { conn: slot, .. } = self.conn_slot;
218        debug_assert!(slot.is_none(), "sqlite conn slot should be empty during tx");
219        *slot = Some(conn);
220    }
221}
222
223/// Builder for executing a prepared SQLite DML statement inside a transaction.
224pub struct PreparedExecute<'tx, 'prepared, 'params, 'conn> {
225    tx: &'tx mut Tx<'conn>,
226    prepared: &'prepared Prepared,
227    params: &'params [RowValues],
228}
229
230impl<'tx, 'prepared, 'params, 'conn> PreparedExecute<'tx, 'prepared, 'params, 'conn> {
231    /// Use middleware `RowValues` parameters.
232    #[must_use]
233    pub fn params<'next>(
234        self,
235        params: &'next [RowValues],
236    ) -> PreparedExecute<'tx, 'prepared, 'next, 'conn> {
237        PreparedExecute {
238            tx: self.tx,
239            prepared: self.prepared,
240            params,
241        }
242    }
243
244    /// Execute the DML statement and return affected rows.
245    ///
246    /// # Errors
247    /// Returns `SqlMiddlewareDbError` if parameter conversion or execution fails.
248    pub async fn run(self) -> Result<usize, SqlMiddlewareDbError> {
249        self.tx.execute_prepared(self.prepared, self.params).await
250    }
251}
252
253/// Builder for executing a prepared SQLite SELECT inside a transaction.
254pub struct PreparedSelect<'tx, 'prepared, 'params, 'conn> {
255    tx: &'tx mut Tx<'conn>,
256    prepared: &'prepared Prepared,
257    params: &'params [RowValues],
258}
259
260impl<'tx, 'prepared, 'params, 'conn> PreparedSelect<'tx, 'prepared, 'params, 'conn> {
261    /// Use middleware `RowValues` parameters.
262    #[must_use]
263    pub fn params<'next>(
264        self,
265        params: &'next [RowValues],
266    ) -> PreparedSelect<'tx, 'prepared, 'next, 'conn> {
267        PreparedSelect {
268            tx: self.tx,
269            prepared: self.prepared,
270            params,
271        }
272    }
273
274    /// Execute and return all rows as a `ResultSet`.
275    ///
276    /// # Errors
277    /// Returns `SqlMiddlewareDbError` if parameter conversion or execution fails.
278    pub async fn all(self) -> Result<ResultSet, SqlMiddlewareDbError> {
279        self.tx.query_prepared(self.prepared, self.params).await
280    }
281
282    /// Execute and return the first row, if present.
283    ///
284    /// # Errors
285    /// Returns `SqlMiddlewareDbError` if parameter conversion or execution fails.
286    pub async fn optional(self) -> Result<Option<CustomDbRow>, SqlMiddlewareDbError> {
287        self.all().await.map(ResultSet::into_optional)
288    }
289
290    /// Execute and return exactly one row.
291    ///
292    /// # Errors
293    /// Returns `SqlMiddlewareDbError` if execution fails or no row is returned.
294    pub async fn one(self) -> Result<CustomDbRow, SqlMiddlewareDbError> {
295        self.all().await?.into_one()
296    }
297}
298
299impl Drop for Tx<'_> {
300    /// Rolls back on drop to avoid leaking open transactions; the rollback is best-effort and
301    /// `SQLite` may report "no transaction is active" if the transaction was already completed
302    /// by user code (e.g., via `execute_batch_in_tx`). Such errors are ignored because the goal
303    /// is simply to leave the connection in a clean state before returning it to the pool.
304    fn drop(&mut self) {
305        if let Some(mut conn) = self.conn.take() {
306            let handle = conn.conn_handle();
307            let rollback_result = super::connection::rollback_with_busy_retries_blocking(&handle);
308            if rollback_result.is_ok() || rewrap_on_rollback_failure_for_tests() {
309                conn.in_transaction = false;
310                self.rewrap(conn);
311            } else {
312                // Mark broken so the pool will drop and replace this connection instead of
313                // handing out one that might still be mid-transaction.
314                handle.mark_broken();
315            }
316        }
317    }
318}