1use kanal::{Receiver, Sender};
2use log::*;
3use parking_lot::RwLock;
4use rusqlite::fallible_iterator::FallibleIterator;
5use rusqlite::hooks::{Action, PreUpdateCase};
6use rusqlite::types::Value;
7use std::ops::{Deref, DerefMut};
8use std::{
9 fmt::{self, Debug},
10 sync::Arc,
11};
12use tokio::sync::oneshot;
13
14use crate::error::Error;
15pub use crate::params::Params;
16use crate::rows::{Column, columns};
17pub use crate::rows::{Row, Rows};
18
/// Builds a fixed-size array of positional SQL parameters.
///
/// Every argument is converted with `Into<$crate::params::ToSqlType>`, and a trailing comma
/// is accepted. `params!()` expands to an empty array of that element type.
#[macro_export]
macro_rules! params {
    () => {
        // The element type must be spelled as a sized array (`; 0`); casting `[]` to the
        // unsized slice type `[T]` is rejected by the compiler.
        [] as [$crate::params::ToSqlType; 0]
    };
    ($($param:expr),+ $(,)?) => {
        [$(Into::<$crate::params::ToSqlType>::into($param)),+]
    };
}
28
/// Builds a fixed-size array of `(name, value)` pairs for named SQL parameters.
///
/// Names must be string literals; values are converted with
/// `Into<$crate::params::ToSqlType>`. `named_params!()` expands to an empty array.
#[macro_export]
macro_rules! named_params {
    () => {
        // Sized array type (`; 0`) is required; casting to the unsized `[T]` does not compile.
        [] as [(&str, $crate::params::ToSqlType); 0]
    };
    ($($param_name:literal: $param_val:expr),+ $(,)?) => {
        [$(($param_name as &str, Into::<$crate::params::ToSqlType>::into($param_val))),+]
    };
}
38
39#[derive(Clone, Debug, PartialEq, serde::Deserialize)]
40pub struct Database {
41 pub seq: u8,
42 pub name: String,
43}
44
/// Every SQLite connection owned by a `Connection` handle, behind one reader/writer lock.
///
/// Slot 0 is the connection used for `Message::RunMut` work; the remaining slots are the
/// extra connections opened by `Connection::new` for reader threads.
struct LockedConnections(RwLock<Vec<rusqlite::Connection>>);

