// spin_sqlite_connection/rusqlite.rs

1use std::{marker::PhantomData, str::from_utf8};
2
3use rusqlite::{params_from_iter, types::ValueRef, Error, Params};
4use spin_sdk::sqlite::{QueryResult, Row, RowResult, Value};
5
6pub struct SqliteConnection<E>
7where
8    E: From<Error>,
9{
10    inner: rusqlite::Connection,
11    phantomdata: PhantomData<E>,
12}
13
14impl<E: From<Error>> SqliteConnection<E> {
15    pub fn try_open_default(migrations: Option<&str>) -> Result<Self, E> {
16        let connection = rusqlite::Connection::open_in_memory()?;
17        if let Some(m) = migrations {
18            connection.execute_batch(m)?;
19        }
20        Ok(Self {
21            inner: connection,
22            phantomdata: PhantomData,
23        })
24    }
25
26    pub fn query<T>(&self, sql: impl AsRef<str>, parameters: &[Value]) -> Result<Vec<T>, E>
27    where
28        T: for<'a> TryFrom<Row<'a>, Error = E>,
29    {
30        self.query_result(sql, parameters)
31            .and_then(|query_result| query_result.rows().map(T::try_from).collect())
32    }
33
34    fn query_result<S>(&self, sql: S, parameters: &[Value]) -> Result<QueryResult, E>
35    where
36        S: AsRef<str>,
37    {
38        let mut prepared = self.inner.prepare(sql.as_ref())?;
39        let columns = prepared
40            .column_names()
41            .into_iter()
42            .map(String::from)
43            .collect::<Vec<_>>();
44        let rows = prepared
45            .query_map(rusqlite_parameters(parameters), |row| {
46                (0..columns.len())
47                    .map(|i| row.get_ref(i).and_then(spin_sqlite_value))
48                    .collect::<Result<Vec<_>, _>>()
49                    .map(|values| RowResult { values })
50            })
51            .and_then(|mapped_rows| mapped_rows.collect::<Result<Vec<_>, Error>>())?;
52        Ok(QueryResult { columns, rows })
53    }
54
55    pub fn execute<S>(&self, sql: S, parameters: &[Value]) -> Result<i64, E>
56    where
57        S: AsRef<str>,
58    {
59        let count = self
60            .inner
61            .execute(sql.as_ref(), rusqlite_parameters(parameters))?;
62        Ok(count.try_into().unwrap())
63    }
64}
65
66fn spin_sqlite_value(value: ValueRef) -> Result<spin_sdk::sqlite::Value, rusqlite::Error> {
67    match value {
68        ValueRef::Blob(blob) => Ok(Value::Blob(blob.to_vec())),
69        ValueRef::Integer(integer) => Ok(Value::Integer(integer)),
70        ValueRef::Real(real) => Ok(Value::Real(real)),
71        ValueRef::Null => Ok(Value::Null),
72        ValueRef::Text(text) => from_utf8(text)
73            .map(String::from)
74            .map(Value::Text)
75            .map_err(Error::Utf8Error),
76    }
77}
78
79fn rusqlite_parameter(parameter: &Value) -> rusqlite::types::Value {
80    match parameter {
81        Value::Blob(blob) => rusqlite::types::Value::Blob(blob.clone()),
82        Value::Integer(integer) => rusqlite::types::Value::Integer(*integer),
83        Value::Null => rusqlite::types::Value::Null,
84        Value::Real(real) => rusqlite::types::Value::Real(*real),
85        Value::Text(text) => rusqlite::types::Value::Text(text.clone()),
86    }
87}
88
89fn rusqlite_parameters(parameters: &[Value]) -> impl Params + use<> {
90    params_from_iter(
91        parameters
92            .iter()
93            .map(rusqlite_parameter)
94            .collect::<Vec<_>>(),
95    )
96}
97
#[cfg(test)]
mod test {
    use spin_sdk::sqlite::Value;

    /// Connection type used by these tests; any `From<rusqlite::Error>` error type would do.
    type TestConnection = super::SqliteConnection<Box<dyn std::error::Error>>;

    #[test]
    fn open_database() {
        TestConnection::try_open_default(None).unwrap();
    }

    #[test]
    fn migrations_are_applied_and_parameters_bound() {
        let connection = TestConnection::try_open_default(Some(
            "CREATE TABLE items (id INTEGER, name TEXT, data BLOB);",
        ))
        .unwrap();
        let changed = connection
            .execute(
                "INSERT INTO items (id, name, data) VALUES (?1, ?2, ?3)",
                &[Value::Integer(1), Value::Text("one".to_string()), Value::Null],
            )
            .unwrap();
        assert_eq!(changed, 1);
    }

    #[test]
    fn failing_migration_is_reported() {
        assert!(TestConnection::try_open_default(Some("THIS IS NOT SQL")).is_err());
    }

    #[test]
    fn execute_against_missing_table_is_reported() {
        let connection = TestConnection::try_open_default(None).unwrap();
        assert!(connection.execute("DELETE FROM missing", &[]).is_err());
    }
}