trailbase_sqlite/
connection.rs

1use kanal::{Receiver, Sender};
2use log::*;
3use parking_lot::RwLock;
4use rusqlite::fallible_iterator::FallibleIterator;
5use rusqlite::hooks::{Action, PreUpdateCase};
6use rusqlite::types::Value;
7use std::ops::{Deref, DerefMut};
8use std::{
9  fmt::{self, Debug},
10  sync::Arc,
11};
12use tokio::sync::oneshot;
13
14use crate::error::Error;
15pub use crate::params::Params;
16use crate::rows::{Column, columns};
17pub use crate::rows::{Row, Rows};
18
/// Builds a positional parameter list for the query helpers in this crate.
///
/// With no arguments it expands to `[] as [ToSqlType]`; otherwise every argument is converted
/// with `Into::<ToSqlType>::into` and the results are collected into an array. A trailing comma
/// is accepted.
#[macro_export]
macro_rules! params {
    // No parameters at all.
    () => {
        [] as [$crate::params::ToSqlType]
    };
    // One or more comma-separated expressions.
    ($($param:expr),+ $(,)?) => {
        [$(Into::<$crate::params::ToSqlType>::into($param)),+]
    };
}
28
/// Builds a named parameter list as an array of `(&str, ToSqlType)` pairs.
///
/// Each entry is written `"name": value`; the name must be a literal and the value is converted
/// with `Into::<ToSqlType>::into`. A trailing comma is accepted.
#[macro_export]
macro_rules! named_params {
    // No parameters at all.
    () => {
        [] as [(&str, $crate::params::ToSqlType)]
    };
    // One or more `literal: expr` pairs.
    ($($param_name:literal: $param_val:expr),+ $(,)?) => {
        [$(($param_name as &str, Into::<$crate::params::ToSqlType>::into($param_val))),+]
    };
}
38
/// One row of `pragma_database_list`, as returned by `Connection::list_databases`.
#[derive(Clone, Debug, PartialEq, serde::Deserialize)]
pub struct Database {
  /// Value of the `seq` column.
  pub seq: u8,
  /// Value of the `name` column.
  pub name: String,
}
44
/// Pool of SQLite connections shared by all worker threads, guarded by a single `RwLock`.
struct LockedConnections(RwLock<Vec<rusqlite::Connection>>);

