Skip to main content

sql_middleware/sqlite/
transaction.rs

1use std::sync::Arc;
2
3use crate::adapters::params::convert_params;
4use crate::middleware::{ConversionMode, ResultSet, RowValues, SqlMiddlewareDbError};
5use crate::pool::MiddlewarePoolConnection;
6use crate::tx_outcome::TxOutcome;
7
8use super::connection::SqliteConnection;
9use super::params::Params;
10
11use std::sync::atomic::{AtomicBool, Ordering};
12
/// Test-only override: when `true`, a connection whose rollback failed is still handed back
/// to the pool wrapper instead of being marked broken.
///
/// `Relaxed` ordering is sufficient because the flag is read on its own and does not publish
/// any other data.
static REWRAP_ON_ROLLBACK_FAILURE: AtomicBool = AtomicBool::new(false);

/// Turn the rewrap-on-rollback-failure override on or off (intended for tests only).
#[doc(hidden)]
pub fn set_rewrap_on_rollback_failure_for_tests(rewrap: bool) {
    AtomicBool::store(&REWRAP_ON_ROLLBACK_FAILURE, rewrap, Ordering::Relaxed);
}

/// Current value of the rewrap-on-rollback-failure override.
fn rewrap_on_rollback_failure_for_tests() -> bool {
    AtomicBool::load(&REWRAP_ON_ROLLBACK_FAILURE, Ordering::Relaxed)
}
23
24/// Transaction handle that owns the `SQLite` connection until completion.
25pub struct Tx<'a> {
26    conn: Option<SqliteConnection>,
27    conn_slot: &'a mut MiddlewarePoolConnection,
28}
29
/// Prepared statement tied to a `SQLite` transaction.
///
/// Only the SQL text is stored. It is executed through the owning [`Tx`] via
/// `execute_prepared` / `query_prepared`. The text lives behind an `Arc`, so cloning a
/// `Prepared` is cheap and shares the SQL rather than copying it.
#[derive(Clone, Debug)]
pub struct Prepared {
    sql: Arc<String>,
}
34
35/// Begin a transaction, temporarily taking ownership of the pooled `SQLite` connection
36/// until commit/rollback (or drop) returns it to the wrapper.
37///
38/// # Errors
39/// Returns `SqlMiddlewareDbError` if the transaction cannot be started.
40pub async fn begin_transaction(
41    conn_slot: &mut MiddlewarePoolConnection,
42) -> Result<Tx<'_>, SqlMiddlewareDbError> {
43    #[cfg(any(feature = "postgres", feature = "mssql", feature = "turso"))]
44    let MiddlewarePoolConnection::Sqlite { conn, .. } = conn_slot else {
45        return Err(SqlMiddlewareDbError::Unimplemented(
46            "begin_transaction is only available for SQLite connections".into(),
47        ));
48    };
49    #[cfg(not(any(feature = "postgres", feature = "mssql", feature = "turso")))]
50    let MiddlewarePoolConnection::Sqlite { conn, .. } = conn_slot;
51
52    let mut conn = conn.take().ok_or_else(|| {
53        SqlMiddlewareDbError::ExecutionError(
54            "SQLite connection already taken from pool wrapper".into(),
55        )
56    })?;
57    conn.begin().await?;
58    Ok(Tx {
59        conn: Some(conn),
60        conn_slot,
61    })
62}
63
64impl Tx<'_> {
65    fn conn_mut(&mut self) -> Result<&mut SqliteConnection, SqlMiddlewareDbError> {
66        self.conn.as_mut().ok_or_else(|| {
67            SqlMiddlewareDbError::ExecutionError("SQLite transaction already completed".into())
68        })
69    }
70
71    /// Prepare a statement within this transaction.
72    ///
73    /// # Errors
74    /// Returns `SqlMiddlewareDbError` if the transaction has already completed.
75    pub fn prepare(&self, sql: &str) -> Result<Prepared, SqlMiddlewareDbError> {
76        if self.conn.is_none() {
77            return Err(SqlMiddlewareDbError::ExecutionError(
78                "SQLite transaction already completed".into(),
79            ));
80        }
81        Ok(Prepared {
82            sql: Arc::new(sql.to_owned()),
83        })
84    }
85
86    /// Execute a prepared statement as DML within this transaction.
87    ///
88    /// # Errors
89    /// Returns `SqlMiddlewareDbError` if parameter conversion or execution fails.
90    pub async fn execute_prepared(
91        &mut self,
92        prepared: &Prepared,
93        params: &[RowValues],
94    ) -> Result<usize, SqlMiddlewareDbError> {
95        let converted = convert_params::<Params>(params, ConversionMode::Execute)?;
96        let conn = self.conn_mut()?;
97        conn.execute_dml_in_tx(prepared.sql.as_ref(), &converted.0)
98            .await
99    }
100
101    /// Execute a prepared statement as a query within this transaction.
102    ///
103    /// # Errors
104    /// Returns `SqlMiddlewareDbError` if parameter conversion or execution fails.
105    pub async fn query_prepared(
106        &mut self,
107        prepared: &Prepared,
108        params: &[RowValues],
109    ) -> Result<ResultSet, SqlMiddlewareDbError> {
110        let converted = convert_params::<Params>(params, ConversionMode::Query)?;
111        let conn = self.conn_mut()?;
112        conn.execute_select_in_tx(
113            prepared.sql.as_ref(),
114            &converted.0,
115            super::query::build_result_set,
116        )
117        .await
118    }
119
120    /// Execute a batch inside the open transaction.
121    ///
122    /// # Errors
123    /// Returns `SqlMiddlewareDbError` if executing the batch fails.
124    pub async fn execute_batch(&mut self, sql: &str) -> Result<(), SqlMiddlewareDbError> {
125        let conn = self.conn_mut()?;
126        conn.execute_batch_in_tx(sql).await
127    }
128
129    /// Commit the transaction and rewrap the pooled connection.
130    ///
131    /// # Errors
132    /// Returns `SqlMiddlewareDbError` if committing the transaction fails.
133    pub async fn commit(mut self) -> Result<TxOutcome, SqlMiddlewareDbError> {
134        let mut conn = self.conn.take().ok_or_else(|| {
135            SqlMiddlewareDbError::ExecutionError("SQLite transaction already completed".into())
136        })?;
137        match conn.commit().await {
138            Ok(()) => {
139                self.rewrap(conn);
140                Ok(TxOutcome::without_restored_connection())
141            }
142            Err(err) => {
143                let handle = conn.conn_handle();
144                let rollback_result =
145                    super::connection::rollback_with_busy_retries(&handle).await;
146                if rollback_result.is_ok() || rewrap_on_rollback_failure_for_tests() {
147                    conn.in_transaction = false;
148                    self.rewrap(conn);
149                }
150                if rollback_result.is_err() && !rewrap_on_rollback_failure_for_tests() {
151                    handle.mark_broken();
152                }
153                Err(err)
154            }
155        }
156    }
157
158    /// Roll back the transaction and rewrap the pooled connection.
159    ///
160    /// # Errors
161    /// Returns `SqlMiddlewareDbError` if rolling back fails.
162    pub async fn rollback(mut self) -> Result<TxOutcome, SqlMiddlewareDbError> {
163        let mut conn = self.conn.take().ok_or_else(|| {
164            SqlMiddlewareDbError::ExecutionError("SQLite transaction already completed".into())
165        })?;
166        let handle = conn.conn_handle();
167        match super::connection::rollback_with_busy_retries(&handle).await {
168            Ok(()) => {
169                conn.in_transaction = false;
170                self.rewrap(conn);
171                Ok(TxOutcome::without_restored_connection())
172            }
173            Err(err) => {
174                if rewrap_on_rollback_failure_for_tests() {
175                    conn.in_transaction = false;
176                    self.rewrap(conn);
177                }
178                if !rewrap_on_rollback_failure_for_tests() {
179                    handle.mark_broken();
180                }
181                Err(err)
182            }
183        }
184    }
185
186    fn rewrap(&mut self, conn: SqliteConnection) {
187        #[cfg(any(feature = "postgres", feature = "mssql", feature = "turso"))]
188        let MiddlewarePoolConnection::Sqlite { conn: slot, .. } = self.conn_slot else {
189            return;
190        };
191        #[cfg(not(any(feature = "postgres", feature = "mssql", feature = "turso")))]
192        let MiddlewarePoolConnection::Sqlite { conn: slot, .. } = self.conn_slot;
193        debug_assert!(slot.is_none(), "sqlite conn slot should be empty during tx");
194        *slot = Some(conn);
195    }
196}
197
198impl Drop for Tx<'_> {
199    /// Rolls back on drop to avoid leaking open transactions; the rollback is best-effort and
200    /// `SQLite` may report "no transaction is active" if the transaction was already completed
201    /// by user code (e.g., via `execute_batch_in_tx`). Such errors are ignored because the goal
202    /// is simply to leave the connection in a clean state before returning it to the pool.
203    fn drop(&mut self) {
204        if let Some(mut conn) = self.conn.take() {
205            let handle = conn.conn_handle();
206            let rollback_result =
207                super::connection::rollback_with_busy_retries_blocking(&handle);
208            if rollback_result.is_ok() || rewrap_on_rollback_failure_for_tests() {
209                conn.in_transaction = false;
210                self.rewrap(conn);
211            } else {
212                // Mark broken so the pool will drop and replace this connection instead of
213                // handing out one that might still be mid-transaction.
214                handle.mark_broken();
215            }
216        }
217    }
218}