sql_middleware/sqlite/connection/
core.rs

1use std::fmt;
2use std::sync::Arc;
3
4use crate::middleware::SqlMiddlewareDbError;
5
6use crate::sqlite::config::{SharedSqliteConnection, SqlitePooledConnection};
7use tokio::sync::oneshot;
8
9/// Connection wrapper backed by a bb8 pooled `SQLite` connection.
10pub struct SqliteConnection {
11    pub(crate) conn: SqlitePooledConnection,
12    pub(crate) in_transaction: bool,
13}
14
15impl SqliteConnection {
16    pub(crate) fn new(conn: SqlitePooledConnection) -> Self {
17        Self {
18            conn,
19            in_transaction: false,
20        }
21    }
22
23    /// Run `func` on the pooled rusqlite connection while no other transaction is in flight.
24    ///
25    /// # Errors
26    /// Returns `SqlMiddlewareDbError::ExecutionError` if the connection is in a transaction or the closure returns an error.
27    pub async fn with_connection<F, R>(&self, func: F) -> Result<R, SqlMiddlewareDbError>
28    where
29        F: FnOnce(&mut rusqlite::Connection) -> Result<R, SqlMiddlewareDbError> + Send + 'static,
30        R: Send + 'static,
31    {
32        if self.in_transaction {
33            return Err(SqlMiddlewareDbError::ExecutionError(
34                "SQLite transaction in progress; operation not permitted (with connection)".into(),
35            ));
36        }
37        run_blocking(self.conn_handle(), func).await
38    }
39
40    pub(crate) fn conn_handle(&self) -> SharedSqliteConnection {
41        Arc::clone(&*self.conn)
42    }
43
44    pub(crate) fn mark_broken(&self) {
45        self.conn_handle().mark_broken();
46    }
47
48    #[doc(hidden)]
49    pub fn set_force_rollback_busy_for_tests(&self, force: bool) {
50        self.conn_handle().set_force_rollback_busy_for_tests(force);
51    }
52
53    pub(crate) fn ensure_not_in_tx(&self, ctx: &str) -> Result<(), SqlMiddlewareDbError> {
54        if self.in_transaction {
55            Err(SqlMiddlewareDbError::ExecutionError(format!(
56                "SQLite transaction in progress; operation not permitted ({ctx})"
57            )))
58        } else {
59            Ok(())
60        }
61    }
62}
63
64impl fmt::Debug for SqliteConnection {
65    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
66        f.debug_struct("SqliteConnection")
67            .field("conn", &self.conn)
68            .field("in_transaction", &self.in_transaction)
69            .finish()
70    }
71}
72
73pub(crate) async fn run_blocking<F, R>(
74    conn: SharedSqliteConnection,
75    func: F,
76) -> Result<R, SqlMiddlewareDbError>
77where
78    F: FnOnce(&mut rusqlite::Connection) -> Result<R, SqlMiddlewareDbError> + Send + 'static,
79    R: Send + 'static,
80{
81    let (tx, rx) = oneshot::channel();
82    conn.execute(move |conn| {
83        let _ = tx.send(func(conn));
84    })?;
85    rx.await.map_err(|e| {
86        SqlMiddlewareDbError::ExecutionError(format!("sqlite worker receive error: {e}"))
87    })?
88}
89
90/// Apply WAL pragmas to a pooled connection.
91///
92/// # Errors
93/// Returns `SqlMiddlewareDbError` if the PRAGMA statements cannot be executed.
94pub async fn apply_wal_pragmas(
95    conn: &mut SqlitePooledConnection,
96) -> Result<(), SqlMiddlewareDbError> {
97    let handle = Arc::clone(&*conn);
98    run_blocking(handle, |guard| {
99        guard
100            .execute_batch("PRAGMA journal_mode = WAL;")
101            .map_err(SqlMiddlewareDbError::SqliteError)
102    })
103    .await
104}