// trailbase_sqlite/connection.rs

1use crossbeam_channel::{Receiver, Sender};
2use rusqlite::fallible_iterator::FallibleIterator;
3use rusqlite::hooks::{Action, PreUpdateCase};
4use rusqlite::types::Value;
5use std::{
6  fmt::{self, Debug},
7  sync::Arc,
8};
9use tokio::sync::oneshot;
10
11use crate::error::Error;
12pub use crate::params::Params;
13use crate::rows::{columns, Column};
14pub use crate::rows::{Row, Rows};
15
/// Build a fixed-size array of positional SQL parameters.
///
/// Each argument is converted with `Into<$crate::params::ToSqlType>`; a
/// trailing comma is accepted. `params!()` yields an empty
/// `[ToSqlType; 0]`.
#[macro_export]
macro_rules! params {
    () => {
        // The explicit length matters: `[ToSqlType]` is unsized and cannot be
        // the target of an `as` cast (E0620), so `params!()` would not compile.
        [] as [$crate::params::ToSqlType; 0]
    };
    ($($param:expr),+ $(,)?) => {
        [$(Into::<$crate::params::ToSqlType>::into($param)),+]
    };
}
25
/// Build a fixed-size array of `(name, value)` SQL parameter pairs.
///
/// Names must be string literals (e.g. `":id"`); values are converted with
/// `Into<$crate::params::ToSqlType>`. A trailing comma is accepted.
/// `named_params!()` yields an empty `[(&str, ToSqlType); 0]`.
#[macro_export]
macro_rules! named_params {
    () => {
        // The explicit length matters: a slice type is unsized and cannot be
        // the target of an `as` cast (E0620), so `named_params!()` would not compile.
        [] as [(&str, $crate::params::ToSqlType); 0]
    };
    ($($param_name:literal: $param_val:expr),+ $(,)?) => {
        [$(($param_name as &str, Into::<$crate::params::ToSqlType>::into($param_val))),+]
    };
}
35
36/// The result returned on method calls in this crate.
37pub type Result<T> = std::result::Result<T, Error>;
38
39type CallFn = Box<dyn FnOnce(&mut rusqlite::Connection) + Send + 'static>;
40
41enum Message {
42  Run(CallFn),
43  Close(oneshot::Sender<std::result::Result<(), rusqlite::Error>>),
44}
45
/// A cloneable handle to a SQLite connection that lives on a background thread.
///
/// All clones share one channel into that thread, which processes the
/// requests one at a time, in the order it receives them.
#[derive(Clone)]
pub struct Connection {
  // Channel into `event_loop`. Once the loop has exited, sends fail and
  // `call` reports `Error::ConnectionClosed`.
  sender: Sender<Message>,
}
51
52impl Connection {
53  pub fn from_conn(conn: rusqlite::Connection) -> Result<Self> {
54    let (sender, receiver) = crossbeam_channel::unbounded::<Message>();
55    std::thread::spawn(move || event_loop(conn, receiver));
56    return Ok(Self { sender });
57  }
58
59  /// Open a new connection to an in-memory SQLite database.
60  ///
61  /// # Failure
62  ///
63  /// Will return `Err` if the underlying SQLite open call fails.
64  pub fn open_in_memory() -> Result<Self> {
65    return Self::from_conn(rusqlite::Connection::open_in_memory()?);
66  }
67
68  /// Call a function in background thread and get the result
69  /// asynchronously.
70  ///
71  /// # Failure
72  ///
73  /// Will return `Err` if the database connection has been closed.
74  pub async fn call<F, R>(&self, function: F) -> Result<R>
75  where
76    F: FnOnce(&mut rusqlite::Connection) -> Result<R> + Send + 'static,
77    R: Send + 'static,
78  {
79    let (sender, receiver) = oneshot::channel::<Result<R>>();
80
81    self
82      .sender
83      .send(Message::Run(Box::new(move |conn| {
84        let value = function(conn);
85        let _ = sender.send(value);
86      })))
87      .map_err(|_| Error::ConnectionClosed)?;
88
89    receiver.await.map_err(|_| Error::ConnectionClosed)?
90  }
91
92  pub fn call_and_forget(&self, function: impl FnOnce(&rusqlite::Connection) + Send + 'static) {
93    let _ = self
94      .sender
95      .send(Message::Run(Box::new(move |conn| function(conn))));
96  }
97
98  /// Query SQL statement.
99  pub async fn query(&self, sql: &str, params: impl Params + Send + 'static) -> Result<Rows> {
100    let sql = sql.to_string();
101    return self
102      .call(move |conn: &mut rusqlite::Connection| {
103        let mut stmt = conn.prepare(&sql)?;
104        params.bind(&mut stmt)?;
105        let rows = stmt.raw_query();
106        Ok(Rows::from_rows(rows)?)
107      })
108      .await;
109  }
110
111  pub async fn query_row(
112    &self,
113    sql: &str,
114    params: impl Params + Send + 'static,
115  ) -> Result<Option<Row>> {
116    let sql = sql.to_string();
117    return self
118      .call(move |conn: &mut rusqlite::Connection| {
119        let mut stmt = conn.prepare(&sql)?;
120        params.bind(&mut stmt)?;
121        let mut rows = stmt.raw_query();
122        if let Some(row) = rows.next()? {
123          return Ok(Some(Row::from_row(row, None)?));
124        }
125        Ok(None)
126      })
127      .await;
128  }
129
130  pub async fn query_value<T: serde::de::DeserializeOwned + Send + 'static>(
131    &self,
132    sql: &str,
133    params: impl Params + Send + 'static,
134  ) -> Result<Option<T>> {
135    let sql = sql.to_string();
136    return self
137      .call(move |conn: &mut rusqlite::Connection| {
138        let mut stmt = conn.prepare(&sql)?;
139        params.bind(&mut stmt)?;
140        let mut rows = stmt.raw_query();
141        if let Some(row) = rows.next()? {
142          return Ok(Some(serde_rusqlite::from_row(row)?));
143        }
144        Ok(None)
145      })
146      .await;
147  }
148
149  pub async fn query_values<T: serde::de::DeserializeOwned + Send + 'static>(
150    &self,
151    sql: &str,
152    params: impl Params + Send + 'static,
153  ) -> Result<Vec<T>> {
154    let sql = sql.to_string();
155    return self
156      .call(move |conn: &mut rusqlite::Connection| {
157        let mut stmt = conn.prepare(&sql)?;
158        params.bind(&mut stmt)?;
159        let mut rows = stmt.raw_query();
160
161        let mut values = vec![];
162        while let Some(row) = rows.next()? {
163          values.push(serde_rusqlite::from_row(row)?);
164        }
165        return Ok(values);
166      })
167      .await;
168  }
169
170  /// Execute SQL statement.
171  pub async fn execute(&self, sql: &str, params: impl Params + Send + 'static) -> Result<usize> {
172    let sql = sql.to_string();
173    return self
174      .call(move |conn: &mut rusqlite::Connection| {
175        let mut stmt = conn.prepare(&sql)?;
176        params.bind(&mut stmt)?;
177        Ok(stmt.raw_execute()?)
178      })
179      .await;
180  }
181
182  /// Batch execute SQL statements and return rows of last statement.
183  pub async fn execute_batch(&self, sql: &str) -> Result<Option<Rows>> {
184    let sql = sql.to_string();
185    return self
186      .call(move |conn: &mut rusqlite::Connection| {
187        let batch = rusqlite::Batch::new(conn, &sql);
188
189        let mut p = batch.peekable();
190        while let Ok(Some(mut stmt)) = p.next() {
191          let mut rows = stmt.raw_query();
192          let row = rows.next()?;
193
194          match p.peek() {
195            Err(_) | Ok(None) => {
196              if let Some(row) = row {
197                let cols: Arc<Vec<Column>> = Arc::new(columns(row.as_ref()));
198
199                let mut result = vec![Row::from_row(row, Some(cols.clone()))?];
200                while let Some(row) = rows.next()? {
201                  result.push(Row::from_row(row, Some(cols.clone()))?);
202                }
203                return Ok(Some(Rows(result, cols)));
204              }
205              return Ok(None);
206            }
207            _ => {}
208          }
209        }
210        return Ok(None);
211      })
212      .await;
213  }
214
215  /// Convenience API for (un)setting a new pre-update hook.
216  pub async fn add_preupdate_hook(
217    &self,
218    hook: Option<impl (Fn(Action, &str, &str, &PreUpdateCase)) + Send + Sync + 'static>,
219  ) -> Result<()> {
220    return self
221      .call(|conn| {
222        conn.preupdate_hook(hook);
223        return Ok(());
224      })
225      .await;
226  }
227
228  /// Close the database connection.
229  ///
230  /// This is functionally equivalent to the `Drop` implementation for
231  /// `Connection`. It consumes the `Connection`, but on error returns it
232  /// to the caller for retry purposes.
233  ///
234  /// If successful, any following `close` operations performed
235  /// on `Connection` copies will succeed immediately.
236  ///
237  /// On the other hand, any calls to [`Connection::call`] will return a
238  /// [`Error::ConnectionClosed`], and any calls to [`Connection::call_unwrap`] will cause a
239  /// `panic`.
240  ///
241  /// # Failure
242  ///
243  /// Will return `Err` if the underlying SQLite close call fails.
244  pub async fn close(self) -> Result<()> {
245    let (sender, receiver) = oneshot::channel::<std::result::Result<(), rusqlite::Error>>();
246
247    if let Err(crossbeam_channel::SendError(_)) = self.sender.send(Message::Close(sender)) {
248      // If the channel is closed on the other side, it means the connection closed successfully
249      // This is a safeguard against calling close on a `Copy` of the connection
250      return Ok(());
251    }
252
253    let Ok(result) = receiver.await else {
254      // If we get a RecvError at this point, it also means the channel closed in the meantime
255      // we can assume the connection is closed
256      return Ok(());
257    };
258
259    return result.map_err(|e| Error::Close(self, e));
260  }
261}
262
263impl Debug for Connection {
264  fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
265    f.debug_struct("Connection").finish()
266  }
267}
268
269fn event_loop(mut conn: rusqlite::Connection, receiver: Receiver<Message>) {
270  const BUG_TEXT: &str = "bug in trailbase-sqlite, please report";
271
272  while let Ok(message) = receiver.recv() {
273    match message {
274      Message::Run(f) => f(&mut conn),
275      Message::Close(ch) => {
276        match conn.close() {
277          Ok(v) => ch.send(Ok(v)).expect(BUG_TEXT),
278          Err((_conn, e)) => ch.send(Err(e)).expect(BUG_TEXT),
279        };
280
281        return;
282      }
283    }
284  }
285}
286
287pub fn extract_row_id(case: &PreUpdateCase) -> Option<i64> {
288  return match case {
289    PreUpdateCase::Insert(accessor) => Some(accessor.get_new_row_id()),
290    PreUpdateCase::Delete(accessor) => Some(accessor.get_old_row_id()),
291    PreUpdateCase::Update {
292      new_value_accessor: accessor,
293      ..
294    } => Some(accessor.get_new_row_id()),
295    PreUpdateCase::Unknown => None,
296  };
297}
298
299pub fn extract_record_values(case: &PreUpdateCase) -> Option<Vec<Value>> {
300  return Some(match case {
301    PreUpdateCase::Insert(accessor) => (0..accessor.get_column_count())
302      .map(|idx| -> Value {
303        accessor
304          .get_new_column_value(idx)
305          .map_or(rusqlite::types::Value::Null, |v| v.into())
306      })
307      .collect(),
308    PreUpdateCase::Delete(accessor) => (0..accessor.get_column_count())
309      .map(|idx| -> rusqlite::types::Value {
310        accessor
311          .get_old_column_value(idx)
312          .map_or(rusqlite::types::Value::Null, |v| v.into())
313      })
314      .collect(),
315    PreUpdateCase::Update {
316      new_value_accessor: accessor,
317      ..
318    } => (0..accessor.get_column_count())
319      .map(|idx| -> rusqlite::types::Value {
320        accessor
321          .get_new_column_value(idx)
322          .map_or(rusqlite::types::Value::Null, |v| v.into())
323      })
324      .collect(),
325    PreUpdateCase::Unknown => {
326      return None;
327    }
328  });
329}
330
331#[cfg(test)]
332#[path = "tests.rs"]
333mod tests;