// trailbase_sqlite/extension.rs
use crate::Error;
use std::path::PathBuf;

4#[allow(unsafe_code)]
5#[no_mangle]
6extern "C" fn init_trailbase_extensions(
7 db: *mut rusqlite::ffi::sqlite3,
8 _pz_err_msg: *mut *mut std::os::raw::c_char,
9 _p_api: *const rusqlite::ffi::sqlite3_api_routines,
10) -> ::std::os::raw::c_int {
11 unsafe {
13 sqlite_vec::sqlite3_vec_init();
14 }
15
16 let status = unsafe { trailbase_sqlean::define_init(db as *mut trailbase_sqlean::sqlite3) };
19 if status != 0 {
20 log::error!("Failed to load sqlean::define",);
21 return status;
22 }
23
24 return status;
25}
26
27#[allow(unsafe_code)]
28pub fn connect_sqlite(
29 path: Option<PathBuf>,
30 extensions: Option<Vec<PathBuf>>,
31) -> Result<rusqlite::Connection, Error> {
32 crate::schema::try_init_schemas();
33
34 let status = unsafe { rusqlite::ffi::sqlite3_auto_extension(Some(init_trailbase_extensions)) };
35 if status != 0 {
36 return Err(Error::Other("Failed to load extensions".into()));
37 }
38
39 let conn = trailbase_extension::sqlite3_extension_init(if let Some(p) = path {
40 use rusqlite::OpenFlags;
41 let flags = OpenFlags::SQLITE_OPEN_READ_WRITE
42 | OpenFlags::SQLITE_OPEN_CREATE
43 | OpenFlags::SQLITE_OPEN_NO_MUTEX;
44
45 rusqlite::Connection::open_with_flags(p, flags)?
46 } else {
47 rusqlite::Connection::open_in_memory()?
48 })?;
49 conn.busy_timeout(std::time::Duration::from_secs(10))?;
50
51 const CONFIG: &[&str] = &[
52 "PRAGMA busy_timeout = 10000",
53 "PRAGMA journal_mode = WAL",
54 "PRAGMA journal_size_limit = 200000000",
55 "PRAGMA synchronous = NORMAL",
57 "PRAGMA foreign_keys = ON",
58 "PRAGMA temp_store = MEMORY",
59 "PRAGMA cache_size = -16000",
60 "PRAGMA trusted_schema = OFF",
68 ];
69
70 for pragma in CONFIG {
72 let mut stmt = conn.prepare(pragma)?;
73 let mut rows = stmt.query([])?;
74 rows.next()?;
75 }
76
77 if let Some(extensions) = extensions {
78 for path in extensions {
79 unsafe { conn.load_extension(path, None)? }
80 }
81 }
82
83 conn.execute("PRAGMA optimize = 0x10002", ())?;
85
86 return Ok(conn);
87}
88
#[cfg(test)]
mod test {
  use super::*;
  use uuid::Uuid;

  #[test]
  fn test_connect_and_extensions() {
    let conn = connect_sqlite(None, None).unwrap();

    // trailbase_extension: `uuid_v7()` yields 16 bytes that decode to a version-7 UUID.
    let row = conn
      .query_row(
        "SELECT (uuid_v7())",
        (),
        |row| -> rusqlite::Result<[u8; 16]> { row.get(0) },
      )
      .unwrap();

    let uuid = Uuid::from_bytes(row);
    assert_eq!(uuid.get_version_num(), 7);
    assert!(trailbase_extension::jsonschema::get_schema("std.FileUpload").is_some());

    // sqlean `define`: define a SQL function, call it, then remove it again.
    conn
      .query_row("SELECT define('sumn', ':n * (:n + 1) / 2')", (), |_row| {
        Ok(())
      })
      .unwrap();

    let value: i64 = conn
      .query_row("SELECT sumn(5)", (), |row| row.get(0))
      .unwrap();
    assert_eq!(value, 15);

    conn
      .query_row("SELECT undefine('sumn')", (), |_row| Ok(()))
      .unwrap();

    // sqlite-vec: `vec_f32` returns the vector as a BLOB of native-endian f32 values.
    let blob: Vec<u8> = conn
      .query_row("SELECT vec_f32('[0, 1, 2, 3]')", (), |row| row.get(0))
      .unwrap();
    assert_eq!(blob.len(), 4 * std::mem::size_of::<f32>());
    let floats: Vec<f32> = blob
      .chunks_exact(4)
      .map(|b| f32::from_ne_bytes([b[0], b[1], b[2], b[3]]))
      .collect();
    assert_eq!(floats, [0.0, 1.0, 2.0, 3.0]);
  }

  #[test]
  fn test_connect_with_missing_extension_fails() {
    let result = connect_sqlite(
      None,
      Some(vec![PathBuf::from("does/not/exist/extension")]),
    );
    assert!(result.is_err());
  }
}