sql_middleware/sqlite/
transaction.rs1use std::sync::Arc;
2
3use crate::middleware::{
4 ConversionMode, ParamConverter, ResultSet, RowValues, SqlMiddlewareDbError,
5};
6use crate::pool::MiddlewarePoolConnection;
7use crate::tx_outcome::TxOutcome;
8
9use super::connection::SqliteConnection;
10use super::params::Params;
11
12use std::sync::atomic::{AtomicBool, Ordering};
13
/// Process-wide test switch. When set, a connection whose rollback failed is still put
/// back into its pool wrapper instead of having its handle marked broken.
static REWRAP_ON_ROLLBACK_FAILURE: AtomicBool = AtomicBool::new(false);

/// Test hook: turn re-wrapping after a failed rollback on or off.
///
/// Hidden from rustdoc; intended for the crate's own tests only.
#[doc(hidden)]
pub fn set_rewrap_on_rollback_failure_for_tests(rewrap: bool) {
    // The flag guards no other data, so relaxed ordering is sufficient.
    let flag = &REWRAP_ON_ROLLBACK_FAILURE;
    flag.store(rewrap, Ordering::Relaxed);
}

/// Reads the test hook configured by [`set_rewrap_on_rollback_failure_for_tests`].
fn rewrap_on_rollback_failure_for_tests() -> bool {
    let flag = &REWRAP_ON_ROLLBACK_FAILURE;
    flag.load(Ordering::Relaxed)
}
24
25pub struct Tx<'a> {
27 conn: Option<SqliteConnection>,
28 conn_slot: &'a mut MiddlewarePoolConnection,
29}
30
/// SQL text prepared for reuse within a [`Tx`].
///
/// Only the statement text is stored, behind an `Arc`, so cloning a `Prepared` is a
/// reference-count bump rather than a string copy.
#[derive(Clone, Debug)]
pub struct Prepared {
    /// Statement text captured by [`Tx::prepare`].
    sql: Arc<String>,
}
35
36pub async fn begin_transaction(
42 conn_slot: &mut MiddlewarePoolConnection,
43) -> Result<Tx<'_>, SqlMiddlewareDbError> {
44 #[cfg(any(feature = "postgres", feature = "mssql", feature = "turso"))]
45 let MiddlewarePoolConnection::Sqlite { conn, .. } = conn_slot else {
46 return Err(SqlMiddlewareDbError::Unimplemented(
47 "begin_transaction is only available for SQLite connections".into(),
48 ));
49 };
50 #[cfg(not(any(feature = "postgres", feature = "mssql", feature = "turso")))]
51 let MiddlewarePoolConnection::Sqlite { conn, .. } = conn_slot;
52
53 let mut conn = conn.take().ok_or_else(|| {
54 SqlMiddlewareDbError::ExecutionError(
55 "SQLite connection already taken from pool wrapper".into(),
56 )
57 })?;
58 conn.begin().await?;
59 Ok(Tx {
60 conn: Some(conn),
61 conn_slot,
62 })
63}
64
65impl Tx<'_> {
66 fn conn_mut(&mut self) -> Result<&mut SqliteConnection, SqlMiddlewareDbError> {
67 self.conn.as_mut().ok_or_else(|| {
68 SqlMiddlewareDbError::ExecutionError("SQLite transaction already completed".into())
69 })
70 }
71
72 pub fn prepare(&self, sql: &str) -> Result<Prepared, SqlMiddlewareDbError> {
77 if self.conn.is_none() {
78 return Err(SqlMiddlewareDbError::ExecutionError(
79 "SQLite transaction already completed".into(),
80 ));
81 }
82 Ok(Prepared {
83 sql: Arc::new(sql.to_owned()),
84 })
85 }
86
87 pub async fn execute_prepared(
92 &mut self,
93 prepared: &Prepared,
94 params: &[RowValues],
95 ) -> Result<usize, SqlMiddlewareDbError> {
96 let converted =
97 <Params as ParamConverter>::convert_sql_params(params, ConversionMode::Execute)?;
98 let conn = self.conn_mut()?;
99 conn.execute_dml_in_tx(prepared.sql.as_ref(), &converted.0)
100 .await
101 }
102
103 pub async fn query_prepared(
108 &mut self,
109 prepared: &Prepared,
110 params: &[RowValues],
111 ) -> Result<ResultSet, SqlMiddlewareDbError> {
112 let converted =
113 <Params as ParamConverter>::convert_sql_params(params, ConversionMode::Query)?;
114 let conn = self.conn_mut()?;
115 conn.execute_select_in_tx(
116 prepared.sql.as_ref(),
117 &converted.0,
118 super::query::build_result_set,
119 )
120 .await
121 }
122
123 pub async fn execute_batch(&mut self, sql: &str) -> Result<(), SqlMiddlewareDbError> {
128 let conn = self.conn_mut()?;
129 conn.execute_batch_in_tx(sql).await
130 }
131
132 pub async fn commit(mut self) -> Result<TxOutcome, SqlMiddlewareDbError> {
137 let mut conn = self.conn.take().ok_or_else(|| {
138 SqlMiddlewareDbError::ExecutionError("SQLite transaction already completed".into())
139 })?;
140 match conn.commit().await {
141 Ok(()) => {
142 self.rewrap(conn);
143 Ok(TxOutcome::without_restored_connection())
144 }
145 Err(err) => {
146 let handle = conn.conn_handle();
147 let rollback_result =
148 super::connection::rollback_with_busy_retries(&handle).await;
149 if rollback_result.is_ok() || rewrap_on_rollback_failure_for_tests() {
150 conn.in_transaction = false;
151 self.rewrap(conn);
152 }
153 if rollback_result.is_err() && !rewrap_on_rollback_failure_for_tests() {
154 handle.mark_broken();
155 }
156 Err(err)
157 }
158 }
159 }
160
161 pub async fn rollback(mut self) -> Result<TxOutcome, SqlMiddlewareDbError> {
166 let mut conn = self.conn.take().ok_or_else(|| {
167 SqlMiddlewareDbError::ExecutionError("SQLite transaction already completed".into())
168 })?;
169 let handle = conn.conn_handle();
170 match super::connection::rollback_with_busy_retries(&handle).await {
171 Ok(()) => {
172 conn.in_transaction = false;
173 self.rewrap(conn);
174 Ok(TxOutcome::without_restored_connection())
175 }
176 Err(err) => {
177 if rewrap_on_rollback_failure_for_tests() {
178 conn.in_transaction = false;
179 self.rewrap(conn);
180 }
181 if !rewrap_on_rollback_failure_for_tests() {
182 handle.mark_broken();
183 }
184 Err(err)
185 }
186 }
187 }
188
189 fn rewrap(&mut self, conn: SqliteConnection) {
190 #[cfg(any(feature = "postgres", feature = "mssql", feature = "turso"))]
191 let MiddlewarePoolConnection::Sqlite { conn: slot, .. } = self.conn_slot else {
192 return;
193 };
194 #[cfg(not(any(feature = "postgres", feature = "mssql", feature = "turso")))]
195 let MiddlewarePoolConnection::Sqlite { conn: slot, .. } = self.conn_slot;
196 debug_assert!(slot.is_none(), "sqlite conn slot should be empty during tx");
197 *slot = Some(conn);
198 }
199}
200
201impl Drop for Tx<'_> {
202 fn drop(&mut self) {
207 if let Some(mut conn) = self.conn.take() {
208 let handle = conn.conn_handle();
209 let rollback_result =
210 super::connection::rollback_with_busy_retries_blocking(&handle);
211 if rollback_result.is_ok() || rewrap_on_rollback_failure_for_tests() {
212 conn.in_transaction = false;
213 self.rewrap(conn);
214 } else {
215 handle.mark_broken();
218 }
219 }
220 }
221}