sql_middleware/sqlite/typed/
tx.rs1use std::sync::Arc;
2use std::sync::atomic::Ordering;
3
4use tokio::runtime::Handle;
5use tokio::task::block_in_place;
6
7use crate::middleware::SqlMiddlewareDbError;
8
9use super::SqliteTypedConnection;
10use super::core::{SKIP_DROP_ROLLBACK, begin_from_conn, run_blocking};
11use crate::sqlite::connection::{rollback_with_busy_retries, rollback_with_busy_retries_blocking};
12use crate::sqlite::config::SharedSqliteConnection;
13
14impl SqliteTypedConnection<super::core::Idle> {
15 pub async fn begin(
20 mut self,
21 ) -> Result<SqliteTypedConnection<super::core::InTx>, SqlMiddlewareDbError> {
22 begin_from_conn(self.take_conn()?).await
23 }
24}
25
26impl SqliteTypedConnection<super::core::InTx> {
27 pub async fn commit(
32 mut self,
33 ) -> Result<SqliteTypedConnection<super::core::Idle>, SqlMiddlewareDbError> {
34 let conn_handle = self.conn_handle()?;
35 let commit_result = run_blocking(Arc::clone(&conn_handle), |guard| {
36 guard
37 .execute_batch("COMMIT")
38 .map_err(SqlMiddlewareDbError::SqliteError)
39 })
40 .await;
41
42 match commit_result {
43 Ok(()) => {
44 let conn = self.take_conn()?;
45 Ok(SqliteTypedConnection {
46 conn: Some(conn),
47 needs_rollback: false,
48 _state: std::marker::PhantomData,
49 })
50 }
51 Err(err) => {
52 if rollback_with_busy_retries(&conn_handle).await.is_err() {
54 conn_handle.mark_broken();
55 }
56 Err(err)
57 }
58 }
59 }
60
61 pub async fn rollback(
66 mut self,
67 ) -> Result<SqliteTypedConnection<super::core::Idle>, SqlMiddlewareDbError> {
68 let conn_handle = self.conn_handle()?;
69 let rollback_result = rollback_with_busy_retries(&conn_handle).await;
70
71 match rollback_result {
72 Ok(()) => {
73 let conn = self.take_conn()?;
74 Ok(SqliteTypedConnection {
75 conn: Some(conn),
76 needs_rollback: false,
77 _state: std::marker::PhantomData,
78 })
79 }
80 Err(err) => {
81 conn_handle.mark_broken();
83 Err(err)
84 }
85 }
86 }
87}
88
89impl<State> Drop for SqliteTypedConnection<State> {
90 fn drop(&mut self) {
91 if self.needs_rollback
92 && !skip_drop_rollback()
93 && let Some(conn) = self.conn.take()
94 {
95 let conn_handle: SharedSqliteConnection = Arc::clone(&*conn);
96 let rollback = || rollback_with_busy_retries_blocking(&conn_handle);
100 let result = if Handle::try_current().is_ok() {
101 block_in_place(rollback)
102 } else {
103 rollback()
104 };
105
106 if result.is_err() {
107 conn_handle.mark_broken();
108 }
109 }
110 }
111}
112
113fn skip_drop_rollback() -> bool {
114 SKIP_DROP_ROLLBACK.load(Ordering::Relaxed)
115}
116
117#[doc(hidden)]
120pub fn set_skip_drop_rollback_for_tests(skip: bool) {
121 SKIP_DROP_ROLLBACK.store(skip, Ordering::Relaxed);
122}