sql_middleware/sqlite/config.rs

1use std::sync::Arc;
2use std::sync::atomic::{AtomicBool, Ordering};
3use std::thread;
4
5use bb8::{ManageConnection, Pool, PooledConnection};
6use crossbeam_channel::{unbounded, Sender};
7
8use crate::middleware::{ConfigAndPool, DatabaseType, MiddlewarePool, SqlMiddlewareDbError};
9
/// Type alias for the pooled `SQLite` connection wrapper.
///
/// This is what a checkout from a `bb8::Pool<SqliteManager>` yields; it dereferences to a
/// [`SharedSqliteConnection`].
pub type SqlitePooledConnection = PooledConnection<'static, SqliteManager>;

/// Shared, worker-backed `SQLite` connection handle.
///
/// This is the `ManageConnection::Connection` type stored by `SqliteManager`'s pool; cloning
/// it only bumps the reference count on the underlying [`SqliteWorker`].
pub type SharedSqliteConnection = Arc<SqliteWorker>;
15
/// Test-only helper that runs `ROLLBACK;` on a connection checked out from `pool`.
///
/// The statement runs on whichever connection the pool hands out. The checkout is held
/// until the rollback finishes and is then returned to the pool.
///
/// # Errors
///
/// Returns `SqlMiddlewareDbError::ConnectionError` if no connection can be checked out,
/// or `SqlMiddlewareDbError::SqliteError` if `SQLite` rejects the `ROLLBACK` statement.
#[doc(hidden)]
#[cfg(feature = "sqlite")]
pub async fn rollback_for_tests(
    pool: &Pool<SqliteManager>,
) -> Result<(), SqlMiddlewareDbError> {
    let pooled = match pool.get_owned().await {
        Ok(pooled) => pooled,
        Err(checkout_err) => {
            return Err(SqlMiddlewareDbError::ConnectionError(format!(
                "sqlite cleanup checkout error: {checkout_err}"
            )));
        }
    };

    let worker: SharedSqliteConnection = Arc::clone(&*pooled);
    crate::sqlite::connection::run_blocking(worker, |sqlite| {
        sqlite
            .execute_batch("ROLLBACK;")
            .map_err(SqlMiddlewareDbError::SqliteError)
    })
    .await
}
32
33enum SqliteWorkerMessage {
34    Execute(Box<dyn FnOnce(&mut rusqlite::Connection) + Send + 'static>),
35    Shutdown,
36}
37
/// Handle to a dedicated thread that owns a single `rusqlite::Connection`.
///
/// Jobs are sent to that thread over an unbounded channel and run there one at a time.
/// The thread sets `broken` when it stops servicing requests. That happens after a
/// panicking job, on `Shutdown`, or when the channel disconnects.
/// `SqliteManager::has_broken` reports this flag to the pool.
#[derive(Debug)]
pub struct SqliteWorker {
    // Command queue consumed by the worker thread.
    sender: Sender<SqliteWorkerMessage>,
    // Set to `true` by the worker thread once its service loop has exited.
    broken: Arc<AtomicBool>,
}
43
44impl SqliteWorker {
45    pub(crate) fn start(conn: rusqlite::Connection) -> Arc<Self> {
46        let (sender, receiver) = unbounded::<SqliteWorkerMessage>();
47        let broken = Arc::new(AtomicBool::new(false));
48        let broken_flag = Arc::clone(&broken);
49        let mut conn = Some(conn);
50        // Dedicated worker thread to service requests for this pooled connection.
51        let _ = thread::Builder::new()
52            .name("sql-middleware-sqlite-worker".into())
53            .spawn(move || {
54                let mut conn = conn
55                    .take()
56                    .expect("sqlite worker missing connection at start");
57                for msg in &receiver {
58                    match msg {
59                        SqliteWorkerMessage::Execute(job) => {
60                            // If a job panics, mark the worker broken and exit to avoid
61                            // leaving the connection in an unknown state.
62                            let result =
63                                std::panic::catch_unwind(std::panic::AssertUnwindSafe(|| {
64                                    job(&mut conn);
65                                }));
66                            if result.is_err() {
67                                broken_flag.store(true, Ordering::Relaxed);
68                                break;
69                            }
70                        }
71                        SqliteWorkerMessage::Shutdown => break,
72                    }
73                }
74                broken_flag.store(true, Ordering::Relaxed);
75            });
76
77        Arc::new(Self { sender, broken })
78    }
79
80    pub(crate) fn execute<F>(&self, func: F) -> Result<(), SqlMiddlewareDbError>
81    where
82        F: FnOnce(&mut rusqlite::Connection) + Send + 'static,
83    {
84        self.sender
85            .send(SqliteWorkerMessage::Execute(Box::new(func)))
86            .map_err(|_| {
87                SqlMiddlewareDbError::ExecutionError(
88                    "sqlite worker channel unexpectedly closed".into(),
89                )
90            })
91    }
92
93    pub(crate) fn execute_blocking<F, R>(&self, func: F) -> Result<R, SqlMiddlewareDbError>
94    where
95        F: FnOnce(&mut rusqlite::Connection) -> Result<R, SqlMiddlewareDbError> + Send + 'static,
96        R: Send + 'static,
97    {
98        let (resp_tx, resp_rx) = crossbeam_channel::bounded(1);
99        self.sender
100            .send(SqliteWorkerMessage::Execute(Box::new(move |conn| {
101                let _ = resp_tx.send(func(conn));
102            })))
103            .map_err(|_| {
104                SqlMiddlewareDbError::ExecutionError(
105                    "sqlite worker channel unexpectedly closed".into(),
106                )
107            })?;
108        resp_rx
109            .recv()
110            .map_err(|_| {
111                SqlMiddlewareDbError::ExecutionError(
112                    "sqlite worker response channel unexpectedly closed".into(),
113                )
114            })?
115    }
116
117    #[must_use]
118    pub(crate) fn is_broken(&self) -> bool {
119        self.broken.load(Ordering::Relaxed)
120    }
121
122    #[cfg(test)]
123    #[must_use]
124    pub fn is_broken_for_tests(&self) -> bool {
125        self.is_broken()
126    }
127}
128
129impl Drop for SqliteWorker {
130    fn drop(&mut self) {
131        let _ = self.sender.send(SqliteWorkerMessage::Shutdown);
132    }
133}
134
#[cfg(test)]
mod tests {
    use std::sync::Arc;
    use std::time::{Duration, Instant};

    use bb8::Pool;

    use super::{SqliteManager, SqliteWorker};
    use crate::middleware::SqlMiddlewareDbError;
    use crate::sqlite::connection::run_blocking;

    /// Upper bound on how long to wait for a worker to publish its broken flag.
    const BROKEN_FLAG_TIMEOUT: Duration = Duration::from_secs(1);

    /// Poll `worker.is_broken()` until it is `true` or `timeout` elapses.
    ///
    /// The worker thread stores its broken flag only after `catch_unwind` returns. Anything
    /// captured by the panicking job, such as a reply sender, is dropped earlier, during
    /// unwinding. The caller can therefore observe the failed job before the flag is set,
    /// so asserting on the flag immediately is racy.
    async fn wait_until_broken(worker: &SqliteWorker, timeout: Duration) -> bool {
        let deadline = Instant::now() + timeout;
        while !worker.is_broken() {
            if Instant::now() >= deadline {
                return false;
            }
            tokio::task::yield_now().await;
        }
        true
    }

    #[tokio::test]
    async fn worker_panic_marks_connection_broken() -> Result<(), Box<dyn std::error::Error>> {
        let pool = Pool::builder()
            .max_size(1)
            .build(SqliteManager::new("file::memory:?cache=shared".to_string()))
            .await?;

        let conn = pool.get_owned().await?;
        let handle = Arc::clone(&*conn);
        let err = run_blocking(handle, |_conn| -> Result<(), SqlMiddlewareDbError> {
            panic!("boom");
        })
        .await
        .expect_err("worker panic should surface as an error");
        assert!(
            err.to_string().contains("worker receive error"),
            "unexpected error for worker panic: {err}"
        );
        assert!(
            wait_until_broken(&conn, BROKEN_FLAG_TIMEOUT).await,
            "connection should be marked broken"
        );

        drop(conn);

        // With max_size(1), a healthy checkout here proves the broken worker was replaced.
        let conn = pool.get_owned().await?;
        let handle = Arc::clone(&*conn);
        run_blocking(handle, |c| {
            c.query_row("SELECT 1", rusqlite::params![], |_row| Ok(()))
                .map_err(SqlMiddlewareDbError::SqliteError)
        })
        .await?;
        assert!(!conn.is_broken(), "replacement connection should be healthy");
        Ok(())
    }
}
178
/// Options for configuring a `SQLite` pool.
///
/// Build with [`SqliteOptions::new`] or the fluent `SqliteOptionsBuilder`, then pass to
/// `ConfigAndPool::new_sqlite`.
#[derive(Debug, Clone)]
pub struct SqliteOptions {
    /// Database path handed to `rusqlite::Connection::open` for every pooled connection.
    /// URI forms such as `file::memory:?cache=shared` are also used in this file's tests.
    pub db_path: String,
    /// Copied onto the resulting `ConfigAndPool`; `false` when created via `new`.
    /// NOTE(review): presumably toggles placeholder-syntax translation elsewhere — confirm.
    pub translate_placeholders: bool,
}
185
186impl SqliteOptions {
187    #[must_use]
188    pub fn new(db_path: String) -> Self {
189        Self {
190            db_path,
191            translate_placeholders: false,
192        }
193    }
194
195    #[must_use]
196    pub fn with_translation(mut self, translate_placeholders: bool) -> Self {
197        self.translate_placeholders = translate_placeholders;
198        self
199    }
200}
201
/// Fluent builder for `SQLite` options.
///
/// Wraps a [`SqliteOptions`] value. Call `finish()` to get the options, or `build()` to go
/// straight to a `ConfigAndPool`.
#[derive(Debug, Clone)]
pub struct SqliteOptionsBuilder {
    // Options accumulated so far.
    opts: SqliteOptions,
}
207
208impl SqliteOptionsBuilder {
209    #[must_use]
210    pub fn new(db_path: String) -> Self {
211        Self {
212            opts: SqliteOptions::new(db_path),
213        }
214    }
215
216    #[must_use]
217    pub fn translation(mut self, translate_placeholders: bool) -> Self {
218        self.opts.translate_placeholders = translate_placeholders;
219        self
220    }
221
222    #[must_use]
223    pub fn finish(self) -> SqliteOptions {
224        self.opts
225    }
226
227    /// Build a `ConfigAndPool` for `SQLite`.
228    ///
229    /// # Errors
230    ///
231    /// Returns `SqlMiddlewareDbError` if pool creation or the initial smoke test fails.
232    pub async fn build(self) -> Result<ConfigAndPool, SqlMiddlewareDbError> {
233        ConfigAndPool::new_sqlite(self.finish()).await
234    }
235}
236
237impl ConfigAndPool {
238    #[must_use]
239    pub fn sqlite_builder(db_path: String) -> SqliteOptionsBuilder {
240        SqliteOptionsBuilder::new(db_path)
241    }
242
243    /// Asynchronous initializer for `ConfigAndPool` with Sqlite using a bb8-backed pool.
244    ///
245    /// # Errors
246    /// Returns `SqlMiddlewareDbError::ConnectionError` if pool creation or connection test fails.
247    pub async fn new_sqlite(opts: SqliteOptions) -> Result<Self, SqlMiddlewareDbError> {
248        let manager = SqliteManager::new(opts.db_path.clone());
249        let pool = manager.build_pool().await?;
250
251        // Initialize the database with WAL and a simple health check.
252        {
253            let mut conn = pool.get_owned().await.map_err(|e| {
254                SqlMiddlewareDbError::ConnectionError(format!("Failed to create SQLite pool: {e}"))
255            })?;
256
257            crate::sqlite::apply_wal_pragmas(&mut conn).await?;
258        }
259
260        Ok(ConfigAndPool {
261            pool: MiddlewarePool::Sqlite(pool),
262            db_type: DatabaseType::Sqlite,
263            translate_placeholders: opts.translate_placeholders,
264        })
265    }
266}
267
/// bb8 manager for `SQLite` connections.
///
/// Each connection it creates opens `db_path` and is serviced by its own `SqliteWorker`.
/// `Debug` is derived so the manager (and types embedding it) can be inspected in logs.
#[derive(Debug)]
pub struct SqliteManager {
    // Path or URI passed to `rusqlite::Connection::open` on every connect.
    db_path: String,
}
272
273impl SqliteManager {
274    #[must_use]
275    pub fn new(db_path: String) -> Self {
276        Self { db_path }
277    }
278
279    /// Build a pool from this manager.
280    ///
281    /// # Errors
282    /// Returns `SqlMiddlewareDbError` if pool creation fails.
283    pub async fn build_pool(self) -> Result<Pool<SqliteManager>, SqlMiddlewareDbError> {
284        Pool::builder()
285            .build(self)
286            .await
287            .map_err(|e| SqlMiddlewareDbError::ConnectionError(format!("sqlite pool error: {e}")))
288    }
289}
290
291impl ManageConnection for SqliteManager {
292    type Connection = SharedSqliteConnection;
293    type Error = SqlMiddlewareDbError;
294
295    fn connect(
296        &self,
297    ) -> impl std::future::Future<Output = Result<Self::Connection, Self::Error>> + Send {
298        let path = self.db_path.clone();
299        async move {
300            let conn =
301                rusqlite::Connection::open(path).map_err(SqlMiddlewareDbError::SqliteError)?;
302            Ok(SqliteWorker::start(conn))
303        }
304    }
305
306    fn is_valid(
307        &self,
308        conn: &mut Self::Connection,
309    ) -> impl std::future::Future<Output = Result<(), Self::Error>> + Send {
310        let conn = Arc::clone(conn);
311        async move {
312            crate::sqlite::connection::run_blocking(conn, |guard| {
313                guard
314                    .query_row("SELECT 1", rusqlite::params![], |_row| Ok(()))
315                    .map_err(SqlMiddlewareDbError::SqliteError)
316            })
317            .await
318        }
319    }
320
321    fn has_broken(&self, conn: &mut Self::Connection) -> bool {
322        conn.is_broken()
323    }
324}