sql_middleware/sqlite/
transaction.rs1use std::sync::Arc;
2
3use crate::adapters::params::convert_params;
4use crate::middleware::{ConversionMode, ResultSet, RowValues, SqlMiddlewareDbError};
5use crate::pool::MiddlewarePoolConnection;
6use crate::tx_outcome::TxOutcome;
7
8use super::connection::SqliteConnection;
9use super::params::Params;
10
11use std::sync::atomic::{AtomicBool, Ordering};
12
/// Test-only switch consulted on the rollback-failure paths in this module.
///
/// While it is `true`, a connection whose rollback failed is still handed back
/// to the pool wrapper instead of being marked broken.
static REWRAP_ON_ROLLBACK_FAILURE: AtomicBool = AtomicBool::new(false);

/// Flips the test-only rollback-failure switch.
///
/// Hidden from the rendered docs; the name marks it as intended for tests.
#[doc(hidden)]
pub fn set_rewrap_on_rollback_failure_for_tests(rewrap: bool) {
    // The flag is a standalone boolean that guards no other data, so relaxed
    // ordering is sufficient.
    AtomicBool::store(&REWRAP_ON_ROLLBACK_FAILURE, rewrap, Ordering::Relaxed);
}

/// Reads the current value of the test-only rollback-failure switch.
fn rewrap_on_rollback_failure_for_tests() -> bool {
    AtomicBool::load(&REWRAP_ON_ROLLBACK_FAILURE, Ordering::Relaxed)
}
23
24pub struct Tx<'a> {
26 conn: Option<SqliteConnection>,
27 conn_slot: &'a mut MiddlewarePoolConnection,
28}
29
/// A statement prepared for use inside a SQLite transaction.
///
/// Only the SQL text is stored. It lives behind an [`Arc`], so cloning a
/// `Prepared` shares the same string instead of copying it.
#[derive(Debug, Clone)]
pub struct Prepared {
    /// The SQL text handed to the connection each time the statement runs.
    sql: Arc<String>,
}
34
35pub async fn begin_transaction(
41 conn_slot: &mut MiddlewarePoolConnection,
42) -> Result<Tx<'_>, SqlMiddlewareDbError> {
43 #[cfg(any(feature = "postgres", feature = "mssql", feature = "turso"))]
44 let MiddlewarePoolConnection::Sqlite { conn, .. } = conn_slot else {
45 return Err(SqlMiddlewareDbError::Unimplemented(
46 "begin_transaction is only available for SQLite connections".into(),
47 ));
48 };
49 #[cfg(not(any(feature = "postgres", feature = "mssql", feature = "turso")))]
50 let MiddlewarePoolConnection::Sqlite { conn, .. } = conn_slot;
51
52 let mut conn = conn.take().ok_or_else(|| {
53 SqlMiddlewareDbError::ExecutionError(
54 "SQLite connection already taken from pool wrapper".into(),
55 )
56 })?;
57 conn.begin().await?;
58 Ok(Tx {
59 conn: Some(conn),
60 conn_slot,
61 })
62}
63
64impl Tx<'_> {
65 fn conn_mut(&mut self) -> Result<&mut SqliteConnection, SqlMiddlewareDbError> {
66 self.conn.as_mut().ok_or_else(|| {
67 SqlMiddlewareDbError::ExecutionError("SQLite transaction already completed".into())
68 })
69 }
70
71 pub fn prepare(&self, sql: &str) -> Result<Prepared, SqlMiddlewareDbError> {
76 if self.conn.is_none() {
77 return Err(SqlMiddlewareDbError::ExecutionError(
78 "SQLite transaction already completed".into(),
79 ));
80 }
81 Ok(Prepared {
82 sql: Arc::new(sql.to_owned()),
83 })
84 }
85
86 pub async fn execute_prepared(
91 &mut self,
92 prepared: &Prepared,
93 params: &[RowValues],
94 ) -> Result<usize, SqlMiddlewareDbError> {
95 let converted = convert_params::<Params>(params, ConversionMode::Execute)?;
96 let conn = self.conn_mut()?;
97 conn.execute_dml_in_tx(prepared.sql.as_ref(), &converted.0)
98 .await
99 }
100
101 pub async fn query_prepared(
106 &mut self,
107 prepared: &Prepared,
108 params: &[RowValues],
109 ) -> Result<ResultSet, SqlMiddlewareDbError> {
110 let converted = convert_params::<Params>(params, ConversionMode::Query)?;
111 let conn = self.conn_mut()?;
112 conn.execute_select_in_tx(
113 prepared.sql.as_ref(),
114 &converted.0,
115 super::query::build_result_set,
116 )
117 .await
118 }
119
120 pub async fn execute_batch(&mut self, sql: &str) -> Result<(), SqlMiddlewareDbError> {
125 let conn = self.conn_mut()?;
126 conn.execute_batch_in_tx(sql).await
127 }
128
129 pub async fn commit(mut self) -> Result<TxOutcome, SqlMiddlewareDbError> {
134 let mut conn = self.conn.take().ok_or_else(|| {
135 SqlMiddlewareDbError::ExecutionError("SQLite transaction already completed".into())
136 })?;
137 match conn.commit().await {
138 Ok(()) => {
139 self.rewrap(conn);
140 Ok(TxOutcome::without_restored_connection())
141 }
142 Err(err) => {
143 let handle = conn.conn_handle();
144 let rollback_result =
145 super::connection::rollback_with_busy_retries(&handle).await;
146 if rollback_result.is_ok() || rewrap_on_rollback_failure_for_tests() {
147 conn.in_transaction = false;
148 self.rewrap(conn);
149 }
150 if rollback_result.is_err() && !rewrap_on_rollback_failure_for_tests() {
151 handle.mark_broken();
152 }
153 Err(err)
154 }
155 }
156 }
157
158 pub async fn rollback(mut self) -> Result<TxOutcome, SqlMiddlewareDbError> {
163 let mut conn = self.conn.take().ok_or_else(|| {
164 SqlMiddlewareDbError::ExecutionError("SQLite transaction already completed".into())
165 })?;
166 let handle = conn.conn_handle();
167 match super::connection::rollback_with_busy_retries(&handle).await {
168 Ok(()) => {
169 conn.in_transaction = false;
170 self.rewrap(conn);
171 Ok(TxOutcome::without_restored_connection())
172 }
173 Err(err) => {
174 if rewrap_on_rollback_failure_for_tests() {
175 conn.in_transaction = false;
176 self.rewrap(conn);
177 }
178 if !rewrap_on_rollback_failure_for_tests() {
179 handle.mark_broken();
180 }
181 Err(err)
182 }
183 }
184 }
185
186 fn rewrap(&mut self, conn: SqliteConnection) {
187 #[cfg(any(feature = "postgres", feature = "mssql", feature = "turso"))]
188 let MiddlewarePoolConnection::Sqlite { conn: slot, .. } = self.conn_slot else {
189 return;
190 };
191 #[cfg(not(any(feature = "postgres", feature = "mssql", feature = "turso")))]
192 let MiddlewarePoolConnection::Sqlite { conn: slot, .. } = self.conn_slot;
193 debug_assert!(slot.is_none(), "sqlite conn slot should be empty during tx");
194 *slot = Some(conn);
195 }
196}
197
198impl Drop for Tx<'_> {
199 fn drop(&mut self) {
204 if let Some(mut conn) = self.conn.take() {
205 let handle = conn.conn_handle();
206 let rollback_result =
207 super::connection::rollback_with_busy_retries_blocking(&handle);
208 if rollback_result.is_ok() || rewrap_on_rollback_failure_for_tests() {
209 conn.in_transaction = false;
210 self.rewrap(conn);
211 } else {
212 handle.mark_broken();
215 }
216 }
217 }
218}