1use std::sync::Arc;
2
3use crate::adapters::params::convert_params;
4use crate::middleware::{ConversionMode, CustomDbRow, ResultSet, RowValues, SqlMiddlewareDbError};
5use crate::pool::MiddlewarePoolConnection;
6use crate::tx_outcome::TxOutcome;
7
8use super::connection::SqliteConnection;
9use super::params::Params;
10
11use std::sync::atomic::{AtomicBool, Ordering};
12
/// Process-wide switch read by the transaction code when a rollback fails.
///
/// While it is `false` (the initial value) a failed rollback marks the connection
/// handle as broken; while it is `true` the connection is put back into the pool
/// wrapper anyway.
static REWRAP_ON_ROLLBACK_FAILURE: AtomicBool = AtomicBool::new(false);

/// Reads the current value of the rollback-failure switch.
fn rewrap_on_rollback_failure_for_tests() -> bool {
    let flag = &REWRAP_ON_ROLLBACK_FAILURE;
    flag.load(Ordering::Relaxed)
}

/// Sets the rollback-failure switch; hidden from the rendered docs and, going by
/// its name, meant for test code only.
#[doc(hidden)]
pub fn set_rewrap_on_rollback_failure_for_tests(rewrap: bool) {
    let flag = &REWRAP_ON_ROLLBACK_FAILURE;
    flag.store(rewrap, Ordering::Relaxed);
}
23
/// Handle for an open SQLite transaction.
///
/// The connection is moved out of the pool wrapper when the transaction starts and is
/// kept here until the transaction is committed, rolled back, or dropped.
pub struct Tx<'a> {
    /// Connection used by the transaction; `None` once it has been taken out again
    /// by `commit`, `rollback`, or `Drop`.
    conn: Option<SqliteConnection>,
    /// Pool wrapper the connection was taken from and is placed back into afterwards.
    conn_slot: &'a mut MiddlewarePoolConnection,
}
29
/// SQL text captured by [`Tx::prepare`] and passed by reference to [`Tx::select`] and
/// [`Tx::execute`].
pub struct Prepared {
    /// Owned copy of the statement text behind an `Arc`.
    sql: Arc<String>,
}
34
35pub async fn begin_transaction(
41 conn_slot: &mut MiddlewarePoolConnection,
42) -> Result<Tx<'_>, SqlMiddlewareDbError> {
43 #[cfg(any(feature = "postgres", feature = "mssql", feature = "turso"))]
44 let MiddlewarePoolConnection::Sqlite { conn, .. } = conn_slot else {
45 return Err(SqlMiddlewareDbError::Unimplemented(
46 "begin_transaction is only available for SQLite connections".into(),
47 ));
48 };
49 #[cfg(not(any(feature = "postgres", feature = "mssql", feature = "turso")))]
50 let MiddlewarePoolConnection::Sqlite { conn, .. } = conn_slot;
51
52 let mut conn = conn.take().ok_or_else(|| {
53 SqlMiddlewareDbError::ExecutionError(
54 "SQLite connection already taken from pool wrapper".into(),
55 )
56 })?;
57 conn.begin().await?;
58 Ok(Tx {
59 conn: Some(conn),
60 conn_slot,
61 })
62}
63
64impl<'conn> Tx<'conn> {
65 fn conn_mut(&mut self) -> Result<&mut SqliteConnection, SqlMiddlewareDbError> {
66 self.conn.as_mut().ok_or_else(|| {
67 SqlMiddlewareDbError::ExecutionError("SQLite transaction already completed".into())
68 })
69 }
70
71 pub fn prepare(&self, sql: &str) -> Result<Prepared, SqlMiddlewareDbError> {
76 if self.conn.is_none() {
77 return Err(SqlMiddlewareDbError::ExecutionError(
78 "SQLite transaction already completed".into(),
79 ));
80 }
81 Ok(Prepared {
82 sql: Arc::new(sql.to_owned()),
83 })
84 }
85
86 #[must_use]
88 pub fn select<'tx, 'prepared>(
89 &'tx mut self,
90 prepared: &'prepared Prepared,
91 ) -> PreparedSelect<'tx, 'prepared, 'static, 'conn> {
92 PreparedSelect {
93 tx: self,
94 prepared,
95 params: &[],
96 }
97 }
98
99 #[must_use]
101 pub fn execute<'tx, 'prepared>(
102 &'tx mut self,
103 prepared: &'prepared Prepared,
104 ) -> PreparedExecute<'tx, 'prepared, 'static, 'conn> {
105 PreparedExecute {
106 tx: self,
107 prepared,
108 params: &[],
109 }
110 }
111
112 pub(crate) async fn execute_prepared(
117 &mut self,
118 prepared: &Prepared,
119 params: &[RowValues],
120 ) -> Result<usize, SqlMiddlewareDbError> {
121 let converted = convert_params::<Params>(params, ConversionMode::Execute)?;
122 let conn = self.conn_mut()?;
123 conn.execute_dml_in_tx(prepared.sql.as_ref(), &converted.0)
124 .await
125 }
126
127 pub(crate) async fn query_prepared(
132 &mut self,
133 prepared: &Prepared,
134 params: &[RowValues],
135 ) -> Result<ResultSet, SqlMiddlewareDbError> {
136 let converted = convert_params::<Params>(params, ConversionMode::Query)?;
137 let conn = self.conn_mut()?;
138 conn.execute_select_in_tx(
139 prepared.sql.as_ref(),
140 &converted.0,
141 super::query::build_result_set,
142 )
143 .await
144 }
145
146 pub async fn execute_batch(&mut self, sql: &str) -> Result<(), SqlMiddlewareDbError> {
151 let conn = self.conn_mut()?;
152 conn.execute_batch_in_tx(sql).await
153 }
154
155 pub async fn commit(mut self) -> Result<TxOutcome, SqlMiddlewareDbError> {
160 let mut conn = self.conn.take().ok_or_else(|| {
161 SqlMiddlewareDbError::ExecutionError("SQLite transaction already completed".into())
162 })?;
163 match conn.commit().await {
164 Ok(()) => {
165 self.rewrap(conn);
166 Ok(TxOutcome::without_restored_connection())
167 }
168 Err(err) => {
169 let handle = conn.conn_handle();
170 let rollback_result = super::connection::rollback_with_busy_retries(&handle).await;
171 if rollback_result.is_ok() || rewrap_on_rollback_failure_for_tests() {
172 conn.in_transaction = false;
173 self.rewrap(conn);
174 }
175 if rollback_result.is_err() && !rewrap_on_rollback_failure_for_tests() {
176 handle.mark_broken();
177 }
178 Err(err)
179 }
180 }
181 }
182
183 pub async fn rollback(mut self) -> Result<TxOutcome, SqlMiddlewareDbError> {
188 let mut conn = self.conn.take().ok_or_else(|| {
189 SqlMiddlewareDbError::ExecutionError("SQLite transaction already completed".into())
190 })?;
191 let handle = conn.conn_handle();
192 match super::connection::rollback_with_busy_retries(&handle).await {
193 Ok(()) => {
194 conn.in_transaction = false;
195 self.rewrap(conn);
196 Ok(TxOutcome::without_restored_connection())
197 }
198 Err(err) => {
199 if rewrap_on_rollback_failure_for_tests() {
200 conn.in_transaction = false;
201 self.rewrap(conn);
202 }
203 if !rewrap_on_rollback_failure_for_tests() {
204 handle.mark_broken();
205 }
206 Err(err)
207 }
208 }
209 }
210
211 fn rewrap(&mut self, conn: SqliteConnection) {
212 #[cfg(any(feature = "postgres", feature = "mssql", feature = "turso"))]
213 let MiddlewarePoolConnection::Sqlite { conn: slot, .. } = self.conn_slot else {
214 return;
215 };
216 #[cfg(not(any(feature = "postgres", feature = "mssql", feature = "turso")))]
217 let MiddlewarePoolConnection::Sqlite { conn: slot, .. } = self.conn_slot;
218 debug_assert!(slot.is_none(), "sqlite conn slot should be empty during tx");
219 *slot = Some(conn);
220 }
221}
222
/// Builder returned by [`Tx::execute`]: pairs a transaction with a prepared statement
/// and the parameters to run it with.
pub struct PreparedExecute<'tx, 'prepared, 'params, 'conn> {
    /// Transaction the statement will run in.
    tx: &'tx mut Tx<'conn>,
    /// Statement to run.
    prepared: &'prepared Prepared,
    /// Parameters passed along with the statement; empty until `params` is called.
    params: &'params [RowValues],
}
229
230impl<'tx, 'prepared, 'params, 'conn> PreparedExecute<'tx, 'prepared, 'params, 'conn> {
231 #[must_use]
233 pub fn params<'next>(
234 self,
235 params: &'next [RowValues],
236 ) -> PreparedExecute<'tx, 'prepared, 'next, 'conn> {
237 PreparedExecute {
238 tx: self.tx,
239 prepared: self.prepared,
240 params,
241 }
242 }
243
244 pub async fn run(self) -> Result<usize, SqlMiddlewareDbError> {
249 self.tx.execute_prepared(self.prepared, self.params).await
250 }
251}
252
/// Builder returned by [`Tx::select`]: pairs a transaction with a prepared query and
/// the parameters to run it with.
pub struct PreparedSelect<'tx, 'prepared, 'params, 'conn> {
    /// Transaction the query will run in.
    tx: &'tx mut Tx<'conn>,
    /// Query to run.
    prepared: &'prepared Prepared,
    /// Parameters passed along with the query; empty until `params` is called.
    params: &'params [RowValues],
}
259
260impl<'tx, 'prepared, 'params, 'conn> PreparedSelect<'tx, 'prepared, 'params, 'conn> {
261 #[must_use]
263 pub fn params<'next>(
264 self,
265 params: &'next [RowValues],
266 ) -> PreparedSelect<'tx, 'prepared, 'next, 'conn> {
267 PreparedSelect {
268 tx: self.tx,
269 prepared: self.prepared,
270 params,
271 }
272 }
273
274 pub async fn all(self) -> Result<ResultSet, SqlMiddlewareDbError> {
279 self.tx.query_prepared(self.prepared, self.params).await
280 }
281
282 pub async fn optional(self) -> Result<Option<CustomDbRow>, SqlMiddlewareDbError> {
287 self.all().await.map(ResultSet::into_optional)
288 }
289
290 pub async fn one(self) -> Result<CustomDbRow, SqlMiddlewareDbError> {
295 self.all().await?.into_one()
296 }
297}
298
299impl Drop for Tx<'_> {
300 fn drop(&mut self) {
305 if let Some(mut conn) = self.conn.take() {
306 let handle = conn.conn_handle();
307 let rollback_result = super::connection::rollback_with_busy_retries_blocking(&handle);
308 if rollback_result.is_ok() || rewrap_on_rollback_failure_for_tests() {
309 conn.in_transaction = false;
310 self.rewrap(conn);
311 } else {
312 handle.mark_broken();
315 }
316 }
317 }
318}