1#![forbid(clippy::unwrap_used)]
2#![allow(clippy::needless_return)]
3
4use rusqlite::functions::FunctionFlags;
5use std::path::PathBuf;
6
7pub mod geoip;
8pub mod jsonschema;
9pub mod password;
10
11mod regex;
12mod uuid;
13mod validators;
14
15#[derive(thiserror::Error, Debug)]
16pub enum Error {
17 #[error("Rusqlite error: {0}")]
18 Rusqlite(#[from] rusqlite::Error),
19 #[error("Other error: {0}")]
20 Other(#[source] Box<dyn std::error::Error + Send + Sync + 'static>),
21}
22
23pub fn apply_default_pragmas(conn: &rusqlite::Connection) -> Result<(), rusqlite::Error> {
24 const CONFIG: &[&str] = &[
25 "PRAGMA busy_timeout = 10000",
26 "PRAGMA journal_mode = WAL",
27 "PRAGMA journal_size_limit = 200000000",
28 "PRAGMA synchronous = NORMAL",
30 "PRAGMA foreign_keys = ON",
31 "PRAGMA temp_store = MEMORY",
32 "PRAGMA cache_size = -16000",
33 "PRAGMA trusted_schema = OFF",
41 ];
42
43 for pragma in CONFIG {
45 let mut stmt = conn.prepare(pragma)?;
47 let mut rows = stmt.query([])?;
48 let _maybe_row = rows.next()?;
49 }
50
51 return Ok(());
52}
53
54#[allow(unsafe_code)]
55pub fn connect_sqlite(
56 path: Option<PathBuf>,
57 extensions: Option<Vec<PathBuf>>,
58) -> Result<rusqlite::Connection, Error> {
59 let status =
61 unsafe { rusqlite::ffi::sqlite3_auto_extension(Some(init_sqlean_and_vector_search)) };
62 if status != 0 {
63 return Err(Error::Other("Failed to load extensions".into()));
64 }
65
66 let conn = sqlite3_extension_init(if let Some(p) = path {
68 use rusqlite::OpenFlags;
69 let flags = OpenFlags::SQLITE_OPEN_READ_WRITE
70 | OpenFlags::SQLITE_OPEN_CREATE
71 | OpenFlags::SQLITE_OPEN_NO_MUTEX;
72
73 rusqlite::Connection::open_with_flags(p, flags)?
74 } else {
75 rusqlite::Connection::open_in_memory()?
76 })?;
77
78 if let Some(extensions) = extensions {
80 for path in extensions {
81 unsafe { conn.load_extension::<_, &str>(path, None)? }
82 }
83 }
84
85 apply_default_pragmas(&conn)?;
86
87 conn.execute("PRAGMA optimize = 0x10002", ())?;
89
90 return Ok(conn);
91}
92
93pub fn sqlite3_extension_init(
94 db: rusqlite::Connection,
95) -> Result<rusqlite::Connection, rusqlite::Error> {
96 db.create_scalar_function(
101 "is_uuid",
102 1,
103 FunctionFlags::SQLITE_DETERMINISTIC | FunctionFlags::SQLITE_INNOCUOUS,
104 uuid::is_uuid,
105 )?;
106 db.create_scalar_function(
107 "is_uuid_v4",
108 1,
109 FunctionFlags::SQLITE_DETERMINISTIC | FunctionFlags::SQLITE_INNOCUOUS,
110 uuid::is_uuid_v4,
111 )?;
112 db.create_scalar_function("uuid_v4", 0, FunctionFlags::SQLITE_INNOCUOUS, uuid::uuid_v4)?;
113 db.create_scalar_function(
114 "is_uuid_v7",
115 1,
116 FunctionFlags::SQLITE_DETERMINISTIC | FunctionFlags::SQLITE_INNOCUOUS,
117 uuid::is_uuid_v7,
118 )?;
119 db.create_scalar_function("uuid_v7", 0, FunctionFlags::SQLITE_INNOCUOUS, uuid::uuid_v7)?;
120 db.create_scalar_function(
121 "uuid_text",
122 1,
123 FunctionFlags::SQLITE_UTF8 | FunctionFlags::SQLITE_INNOCUOUS,
124 uuid::uuid_text,
125 )?;
126
127 db.create_scalar_function(
128 "uuid_parse",
129 1,
130 FunctionFlags::SQLITE_UTF8
131 | FunctionFlags::SQLITE_DETERMINISTIC
132 | FunctionFlags::SQLITE_INNOCUOUS,
133 uuid::uuid_parse,
134 )?;
135
136 db.create_scalar_function(
138 "hash_password",
139 1,
140 FunctionFlags::SQLITE_UTF8 | FunctionFlags::SQLITE_INNOCUOUS,
141 password::hash_password_sqlite,
142 )?;
143
144 db.create_scalar_function(
146 "jsonschema_matches",
147 2,
148 FunctionFlags::SQLITE_UTF8
149 | FunctionFlags::SQLITE_DETERMINISTIC
150 | FunctionFlags::SQLITE_INNOCUOUS,
151 jsonschema::jsonschema_matches,
152 )?;
153 db.create_scalar_function(
155 "jsonschema",
156 2,
157 FunctionFlags::SQLITE_UTF8
158 | FunctionFlags::SQLITE_DETERMINISTIC
159 | FunctionFlags::SQLITE_INNOCUOUS,
160 jsonschema::jsonschema_by_name,
161 )?;
162 db.create_scalar_function(
163 "jsonschema",
164 3,
165 FunctionFlags::SQLITE_UTF8
166 | FunctionFlags::SQLITE_DETERMINISTIC
167 | FunctionFlags::SQLITE_INNOCUOUS,
168 jsonschema::jsonschema_by_name_with_extra_args,
169 )?;
170
171 db.create_scalar_function(
173 "regexp",
176 2,
177 FunctionFlags::SQLITE_UTF8
178 | FunctionFlags::SQLITE_DETERMINISTIC
179 | FunctionFlags::SQLITE_INNOCUOUS,
180 regex::regexp,
181 )?;
182 db.create_scalar_function(
183 "is_email",
184 1,
185 FunctionFlags::SQLITE_UTF8
186 | FunctionFlags::SQLITE_DETERMINISTIC
187 | FunctionFlags::SQLITE_INNOCUOUS,
188 validators::is_email,
189 )?;
190 db.create_scalar_function(
192 "is_json",
193 1,
194 FunctionFlags::SQLITE_UTF8
195 | FunctionFlags::SQLITE_DETERMINISTIC
196 | FunctionFlags::SQLITE_INNOCUOUS,
197 validators::is_json,
198 )?;
199
200 db.create_scalar_function(
201 "geoip_country",
202 1,
203 FunctionFlags::SQLITE_UTF8
204 | FunctionFlags::SQLITE_DETERMINISTIC
205 | FunctionFlags::SQLITE_INNOCUOUS,
206 geoip::geoip_country,
207 )?;
208 db.create_scalar_function(
209 "geoip_city_name",
210 1,
211 FunctionFlags::SQLITE_UTF8
212 | FunctionFlags::SQLITE_DETERMINISTIC
213 | FunctionFlags::SQLITE_INNOCUOUS,
214 geoip::geoip_city_name,
215 )?;
216 db.create_scalar_function(
217 "geoip_city_json",
218 1,
219 FunctionFlags::SQLITE_UTF8
220 | FunctionFlags::SQLITE_DETERMINISTIC
221 | FunctionFlags::SQLITE_INNOCUOUS,
222 geoip::geoip_city_json,
223 )?;
224
225 return Ok(db);
226}
227
228#[allow(unsafe_code)]
229#[unsafe(no_mangle)]
230extern "C" fn init_sqlean_and_vector_search(
231 db: *mut rusqlite::ffi::sqlite3,
232 _pz_err_msg: *mut *mut std::os::raw::c_char,
233 _p_api: *const rusqlite::ffi::sqlite3_api_routines,
234) -> ::std::os::raw::c_int {
235 unsafe {
237 sqlite_vec::sqlite3_vec_init();
238 }
239
240 let status = unsafe { trailbase_sqlean::define_init(db as *mut trailbase_sqlean::sqlite3) };
243 if status != 0 {
244 log::error!("Failed to load sqlean::define",);
245 return status;
246 }
247
248 return status;
249}
250
#[cfg(test)]
mod test {
  use ::uuid::Uuid;
  use rusqlite::Error;

  use super::*;

  /// A fresh in-memory connection exposes the custom UUID functions, sqlean's
  /// `define`/`undefine`, and sqlite-vec's `vec_f32`.
  #[test]
  fn test_connect_and_extensions() {
    let conn = connect_sqlite(None, None).unwrap();

    let bytes = conn
      .query_row("SELECT (uuid_v7())", (), |row| -> Result<[u8; 16], Error> {
        row.get(0)
      })
      .unwrap();
    assert_eq!(Uuid::from_bytes(bytes).get_version_num(), 7);

    // Runs a statement for its side effects only, discarding the row.
    let run = |sql: &str| {
      conn.query_row(sql, (), |_row| Ok(())).unwrap();
    };

    run("SELECT define('sumn', ':n * (:n + 1) / 2')");

    let sum: i64 = conn
      .query_row("SELECT sumn(5)", (), |row| row.get(0))
      .unwrap();
    assert_eq!(sum, 15);

    run("SELECT undefine('sumn')");
    run("SELECT vec_f32('[0, 1, 2, 3]')");
  }

  /// The `is_uuid_v7` CHECK rejects v4 ids and round-trips v7 ids as blobs.
  #[test]
  fn test_uuids() {
    let conn = connect_sqlite(None, None).unwrap();

    conn
      .execute(
        r#"CREATE TABLE test (
          id BLOB PRIMARY KEY NOT NULL CHECK(is_uuid_v7(id)) DEFAULT(uuid_v7()),
          text TEXT
        )"#,
        (),
      )
      .unwrap();

    let insert_id = |id: Uuid| {
      conn.execute(
        "INSERT INTO test (id) VALUES (?1) ",
        rusqlite::params!(id.into_bytes()),
      )
    };

    let rejected = insert_id(Uuid::new_v4());
    assert!(rejected.is_err());

    let id = Uuid::now_v7();
    let accepted = insert_id(id);
    assert!(accepted.is_ok());

    let read_id = conn
      .query_row("SELECT id FROM test LIMIT 1", [], |row| -> Result<Uuid, Error> {
        let raw: [u8; 16] = row.get(0)?;
        Ok(Uuid::from_bytes(raw))
      })
      .unwrap();

    assert_eq!(id, read_id);
  }
}