// sql_middleware/sqlite/typed/tx.rs
use std::sync::Arc;
use std::sync::atomic::Ordering;

use tokio::runtime::{Handle, RuntimeFlavor};
use tokio::task::block_in_place;

use crate::middleware::SqlMiddlewareDbError;

use super::SqliteTypedConnection;
use super::core::{SKIP_DROP_ROLLBACK, begin_from_conn, run_blocking};
use crate::sqlite::config::SharedSqliteConnection;

13impl SqliteTypedConnection<super::core::Idle> {
14 pub async fn begin(
19 mut self,
20 ) -> Result<SqliteTypedConnection<super::core::InTx>, SqlMiddlewareDbError> {
21 begin_from_conn(self.take_conn()?).await
22 }
23}
24
25impl SqliteTypedConnection<super::core::InTx> {
26 pub async fn commit(
31 mut self,
32 ) -> Result<SqliteTypedConnection<super::core::Idle>, SqlMiddlewareDbError> {
33 let conn_handle = self.conn_handle()?;
34 let commit_result = run_blocking(Arc::clone(&conn_handle), |guard| {
35 guard
36 .execute_batch("COMMIT")
37 .map_err(SqlMiddlewareDbError::SqliteError)
38 })
39 .await;
40
41 match commit_result {
42 Ok(()) => {
43 let conn = self.take_conn()?;
44 Ok(SqliteTypedConnection {
45 conn: Some(conn),
46 needs_rollback: false,
47 _state: std::marker::PhantomData,
48 })
49 }
50 Err(err) => {
51 let _ = run_blocking(conn_handle, |guard| {
53 guard
54 .execute_batch("ROLLBACK")
55 .map_err(SqlMiddlewareDbError::SqliteError)
56 })
57 .await;
58 Err(err)
59 }
60 }
61 }
62
63 pub async fn rollback(
68 mut self,
69 ) -> Result<SqliteTypedConnection<super::core::Idle>, SqlMiddlewareDbError> {
70 let conn_handle = self.conn_handle()?;
71 let rollback_result = run_blocking(Arc::clone(&conn_handle), |guard| {
72 guard
73 .execute_batch("ROLLBACK")
74 .map_err(SqlMiddlewareDbError::SqliteError)
75 })
76 .await;
77
78 match rollback_result {
79 Ok(()) => {
80 let conn = self.take_conn()?;
81 Ok(SqliteTypedConnection {
82 conn: Some(conn),
83 needs_rollback: false,
84 _state: std::marker::PhantomData,
85 })
86 }
87 Err(err) => {
88 Err(err)
90 }
91 }
92 }
93}
94
95impl<State> Drop for SqliteTypedConnection<State> {
96 fn drop(&mut self) {
97 if self.needs_rollback
98 && !skip_drop_rollback()
99 && let Some(conn) = self.conn.take()
100 {
101 let conn_handle: SharedSqliteConnection = Arc::clone(&*conn);
102 if Handle::try_current().is_ok() {
106 block_in_place(|| {
107 let _ = conn_handle.execute_blocking(|conn| {
108 conn.execute_batch("ROLLBACK")
109 .map_err(SqlMiddlewareDbError::SqliteError)
110 });
111 });
112 } else {
113 let _ = conn_handle.execute_blocking(|conn| {
114 conn.execute_batch("ROLLBACK")
115 .map_err(SqlMiddlewareDbError::SqliteError)
116 });
117 }
118 }
119 }
120}
121
122fn skip_drop_rollback() -> bool {
123 SKIP_DROP_ROLLBACK.load(Ordering::Relaxed)
124}
125
126#[doc(hidden)]
129pub fn set_skip_drop_rollback_for_tests(skip: bool) {
130 SKIP_DROP_ROLLBACK.store(skip, Ordering::Relaxed);
131}