sql_middleware/sqlite/
config.rs1use std::sync::Arc;
2use std::sync::atomic::{AtomicBool, Ordering};
3use std::thread;
4
5use bb8::{ManageConnection, Pool, PooledConnection};
6use crossbeam_channel::{unbounded, Sender};
7
8use crate::middleware::{ConfigAndPool, DatabaseType, MiddlewarePool, SqlMiddlewareDbError};
9
10pub type SqlitePooledConnection = PooledConnection<'static, SqliteManager>;
12
13pub type SharedSqliteConnection = Arc<SqliteWorker>;
15
/// Test-support helper that rolls back a transaction left open on a pooled SQLite connection.
///
/// Checks a connection out of `pool` and, on that connection's worker thread, issues
/// `ROLLBACK` if a transaction is active. If the connection is in autocommit mode (no
/// transaction open) nothing is executed, because SQLite rejects `ROLLBACK` outside a
/// transaction and a cleanup helper should not fail when there is nothing to undo.
///
/// NOTE(review): with a pool of more than one connection, the checked-out connection may
/// not be the one that opened the transaction.
///
/// # Errors
///
/// Returns [`SqlMiddlewareDbError::ConnectionError`] if no connection can be checked out,
/// or [`SqlMiddlewareDbError::SqliteError`] if the rollback itself fails; any error
/// produced by `run_blocking` is propagated as-is.
#[doc(hidden)]
#[cfg(feature = "sqlite")]
pub async fn rollback_for_tests(
    pool: &Pool<SqliteManager>,
) -> Result<(), SqlMiddlewareDbError> {
    let conn = pool.get_owned().await.map_err(|e| {
        SqlMiddlewareDbError::ConnectionError(format!("sqlite cleanup checkout error: {e}"))
    })?;
    let handle = Arc::clone(&*conn);
    crate::sqlite::connection::run_blocking(handle, |c| {
        if c.is_autocommit() {
            // No transaction is active: nothing to roll back.
            Ok(())
        } else {
            c.execute_batch("ROLLBACK;")
                .map_err(SqlMiddlewareDbError::SqliteError)
        }
    })
    .await
}
32
33enum SqliteWorkerMessage {
34 Execute(Box<dyn FnOnce(&mut rusqlite::Connection) + Send + 'static>),
35 Shutdown,
36}
37
38#[derive(Debug)]
39pub struct SqliteWorker {
40 sender: Sender<SqliteWorkerMessage>,
41 broken: Arc<AtomicBool>,
42}
43
44impl SqliteWorker {
45 pub(crate) fn start(conn: rusqlite::Connection) -> Arc<Self> {
46 let (sender, receiver) = unbounded::<SqliteWorkerMessage>();
47 let broken = Arc::new(AtomicBool::new(false));
48 let broken_flag = Arc::clone(&broken);
49 let mut conn = Some(conn);
50 let _ = thread::Builder::new()
52 .name("sql-middleware-sqlite-worker".into())
53 .spawn(move || {
54 let mut conn = conn
55 .take()
56 .expect("sqlite worker missing connection at start");
57 for msg in &receiver {
58 match msg {
59 SqliteWorkerMessage::Execute(job) => {
60 let result =
63 std::panic::catch_unwind(std::panic::AssertUnwindSafe(|| {
64 job(&mut conn);
65 }));
66 if result.is_err() {
67 broken_flag.store(true, Ordering::Relaxed);
68 break;
69 }
70 }
71 SqliteWorkerMessage::Shutdown => break,
72 }
73 }
74 broken_flag.store(true, Ordering::Relaxed);
75 });
76
77 Arc::new(Self { sender, broken })
78 }
79
80 pub(crate) fn execute<F>(&self, func: F) -> Result<(), SqlMiddlewareDbError>
81 where
82 F: FnOnce(&mut rusqlite::Connection) + Send + 'static,
83 {
84 self.sender
85 .send(SqliteWorkerMessage::Execute(Box::new(func)))
86 .map_err(|_| {
87 SqlMiddlewareDbError::ExecutionError(
88 "sqlite worker channel unexpectedly closed".into(),
89 )
90 })
91 }
92
93 pub(crate) fn execute_blocking<F, R>(&self, func: F) -> Result<R, SqlMiddlewareDbError>
94 where
95 F: FnOnce(&mut rusqlite::Connection) -> Result<R, SqlMiddlewareDbError> + Send + 'static,
96 R: Send + 'static,
97 {
98 let (resp_tx, resp_rx) = crossbeam_channel::bounded(1);
99 self.sender
100 .send(SqliteWorkerMessage::Execute(Box::new(move |conn| {
101 let _ = resp_tx.send(func(conn));
102 })))
103 .map_err(|_| {
104 SqlMiddlewareDbError::ExecutionError(
105 "sqlite worker channel unexpectedly closed".into(),
106 )
107 })?;
108 resp_rx
109 .recv()
110 .map_err(|_| {
111 SqlMiddlewareDbError::ExecutionError(
112 "sqlite worker response channel unexpectedly closed".into(),
113 )
114 })?
115 }
116
117 #[must_use]
118 pub(crate) fn is_broken(&self) -> bool {
119 self.broken.load(Ordering::Relaxed)
120 }
121
122 #[cfg(test)]
123 #[must_use]
124 pub fn is_broken_for_tests(&self) -> bool {
125 self.is_broken()
126 }
127}
128
129impl Drop for SqliteWorker {
130 fn drop(&mut self) {
131 let _ = self.sender.send(SqliteWorkerMessage::Shutdown);
132 }
133}
134
#[cfg(test)]
mod tests {
    use std::sync::Arc;
    use std::time::{Duration, Instant};

    use bb8::Pool;

    use super::SqliteManager;
    use crate::middleware::SqlMiddlewareDbError;
    use crate::sqlite::connection::run_blocking;

    /// Polls `worker.is_broken()` until it reports `true` or `timeout` elapses.
    ///
    /// The worker records a panic only after the panicking job has finished unwinding. Any
    /// response sender the job captured has been dropped by then, so the caller can see the
    /// error first. An immediate `is_broken()` check would race the worker thread. The
    /// worker runs on its own OS thread, so yielding here does not hold it up.
    async fn wait_until_broken(worker: &super::SqliteWorker, timeout: Duration) -> bool {
        let deadline = Instant::now() + timeout;
        while !worker.is_broken() {
            if Instant::now() >= deadline {
                return false;
            }
            tokio::task::yield_now().await;
        }
        true
    }

    #[tokio::test]
    async fn worker_panic_marks_connection_broken() -> Result<(), Box<dyn std::error::Error>> {
        let pool = Pool::builder()
            .max_size(1)
            .build(SqliteManager::new("file::memory:?cache=shared".to_string()))
            .await?;

        let conn = pool.get_owned().await?;
        let handle = Arc::clone(&*conn);
        let err = run_blocking(handle, |_conn| -> Result<(), SqlMiddlewareDbError> {
            panic!("boom");
        })
        .await
        .expect_err("worker panic should surface as an error");
        assert!(
            err.to_string().contains("worker receive error"),
            "unexpected error for worker panic: {err}"
        );
        assert!(
            wait_until_broken(&conn, Duration::from_secs(5)).await,
            "connection should be marked broken"
        );

        drop(conn);

        let conn = pool.get_owned().await?;
        let handle = Arc::clone(&*conn);
        run_blocking(handle, |c| {
            c.query_row("SELECT 1", rusqlite::params![], |_row| Ok(()))
                .map_err(SqlMiddlewareDbError::SqliteError)
        })
        .await?;
        assert!(!conn.is_broken(), "replacement connection should be healthy");
        Ok(())
    }
}
178
/// Options used to construct a SQLite-backed `ConfigAndPool`.
#[derive(Debug, Clone)]
pub struct SqliteOptions {
    /// Path (or SQLite URI) of the database to open.
    pub db_path: String,
    /// Value copied into `ConfigAndPool::translate_placeholders`; `false` unless set.
    pub translate_placeholders: bool,
}

impl SqliteOptions {
    /// Creates options for `db_path` with placeholder translation turned off.
    #[must_use]
    pub fn new(db_path: String) -> Self {
        let translate_placeholders = false;
        Self {
            db_path,
            translate_placeholders,
        }
    }

    /// Returns these options with `translate_placeholders` replaced by the given value.
    #[must_use]
    pub fn with_translation(self, translate_placeholders: bool) -> Self {
        Self {
            translate_placeholders,
            ..self
        }
    }
}
201
202#[derive(Debug, Clone)]
204pub struct SqliteOptionsBuilder {
205 opts: SqliteOptions,
206}
207
208impl SqliteOptionsBuilder {
209 #[must_use]
210 pub fn new(db_path: String) -> Self {
211 Self {
212 opts: SqliteOptions::new(db_path),
213 }
214 }
215
216 #[must_use]
217 pub fn translation(mut self, translate_placeholders: bool) -> Self {
218 self.opts.translate_placeholders = translate_placeholders;
219 self
220 }
221
222 #[must_use]
223 pub fn finish(self) -> SqliteOptions {
224 self.opts
225 }
226
227 pub async fn build(self) -> Result<ConfigAndPool, SqlMiddlewareDbError> {
233 ConfigAndPool::new_sqlite(self.finish()).await
234 }
235}
236
237impl ConfigAndPool {
238 #[must_use]
239 pub fn sqlite_builder(db_path: String) -> SqliteOptionsBuilder {
240 SqliteOptionsBuilder::new(db_path)
241 }
242
243 pub async fn new_sqlite(opts: SqliteOptions) -> Result<Self, SqlMiddlewareDbError> {
248 let manager = SqliteManager::new(opts.db_path.clone());
249 let pool = manager.build_pool().await?;
250
251 {
253 let mut conn = pool.get_owned().await.map_err(|e| {
254 SqlMiddlewareDbError::ConnectionError(format!("Failed to create SQLite pool: {e}"))
255 })?;
256
257 crate::sqlite::apply_wal_pragmas(&mut conn).await?;
258 }
259
260 Ok(ConfigAndPool {
261 pool: MiddlewarePool::Sqlite(pool),
262 db_type: DatabaseType::Sqlite,
263 translate_placeholders: opts.translate_placeholders,
264 })
265 }
266}
267
/// `bb8` connection manager for SQLite.
///
/// Each connection it creates is a `rusqlite::Connection` owned by its own
/// [`SqliteWorker`] thread.
pub struct SqliteManager {
    /// Path (or SQLite URI) passed to `rusqlite::Connection::open` for every new connection.
    db_path: String,
}
272
273impl SqliteManager {
274 #[must_use]
275 pub fn new(db_path: String) -> Self {
276 Self { db_path }
277 }
278
279 pub async fn build_pool(self) -> Result<Pool<SqliteManager>, SqlMiddlewareDbError> {
284 Pool::builder()
285 .build(self)
286 .await
287 .map_err(|e| SqlMiddlewareDbError::ConnectionError(format!("sqlite pool error: {e}")))
288 }
289}
290
291impl ManageConnection for SqliteManager {
292 type Connection = SharedSqliteConnection;
293 type Error = SqlMiddlewareDbError;
294
295 fn connect(
296 &self,
297 ) -> impl std::future::Future<Output = Result<Self::Connection, Self::Error>> + Send {
298 let path = self.db_path.clone();
299 async move {
300 let conn =
301 rusqlite::Connection::open(path).map_err(SqlMiddlewareDbError::SqliteError)?;
302 Ok(SqliteWorker::start(conn))
303 }
304 }
305
306 fn is_valid(
307 &self,
308 conn: &mut Self::Connection,
309 ) -> impl std::future::Future<Output = Result<(), Self::Error>> + Send {
310 let conn = Arc::clone(conn);
311 async move {
312 crate::sqlite::connection::run_blocking(conn, |guard| {
313 guard
314 .query_row("SELECT 1", rusqlite::params![], |_row| Ok(()))
315 .map_err(SqlMiddlewareDbError::SqliteError)
316 })
317 .await
318 }
319 }
320
321 fn has_broken(&self, conn: &mut Self::Connection) -> bool {
322 conn.is_broken()
323 }
324}