spin_sqlite_connection/
rusqlite.rs1use std::{marker::PhantomData, str::from_utf8};
2
3use rusqlite::{params_from_iter, types::ValueRef, Error, Params};
4use spin_sdk::sqlite::{QueryResult, Row, RowResult, Value};
5
6pub struct SqliteConnection<E>
7where
8 E: From<Error>,
9{
10 inner: rusqlite::Connection,
11 phantomdata: PhantomData<E>,
12}
13
14impl<E: From<Error>> SqliteConnection<E> {
15 pub fn try_open_default(migrations: Option<&str>) -> Result<Self, E> {
16 let connection = rusqlite::Connection::open_in_memory()?;
17 if let Some(m) = migrations {
18 connection.execute_batch(m)?;
19 }
20 Ok(Self {
21 inner: connection,
22 phantomdata: PhantomData,
23 })
24 }
25
26 pub fn query<T>(&self, sql: impl AsRef<str>, parameters: &[Value]) -> Result<Vec<T>, E>
27 where
28 T: for<'a> TryFrom<Row<'a>, Error = E>,
29 {
30 self.query_result(sql, parameters)
31 .and_then(|query_result| query_result.rows().map(T::try_from).collect())
32 }
33
34 fn query_result<S>(&self, sql: S, parameters: &[Value]) -> Result<QueryResult, E>
35 where
36 S: AsRef<str>,
37 {
38 let mut prepared = self.inner.prepare(sql.as_ref())?;
39 let columns = prepared
40 .column_names()
41 .into_iter()
42 .map(String::from)
43 .collect::<Vec<_>>();
44 let rows = prepared
45 .query_map(rusqlite_parameters(parameters), |row| {
46 (0..columns.len())
47 .map(|i| row.get_ref(i).and_then(spin_sqlite_value))
48 .collect::<Result<Vec<_>, _>>()
49 .map(|values| RowResult { values })
50 })
51 .and_then(|mapped_rows| mapped_rows.collect::<Result<Vec<_>, Error>>())?;
52 Ok(QueryResult { columns, rows })
53 }
54
55 pub fn execute<S>(&self, sql: S, parameters: &[Value]) -> Result<i64, E>
56 where
57 S: AsRef<str>,
58 {
59 let count = self
60 .inner
61 .execute(sql.as_ref(), rusqlite_parameters(parameters))?;
62 Ok(count.try_into().unwrap())
63 }
64}
65
66fn spin_sqlite_value(value: ValueRef) -> Result<spin_sdk::sqlite::Value, rusqlite::Error> {
67 match value {
68 ValueRef::Blob(blob) => Ok(Value::Blob(blob.to_vec())),
69 ValueRef::Integer(integer) => Ok(Value::Integer(integer)),
70 ValueRef::Real(real) => Ok(Value::Real(real)),
71 ValueRef::Null => Ok(Value::Null),
72 ValueRef::Text(text) => from_utf8(text)
73 .map(String::from)
74 .map(Value::Text)
75 .map_err(Error::Utf8Error),
76 }
77}
78
79fn rusqlite_parameter(parameter: &Value) -> rusqlite::types::Value {
80 match parameter {
81 Value::Blob(blob) => rusqlite::types::Value::Blob(blob.clone()),
82 Value::Integer(integer) => rusqlite::types::Value::Integer(*integer),
83 Value::Null => rusqlite::types::Value::Null,
84 Value::Real(real) => rusqlite::types::Value::Real(*real),
85 Value::Text(text) => rusqlite::types::Value::Text(text.clone()),
86 }
87}
88
89fn rusqlite_parameters(parameters: &[Value]) -> impl Params + use<> {
90 params_from_iter(
91 parameters
92 .iter()
93 .map(rusqlite_parameter)
94 .collect::<Vec<_>>(),
95 )
96}
97
#[cfg(test)]
mod test {
    use spin_sdk::sqlite::Value;

    /// Concrete connection type used by the tests; `Box<dyn Error>` accepts any
    /// `rusqlite::Error` through its blanket `From` impl.
    type Connection = super::SqliteConnection<Box<dyn std::error::Error>>;

    #[test]
    fn open_database() {
        Connection::try_open_default(None).unwrap();
    }

    #[test]
    fn migrations_run_on_open() {
        let connection = Connection::try_open_default(Some(
            "CREATE TABLE t (id INTEGER); INSERT INTO t VALUES (1);",
        ))
        .unwrap();
        // The row inserted by the migration must be visible to later statements.
        let deleted = connection.execute("DELETE FROM t WHERE id = 1", &[]).unwrap();
        assert_eq!(deleted, 1);
    }

    #[test]
    fn invalid_migration_is_an_error() {
        assert!(Connection::try_open_default(Some("THIS IS NOT SQL")).is_err());
    }

    #[test]
    fn execute_binds_every_value_kind_and_counts_changes() {
        let connection = Connection::try_open_default(Some(
            "CREATE TABLE t (id INTEGER, name TEXT, data BLOB, score REAL);",
        ))
        .unwrap();

        let inserted = connection
            .execute(
                "INSERT INTO t VALUES (?1, ?2, ?3, ?4)",
                &[
                    Value::Integer(1),
                    Value::Text("a".to_owned()),
                    Value::Blob(vec![1, 2, 3]),
                    Value::Real(0.5),
                ],
            )
            .unwrap();
        assert_eq!(inserted, 1);

        let updated = connection
            .execute(
                "UPDATE t SET name = ?1 WHERE id = ?2",
                &[Value::Null, Value::Integer(1)],
            )
            .unwrap();
        assert_eq!(updated, 1);

        let untouched = connection
            .execute(
                "UPDATE t SET name = ?1 WHERE id = ?2",
                &[Value::Text("b".to_owned()), Value::Integer(2)],
            )
            .unwrap();
        assert_eq!(untouched, 0);
    }

    #[test]
    fn invalid_statement_is_an_error() {
        let connection = Connection::try_open_default(None).unwrap();
        assert!(connection.execute("THIS IS NOT SQL", &[]).is_err());
    }
}