// SAFETY: `rusqlite::Connection` is `Send` but not `Sync`, so `RwLock<Vec<_>>` is not `Sync`
// on its own. Sharing it is sound only while no single `rusqlite::Connection` is used from two
// threads at the same time. Within this file, every access goes through the lock:
// `RunMut` work, `attach`, `close` and `LockGuard` take the exclusive write lock. Each reader
// event loop takes the shared read lock but is given its own slot index, so two threads never
// hold `&rusqlite::Connection` to the same slot concurrently. Any new code that hands out a
// connection under the read lock must keep this "one thread per slot" rule.
unsafe impl Sync for LockedConnections {}
50
51pub type Result<T> = std::result::Result<T, Error>;
53
54enum Message {
55 RunMut(Box<dyn FnOnce(&mut rusqlite::Connection) + Send + 'static>),
56 RunConst(Box<dyn FnOnce(&rusqlite::Connection) + Send + 'static>),
57 Terminate,
58}
59
/// Tuning knobs accepted by `Connection::new`.
#[derive(Clone, Debug)]
pub struct Options {
    /// Busy timeout applied to every SQLite connection opened by `Connection::new`.
    pub busy_timeout: std::time::Duration,
    /// Requested number of dedicated reader threads.
    ///
    /// `Connection::new` ignores this for databases without a path and treats `1` as `0`.
    pub n_read_threads: usize,
}

impl Default for Options {
    /// Five-second busy timeout and no dedicated reader threads.
    fn default() -> Self {
        return Self {
            busy_timeout: std::time::Duration::from_secs(5),
            n_read_threads: 0,
        };
    }
}
74
75#[derive(Clone)]
77pub struct Connection {
78 reader: Sender<Message>,
79 writer: Sender<Message>,
80 conns: Arc<LockedConnections>,
81}
82
83impl Connection {
84 pub fn new<E>(
85 builder: impl Fn() -> std::result::Result<rusqlite::Connection, E>,
86 opt: Option<Options>,
87 ) -> std::result::Result<Self, E> {
88 let new_conn = || -> std::result::Result<rusqlite::Connection, E> {
89 let conn = builder()?;
90 if let Some(timeout) = opt.as_ref().map(|o| o.busy_timeout) {
91 conn.busy_timeout(timeout).expect("busy timeout failed");
92 }
93 return Ok(conn);
94 };
95
96 let conn = new_conn()?;
97 let name = conn.path().and_then(|s| {
98 if s.is_empty() {
100 None
101 } else {
102 Some(s.to_string())
103 }
104 });
105
106 let n_read_threads = if name.is_some() {
107 let n_read_threads = match opt.as_ref().map_or(0, |o| o.n_read_threads) {
108 1 => {
109 warn!(
110 "Using a single dedicated reader thread won't improve performance, falling back to 0."
111 );
112 0
113 }
114 n => n,
115 };
116
117 if let Ok(n) = std::thread::available_parallelism() {
118 if n_read_threads > n.get() {
119 debug!(
120 "Using {n_read_threads} exceeding hardware parallelism: {}",
121 n.get()
122 );
123 }
124 }
125
126 n_read_threads
127 } else {
128 0
130 };
131
132 let conns = {
133 let mut conns = vec![conn];
134 for _ in 0..n_read_threads {
135 conns.push(new_conn()?);
136 }
137
138 Arc::new(LockedConnections(RwLock::new(conns)))
139 };
140
141 let (shared_write_sender, shared_write_receiver) = kanal::unbounded::<Message>();
143 let conns_clone = conns.clone();
144 std::thread::spawn(move || event_loop(0, conns_clone, shared_write_receiver));
145
146 let shared_read_sender = if n_read_threads > 0 {
147 let (shared_read_sender, shared_read_receiver) = kanal::unbounded::<Message>();
148 for i in 0..n_read_threads {
149 let shared_read_receiver = shared_read_receiver.clone();
150 let conns_clone = conns.clone();
151 std::thread::spawn(move || event_loop(i, conns_clone, shared_read_receiver));
152 }
153 shared_read_sender
154 } else {
155 shared_write_sender.clone()
156 };
157
158 debug!(
159 "Opened SQLite DB '{name}' with {n_read_threads} dedicated reader threads",
160 name = name.as_deref().unwrap_or("<in-memory>")
161 );
162
163 return Ok(Self {
164 reader: shared_read_sender,
165 writer: shared_write_sender,
166 conns,
167 });
168 }
169
170 pub fn from_connection_test_only(conn: rusqlite::Connection) -> Self {
171 use parking_lot::lock_api::RwLock;
172
173 let (shared_write_sender, shared_write_receiver) = kanal::unbounded::<Message>();
174 let conns = Arc::new(LockedConnections(RwLock::new(vec![conn])));
175 let conns_clone = conns.clone();
176 std::thread::spawn(move || event_loop(0, conns_clone, shared_write_receiver));
177
178 return Self {
179 reader: shared_write_sender.clone(),
180 writer: shared_write_sender,
181 conns,
182 };
183 }
184
185 pub fn open_in_memory() -> Result<Self> {
191 return Self::new(|| Ok(rusqlite::Connection::open_in_memory()?), None);
192 }
193
194 #[inline]
195 pub fn write_lock(&self) -> LockGuard<'_> {
196 return LockGuard {
197 guard: self.conns.0.write(),
198 };
199 }
200
201 #[inline]
202 pub fn try_write_lock_for(&self, duration: tokio::time::Duration) -> Option<LockGuard<'_>> {
203 return self
204 .conns
205 .0
206 .try_write_for(duration)
207 .map(|guard| LockGuard { guard });
208 }
209
210 #[inline]
217 pub async fn call<F, R>(&self, function: F) -> Result<R>
218 where
219 F: FnOnce(&mut rusqlite::Connection) -> Result<R> + Send + 'static,
220 R: Send + 'static,
221 {
222 let (sender, receiver) = oneshot::channel::<Result<R>>();
224
225 self
226 .writer
227 .send(Message::RunMut(Box::new(move |conn| {
228 if !sender.is_closed() {
229 let _ = sender.send(function(conn));
230 }
231 })))
232 .map_err(|_| Error::ConnectionClosed)?;
233
234 receiver.await.map_err(|_| Error::ConnectionClosed)?
235 }
236
237 #[inline]
238 pub fn call_and_forget(&self, function: impl FnOnce(&rusqlite::Connection) + Send + 'static) {
239 let _ = self
240 .writer
241 .send(Message::RunMut(Box::new(move |conn| function(conn))));
242 }
243
244 #[inline]
245 async fn call_reader<F, R>(&self, function: F) -> Result<R>
246 where
247 F: FnOnce(&rusqlite::Connection) -> Result<R> + Send + 'static,
248 R: Send + 'static,
249 {
250 let (sender, receiver) = oneshot::channel::<Result<R>>();
251
252 self
253 .reader
254 .send(Message::RunConst(Box::new(move |conn| {
255 if !sender.is_closed() {
256 let _ = sender.send(function(conn));
257 }
258 })))
259 .map_err(|_| Error::ConnectionClosed)?;
260
261 receiver.await.map_err(|_| Error::ConnectionClosed)?
262 }
263
264 pub async fn read_query_rows(
266 &self,
267 sql: impl AsRef<str> + Send + 'static,
268 params: impl Params + Send + 'static,
269 ) -> Result<Rows> {
270 return self
271 .call_reader(move |conn: &rusqlite::Connection| {
272 let mut stmt = conn.prepare_cached(sql.as_ref())?;
273 assert!(stmt.readonly());
274
275 params.bind(&mut stmt)?;
276 let rows = stmt.raw_query();
277 Ok(Rows::from_rows(rows)?)
278 })
279 .await;
280 }
281
282 pub async fn write_query_rows(
283 &self,
284 sql: impl AsRef<str> + Send + 'static,
285 params: impl Params + Send + 'static,
286 ) -> Result<Rows> {
287 return self
288 .call(move |conn: &mut rusqlite::Connection| {
289 let mut stmt = conn.prepare_cached(sql.as_ref())?;
290
291 params.bind(&mut stmt)?;
292 let rows = stmt.raw_query();
293 Ok(Rows::from_rows(rows)?)
294 })
295 .await;
296 }
297
298 pub async fn read_query_row(
299 &self,
300 sql: impl AsRef<str> + Send + 'static,
301 params: impl Params + Send + 'static,
302 ) -> Result<Option<Row>> {
303 return self
304 .read_query_row_f(sql, params, |row| Row::from_row(row, None))
305 .await;
306 }
307
308 #[inline]
309 pub async fn query_row_f<T, E>(
310 &self,
311 sql: impl AsRef<str> + Send + 'static,
312 params: impl Params + Send + 'static,
313 f: impl (FnOnce(&rusqlite::Row<'_>) -> std::result::Result<T, E>) + Send + 'static,
314 ) -> Result<Option<T>>
315 where
316 T: Send + 'static,
317 crate::error::Error: From<E>,
318 {
319 return self
320 .call(move |conn: &mut rusqlite::Connection| {
321 let mut stmt = conn.prepare_cached(sql.as_ref())?;
322 params.bind(&mut stmt)?;
323
324 let mut rows = stmt.raw_query();
325
326 if let Some(row) = rows.next()? {
327 return Ok(Some(f(row)?));
328 }
329 Ok(None)
330 })
331 .await;
332 }
333
334 #[inline]
335 pub async fn read_query_row_f<T, E>(
336 &self,
337 sql: impl AsRef<str> + Send + 'static,
338 params: impl Params + Send + 'static,
339 f: impl (FnOnce(&rusqlite::Row<'_>) -> std::result::Result<T, E>) + Send + 'static,
340 ) -> Result<Option<T>>
341 where
342 T: Send + 'static,
343 crate::error::Error: From<E>,
344 {
345 return self
346 .call_reader(move |conn: &rusqlite::Connection| {
347 let mut stmt = conn.prepare_cached(sql.as_ref())?;
348 assert!(stmt.readonly());
349
350 params.bind(&mut stmt)?;
351
352 let mut rows = stmt.raw_query();
353
354 if let Some(row) = rows.next()? {
355 return Ok(Some(f(row)?));
356 }
357 Ok(None)
358 })
359 .await;
360 }
361
362 pub async fn read_query_value<T: serde::de::DeserializeOwned + Send + 'static>(
363 &self,
364 sql: impl AsRef<str> + Send + 'static,
365 params: impl Params + Send + 'static,
366 ) -> Result<Option<T>> {
367 return self
368 .read_query_row_f(sql, params, serde_rusqlite::from_row)
369 .await;
370 }
371
372 pub async fn write_query_value<T: serde::de::DeserializeOwned + Send + 'static>(
373 &self,
374 sql: impl AsRef<str> + Send + 'static,
375 params: impl Params + Send + 'static,
376 ) -> Result<Option<T>> {
377 return self
378 .query_row_f(sql, params, serde_rusqlite::from_row)
379 .await;
380 }
381
382 pub async fn read_query_values<T: serde::de::DeserializeOwned + Send + 'static>(
383 &self,
384 sql: impl AsRef<str> + Send + 'static,
385 params: impl Params + Send + 'static,
386 ) -> Result<Vec<T>> {
387 return self
388 .call_reader(move |conn: &rusqlite::Connection| {
389 let mut stmt = conn.prepare_cached(sql.as_ref())?;
390 assert!(stmt.readonly());
391
392 params.bind(&mut stmt)?;
393 let mut rows = stmt.raw_query();
394
395 let mut values = vec![];
396 while let Some(row) = rows.next()? {
397 values.push(serde_rusqlite::from_row(row)?);
398 }
399 return Ok(values);
400 })
401 .await;
402 }
403
404 pub async fn execute(
406 &self,
407 sql: impl AsRef<str> + Send + 'static,
408 params: impl Params + Send + 'static,
409 ) -> Result<usize> {
410 return self
411 .call(move |conn: &mut rusqlite::Connection| {
412 let mut stmt = conn.prepare_cached(sql.as_ref())?;
413 params.bind(&mut stmt)?;
414
415 let n = stmt.raw_execute()?;
416
417 return Ok(n);
418 })
419 .await;
420 }
421
422 pub async fn execute_batch(&self, sql: impl AsRef<str> + Send + 'static) -> Result<Option<Rows>> {
424 return self
425 .call(move |conn: &mut rusqlite::Connection| {
426 let batch = rusqlite::Batch::new(conn, sql.as_ref());
427
428 let mut p = batch.peekable();
429 while let Some(mut stmt) = p.next()? {
430 let mut rows = stmt.raw_query();
431 let row = rows.next()?;
432
433 match p.peek()? {
434 Some(_) => {}
435 None => {
436 if let Some(row) = row {
437 let cols: Arc<Vec<Column>> = Arc::new(columns(row.as_ref()));
438
439 let mut result = vec![Row::from_row(row, Some(cols.clone()))?];
440 while let Some(row) = rows.next()? {
441 result.push(Row::from_row(row, Some(cols.clone()))?);
442 }
443 return Ok(Some(Rows(result, cols)));
444 }
445
446 return Ok(None);
447 }
448 }
449 }
450
451 return Ok(None);
452 })
453 .await;
454 }
455
456 pub async fn add_preupdate_hook(
458 &self,
459 hook: Option<impl (Fn(Action, &str, &str, &PreUpdateCase)) + Send + Sync + 'static>,
460 ) -> Result<()> {
461 return self
462 .call(move |conn| {
463 conn.preupdate_hook(hook);
464 return Ok(());
465 })
466 .await;
467 }
468
469 pub fn attach(&self, path: &str, name: &str) -> Result<()> {
470 let lock = self.conns.0.write();
471 for conn in &*lock {
472 conn.execute(&format!("ATTACH DATABASE '{path}' AS {name} "), ())?;
473 }
474 return Ok(());
475 }
476
477 pub async fn list_databases(&self) -> Result<Vec<Database>> {
478 return self
479 .call(|conn| {
480 let mut stmt = conn.prepare("SELECT seq, name FROM pragma_database_list")?;
481 let mut rows = stmt.raw_query();
482
483 let mut databases: Vec<Database> = vec![];
484 while let Some(row) = rows.next()? {
485 databases.push(serde_rusqlite::from_row(row)?)
486 }
487 return Ok(databases);
488 })
489 .await;
490 }
491
492 pub async fn close(self) -> Result<()> {
504 let _ = self.writer.send(Message::Terminate);
505 while self.reader.send(Message::Terminate).is_ok() {
506 }
508
509 let mut errors = vec![];
510 let conns: Vec<_> = std::mem::take(&mut self.conns.0.write());
511 for conn in conns {
512 if let Err((_, err)) = conn.close() {
513 errors.push(err);
514 };
515 }
516
517 if !errors.is_empty() {
518 debug!("Closing connection: {errors:?}");
519 return Err(Error::Close(errors.swap_remove(0)));
520 }
521
522 return Ok(());
523 }
524}
525
526impl Debug for Connection {
527 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
528 f.debug_struct("Connection").finish()
529 }
530}
531
532fn event_loop(id: usize, conns: Arc<LockedConnections>, receiver: Receiver<Message>) {
533 while let Ok(message) = receiver.recv() {
534 match message {
535 Message::RunConst(f) => {
536 let lock = conns.0.read();
537 f(&lock[id])
538 }
539 Message::RunMut(f) => {
540 let mut lock = conns.0.write();
541 f(&mut lock[0])
542 }
543 Message::Terminate => {
544 return;
545 }
546 };
547 }
548}
549
550pub fn extract_row_id(case: &PreUpdateCase) -> Option<i64> {
551 return match case {
552 PreUpdateCase::Insert(accessor) => Some(accessor.get_new_row_id()),
553 PreUpdateCase::Delete(accessor) => Some(accessor.get_old_row_id()),
554 PreUpdateCase::Update {
555 new_value_accessor: accessor,
556 ..
557 } => Some(accessor.get_new_row_id()),
558 PreUpdateCase::Unknown => None,
559 };
560}
561
562pub fn extract_record_values(case: &PreUpdateCase) -> Option<Vec<Value>> {
563 return Some(match case {
564 PreUpdateCase::Insert(accessor) => (0..accessor.get_column_count())
565 .map(|idx| -> Value {
566 accessor
567 .get_new_column_value(idx)
568 .map_or(rusqlite::types::Value::Null, |v| v.into())
569 })
570 .collect(),
571 PreUpdateCase::Delete(accessor) => (0..accessor.get_column_count())
572 .map(|idx| -> rusqlite::types::Value {
573 accessor
574 .get_old_column_value(idx)
575 .map_or(rusqlite::types::Value::Null, |v| v.into())
576 })
577 .collect(),
578 PreUpdateCase::Update {
579 new_value_accessor: accessor,
580 ..
581 } => (0..accessor.get_column_count())
582 .map(|idx| -> rusqlite::types::Value {
583 accessor
584 .get_new_column_value(idx)
585 .map_or(rusqlite::types::Value::Null, |v| v.into())
586 })
587 .collect(),
588 PreUpdateCase::Unknown => {
589 return None;
590 }
591 });
592}
593
594pub struct LockGuard<'a> {
595 guard: parking_lot::RwLockWriteGuard<'a, Vec<rusqlite::Connection>>,
596}
597
598impl Deref for LockGuard<'_> {
599 type Target = rusqlite::Connection;
600 #[inline]
601 fn deref(&self) -> &rusqlite::Connection {
602 return &self.guard.deref()[0];
603 }
604}
605
606impl DerefMut for LockGuard<'_> {
607 #[inline]
608 fn deref_mut(&mut self) -> &mut rusqlite::Connection {
609 return &mut self.guard.deref_mut()[0];
610 }
611}
612
// Unit tests for this module live in `tests.rs` next to this file and are only compiled for
// test builds.
#[cfg(test)]
#[path = "tests.rs"]
mod tests;