// NOTE: We must never access the same connection concurrently even as &Connection, due to
// Statement cache. We can ensure this by uniquely assigning one connection to each thread.
//
// SAFETY: `rusqlite::Connection` is not `Sync`, so this impl is only sound as long as no two
// threads use the same element at the same time. `event_loop` upholds this by indexing the
// vector with its own worker id while holding the read lock, and by taking the write lock for
// any mutable access. The writer and the reader with id 0 both use index 0; the read/write lock
// keeps those two mutually exclusive.
unsafe impl Sync for LockedConnections {}
50
/// The result type returned by fallible methods in this crate; the error is always the crate's
/// own `Error`.
pub type Result<T> = std::result::Result<T, Error>;
53
/// Unit of work sent over a channel to a worker thread running `event_loop`.
enum Message {
  /// Run the closure with exclusive (`&mut`) access to the connection at index 0.
  RunMut(Box<dyn FnOnce(&mut rusqlite::Connection) + Send + 'static>),
  /// Run the closure with shared (`&`) access to the receiving worker's own connection.
  RunConst(Box<dyn FnOnce(&rusqlite::Connection) + Send + 'static>),
  /// Make the receiving worker thread leave its loop.
  Terminate,
}
59
/// Tuning knobs accepted by `Connection::new`.
#[derive(Clone)]
pub struct Options {
  /// Busy timeout applied to every underlying SQLite connection.
  pub busy_timeout: std::time::Duration,
  /// Number of dedicated reader threads requested.
  pub n_read_threads: usize,
}

impl Default for Options {
  /// A five second busy timeout and no dedicated reader threads.
  fn default() -> Self {
    let busy_timeout = std::time::Duration::from_secs(5);
    Self {
      busy_timeout,
      n_read_threads: 0,
    }
  }
}
74
/// A handle to call functions in background thread.
///
/// Cloning is cheap: clones share the same channels and the same connection pool.
#[derive(Clone)]
pub struct Connection {
  /// Channel used for read-only work (`Message::RunConst`). It is a clone of `writer` when no
  /// dedicated reader threads were spawned.
  reader: Sender<Message>,
  /// Channel feeding the single writer thread.
  writer: Sender<Message>,
  /// The underlying SQLite connections, shared with the worker threads.
  conns: Arc<LockedConnections>,
}
82
83impl Connection {
84  pub fn new<E>(
85    builder: impl Fn() -> std::result::Result<rusqlite::Connection, E>,
86    opt: Option<Options>,
87  ) -> std::result::Result<Self, E> {
88    let new_conn = || -> std::result::Result<rusqlite::Connection, E> {
89      let conn = builder()?;
90      if let Some(timeout) = opt.as_ref().map(|o| o.busy_timeout) {
91        conn.busy_timeout(timeout).expect("busy timeout failed");
92      }
93      return Ok(conn);
94    };
95
96    let conn = new_conn()?;
97    let name = conn.path().and_then(|s| {
98      // Returns empty string for in-memory databases.
99      if s.is_empty() {
100        None
101      } else {
102        Some(s.to_string())
103      }
104    });
105
106    let n_read_threads = if name.is_some() {
107      let n_read_threads = match opt.as_ref().map_or(0, |o| o.n_read_threads) {
108        1 => {
109          warn!(
110            "Using a single dedicated reader thread won't improve performance, falling back to 0."
111          );
112          0
113        }
114        n => n,
115      };
116
117      if let Ok(n) = std::thread::available_parallelism() {
118        if n_read_threads > n.get() {
119          debug!(
120            "Using {n_read_threads} exceeding hardware parallelism: {}",
121            n.get()
122          );
123        }
124      }
125
126      n_read_threads
127    } else {
128      // We cannot share an in-memory database across threads, they're all independent.
129      0
130    };
131
132    let conns = {
133      let mut conns = vec![conn];
134      for _ in 0..n_read_threads {
135        conns.push(new_conn()?);
136      }
137
138      Arc::new(LockedConnections(RwLock::new(conns)))
139    };
140
141    // Spawn writer.
142    let (shared_write_sender, shared_write_receiver) = kanal::unbounded::<Message>();
143    let conns_clone = conns.clone();
144    std::thread::spawn(move || event_loop(0, conns_clone, shared_write_receiver));
145
146    let shared_read_sender = if n_read_threads > 0 {
147      let (shared_read_sender, shared_read_receiver) = kanal::unbounded::<Message>();
148      for i in 0..n_read_threads {
149        let shared_read_receiver = shared_read_receiver.clone();
150        let conns_clone = conns.clone();
151        std::thread::spawn(move || event_loop(i, conns_clone, shared_read_receiver));
152      }
153      shared_read_sender
154    } else {
155      shared_write_sender.clone()
156    };
157
158    debug!(
159      "Opened SQLite DB '{name}' with {n_read_threads} dedicated reader threads",
160      name = name.as_deref().unwrap_or("<in-memory>")
161    );
162
163    return Ok(Self {
164      reader: shared_read_sender,
165      writer: shared_write_sender,
166      conns,
167    });
168  }
169
170  pub fn from_connection_test_only(conn: rusqlite::Connection) -> Self {
171    use parking_lot::lock_api::RwLock;
172
173    let (shared_write_sender, shared_write_receiver) = kanal::unbounded::<Message>();
174    let conns = Arc::new(LockedConnections(RwLock::new(vec![conn])));
175    let conns_clone = conns.clone();
176    std::thread::spawn(move || event_loop(0, conns_clone, shared_write_receiver));
177
178    return Self {
179      reader: shared_write_sender.clone(),
180      writer: shared_write_sender,
181      conns,
182    };
183  }
184
185  /// Open a new connection to an in-memory SQLite database.
186  ///
187  /// # Failure
188  ///
189  /// Will return `Err` if the underlying SQLite open call fails.
190  pub fn open_in_memory() -> Result<Self> {
191    return Self::new(|| Ok(rusqlite::Connection::open_in_memory()?), None);
192  }
193
194  #[inline]
195  pub fn write_lock(&self) -> LockGuard<'_> {
196    return LockGuard {
197      guard: self.conns.0.write(),
198    };
199  }
200
201  #[inline]
202  pub fn try_write_lock_for(&self, duration: tokio::time::Duration) -> Option<LockGuard<'_>> {
203    return self
204      .conns
205      .0
206      .try_write_for(duration)
207      .map(|guard| LockGuard { guard });
208  }
209
210  /// Call a function in background thread and get the result
211  /// asynchronously.
212  ///
213  /// # Failure
214  ///
215  /// Will return `Err` if the database connection has been closed.
216  #[inline]
217  pub async fn call<F, R>(&self, function: F) -> Result<R>
218  where
219    F: FnOnce(&mut rusqlite::Connection) -> Result<R> + Send + 'static,
220    R: Send + 'static,
221  {
222    // return call_impl(&self.writer, function).await;
223    let (sender, receiver) = oneshot::channel::<Result<R>>();
224
225    self
226      .writer
227      .send(Message::RunMut(Box::new(move |conn| {
228        if !sender.is_closed() {
229          let _ = sender.send(function(conn));
230        }
231      })))
232      .map_err(|_| Error::ConnectionClosed)?;
233
234    receiver.await.map_err(|_| Error::ConnectionClosed)?
235  }
236
237  #[inline]
238  pub fn call_and_forget(&self, function: impl FnOnce(&rusqlite::Connection) + Send + 'static) {
239    let _ = self
240      .writer
241      .send(Message::RunMut(Box::new(move |conn| function(conn))));
242  }
243
244  #[inline]
245  async fn call_reader<F, R>(&self, function: F) -> Result<R>
246  where
247    F: FnOnce(&rusqlite::Connection) -> Result<R> + Send + 'static,
248    R: Send + 'static,
249  {
250    let (sender, receiver) = oneshot::channel::<Result<R>>();
251
252    self
253      .reader
254      .send(Message::RunConst(Box::new(move |conn| {
255        if !sender.is_closed() {
256          let _ = sender.send(function(conn));
257        }
258      })))
259      .map_err(|_| Error::ConnectionClosed)?;
260
261    receiver.await.map_err(|_| Error::ConnectionClosed)?
262  }
263
264  /// Query SQL statement.
265  pub async fn read_query_rows(
266    &self,
267    sql: impl AsRef<str> + Send + 'static,
268    params: impl Params + Send + 'static,
269  ) -> Result<Rows> {
270    return self
271      .call_reader(move |conn: &rusqlite::Connection| {
272        let mut stmt = conn.prepare_cached(sql.as_ref())?;
273        assert!(stmt.readonly());
274
275        params.bind(&mut stmt)?;
276        let rows = stmt.raw_query();
277        Ok(Rows::from_rows(rows)?)
278      })
279      .await;
280  }
281
282  pub async fn write_query_rows(
283    &self,
284    sql: impl AsRef<str> + Send + 'static,
285    params: impl Params + Send + 'static,
286  ) -> Result<Rows> {
287    return self
288      .call(move |conn: &mut rusqlite::Connection| {
289        let mut stmt = conn.prepare_cached(sql.as_ref())?;
290
291        params.bind(&mut stmt)?;
292        let rows = stmt.raw_query();
293        Ok(Rows::from_rows(rows)?)
294      })
295      .await;
296  }
297
298  pub async fn read_query_row(
299    &self,
300    sql: impl AsRef<str> + Send + 'static,
301    params: impl Params + Send + 'static,
302  ) -> Result<Option<Row>> {
303    return self
304      .read_query_row_f(sql, params, |row| Row::from_row(row, None))
305      .await;
306  }
307
308  #[inline]
309  pub async fn query_row_f<T, E>(
310    &self,
311    sql: impl AsRef<str> + Send + 'static,
312    params: impl Params + Send + 'static,
313    f: impl (FnOnce(&rusqlite::Row<'_>) -> std::result::Result<T, E>) + Send + 'static,
314  ) -> Result<Option<T>>
315  where
316    T: Send + 'static,
317    crate::error::Error: From<E>,
318  {
319    return self
320      .call(move |conn: &mut rusqlite::Connection| {
321        let mut stmt = conn.prepare_cached(sql.as_ref())?;
322        params.bind(&mut stmt)?;
323
324        let mut rows = stmt.raw_query();
325
326        if let Some(row) = rows.next()? {
327          return Ok(Some(f(row)?));
328        }
329        Ok(None)
330      })
331      .await;
332  }
333
334  #[inline]
335  pub async fn read_query_row_f<T, E>(
336    &self,
337    sql: impl AsRef<str> + Send + 'static,
338    params: impl Params + Send + 'static,
339    f: impl (FnOnce(&rusqlite::Row<'_>) -> std::result::Result<T, E>) + Send + 'static,
340  ) -> Result<Option<T>>
341  where
342    T: Send + 'static,
343    crate::error::Error: From<E>,
344  {
345    return self
346      .call_reader(move |conn: &rusqlite::Connection| {
347        let mut stmt = conn.prepare_cached(sql.as_ref())?;
348        assert!(stmt.readonly());
349
350        params.bind(&mut stmt)?;
351
352        let mut rows = stmt.raw_query();
353
354        if let Some(row) = rows.next()? {
355          return Ok(Some(f(row)?));
356        }
357        Ok(None)
358      })
359      .await;
360  }
361
362  pub async fn read_query_value<T: serde::de::DeserializeOwned + Send + 'static>(
363    &self,
364    sql: impl AsRef<str> + Send + 'static,
365    params: impl Params + Send + 'static,
366  ) -> Result<Option<T>> {
367    return self
368      .read_query_row_f(sql, params, serde_rusqlite::from_row)
369      .await;
370  }
371
372  pub async fn write_query_value<T: serde::de::DeserializeOwned + Send + 'static>(
373    &self,
374    sql: impl AsRef<str> + Send + 'static,
375    params: impl Params + Send + 'static,
376  ) -> Result<Option<T>> {
377    return self
378      .query_row_f(sql, params, serde_rusqlite::from_row)
379      .await;
380  }
381
382  pub async fn read_query_values<T: serde::de::DeserializeOwned + Send + 'static>(
383    &self,
384    sql: impl AsRef<str> + Send + 'static,
385    params: impl Params + Send + 'static,
386  ) -> Result<Vec<T>> {
387    return self
388      .call_reader(move |conn: &rusqlite::Connection| {
389        let mut stmt = conn.prepare_cached(sql.as_ref())?;
390        assert!(stmt.readonly());
391
392        params.bind(&mut stmt)?;
393        let mut rows = stmt.raw_query();
394
395        let mut values = vec![];
396        while let Some(row) = rows.next()? {
397          values.push(serde_rusqlite::from_row(row)?);
398        }
399        return Ok(values);
400      })
401      .await;
402  }
403
404  /// Execute SQL statement.
405  pub async fn execute(
406    &self,
407    sql: impl AsRef<str> + Send + 'static,
408    params: impl Params + Send + 'static,
409  ) -> Result<usize> {
410    return self
411      .call(move |conn: &mut rusqlite::Connection| {
412        let mut stmt = conn.prepare_cached(sql.as_ref())?;
413        params.bind(&mut stmt)?;
414
415        let n = stmt.raw_execute()?;
416
417        return Ok(n);
418      })
419      .await;
420  }
421
422  /// Batch execute SQL statements and return rows of last statement.
423  pub async fn execute_batch(&self, sql: impl AsRef<str> + Send + 'static) -> Result<Option<Rows>> {
424    return self
425      .call(move |conn: &mut rusqlite::Connection| {
426        let batch = rusqlite::Batch::new(conn, sql.as_ref());
427
428        let mut p = batch.peekable();
429        while let Some(mut stmt) = p.next()? {
430          let mut rows = stmt.raw_query();
431          let row = rows.next()?;
432
433          match p.peek()? {
434            Some(_) => {}
435            None => {
436              if let Some(row) = row {
437                let cols: Arc<Vec<Column>> = Arc::new(columns(row.as_ref()));
438
439                let mut result = vec![Row::from_row(row, Some(cols.clone()))?];
440                while let Some(row) = rows.next()? {
441                  result.push(Row::from_row(row, Some(cols.clone()))?);
442                }
443                return Ok(Some(Rows(result, cols)));
444              }
445
446              return Ok(None);
447            }
448          }
449        }
450
451        return Ok(None);
452      })
453      .await;
454  }
455
456  /// Convenience API for (un)setting a new pre-update hook.
457  pub async fn add_preupdate_hook(
458    &self,
459    hook: Option<impl (Fn(Action, &str, &str, &PreUpdateCase)) + Send + Sync + 'static>,
460  ) -> Result<()> {
461    return self
462      .call(move |conn| {
463        conn.preupdate_hook(hook);
464        return Ok(());
465      })
466      .await;
467  }
468
469  pub fn attach(&self, path: &str, name: &str) -> Result<()> {
470    let lock = self.conns.0.write();
471    for conn in &*lock {
472      conn.execute(&format!("ATTACH DATABASE '{path}' AS {name} "), ())?;
473    }
474    return Ok(());
475  }
476
477  pub async fn list_databases(&self) -> Result<Vec<Database>> {
478    return self
479      .call(|conn| {
480        let mut stmt = conn.prepare("SELECT seq, name FROM pragma_database_list")?;
481        let mut rows = stmt.raw_query();
482
483        let mut databases: Vec<Database> = vec![];
484        while let Some(row) = rows.next()? {
485          databases.push(serde_rusqlite::from_row(row)?)
486        }
487        return Ok(databases);
488      })
489      .await;
490  }
491
492  /// Close the database connection.
493  ///
494  /// This is functionally equivalent to the `Drop` implementation for `Connection`. It consumes
495  /// the `Connection`, but on error returns it to the caller for retry purposes.
496  ///
497  /// If successful, any following `close` operations performed on `Connection` copies will succeed
498  /// immediately.
499  ///
500  /// # Failure
501  ///
502  /// Will return `Err` if the underlying SQLite close call fails.
503  pub async fn close(self) -> Result<()> {
504    let _ = self.writer.send(Message::Terminate);
505    while self.reader.send(Message::Terminate).is_ok() {
506      // Continue to close readers while the channel is alive.
507    }
508
509    let mut errors = vec![];
510    let conns: Vec<_> = std::mem::take(&mut self.conns.0.write());
511    for conn in conns {
512      if let Err((_, err)) = conn.close() {
513        errors.push(err);
514      };
515    }
516
517    if !errors.is_empty() {
518      debug!("Closing connection: {errors:?}");
519      return Err(Error::Close(errors.swap_remove(0)));
520    }
521
522    return Ok(());
523  }
524}
525
526impl Debug for Connection {
527  fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
528    f.debug_struct("Connection").finish()
529  }
530}
531
532fn event_loop(id: usize, conns: Arc<LockedConnections>, receiver: Receiver<Message>) {
533  while let Ok(message) = receiver.recv() {
534    match message {
535      Message::RunConst(f) => {
536        let lock = conns.0.read();
537        f(&lock[id])
538      }
539      Message::RunMut(f) => {
540        let mut lock = conns.0.write();
541        f(&mut lock[0])
542      }
543      Message::Terminate => {
544        return;
545      }
546    };
547  }
548}
549
550pub fn extract_row_id(case: &PreUpdateCase) -> Option<i64> {
551  return match case {
552    PreUpdateCase::Insert(accessor) => Some(accessor.get_new_row_id()),
553    PreUpdateCase::Delete(accessor) => Some(accessor.get_old_row_id()),
554    PreUpdateCase::Update {
555      new_value_accessor: accessor,
556      ..
557    } => Some(accessor.get_new_row_id()),
558    PreUpdateCase::Unknown => None,
559  };
560}
561
562pub fn extract_record_values(case: &PreUpdateCase) -> Option<Vec<Value>> {
563  return Some(match case {
564    PreUpdateCase::Insert(accessor) => (0..accessor.get_column_count())
565      .map(|idx| -> Value {
566        accessor
567          .get_new_column_value(idx)
568          .map_or(rusqlite::types::Value::Null, |v| v.into())
569      })
570      .collect(),
571    PreUpdateCase::Delete(accessor) => (0..accessor.get_column_count())
572      .map(|idx| -> rusqlite::types::Value {
573        accessor
574          .get_old_column_value(idx)
575          .map_or(rusqlite::types::Value::Null, |v| v.into())
576      })
577      .collect(),
578    PreUpdateCase::Update {
579      new_value_accessor: accessor,
580      ..
581    } => (0..accessor.get_column_count())
582      .map(|idx| -> rusqlite::types::Value {
583        accessor
584          .get_new_column_value(idx)
585          .map_or(rusqlite::types::Value::Null, |v| v.into())
586      })
587      .collect(),
588    PreUpdateCase::Unknown => {
589      return None;
590    }
591  });
592}
593
594pub struct LockGuard<'a> {
595  guard: parking_lot::RwLockWriteGuard<'a, Vec<rusqlite::Connection>>,
596}
597
598impl Deref for LockGuard<'_> {
599  type Target = rusqlite::Connection;
600  #[inline]
601  fn deref(&self) -> &rusqlite::Connection {
602    return &self.guard.deref()[0];
603  }
604}
605
606impl DerefMut for LockGuard<'_> {
607  #[inline]
608  fn deref_mut(&mut self) -> &mut rusqlite::Connection {
609    return &mut self.guard.deref_mut()[0];
610  }
611}
612
613#[cfg(test)]
614#[path = "tests.rs"]
615mod tests;