sql_middleware/sqlite/
config.rs

1use std::sync::Arc;
2use std::sync::atomic::{AtomicBool, Ordering};
3use std::thread;
4
5use bb8::{ManageConnection, Pool, PooledConnection};
6use crossbeam_channel::{Sender, unbounded};
7
8use crate::middleware::{ConfigAndPool, DatabaseType, MiddlewarePool, SqlMiddlewareDbError};
9
10/// Type alias for the pooled `SQLite` connection wrapper.
11pub type SqlitePooledConnection = PooledConnection<'static, SqliteManager>;
12
13/// Shared, worker-backed `SQLite` connection handle.
14pub type SharedSqliteConnection = Arc<SqliteWorker>;
15
/// Test-only helper: check a connection out of `pool` and run `ROLLBACK;` on its worker.
///
/// The pooled connection stays checked out until the rollback has completed.
///
/// # Errors
/// Returns `SqlMiddlewareDbError::ConnectionError` if no connection can be checked out, or
/// `SqlMiddlewareDbError::SqliteError` if `SQLite` rejects the `ROLLBACK`.
#[doc(hidden)]
#[cfg(feature = "sqlite")]
pub async fn rollback_for_tests(pool: &Pool<SqliteManager>) -> Result<(), SqlMiddlewareDbError> {
    let checkout = pool.get_owned().await;
    let pooled = match checkout {
        Ok(pooled) => pooled,
        Err(e) => {
            return Err(SqlMiddlewareDbError::ConnectionError(format!(
                "sqlite cleanup checkout error: {e}"
            )));
        }
    };

    let worker = Arc::clone(&*pooled);
    crate::sqlite::connection::run_blocking(worker, |sqlite| {
        sqlite
            .execute_batch("ROLLBACK;")
            .map_err(SqlMiddlewareDbError::SqliteError)
    })
    .await
}
30
31enum SqliteWorkerMessage {
32    Execute(Box<dyn FnOnce(&mut rusqlite::Connection) + Send + 'static>),
33    Shutdown,
34}
35
36#[derive(Debug)]
37pub struct SqliteWorker {
38    sender: Sender<SqliteWorkerMessage>,
39    broken: Arc<AtomicBool>,
40    force_rollback_busy_for_tests: AtomicBool,
41}
42
43impl SqliteWorker {
44    pub(crate) fn start(conn: rusqlite::Connection) -> Arc<Self> {
45        let (sender, receiver) = unbounded::<SqliteWorkerMessage>();
46        let broken = Arc::new(AtomicBool::new(false));
47        let broken_flag = Arc::clone(&broken);
48        let mut conn = Some(conn);
49        // Dedicated worker thread to service requests for this pooled connection.
50        let _ = thread::Builder::new()
51            .name("sql-middleware-sqlite-worker".into())
52            .spawn(move || {
53                let mut conn = conn
54                    .take()
55                    .expect("sqlite worker missing connection at start");
56                for msg in &receiver {
57                    match msg {
58                        SqliteWorkerMessage::Execute(job) => {
59                            // If a job panics, mark the worker broken and exit to avoid
60                            // leaving the connection in an unknown state.
61                            let result =
62                                std::panic::catch_unwind(std::panic::AssertUnwindSafe(|| {
63                                    job(&mut conn);
64                                }));
65                            if result.is_err() {
66                                broken_flag.store(true, Ordering::Relaxed);
67                                break;
68                            }
69                        }
70                        SqliteWorkerMessage::Shutdown => break,
71                    }
72                }
73                broken_flag.store(true, Ordering::Relaxed);
74            });
75
76        Arc::new(Self {
77            sender,
78            broken,
79            force_rollback_busy_for_tests: AtomicBool::new(false),
80        })
81    }
82
83    pub(crate) fn execute<F>(&self, func: F) -> Result<(), SqlMiddlewareDbError>
84    where
85        F: FnOnce(&mut rusqlite::Connection) + Send + 'static,
86    {
87        self.sender
88            .send(SqliteWorkerMessage::Execute(Box::new(func)))
89            .map_err(|_| {
90                SqlMiddlewareDbError::ExecutionError(
91                    "sqlite worker channel unexpectedly closed".into(),
92                )
93            })
94    }
95
96    pub(crate) fn execute_blocking<F, R>(&self, func: F) -> Result<R, SqlMiddlewareDbError>
97    where
98        F: FnOnce(&mut rusqlite::Connection) -> Result<R, SqlMiddlewareDbError> + Send + 'static,
99        R: Send + 'static,
100    {
101        let (resp_tx, resp_rx) = crossbeam_channel::bounded(1);
102        self.sender
103            .send(SqliteWorkerMessage::Execute(Box::new(move |conn| {
104                let _ = resp_tx.send(func(conn));
105            })))
106            .map_err(|_| {
107                SqlMiddlewareDbError::ExecutionError(
108                    "sqlite worker channel unexpectedly closed".into(),
109                )
110            })?;
111        resp_rx.recv().map_err(|_| {
112            SqlMiddlewareDbError::ExecutionError(
113                "sqlite worker response channel unexpectedly closed".into(),
114            )
115        })?
116    }
117
118    #[must_use]
119    pub(crate) fn is_broken(&self) -> bool {
120        self.broken.load(Ordering::Relaxed)
121    }
122
123    #[cfg(test)]
124    #[must_use]
125    pub fn is_broken_for_tests(&self) -> bool {
126        self.is_broken()
127    }
128
129    pub(crate) fn mark_broken(&self) {
130        self.broken.store(true, Ordering::Relaxed);
131    }
132
133    #[doc(hidden)]
134    pub fn set_force_rollback_busy_for_tests(&self, force: bool) {
135        self.force_rollback_busy_for_tests
136            .store(force, Ordering::Relaxed);
137    }
138
139    pub(crate) fn force_rollback_busy_for_tests(&self) -> bool {
140        self.force_rollback_busy_for_tests.load(Ordering::Relaxed)
141    }
142}
143
144impl Drop for SqliteWorker {
145    fn drop(&mut self) {
146        let _ = self.sender.send(SqliteWorkerMessage::Shutdown);
147    }
148}
149
#[cfg(test)]
mod tests {
    use std::sync::Arc;
    use std::time::{Duration, Instant};

    use bb8::Pool;

    use super::SqliteManager;
    use crate::middleware::SqlMiddlewareDbError;
    use crate::sqlite::connection::run_blocking;

    /// Poll `check` until it returns `true` or `timeout` elapses. Returns the last result.
    ///
    /// The worker stores its broken flag only after `catch_unwind` returns. Anything the
    /// panicking job captured is dropped earlier, during unwinding. That presumably includes
    /// the reply channel `run_blocking` waits on, given the "worker receive error". So the
    /// caller can observe the error before the flag is set, and a bounded wait avoids that race.
    fn wait_until(timeout: Duration, mut check: impl FnMut() -> bool) -> bool {
        let deadline = Instant::now() + timeout;
        loop {
            if check() {
                return true;
            }
            if Instant::now() >= deadline {
                return false;
            }
            std::thread::sleep(Duration::from_millis(1));
        }
    }

    #[tokio::test]
    async fn worker_panic_marks_connection_broken() -> Result<(), Box<dyn std::error::Error>> {
        let pool = Pool::builder()
            .max_size(1)
            .build(SqliteManager::new("file::memory:?cache=shared".to_string()))
            .await?;

        let conn = pool.get_owned().await?;
        let handle = Arc::clone(&*conn);
        let err = run_blocking(handle, |_conn| -> Result<(), SqlMiddlewareDbError> {
            panic!("boom");
        })
        .await
        .expect_err("worker panic should surface as an error");
        assert!(
            err.to_string().contains("worker receive error"),
            "unexpected error for worker panic: {err}"
        );
        assert!(
            wait_until(Duration::from_secs(5), || conn.is_broken()),
            "connection should be marked broken"
        );

        drop(conn);

        let conn = pool.get_owned().await?;
        let handle = Arc::clone(&*conn);
        run_blocking(handle, |c| {
            c.query_row("SELECT 1", rusqlite::params![], |_row| Ok(()))
                .map_err(SqlMiddlewareDbError::SqliteError)
        })
        .await?;
        assert!(
            !conn.is_broken(),
            "replacement connection should be healthy"
        );
        Ok(())
    }
}
196
/// Options for configuring a `SQLite` pool.
#[derive(Debug, Clone)]
pub struct SqliteOptions {
    /// Database location handed to `SqliteManager` and opened for every pooled connection.
    pub db_path: String,
    /// Copied into `ConfigAndPool::translate_placeholders`; `false` unless set explicitly.
    pub translate_placeholders: bool,
}

impl SqliteOptions {
    /// Options for `db_path` with placeholder translation disabled.
    #[must_use]
    pub fn new(db_path: String) -> Self {
        Self {
            translate_placeholders: false,
            db_path,
        }
    }

    /// Return these options with `translate_placeholders` replaced.
    #[must_use]
    pub fn with_translation(self, translate_placeholders: bool) -> Self {
        Self {
            translate_placeholders,
            ..self
        }
    }
}
219
220/// Fluent builder for `SQLite` options.
221#[derive(Debug, Clone)]
222pub struct SqliteOptionsBuilder {
223    opts: SqliteOptions,
224}
225
226impl SqliteOptionsBuilder {
227    #[must_use]
228    pub fn new(db_path: String) -> Self {
229        Self {
230            opts: SqliteOptions::new(db_path),
231        }
232    }
233
234    #[must_use]
235    pub fn translation(mut self, translate_placeholders: bool) -> Self {
236        self.opts.translate_placeholders = translate_placeholders;
237        self
238    }
239
240    #[must_use]
241    pub fn finish(self) -> SqliteOptions {
242        self.opts
243    }
244
245    /// Build a `ConfigAndPool` for `SQLite`.
246    ///
247    /// # Errors
248    ///
249    /// Returns `SqlMiddlewareDbError` if pool creation or the initial smoke test fails.
250    pub async fn build(self) -> Result<ConfigAndPool, SqlMiddlewareDbError> {
251        ConfigAndPool::new_sqlite(self.finish()).await
252    }
253}
254
255impl ConfigAndPool {
256    #[must_use]
257    pub fn sqlite_builder(db_path: String) -> SqliteOptionsBuilder {
258        SqliteOptionsBuilder::new(db_path)
259    }
260
261    /// Asynchronous initializer for `ConfigAndPool` with Sqlite using a bb8-backed pool.
262    ///
263    /// # Errors
264    /// Returns `SqlMiddlewareDbError::ConnectionError` if pool creation or connection test fails.
265    pub async fn new_sqlite(opts: SqliteOptions) -> Result<Self, SqlMiddlewareDbError> {
266        let manager = SqliteManager::new(opts.db_path.clone());
267        let pool = manager.build_pool().await?;
268
269        // Initialize the database with WAL and a simple health check.
270        {
271            let mut conn = pool.get_owned().await.map_err(|e| {
272                SqlMiddlewareDbError::ConnectionError(format!("Failed to create SQLite pool: {e}"))
273            })?;
274
275            crate::sqlite::apply_wal_pragmas(&mut conn).await?;
276        }
277
278        Ok(ConfigAndPool {
279            pool: MiddlewarePool::Sqlite(pool),
280            db_type: DatabaseType::Sqlite,
281            translate_placeholders: opts.translate_placeholders,
282        })
283    }
284}
285
286/// bb8 manager for `SQLite` connections.
287pub struct SqliteManager {
288    db_path: String,
289}
290
291impl SqliteManager {
292    #[must_use]
293    pub fn new(db_path: String) -> Self {
294        Self { db_path }
295    }
296
297    /// Build a pool from this manager.
298    ///
299    /// # Errors
300    /// Returns `SqlMiddlewareDbError` if pool creation fails.
301    pub async fn build_pool(self) -> Result<Pool<SqliteManager>, SqlMiddlewareDbError> {
302        Pool::builder()
303            .build(self)
304            .await
305            .map_err(|e| SqlMiddlewareDbError::ConnectionError(format!("sqlite pool error: {e}")))
306    }
307}
308
309impl ManageConnection for SqliteManager {
310    type Connection = SharedSqliteConnection;
311    type Error = SqlMiddlewareDbError;
312
313    fn connect(
314        &self,
315    ) -> impl std::future::Future<Output = Result<Self::Connection, Self::Error>> + Send {
316        let path = self.db_path.clone();
317        async move {
318            let conn =
319                rusqlite::Connection::open(path).map_err(SqlMiddlewareDbError::SqliteError)?;
320            Ok(SqliteWorker::start(conn))
321        }
322    }
323
324    fn is_valid(
325        &self,
326        conn: &mut Self::Connection,
327    ) -> impl std::future::Future<Output = Result<(), Self::Error>> + Send {
328        let conn = Arc::clone(conn);
329        async move {
330            crate::sqlite::connection::run_blocking(conn, |guard| {
331                guard
332                    .query_row("SELECT 1", rusqlite::params![], |_row| Ok(()))
333                    .map_err(SqlMiddlewareDbError::SqliteError)
334            })
335            .await
336        }
337    }
338
339    fn has_broken(&self, conn: &mut Self::Connection) -> bool {
340        conn.is_broken()
341    }
342}