trailbase_sqlite/
connection.rs1use crossbeam_channel::{Receiver, Sender};
2use rusqlite::fallible_iterator::FallibleIterator;
3use rusqlite::hooks::{Action, PreUpdateCase};
4use rusqlite::types::Value;
5use std::{
6 fmt::{self, Debug},
7 sync::Arc,
8};
9use tokio::sync::oneshot;
10
11use crate::error::Error;
12pub use crate::params::Params;
13use crate::rows::{columns, Column};
14pub use crate::rows::{Row, Rows};
15
/// Builds a fixed-size array of positional SQL parameters.
///
/// Each argument is converted with `Into::<crate::params::ToSqlType>::into`; a trailing comma
/// is accepted. Invoked without arguments, the macro yields an empty, zero-length array.
#[macro_export]
macro_rules! params {
    () => {
        // A cast to the unsized slice type `[T]` is rejected by the compiler, so the
        // zero-length array type is spelled out. It has the same array shape as the
        // value produced by the non-empty arm.
        [] as [$crate::params::ToSqlType; 0]
    };
    ($($param:expr),+ $(,)?) => {
        [$(Into::<$crate::params::ToSqlType>::into($param)),+]
    };
}
25
/// Builds a fixed-size array of `(name, value)` pairs for named SQL parameters.
///
/// Each entry is written as `"name": value`; the name must be a literal and the value is
/// converted with `Into::<crate::params::ToSqlType>::into`. A trailing comma is accepted.
/// Invoked without arguments, the macro yields an empty, zero-length array.
#[macro_export]
macro_rules! named_params {
    () => {
        // A cast to the unsized slice type `[T]` is rejected by the compiler, so the
        // zero-length array type is spelled out. It has the same array shape as the
        // value produced by the non-empty arm.
        [] as [(&str, $crate::params::ToSqlType); 0]
    };
    ($($param_name:literal: $param_val:expr),+ $(,)?) => {
        [$(($param_name as &str, Into::<$crate::params::ToSqlType>::into($param_val))),+]
    };
}
35
/// Result alias whose error type is this crate's [`Error`].
pub type Result<T> = std::result::Result<T, Error>;
38
/// Boxed, sendable, one-shot closure that receives a mutable reference to the
/// `rusqlite::Connection`.
type CallFn = Box<dyn FnOnce(&mut rusqlite::Connection) + Send + 'static>;
40
/// Messages sent over the channel that `event_loop` receives from.
enum Message {
    /// Invoke the closure with the connection.
    Run(CallFn),
    /// Close the connection and send the outcome back through the oneshot sender.
    Close(oneshot::Sender<std::result::Result<(), rusqlite::Error>>),
}
45
/// Cloneable handle holding the sending half of the `Message` channel.
#[derive(Clone)]
pub struct Connection {
    // Sending side of the channel; `from_conn` hands the receiving side to `event_loop`.
    sender: Sender<Message>,
}
51
52impl Connection {
53 pub fn from_conn(conn: rusqlite::Connection) -> Result<Self> {
54 let (sender, receiver) = crossbeam_channel::unbounded::<Message>();
55 std::thread::spawn(move || event_loop(conn, receiver));
56 return Ok(Self { sender });
57 }
58
59 pub fn open_in_memory() -> Result<Self> {
65 return Self::from_conn(rusqlite::Connection::open_in_memory()?);
66 }
67
68 pub async fn call<F, R>(&self, function: F) -> Result<R>
75 where
76 F: FnOnce(&mut rusqlite::Connection) -> Result<R> + Send + 'static,
77 R: Send + 'static,
78 {
79 let (sender, receiver) = oneshot::channel::<Result<R>>();
80
81 self
82 .sender
83 .send(Message::Run(Box::new(move |conn| {
84 let value = function(conn);
85 let _ = sender.send(value);
86 })))
87 .map_err(|_| Error::ConnectionClosed)?;
88
89 receiver.await.map_err(|_| Error::ConnectionClosed)?
90 }
91
92 pub fn call_and_forget(&self, function: impl FnOnce(&rusqlite::Connection) + Send + 'static) {
93 let _ = self
94 .sender
95 .send(Message::Run(Box::new(move |conn| function(conn))));
96 }
97
98 pub async fn query(&self, sql: &str, params: impl Params + Send + 'static) -> Result<Rows> {
100 let sql = sql.to_string();
101 return self
102 .call(move |conn: &mut rusqlite::Connection| {
103 let mut stmt = conn.prepare(&sql)?;
104 params.bind(&mut stmt)?;
105 let rows = stmt.raw_query();
106 Ok(Rows::from_rows(rows)?)
107 })
108 .await;
109 }
110
111 pub async fn query_row(
112 &self,
113 sql: &str,
114 params: impl Params + Send + 'static,
115 ) -> Result<Option<Row>> {
116 let sql = sql.to_string();
117 return self
118 .call(move |conn: &mut rusqlite::Connection| {
119 let mut stmt = conn.prepare(&sql)?;
120 params.bind(&mut stmt)?;
121 let mut rows = stmt.raw_query();
122 if let Some(row) = rows.next()? {
123 return Ok(Some(Row::from_row(row, None)?));
124 }
125 Ok(None)
126 })
127 .await;
128 }
129
130 pub async fn query_value<T: serde::de::DeserializeOwned + Send + 'static>(
131 &self,
132 sql: &str,
133 params: impl Params + Send + 'static,
134 ) -> Result<Option<T>> {
135 let sql = sql.to_string();
136 return self
137 .call(move |conn: &mut rusqlite::Connection| {
138 let mut stmt = conn.prepare(&sql)?;
139 params.bind(&mut stmt)?;
140 let mut rows = stmt.raw_query();
141 if let Some(row) = rows.next()? {
142 return Ok(Some(serde_rusqlite::from_row(row)?));
143 }
144 Ok(None)
145 })
146 .await;
147 }
148
149 pub async fn query_values<T: serde::de::DeserializeOwned + Send + 'static>(
150 &self,
151 sql: &str,
152 params: impl Params + Send + 'static,
153 ) -> Result<Vec<T>> {
154 let sql = sql.to_string();
155 return self
156 .call(move |conn: &mut rusqlite::Connection| {
157 let mut stmt = conn.prepare(&sql)?;
158 params.bind(&mut stmt)?;
159 let mut rows = stmt.raw_query();
160
161 let mut values = vec![];
162 while let Some(row) = rows.next()? {
163 values.push(serde_rusqlite::from_row(row)?);
164 }
165 return Ok(values);
166 })
167 .await;
168 }
169
170 pub async fn execute(&self, sql: &str, params: impl Params + Send + 'static) -> Result<usize> {
172 let sql = sql.to_string();
173 return self
174 .call(move |conn: &mut rusqlite::Connection| {
175 let mut stmt = conn.prepare(&sql)?;
176 params.bind(&mut stmt)?;
177 Ok(stmt.raw_execute()?)
178 })
179 .await;
180 }
181
182 pub async fn execute_batch(&self, sql: &str) -> Result<Option<Rows>> {
184 let sql = sql.to_string();
185 return self
186 .call(move |conn: &mut rusqlite::Connection| {
187 let batch = rusqlite::Batch::new(conn, &sql);
188
189 let mut p = batch.peekable();
190 while let Ok(Some(mut stmt)) = p.next() {
191 let mut rows = stmt.raw_query();
192 let row = rows.next()?;
193
194 match p.peek() {
195 Err(_) | Ok(None) => {
196 if let Some(row) = row {
197 let cols: Arc<Vec<Column>> = Arc::new(columns(row.as_ref()));
198
199 let mut result = vec![Row::from_row(row, Some(cols.clone()))?];
200 while let Some(row) = rows.next()? {
201 result.push(Row::from_row(row, Some(cols.clone()))?);
202 }
203 return Ok(Some(Rows(result, cols)));
204 }
205 return Ok(None);
206 }
207 _ => {}
208 }
209 }
210 return Ok(None);
211 })
212 .await;
213 }
214
215 pub async fn add_preupdate_hook(
217 &self,
218 hook: Option<impl (Fn(Action, &str, &str, &PreUpdateCase)) + Send + Sync + 'static>,
219 ) -> Result<()> {
220 return self
221 .call(|conn| {
222 conn.preupdate_hook(hook);
223 return Ok(());
224 })
225 .await;
226 }
227
228 pub async fn close(self) -> Result<()> {
245 let (sender, receiver) = oneshot::channel::<std::result::Result<(), rusqlite::Error>>();
246
247 if let Err(crossbeam_channel::SendError(_)) = self.sender.send(Message::Close(sender)) {
248 return Ok(());
251 }
252
253 let Ok(result) = receiver.await else {
254 return Ok(());
257 };
258
259 return result.map_err(|e| Error::Close(self, e));
260 }
261}
262
263impl Debug for Connection {
264 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
265 f.debug_struct("Connection").finish()
266 }
267}
268
269fn event_loop(mut conn: rusqlite::Connection, receiver: Receiver<Message>) {
270 const BUG_TEXT: &str = "bug in trailbase-sqlite, please report";
271
272 while let Ok(message) = receiver.recv() {
273 match message {
274 Message::Run(f) => f(&mut conn),
275 Message::Close(ch) => {
276 match conn.close() {
277 Ok(v) => ch.send(Ok(v)).expect(BUG_TEXT),
278 Err((_conn, e)) => ch.send(Err(e)).expect(BUG_TEXT),
279 };
280
281 return;
282 }
283 }
284 }
285}
286
287pub fn extract_row_id(case: &PreUpdateCase) -> Option<i64> {
288 return match case {
289 PreUpdateCase::Insert(accessor) => Some(accessor.get_new_row_id()),
290 PreUpdateCase::Delete(accessor) => Some(accessor.get_old_row_id()),
291 PreUpdateCase::Update {
292 new_value_accessor: accessor,
293 ..
294 } => Some(accessor.get_new_row_id()),
295 PreUpdateCase::Unknown => None,
296 };
297}
298
299pub fn extract_record_values(case: &PreUpdateCase) -> Option<Vec<Value>> {
300 return Some(match case {
301 PreUpdateCase::Insert(accessor) => (0..accessor.get_column_count())
302 .map(|idx| -> Value {
303 accessor
304 .get_new_column_value(idx)
305 .map_or(rusqlite::types::Value::Null, |v| v.into())
306 })
307 .collect(),
308 PreUpdateCase::Delete(accessor) => (0..accessor.get_column_count())
309 .map(|idx| -> rusqlite::types::Value {
310 accessor
311 .get_old_column_value(idx)
312 .map_or(rusqlite::types::Value::Null, |v| v.into())
313 })
314 .collect(),
315 PreUpdateCase::Update {
316 new_value_accessor: accessor,
317 ..
318 } => (0..accessor.get_column_count())
319 .map(|idx| -> rusqlite::types::Value {
320 accessor
321 .get_new_column_value(idx)
322 .map_or(rusqlite::types::Value::Null, |v| v.into())
323 })
324 .collect(),
325 PreUpdateCase::Unknown => {
326 return None;
327 }
328 });
329}
330
331#[cfg(test)]
332#[path = "tests.rs"]
333mod tests;