Skip to main content

sql_middleware/sqlite/
config.rs

1use std::path::PathBuf;
2use std::sync::Arc;
3use std::sync::atomic::{AtomicBool, Ordering};
4use std::thread;
5
6use bb8::{ManageConnection, Pool, PooledConnection};
7use crossbeam_channel::{Sender, unbounded};
8
9use crate::middleware::SqlMiddlewareDbError;
10
11/// Type alias for the pooled `SQLite` connection wrapper.
12pub type SqlitePooledConnection = PooledConnection<'static, SqliteManager>;
13
14/// Shared, worker-backed `SQLite` connection handle.
15pub type SharedSqliteConnection = Arc<SqliteWorker>;
16
/// Test-only helper: check one connection out of `pool` and run `ROLLBACK;` on it.
///
/// # Errors
/// Returns `SqlMiddlewareDbError::ConnectionError` when no connection can be checked out. It
/// returns `SqlMiddlewareDbError::SqliteError` when `SQLite` rejects `ROLLBACK;` (assuming
/// `run_blocking` passes the closure's result through), or whatever error `run_blocking`
/// itself reports.
#[cfg(feature = "sqlite")]
#[doc(hidden)]
pub async fn rollback_for_tests(pool: &Pool<SqliteManager>) -> Result<(), SqlMiddlewareDbError> {
    let checkout = pool.get_owned().await;
    let pooled = match checkout {
        Ok(pooled) => pooled,
        Err(err) => {
            return Err(SqlMiddlewareDbError::ConnectionError(format!(
                "sqlite cleanup checkout error: {err}"
            )));
        }
    };
    // Clone the worker handle out of the pooled wrapper; `pooled` itself stays checked out
    // until this function returns.
    let worker: SharedSqliteConnection = Arc::clone(&*pooled);
    crate::sqlite::connection::run_blocking(worker, |db| {
        db.execute_batch("ROLLBACK;")
            .map_err(SqlMiddlewareDbError::SqliteError)
    })
    .await
}
31
32enum SqliteWorkerMessage {
33    Execute(Box<dyn FnOnce(&mut rusqlite::Connection) + Send + 'static>),
34    Shutdown,
35}
36
37#[derive(Debug)]
38pub struct SqliteWorker {
39    sender: Sender<SqliteWorkerMessage>,
40    broken: Arc<AtomicBool>,
41    force_rollback_busy_for_tests: AtomicBool,
42}
43
44impl SqliteWorker {
45    pub(crate) fn start(conn: rusqlite::Connection) -> Arc<Self> {
46        let (sender, receiver) = unbounded::<SqliteWorkerMessage>();
47        let broken = Arc::new(AtomicBool::new(false));
48        let broken_flag = Arc::clone(&broken);
49        let mut conn = Some(conn);
50        // Dedicated worker thread to service requests for this pooled connection.
51        let _ = thread::Builder::new()
52            .name("sql-middleware-sqlite-worker".into())
53            .spawn(move || {
54                let mut conn = conn
55                    .take()
56                    .expect("sqlite worker missing connection at start");
57                for msg in &receiver {
58                    match msg {
59                        SqliteWorkerMessage::Execute(job) => {
60                            // If a job panics, mark the worker broken and exit to avoid
61                            // leaving the connection in an unknown state.
62                            let result =
63                                std::panic::catch_unwind(std::panic::AssertUnwindSafe(|| {
64                                    job(&mut conn);
65                                }));
66                            if result.is_err() {
67                                broken_flag.store(true, Ordering::Relaxed);
68                                break;
69                            }
70                        }
71                        SqliteWorkerMessage::Shutdown => break,
72                    }
73                }
74                broken_flag.store(true, Ordering::Relaxed);
75            });
76
77        Arc::new(Self {
78            sender,
79            broken,
80            force_rollback_busy_for_tests: AtomicBool::new(false),
81        })
82    }
83
84    pub(crate) fn execute<F>(&self, func: F) -> Result<(), SqlMiddlewareDbError>
85    where
86        F: FnOnce(&mut rusqlite::Connection) + Send + 'static,
87    {
88        self.sender
89            .send(SqliteWorkerMessage::Execute(Box::new(func)))
90            .map_err(|_| {
91                SqlMiddlewareDbError::ExecutionError(
92                    "sqlite worker channel unexpectedly closed".into(),
93                )
94            })
95    }
96
97    pub(crate) fn execute_blocking<F, R>(&self, func: F) -> Result<R, SqlMiddlewareDbError>
98    where
99        F: FnOnce(&mut rusqlite::Connection) -> Result<R, SqlMiddlewareDbError> + Send + 'static,
100        R: Send + 'static,
101    {
102        let (resp_tx, resp_rx) = crossbeam_channel::bounded(1);
103        self.sender
104            .send(SqliteWorkerMessage::Execute(Box::new(move |conn| {
105                let _ = resp_tx.send(func(conn));
106            })))
107            .map_err(|_| {
108                SqlMiddlewareDbError::ExecutionError(
109                    "sqlite worker channel unexpectedly closed".into(),
110                )
111            })?;
112        resp_rx.recv().map_err(|_| {
113            SqlMiddlewareDbError::ExecutionError(
114                "sqlite worker response channel unexpectedly closed".into(),
115            )
116        })?
117    }
118
119    #[must_use]
120    pub(crate) fn is_broken(&self) -> bool {
121        self.broken.load(Ordering::Relaxed)
122    }
123
124    #[cfg(test)]
125    #[must_use]
126    pub fn is_broken_for_tests(&self) -> bool {
127        self.is_broken()
128    }
129
130    pub(crate) fn mark_broken(&self) {
131        self.broken.store(true, Ordering::Relaxed);
132    }
133
134    #[doc(hidden)]
135    pub fn set_force_rollback_busy_for_tests(&self, force: bool) {
136        self.force_rollback_busy_for_tests
137            .store(force, Ordering::Relaxed);
138    }
139
140    pub(crate) fn force_rollback_busy_for_tests(&self) -> bool {
141        self.force_rollback_busy_for_tests.load(Ordering::Relaxed)
142    }
143}
144
145impl Drop for SqliteWorker {
146    fn drop(&mut self) {
147        let _ = self.sender.send(SqliteWorkerMessage::Shutdown);
148    }
149}
150
/// bb8 manager for `SQLite` connections.
///
/// Each connection it creates opens the database file at `db_path` and is serviced by its own
/// worker thread.
#[derive(Debug)]
pub struct SqliteManager {
    /// Filesystem path of the database opened for every new pooled connection.
    db_path: PathBuf,
}
155
156impl SqliteManager {
157    #[must_use]
158    pub fn new(db_path: String) -> Self {
159        Self {
160            db_path: db_path.into(),
161        }
162    }
163
164    #[must_use]
165    pub fn from_path(db_path: impl Into<PathBuf>) -> Self {
166        Self {
167            db_path: db_path.into(),
168        }
169    }
170
171    /// Build a pool from this manager.
172    ///
173    /// # Errors
174    /// Returns `SqlMiddlewareDbError` if pool creation fails.
175    pub async fn build_pool(self) -> Result<Pool<SqliteManager>, SqlMiddlewareDbError> {
176        Pool::builder()
177            .build(self)
178            .await
179            .map_err(|e| SqlMiddlewareDbError::ConnectionError(format!("sqlite pool error: {e}")))
180    }
181}
182
183impl ManageConnection for SqliteManager {
184    type Connection = SharedSqliteConnection;
185    type Error = SqlMiddlewareDbError;
186
187    fn connect(
188        &self,
189    ) -> impl std::future::Future<Output = Result<Self::Connection, Self::Error>> + Send {
190        let path = self.db_path.clone();
191        async move {
192            let conn =
193                rusqlite::Connection::open(path).map_err(SqlMiddlewareDbError::SqliteError)?;
194            Ok(SqliteWorker::start(conn))
195        }
196    }
197
198    fn is_valid(
199        &self,
200        conn: &mut Self::Connection,
201    ) -> impl std::future::Future<Output = Result<(), Self::Error>> + Send {
202        let conn = Arc::clone(conn);
203        async move {
204            crate::sqlite::connection::run_blocking(conn, |guard| {
205                guard
206                    .query_row("SELECT 1", rusqlite::params![], |_row| Ok(()))
207                    .map_err(SqlMiddlewareDbError::SqliteError)
208            })
209            .await
210        }
211    }
212
213    fn has_broken(&self, conn: &mut Self::Connection) -> bool {
214        conn.is_broken()
215    }
216}