sql_middleware/sqlite/
config.rs1use std::sync::Arc;
2use std::sync::atomic::{AtomicBool, Ordering};
3use std::thread;
4
5use bb8::{ManageConnection, Pool, PooledConnection};
6use crossbeam_channel::{Sender, unbounded};
7
8use crate::middleware::{ConfigAndPool, DatabaseType, MiddlewarePool, SqlMiddlewareDbError};
9
10pub type SqlitePooledConnection = PooledConnection<'static, SqliteManager>;
12
13pub type SharedSqliteConnection = Arc<SqliteWorker>;
15
/// Test helper: checks out one connection from `pool` and runs `ROLLBACK;` on it.
///
/// # Errors
///
/// Returns `SqlMiddlewareDbError::ConnectionError` when no connection can be
/// checked out, `SqlMiddlewareDbError::SqliteError` when SQLite rejects the
/// statement, or whatever error `run_blocking` itself reports.
#[doc(hidden)]
#[cfg(feature = "sqlite")]
pub async fn rollback_for_tests(pool: &Pool<SqliteManager>) -> Result<(), SqlMiddlewareDbError> {
    let conn = match pool.get_owned().await {
        Ok(conn) => conn,
        Err(e) => {
            return Err(SqlMiddlewareDbError::ConnectionError(format!(
                "sqlite cleanup checkout error: {e}"
            )));
        }
    };

    // The pooled connection stays checked out (bound to `conn`) until this
    // function returns; only a clone of the worker handle is passed along.
    let worker = Arc::clone(&*conn);
    crate::sqlite::connection::run_blocking(worker, |db| {
        db.execute_batch("ROLLBACK;")
            .map_err(SqlMiddlewareDbError::SqliteError)
    })
    .await
}
30
31enum SqliteWorkerMessage {
32 Execute(Box<dyn FnOnce(&mut rusqlite::Connection) + Send + 'static>),
33 Shutdown,
34}
35
36#[derive(Debug)]
37pub struct SqliteWorker {
38 sender: Sender<SqliteWorkerMessage>,
39 broken: Arc<AtomicBool>,
40 force_rollback_busy_for_tests: AtomicBool,
41}
42
43impl SqliteWorker {
44 pub(crate) fn start(conn: rusqlite::Connection) -> Arc<Self> {
45 let (sender, receiver) = unbounded::<SqliteWorkerMessage>();
46 let broken = Arc::new(AtomicBool::new(false));
47 let broken_flag = Arc::clone(&broken);
48 let mut conn = Some(conn);
49 let _ = thread::Builder::new()
51 .name("sql-middleware-sqlite-worker".into())
52 .spawn(move || {
53 let mut conn = conn
54 .take()
55 .expect("sqlite worker missing connection at start");
56 for msg in &receiver {
57 match msg {
58 SqliteWorkerMessage::Execute(job) => {
59 let result =
62 std::panic::catch_unwind(std::panic::AssertUnwindSafe(|| {
63 job(&mut conn);
64 }));
65 if result.is_err() {
66 broken_flag.store(true, Ordering::Relaxed);
67 break;
68 }
69 }
70 SqliteWorkerMessage::Shutdown => break,
71 }
72 }
73 broken_flag.store(true, Ordering::Relaxed);
74 });
75
76 Arc::new(Self {
77 sender,
78 broken,
79 force_rollback_busy_for_tests: AtomicBool::new(false),
80 })
81 }
82
83 pub(crate) fn execute<F>(&self, func: F) -> Result<(), SqlMiddlewareDbError>
84 where
85 F: FnOnce(&mut rusqlite::Connection) + Send + 'static,
86 {
87 self.sender
88 .send(SqliteWorkerMessage::Execute(Box::new(func)))
89 .map_err(|_| {
90 SqlMiddlewareDbError::ExecutionError(
91 "sqlite worker channel unexpectedly closed".into(),
92 )
93 })
94 }
95
96 pub(crate) fn execute_blocking<F, R>(&self, func: F) -> Result<R, SqlMiddlewareDbError>
97 where
98 F: FnOnce(&mut rusqlite::Connection) -> Result<R, SqlMiddlewareDbError> + Send + 'static,
99 R: Send + 'static,
100 {
101 let (resp_tx, resp_rx) = crossbeam_channel::bounded(1);
102 self.sender
103 .send(SqliteWorkerMessage::Execute(Box::new(move |conn| {
104 let _ = resp_tx.send(func(conn));
105 })))
106 .map_err(|_| {
107 SqlMiddlewareDbError::ExecutionError(
108 "sqlite worker channel unexpectedly closed".into(),
109 )
110 })?;
111 resp_rx.recv().map_err(|_| {
112 SqlMiddlewareDbError::ExecutionError(
113 "sqlite worker response channel unexpectedly closed".into(),
114 )
115 })?
116 }
117
118 #[must_use]
119 pub(crate) fn is_broken(&self) -> bool {
120 self.broken.load(Ordering::Relaxed)
121 }
122
123 #[cfg(test)]
124 #[must_use]
125 pub fn is_broken_for_tests(&self) -> bool {
126 self.is_broken()
127 }
128
129 pub(crate) fn mark_broken(&self) {
130 self.broken.store(true, Ordering::Relaxed);
131 }
132
133 #[doc(hidden)]
134 pub fn set_force_rollback_busy_for_tests(&self, force: bool) {
135 self.force_rollback_busy_for_tests
136 .store(force, Ordering::Relaxed);
137 }
138
139 pub(crate) fn force_rollback_busy_for_tests(&self) -> bool {
140 self.force_rollback_busy_for_tests.load(Ordering::Relaxed)
141 }
142}
143
144impl Drop for SqliteWorker {
145 fn drop(&mut self) {
146 let _ = self.sender.send(SqliteWorkerMessage::Shutdown);
147 }
148}
149
#[cfg(test)]
mod tests {
    use std::sync::Arc;

    use bb8::Pool;

    use super::SqliteManager;
    use crate::middleware::SqlMiddlewareDbError;
    use crate::sqlite::connection::run_blocking;

    /// A job that panics must surface as an error, flag its connection as
    /// broken, and leave the pool able to hand out a healthy connection.
    #[tokio::test]
    async fn worker_panic_marks_connection_broken() -> Result<(), Box<dyn std::error::Error>> {
        // The pool holds one connection at most, and the broken flag is never
        // cleared, so the second checkout must be a different connection.
        let manager = SqliteManager::new("file::memory:?cache=shared".to_string());
        let pool = Pool::builder().max_size(1).build(manager).await?;

        let poisoned = pool.get_owned().await?;
        let panic_outcome = run_blocking(
            Arc::clone(&*poisoned),
            |_conn| -> Result<(), SqlMiddlewareDbError> {
                panic!("boom");
            },
        )
        .await;
        let err = panic_outcome.expect_err("worker panic should surface as an error");
        assert!(
            err.to_string().contains("worker receive error"),
            "unexpected error for worker panic: {err}"
        );
        assert!(poisoned.is_broken(), "connection should be marked broken");

        // Return the broken connection before asking the pool for another.
        drop(poisoned);

        let replacement = pool.get_owned().await?;
        run_blocking(Arc::clone(&*replacement), |db| {
            db.query_row("SELECT 1", rusqlite::params![], |_row| Ok(()))
                .map_err(SqlMiddlewareDbError::SqliteError)
        })
        .await?;
        assert!(
            !replacement.is_broken(),
            "replacement connection should be healthy"
        );
        Ok(())
    }
}
196
/// Settings used to open a SQLite-backed pool.
#[derive(Debug, Clone)]
pub struct SqliteOptions {
    /// Database location handed to the connection manager.
    pub db_path: String,
    /// Value copied into the resulting pool configuration's
    /// `translate_placeholders` field.
    pub translate_placeholders: bool,
}

impl SqliteOptions {
    /// Creates options for `db_path` with placeholder translation turned off.
    #[must_use]
    pub fn new(db_path: String) -> Self {
        let translate_placeholders = false;
        Self {
            db_path,
            translate_placeholders,
        }
    }

    /// Returns the options with `translate_placeholders` replaced and every
    /// other field unchanged.
    #[must_use]
    pub fn with_translation(self, translate_placeholders: bool) -> Self {
        Self {
            translate_placeholders,
            ..self
        }
    }
}
219
220#[derive(Debug, Clone)]
222pub struct SqliteOptionsBuilder {
223 opts: SqliteOptions,
224}
225
226impl SqliteOptionsBuilder {
227 #[must_use]
228 pub fn new(db_path: String) -> Self {
229 Self {
230 opts: SqliteOptions::new(db_path),
231 }
232 }
233
234 #[must_use]
235 pub fn translation(mut self, translate_placeholders: bool) -> Self {
236 self.opts.translate_placeholders = translate_placeholders;
237 self
238 }
239
240 #[must_use]
241 pub fn finish(self) -> SqliteOptions {
242 self.opts
243 }
244
245 pub async fn build(self) -> Result<ConfigAndPool, SqlMiddlewareDbError> {
251 ConfigAndPool::new_sqlite(self.finish()).await
252 }
253}
254
255impl ConfigAndPool {
256 #[must_use]
257 pub fn sqlite_builder(db_path: String) -> SqliteOptionsBuilder {
258 SqliteOptionsBuilder::new(db_path)
259 }
260
261 pub async fn new_sqlite(opts: SqliteOptions) -> Result<Self, SqlMiddlewareDbError> {
266 let manager = SqliteManager::new(opts.db_path.clone());
267 let pool = manager.build_pool().await?;
268
269 {
271 let mut conn = pool.get_owned().await.map_err(|e| {
272 SqlMiddlewareDbError::ConnectionError(format!("Failed to create SQLite pool: {e}"))
273 })?;
274
275 crate::sqlite::apply_wal_pragmas(&mut conn).await?;
276 }
277
278 Ok(ConfigAndPool {
279 pool: MiddlewarePool::Sqlite(pool),
280 db_type: DatabaseType::Sqlite,
281 translate_placeholders: opts.translate_placeholders,
282 })
283 }
284}
285
286pub struct SqliteManager {
288 db_path: String,
289}
290
291impl SqliteManager {
292 #[must_use]
293 pub fn new(db_path: String) -> Self {
294 Self { db_path }
295 }
296
297 pub async fn build_pool(self) -> Result<Pool<SqliteManager>, SqlMiddlewareDbError> {
302 Pool::builder()
303 .build(self)
304 .await
305 .map_err(|e| SqlMiddlewareDbError::ConnectionError(format!("sqlite pool error: {e}")))
306 }
307}
308
309impl ManageConnection for SqliteManager {
310 type Connection = SharedSqliteConnection;
311 type Error = SqlMiddlewareDbError;
312
313 fn connect(
314 &self,
315 ) -> impl std::future::Future<Output = Result<Self::Connection, Self::Error>> + Send {
316 let path = self.db_path.clone();
317 async move {
318 let conn =
319 rusqlite::Connection::open(path).map_err(SqlMiddlewareDbError::SqliteError)?;
320 Ok(SqliteWorker::start(conn))
321 }
322 }
323
324 fn is_valid(
325 &self,
326 conn: &mut Self::Connection,
327 ) -> impl std::future::Future<Output = Result<(), Self::Error>> + Send {
328 let conn = Arc::clone(conn);
329 async move {
330 crate::sqlite::connection::run_blocking(conn, |guard| {
331 guard
332 .query_row("SELECT 1", rusqlite::params![], |_row| Ok(()))
333 .map_err(SqlMiddlewareDbError::SqliteError)
334 })
335 .await
336 }
337 }
338
339 fn has_broken(&self, conn: &mut Self::Connection) -> bool {
340 conn.is_broken()
341 }
342}