trailbase_sqlite/
extension.rs

1use crate::Error;
2use std::path::PathBuf;
3
4#[allow(unsafe_code)]
5#[no_mangle]
6extern "C" fn init_trailbase_extensions(
7  db: *mut rusqlite::ffi::sqlite3,
8  _pz_err_msg: *mut *mut std::os::raw::c_char,
9  _p_api: *const rusqlite::ffi::sqlite3_api_routines,
10) -> ::std::os::raw::c_int {
11  // Add sqlite-vec extension.
12  unsafe {
13    sqlite_vec::sqlite3_vec_init();
14  }
15
16  // Init sqlean's stored procedures: "define", see:
17  //   https://github.com/nalgeon/sqlean/blob/main/docs/define.md
18  let status = unsafe { trailbase_sqlean::define_init(db as *mut trailbase_sqlean::sqlite3) };
19  if status != 0 {
20    log::error!("Failed to load sqlean::define",);
21    return status;
22  }
23
24  return status;
25}
26
27#[allow(unsafe_code)]
28pub fn connect_sqlite(
29  path: Option<PathBuf>,
30  extensions: Option<Vec<PathBuf>>,
31) -> Result<rusqlite::Connection, Error> {
32  crate::schema::try_init_schemas();
33
34  let status = unsafe { rusqlite::ffi::sqlite3_auto_extension(Some(init_trailbase_extensions)) };
35  if status != 0 {
36    return Err(Error::Other("Failed to load extensions".into()));
37  }
38
39  let conn = trailbase_extension::sqlite3_extension_init(if let Some(p) = path {
40    use rusqlite::OpenFlags;
41    let flags = OpenFlags::SQLITE_OPEN_READ_WRITE
42      | OpenFlags::SQLITE_OPEN_CREATE
43      | OpenFlags::SQLITE_OPEN_NO_MUTEX;
44
45    rusqlite::Connection::open_with_flags(p, flags)?
46  } else {
47    rusqlite::Connection::open_in_memory()?
48  })?;
49  conn.busy_timeout(std::time::Duration::from_secs(10))?;
50
51  const CONFIG: &[&str] = &[
52    "PRAGMA busy_timeout       = 10000",
53    "PRAGMA journal_mode       = WAL",
54    "PRAGMA journal_size_limit = 200000000",
55    // Sync the file system less often.
56    "PRAGMA synchronous        = NORMAL",
57    "PRAGMA foreign_keys       = ON",
58    "PRAGMA temp_store         = MEMORY",
59    "PRAGMA cache_size         = -16000",
60    // TODO: Maybe worth exploring once we have a benchmark, based on
61    // https://phiresky.github.io/blog/2020/sqlite-performance-tuning/.
62    // "PRAGMA mmap_size          = 30000000000",
63    // "PRAGMA page_size          = 32768",
64
65    // Safety feature around application-defined functions recommended by
66    // https://sqlite.org/appfunc.html
67    "PRAGMA trusted_schema     = OFF",
68  ];
69
70  // NOTE: we're querying here since some pragmas return data.
71  for pragma in CONFIG {
72    let mut stmt = conn.prepare(pragma)?;
73    let mut rows = stmt.query([])?;
74    rows.next()?;
75  }
76
77  if let Some(extensions) = extensions {
78    for path in extensions {
79      unsafe { conn.load_extension(path, None)? }
80    }
81  }
82
83  // Initial optimize.
84  conn.execute("PRAGMA optimize = 0x10002", ())?;
85
86  return Ok(conn);
87}
88
#[cfg(test)]
mod test {
  use super::*;
  use uuid::Uuid;

  /// Opens an in-memory connection through `connect_sqlite` and calls one
  /// function from each bundled extension to show that it was registered.
  #[test]
  fn test_connect_and_extensions() {
    let conn = connect_sqlite(None, None).unwrap();

    // Runs `sql`, requiring it to yield a row; the row's contents are ignored.
    let exec_single = |sql: &str| {
      conn.query_row(sql, (), |_row| Ok(())).unwrap();
    };

    // trailbase extension: `uuid_v7()` returns the 16 raw bytes of a v7 UUID.
    let uuid_bytes: [u8; 16] = conn
      .query_row("SELECT (uuid_v7())", (), |row| row.get(0))
      .unwrap();
    assert_eq!(Uuid::from_bytes(uuid_bytes).get_version_num(), 7);
    assert!(trailbase_extension::jsonschema::get_schema("std.FileUpload").is_some());

    // sqlean: define a stored procedure, call it, then remove it again.
    exec_single("SELECT define('sumn', ':n * (:n + 1) / 2')");
    let sum_to_five: i64 = conn
      .query_row("SELECT sumn(5)", (), |row| row.get(0))
      .unwrap();
    assert_eq!(sum_to_five, 15);
    exec_single("SELECT undefine('sumn')");

    // sqlite-vec: parse a JSON array into a float32 vector.
    exec_single("SELECT vec_f32('[0, 1, 2, 3]')");
  }
}