trailbase_extension/
lib.rs

1#![forbid(clippy::unwrap_used)]
2#![allow(clippy::needless_return)]
3
4use rusqlite::functions::FunctionFlags;
5use std::path::PathBuf;
6
7pub mod geoip;
8pub mod jsonschema;
9pub mod password;
10
11mod regex;
12mod uuid;
13mod validators;
14
/// Error type returned by `connect_sqlite`.
#[derive(thiserror::Error, Debug)]
pub enum Error {
  /// A failure reported by `rusqlite`. The `#[from]` attribute lets `?` convert a
  /// `rusqlite::Error` into this variant automatically.
  #[error("Rusqlite error: {0}")]
  Rusqlite(#[from] rusqlite::Error),
  /// Any other failure, carried as a boxed error that is also exposed as this error's source.
  #[error("Other error: {0}")]
  Other(#[source] Box<dyn std::error::Error + Send + Sync + 'static>),
}
22
23pub fn apply_default_pragmas(conn: &rusqlite::Connection) -> Result<(), rusqlite::Error> {
24  const CONFIG: &[&str] = &[
25    "PRAGMA busy_timeout       = 10000",
26    "PRAGMA journal_mode       = WAL",
27    "PRAGMA journal_size_limit = 200000000",
28    // Sync the file system less often.
29    "PRAGMA synchronous        = NORMAL",
30    "PRAGMA foreign_keys       = ON",
31    "PRAGMA temp_store         = MEMORY",
32    "PRAGMA cache_size         = -16000",
33    // TODO: Maybe worth exploring once we have a benchmark, based on
34    // https://phiresky.github.io/blog/2020/sqlite-performance-tuning/.
35    // "PRAGMA mmap_size          = 30000000000",
36    // "PRAGMA page_size          = 32768",
37
38    // Safety feature around application-defined functions recommended by
39    // https://sqlite.org/appfunc.html
40    "PRAGMA trusted_schema     = OFF",
41  ];
42
43  // NOTE: we're querying here since some pragmas return data.
44  for pragma in CONFIG {
45    // TODO: Use conn.pragma_update instead.
46    let mut stmt = conn.prepare(pragma)?;
47    let mut rows = stmt.query([])?;
48    let _maybe_row = rows.next()?;
49  }
50
51  return Ok(());
52}
53
54#[allow(unsafe_code)]
55pub fn connect_sqlite(
56  path: Option<PathBuf>,
57  extensions: Option<Vec<PathBuf>>,
58) -> Result<rusqlite::Connection, Error> {
59  // First load C extensions like sqlean and vector search.
60  let status =
61    unsafe { rusqlite::ffi::sqlite3_auto_extension(Some(init_sqlean_and_vector_search)) };
62  if status != 0 {
63    return Err(Error::Other("Failed to load extensions".into()));
64  }
65
66  // Then open database and load trailbase_extensions.
67  let conn = sqlite3_extension_init(if let Some(p) = path {
68    use rusqlite::OpenFlags;
69    let flags = OpenFlags::SQLITE_OPEN_READ_WRITE
70      | OpenFlags::SQLITE_OPEN_CREATE
71      | OpenFlags::SQLITE_OPEN_NO_MUTEX;
72
73    rusqlite::Connection::open_with_flags(p, flags)?
74  } else {
75    rusqlite::Connection::open_in_memory()?
76  })?;
77
78  // Load user-provided extensions.
79  if let Some(extensions) = extensions {
80    for path in extensions {
81      unsafe { conn.load_extension::<_, &str>(path, None)? }
82    }
83  }
84
85  apply_default_pragmas(&conn)?;
86
87  // Initial optimize.
88  conn.execute("PRAGMA optimize = 0x10002", ())?;
89
90  return Ok(conn);
91}
92
93pub fn sqlite3_extension_init(
94  db: rusqlite::Connection,
95) -> Result<rusqlite::Connection, rusqlite::Error> {
96  // WARN: Be careful with declaring INNOCUOUS. This allows these "app-defined functions" to run
97  // even when "trusted_schema=OFF", which means as part of: VIEWs, TRIGGERs, CHECK, DEFAULT,
98  // GENERATED cols, ... as opposed to just top-level SELECTs.
99
100  db.create_scalar_function(
101    "is_uuid",
102    1,
103    FunctionFlags::SQLITE_DETERMINISTIC | FunctionFlags::SQLITE_INNOCUOUS,
104    uuid::is_uuid,
105  )?;
106  db.create_scalar_function(
107    "is_uuid_v4",
108    1,
109    FunctionFlags::SQLITE_DETERMINISTIC | FunctionFlags::SQLITE_INNOCUOUS,
110    uuid::is_uuid_v4,
111  )?;
112  db.create_scalar_function("uuid_v4", 0, FunctionFlags::SQLITE_INNOCUOUS, uuid::uuid_v4)?;
113  db.create_scalar_function(
114    "is_uuid_v7",
115    1,
116    FunctionFlags::SQLITE_DETERMINISTIC | FunctionFlags::SQLITE_INNOCUOUS,
117    uuid::is_uuid_v7,
118  )?;
119  db.create_scalar_function("uuid_v7", 0, FunctionFlags::SQLITE_INNOCUOUS, uuid::uuid_v7)?;
120  db.create_scalar_function(
121    "uuid_text",
122    1,
123    FunctionFlags::SQLITE_UTF8 | FunctionFlags::SQLITE_INNOCUOUS,
124    uuid::uuid_text,
125  )?;
126
127  db.create_scalar_function(
128    "uuid_parse",
129    1,
130    FunctionFlags::SQLITE_UTF8
131      | FunctionFlags::SQLITE_DETERMINISTIC
132      | FunctionFlags::SQLITE_INNOCUOUS,
133    uuid::uuid_parse,
134  )?;
135
136  // Used to create initial user credentials in migrations.
137  db.create_scalar_function(
138    "hash_password",
139    1,
140    FunctionFlags::SQLITE_UTF8 | FunctionFlags::SQLITE_INNOCUOUS,
141    password::hash_password_sqlite,
142  )?;
143
144  // Match column against given JSON schema, e.g. jsonschema_matches(col, '<schema>').
145  db.create_scalar_function(
146    "jsonschema_matches",
147    2,
148    FunctionFlags::SQLITE_UTF8
149      | FunctionFlags::SQLITE_DETERMINISTIC
150      | FunctionFlags::SQLITE_INNOCUOUS,
151    jsonschema::jsonschema_matches,
152  )?;
153  // Match column against registered JSON schema by name, e.g. jsonschema(col, 'schema-name').
154  db.create_scalar_function(
155    "jsonschema",
156    2,
157    FunctionFlags::SQLITE_UTF8
158      | FunctionFlags::SQLITE_DETERMINISTIC
159      | FunctionFlags::SQLITE_INNOCUOUS,
160    jsonschema::jsonschema_by_name,
161  )?;
162  db.create_scalar_function(
163    "jsonschema",
164    3,
165    FunctionFlags::SQLITE_UTF8
166      | FunctionFlags::SQLITE_DETERMINISTIC
167      | FunctionFlags::SQLITE_INNOCUOUS,
168    jsonschema::jsonschema_by_name_with_extra_args,
169  )?;
170
171  // Validators for CHECK constraints.
172  db.create_scalar_function(
173    // NOTE: the name needs to be "regexp" to be picked up by sqlites REGEXP matcher:
174    // https://www.sqlite.org/lang_expr.html
175    "regexp",
176    2,
177    FunctionFlags::SQLITE_UTF8
178      | FunctionFlags::SQLITE_DETERMINISTIC
179      | FunctionFlags::SQLITE_INNOCUOUS,
180    regex::regexp,
181  )?;
182  db.create_scalar_function(
183    "is_email",
184    1,
185    FunctionFlags::SQLITE_UTF8
186      | FunctionFlags::SQLITE_DETERMINISTIC
187      | FunctionFlags::SQLITE_INNOCUOUS,
188    validators::is_email,
189  )?;
190  // NOTE: there's also https://sqlite.org/json1.html#jvalid
191  db.create_scalar_function(
192    "is_json",
193    1,
194    FunctionFlags::SQLITE_UTF8
195      | FunctionFlags::SQLITE_DETERMINISTIC
196      | FunctionFlags::SQLITE_INNOCUOUS,
197    validators::is_json,
198  )?;
199
200  db.create_scalar_function(
201    "geoip_country",
202    1,
203    FunctionFlags::SQLITE_UTF8
204      | FunctionFlags::SQLITE_DETERMINISTIC
205      | FunctionFlags::SQLITE_INNOCUOUS,
206    geoip::geoip_country,
207  )?;
208  db.create_scalar_function(
209    "geoip_city_name",
210    1,
211    FunctionFlags::SQLITE_UTF8
212      | FunctionFlags::SQLITE_DETERMINISTIC
213      | FunctionFlags::SQLITE_INNOCUOUS,
214    geoip::geoip_city_name,
215  )?;
216  db.create_scalar_function(
217    "geoip_city_json",
218    1,
219    FunctionFlags::SQLITE_UTF8
220      | FunctionFlags::SQLITE_DETERMINISTIC
221      | FunctionFlags::SQLITE_INNOCUOUS,
222    geoip::geoip_city_json,
223  )?;
224
225  return Ok(db);
226}
227
228#[allow(unsafe_code)]
229#[unsafe(no_mangle)]
230extern "C" fn init_sqlean_and_vector_search(
231  db: *mut rusqlite::ffi::sqlite3,
232  _pz_err_msg: *mut *mut std::os::raw::c_char,
233  _p_api: *const rusqlite::ffi::sqlite3_api_routines,
234) -> ::std::os::raw::c_int {
235  // Add sqlite-vec extension.
236  unsafe {
237    sqlite_vec::sqlite3_vec_init();
238  }
239
240  // Init sqlean's stored procedures: "define", see:
241  //   https://github.com/nalgeon/sqlean/blob/main/docs/define.md
242  let status = unsafe { trailbase_sqlean::define_init(db as *mut trailbase_sqlean::sqlite3) };
243  if status != 0 {
244    log::error!("Failed to load sqlean::define",);
245    return status;
246  }
247
248  return status;
249}
250
#[cfg(test)]
mod test {
  use ::uuid::Uuid;
  use rusqlite::Error;

  use super::*;

  /// Opens an in-memory database and calls one function from each extension source: this
  /// crate (`uuid_v7`), sqlean (`define`/`undefine`) and sqlite-vec (`vec_f32`).
  #[test]
  fn test_connect_and_extensions() {
    let db = connect_sqlite(None, None).unwrap();

    // uuid_v7() must yield 16 bytes that form a version-7 UUID.
    let bytes = db
      .query_row("SELECT (uuid_v7())", (), |r| -> Result<[u8; 16], Error> {
        r.get(0)
      })
      .unwrap();
    assert_eq!(Uuid::from_bytes(bytes).get_version_num(), 7);

    // sqlean: Define a stored procedure, use it, and remove it.
    db.query_row("SELECT define('sumn', ':n * (:n + 1) / 2')", (), |_| {
      Ok(())
    })
    .unwrap();

    let sum = db
      .query_row("SELECT sumn(5)", (), |r| r.get::<_, i64>(0))
      .unwrap();
    assert_eq!(sum, 15);

    db.query_row("SELECT undefine('sumn')", (), |_| Ok(()))
      .unwrap();

    // sqlite-vec
    db.query_row("SELECT vec_f32('[0, 1, 2, 3]')", (), |_| Ok(()))
      .unwrap();
  }

  /// Checks that a `CHECK(is_uuid_v7(id))` constraint rejects v4 ids, accepts v7 ids, and that
  /// the stored id reads back unchanged.
  #[test]
  fn test_uuids() {
    const INSERT: &str = "INSERT INTO test (id) VALUES (?1) ";

    let db = connect_sqlite(None, None).unwrap();

    db.execute(
      r#"CREATE TABLE test (
        id    BLOB PRIMARY KEY NOT NULL CHECK(is_uuid_v7(id)) DEFAULT(uuid_v7()),
        text  TEXT
      )"#,
      (),
    )
    .unwrap();

    // V4 fails
    let v4_insert = db.execute(INSERT, rusqlite::params!(Uuid::new_v4().into_bytes()));
    assert!(v4_insert.is_err());

    // V7 succeeds
    let id = Uuid::now_v7();
    let v7_insert = db.execute(INSERT, rusqlite::params!(id.into_bytes()));
    assert!(v7_insert.is_ok());

    let read_id: Uuid = db
      .query_row("SELECT id FROM test LIMIT 1", [], |r| {
        r.get::<_, [u8; 16]>(0).map(Uuid::from_bytes)
      })
      .unwrap();

    assert_eq!(id, read_id);
  }
}