sqlint/connector/
sqlite.rs

1mod conversion;
2mod error;
3
4pub use rusqlite::{params_from_iter, version as sqlite_version};
5
6use super::IsolationLevel;
7use crate::{
8    ast::{Query, Value},
9    connector::{metrics, queryable::*, ResultSet},
10    error::{Error, ErrorKind},
11    visitor::{self, Visitor},
12};
13use async_trait::async_trait;
14use std::{convert::TryFrom, path::Path, time::Duration};
15use tokio::sync::Mutex;
16
/// Schema name that `SqliteParams::try_from` assigns to `db_name`. `"main"` is the name SQLite
/// gives to the primary database of a connection.
pub(crate) const DEFAULT_SQLITE_SCHEMA_NAME: &str = "main";
18
19/// The underlying sqlite driver. Only available with the `expose-drivers` Cargo feature.
20#[cfg(feature = "expose-drivers")]
21pub use rusqlite;
22
/// A connector interface for the SQLite database
///
/// Holds a single `rusqlite::Connection` behind a `tokio::sync::Mutex`. Every query method locks
/// it for the duration of the call, so concurrent calls on one `Sqlite` are serialized.
#[cfg_attr(feature = "docs", doc(cfg(feature = "sqlite")))]
pub struct Sqlite {
    /// The one underlying driver connection; locked by each query method.
    pub(crate) client: Mutex<rusqlite::Connection>,
}
28
/// Wraps a connection url and exposes the parsing logic used by Sqlint,
/// including default values.
///
/// Constructed from a connection string through `TryFrom<&str>`.
#[derive(Debug)]
#[cfg_attr(feature = "docs", doc(cfg(feature = "sqlite")))]
pub struct SqliteParams {
    /// Value of the `connection_limit` query parameter, if one was given.
    pub connection_limit: Option<usize>,
    /// This is not a `PathBuf` because we need to `ATTACH` the database to the path, and this can
    /// only be done with UTF-8 paths.
    pub file_path: String,
    /// Schema name of the database; the parser always sets it to `DEFAULT_SQLITE_SCHEMA_NAME`.
    pub db_name: String,
    /// `socket_timeout` query parameter, in whole seconds. `Sqlite::try_from` applies it as the
    /// connection's busy timeout.
    pub socket_timeout: Option<Duration>,
    /// `max_connection_lifetime` query parameter, in whole seconds; a value of `0` is parsed as
    /// `None`.
    pub max_connection_lifetime: Option<Duration>,
    /// `max_idle_connection_lifetime` query parameter, in whole seconds; a value of `0` is parsed
    /// as `None`.
    pub max_idle_connection_lifetime: Option<Duration>,
}
43
44impl TryFrom<&str> for SqliteParams {
45    type Error = Error;
46
47    fn try_from(path: &str) -> crate::Result<Self> {
48        let path = if path.starts_with("file:") {
49            path.trim_start_matches("file:")
50        } else {
51            path.trim_start_matches("sqlite:")
52        };
53
54        let path_parts: Vec<&str> = path.split('?').collect();
55        let path_str = path_parts[0];
56        let path = Path::new(path_str);
57
58        if path.is_dir() {
59            Err(Error::builder(ErrorKind::DatabaseUrlIsInvalid(path.to_str().unwrap().to_string())).build())
60        } else {
61            let mut connection_limit = None;
62            let mut socket_timeout = None;
63            let mut max_connection_lifetime = None;
64            let mut max_idle_connection_lifetime = None;
65
66            if path_parts.len() > 1 {
67                let params = path_parts.last().unwrap().split('&').map(|kv| {
68                    let splitted: Vec<&str> = kv.split('=').collect();
69                    (splitted[0], splitted[1])
70                });
71
72                for (k, v) in params {
73                    match k {
74                        "connection_limit" => {
75                            let as_int: usize =
76                                v.parse().map_err(|_| Error::builder(ErrorKind::InvalidConnectionArguments).build())?;
77
78                            connection_limit = Some(as_int);
79                        }
80                        "socket_timeout" => {
81                            let as_int =
82                                v.parse().map_err(|_| Error::builder(ErrorKind::InvalidConnectionArguments).build())?;
83
84                            socket_timeout = Some(Duration::from_secs(as_int));
85                        }
86                        "max_connection_lifetime" => {
87                            let as_int =
88                                v.parse().map_err(|_| Error::builder(ErrorKind::InvalidConnectionArguments).build())?;
89
90                            if as_int == 0 {
91                                max_connection_lifetime = None;
92                            } else {
93                                max_connection_lifetime = Some(Duration::from_secs(as_int));
94                            }
95                        }
96                        "max_idle_connection_lifetime" => {
97                            let as_int =
98                                v.parse().map_err(|_| Error::builder(ErrorKind::InvalidConnectionArguments).build())?;
99
100                            if as_int == 0 {
101                                max_idle_connection_lifetime = None;
102                            } else {
103                                max_idle_connection_lifetime = Some(Duration::from_secs(as_int));
104                            }
105                        }
106                        _ => {
107                            tracing::trace!(message = "Discarding connection string param", param = k);
108                        }
109                    };
110                }
111            }
112
113            Ok(Self {
114                connection_limit,
115                file_path: path_str.to_owned(),
116                db_name: DEFAULT_SQLITE_SCHEMA_NAME.to_owned(),
117                socket_timeout,
118                max_connection_lifetime,
119                max_idle_connection_lifetime,
120            })
121        }
122    }
123}
124
125impl TryFrom<&str> for Sqlite {
126    type Error = Error;
127
128    fn try_from(path: &str) -> crate::Result<Self> {
129        let params = SqliteParams::try_from(path)?;
130        let file_path = params.file_path;
131
132        let conn = rusqlite::Connection::open(file_path.as_str())?;
133
134        if let Some(timeout) = params.socket_timeout {
135            conn.busy_timeout(timeout)?;
136        };
137
138        let client = Mutex::new(conn);
139
140        Ok(Sqlite { client })
141    }
142}
143
144impl Sqlite {
145    pub fn new(file_path: &str) -> crate::Result<Sqlite> {
146        Self::try_from(file_path)
147    }
148
149    /// Open a new SQLite database in memory.
150    pub fn new_in_memory() -> crate::Result<Sqlite> {
151        let client = rusqlite::Connection::open_in_memory()?;
152
153        Ok(Sqlite { client: Mutex::new(client) })
154    }
155
156    /// The underlying rusqlite::Connection. Only available with the `expose-drivers` Cargo
157    /// feature. This is a lower level API when you need to get into database specific features.
158    #[cfg(feature = "expose-drivers")]
159    pub fn connection(&self) -> &Mutex<rusqlite::Connection> {
160        &self.client
161    }
162}
163
// Empty impl: no methods are overridden, so `Sqlite` relies entirely on the trait's provided
// implementations.
impl TransactionCapable for Sqlite {}
165
166#[async_trait]
167impl Queryable for Sqlite {
168    async fn query(&self, q: Query<'_>) -> crate::Result<ResultSet> {
169        let (sql, params) = visitor::Sqlite::build(q)?;
170        self.query_raw(&sql, &params).await
171    }
172
173    async fn query_raw(&self, sql: &str, params: &[Value<'_>]) -> crate::Result<ResultSet> {
174        metrics::query("sqlite.query_raw", sql, params, move || async move {
175            let client = self.client.lock().await;
176
177            let mut stmt = client.prepare_cached(sql)?;
178
179            let mut rows = stmt.query(params_from_iter(params.iter()))?;
180            let mut result = ResultSet::new(rows.to_column_names(), Vec::new());
181
182            while let Some(row) = rows.next()? {
183                result.rows.push(row.get_result_row()?);
184            }
185
186            result.set_last_insert_id(u64::try_from(client.last_insert_rowid()).unwrap_or(0));
187
188            Ok(result)
189        })
190        .await
191    }
192
193    async fn query_raw_typed(&self, sql: &str, params: &[Value<'_>]) -> crate::Result<ResultSet> {
194        self.query_raw(sql, params).await
195    }
196
197    async fn execute(&self, q: Query<'_>) -> crate::Result<u64> {
198        let (sql, params) = visitor::Sqlite::build(q)?;
199        self.execute_raw(&sql, &params).await
200    }
201
202    async fn execute_raw(&self, sql: &str, params: &[Value<'_>]) -> crate::Result<u64> {
203        metrics::query("sqlite.query_raw", sql, params, move || async move {
204            let client = self.client.lock().await;
205            let mut stmt = client.prepare_cached(sql)?;
206            let res = u64::try_from(stmt.execute(params_from_iter(params.iter()))?)?;
207
208            Ok(res)
209        })
210        .await
211    }
212
213    async fn execute_raw_typed(&self, sql: &str, params: &[Value<'_>]) -> crate::Result<u64> {
214        self.execute_raw(sql, params).await
215    }
216
217    async fn raw_cmd(&self, cmd: &str) -> crate::Result<()> {
218        metrics::query("sqlite.raw_cmd", cmd, &[], move || async move {
219            let client = self.client.lock().await;
220            client.execute_batch(cmd)?;
221            Ok(())
222        })
223        .await
224    }
225
226    async fn version(&self) -> crate::Result<Option<String>> {
227        Ok(Some(rusqlite::version().into()))
228    }
229
230    fn is_healthy(&self) -> bool {
231        true
232    }
233
234    async fn set_tx_isolation_level(&self, isolation_level: IsolationLevel) -> crate::Result<()> {
235        // SQLite is always "serializable", other modes involve pragmas
236        // and shared cache mode, which is out of scope for now and should be implemented
237        // as part of a separate effort.
238        if !matches!(isolation_level, IsolationLevel::Serializable) {
239            let kind = ErrorKind::invalid_isolation_level(&isolation_level);
240            return Err(Error::builder(kind).build());
241        }
242
243        Ok(())
244    }
245
246    fn requires_isolation_first(&self) -> bool {
247        false
248    }
249}
250
// Unit tests for connection-string parsing and basic query round-trips on the SQLite connector.
#[cfg(test)]
mod tests {
    use super::*;
    use crate::{
        ast::*,
        connector::Queryable,
        error::{ErrorKind, Name},
    };

    // The `file:` scheme prefix is stripped, leaving the bare file path.
    #[test]
    fn sqlite_params_from_str_should_resolve_path_correctly_with_file_scheme() {
        let path = "file:dev.db";
        let params = SqliteParams::try_from(path).unwrap();
        assert_eq!(params.file_path, "dev.db");
    }

    // The `sqlite:` scheme prefix is stripped, leaving the bare file path.
    #[test]
    fn sqlite_params_from_str_should_resolve_path_correctly_with_sqlite_scheme() {
        let path = "sqlite:dev.db";
        let params = SqliteParams::try_from(path).unwrap();
        assert_eq!(params.file_path, "dev.db");
    }

    // A path with no scheme is used unchanged.
    #[test]
    fn sqlite_params_from_str_should_resolve_path_correctly_with_no_scheme() {
        let path = "dev.db";
        let params = SqliteParams::try_from(path).unwrap();
        assert_eq!(params.file_path, "dev.db");
    }

    // Selecting from a missing table must surface as `TableDoesNotExist` naming that table.
    // NOTE(review): this opens `db/test.db` relative to the working directory, so it presumably
    // needs a `db/` directory to exist when the tests run — confirm.
    #[tokio::test]
    async fn unknown_table_should_give_a_good_error() {
        let conn = Sqlite::try_from("file:db/test.db").unwrap();
        let select = Select::from_table("not_there");

        let err = conn.select(select).await.unwrap_err();

        match err.kind() {
            ErrorKind::TableDoesNotExist { table } => {
                assert_eq!(&Name::available("not_there"), table);
            }
            e => panic!("Expected error TableDoesNotExist, got {:?}", e),
        }
    }

    // Round-trips a row through an in-memory database, then checks that a second in-memory
    // connection does not see the first one's table.
    #[tokio::test]
    async fn in_memory_sqlite_works() {
        let conn = Sqlite::new_in_memory().unwrap();

        conn.raw_cmd("CREATE TABLE test (id INTEGER PRIMARY KEY, txt TEXT NOT NULL);").await.unwrap();

        let insert = Insert::single_into("test").value("txt", "henlo");
        conn.insert(insert.into()).await.unwrap();

        let select = Select::from_table("test").value(asterisk());
        let result = conn.select(select.clone()).await.unwrap();
        let result = result.into_single().unwrap();

        assert_eq!(result.get("id").unwrap(), &Value::int32(1));
        assert_eq!(result.get("txt").unwrap(), &Value::text("henlo"));

        // Check that we do get a separate, new database.
        let other_conn = Sqlite::new_in_memory().unwrap();

        let err = other_conn.select(select).await.unwrap_err();
        assert!(matches!(err.kind(), ErrorKind::TableDoesNotExist { .. }));
    }

    // A column name containing a space must be quoted correctly in INSERT, SELECT and RETURNING.
    #[tokio::test]
    async fn quoting_in_returning_in_sqlite_works() {
        let conn = Sqlite::new_in_memory().unwrap();

        conn.raw_cmd("CREATE TABLE test (id  INTEGER PRIMARY KEY, `txt space` TEXT NOT NULL);").await.unwrap();

        let insert = Insert::single_into("test").value("txt space", "henlo");
        conn.insert(insert.into()).await.unwrap();

        let select = Select::from_table("test").value(asterisk());
        let result = conn.select(select.clone()).await.unwrap();
        let result = result.into_single().unwrap();

        assert_eq!(result.get("id").unwrap(), &Value::int32(1));
        assert_eq!(result.get("txt space").unwrap(), &Value::text("henlo"));

        // Second insert exercises the RETURNING clause with the space-containing column.
        let insert = Insert::single_into("test").value("txt space", "henlo");
        let insert: Insert = Insert::from(insert).returning(["txt space"]);

        let result = conn.insert(insert).await.unwrap();
        let result = result.into_single().unwrap();

        assert_eq!(result.get("txt space").unwrap(), &Value::text("henlo"));
    }
}