Skip to main content

sqlite_vector_rs/
lib.rs

1pub mod arrow_io;
2pub mod distance;
3pub mod index;
4pub mod json;
5pub mod scalar;
6pub mod types;
7pub mod vtab;
8
9#[cfg(feature = "loadable_extension")]
10use sqlite3_ext::*;
11
/// Entry point for the loadable SQLite extension.
///
/// Registers the `vector` virtual-table module, with its update, transaction
/// and find-function hooks enabled. It then registers the crate's scalar
/// functions on the same connection.
#[cfg(feature = "loadable_extension")]
#[sqlite3_ext_main(persistent)]
fn sqlite3_extension_init(db: &Connection) -> Result<()> {
    use sqlite3_ext::vtab::{Module, StandardModule};

    // Name used in `CREATE VIRTUAL TABLE ... USING vector(...)`.
    const MODULE_NAME: &str = "vector";

    let vector_module = StandardModule::<vtab::VectorTable<'_>>::new()
        .with_update()
        .with_transactions()
        .with_find_function();

    db.create_module(MODULE_NAME, vector_module, ())?;
    scalar::register_scalar_functions(db)?;

    Ok(())
}
25
/// Register the vector extension on a rusqlite connection.
///
/// Loads the companion shared library (`libsqlite_vector_rs.so` / `.dylib` / `.dll`),
/// which provides the `vector` virtual table module and all scalar functions.
///
/// The shared library is located by searching (in order):
///
/// 1. The `SQLITE_VECTOR_RS_LIB` environment variable (path with or without extension),
///    used only if it points at an existing file
/// 2. The directory containing the current executable
/// 3. `target/debug/` and then `target/release/` relative to the working directory
///
/// Extension loading is enabled on `conn` only for the duration of this call.
/// It is always disabled again before returning, whether or not the load
/// succeeded. That includes the case where the caller had enabled it beforehand.
///
/// # Errors
///
/// Returns an error if:
///
/// - the shared library cannot be found or loaded, or
/// - extension loading cannot be disabled again after a successful load.
///
/// If both the load and the subsequent disable fail, the load error is reported.
///
/// # Example
///
/// ```no_run
/// let conn = rusqlite::Connection::open_in_memory().unwrap();
/// sqlite_vector_rs::register(&conn).unwrap();
/// conn.execute_batch("
///     CREATE VIRTUAL TABLE embeddings USING vector(
///         dim=3, type=float4, metric=cosine
///     );
/// ").unwrap();
/// ```
#[cfg(feature = "library")]
pub fn register(conn: &rusqlite::Connection) -> std::result::Result<(), rusqlite::Error> {
    let path = find_extension_path().ok_or_else(|| {
        rusqlite::Error::ModuleError(
            "Could not find the sqlite-vector-rs extension library. \
             Build with `cargo build` first, or set SQLITE_VECTOR_RS_LIB to the library path."
                .into(),
        )
    })?;

    // SAFETY: extension loading is only needed for the single load below and is
    // switched off again before this function returns (see `disabled`).
    unsafe {
        conn.load_extension_enable()?;
    }

    // SAFETY: `path` names the companion library chosen by `find_extension_path`.
    // Loading it runs that library's init routine in-process, so the caller is
    // trusting whatever library the search selected.
    let loaded = unsafe { conn.load_extension(&path, None::<&str>) };

    // Always turn extension loading back off, even when the load failed, so the
    // connection is never left with it enabled.
    let disabled = conn.load_extension_disable();

    // A load failure takes precedence. Otherwise surface a failed disable
    // instead of silently reporting success on a still-enabled connection.
    loaded.and(disabled)
}
72
73#[cfg(feature = "library")]
/// Locate the companion shared library.
///
/// Returns a path string that may or may not carry the platform file suffix.
/// Candidates are tried in this order:
///
/// 1. the `SQLITE_VECTOR_RS_LIB` environment variable (accepted as-is, or when
///    appending the platform suffix names an existing file);
/// 2. `<stem>` in the directory of the current executable;
/// 3. `target/debug/<stem>`, then `target/release/<stem>`, relative to the
///    working directory.
///
/// Returns `None` when no candidate exists on disk.
fn find_extension_path() -> Option<String> {
    use std::path::Path;

    // Platform-specific file naming of the compiled library.
    let (stem, suffixes): (&str, &[&str]) = if cfg!(target_os = "windows") {
        ("sqlite_vector_rs", &[".dll"])
    } else if cfg!(target_os = "macos") {
        ("libsqlite_vector_rs", &[".dylib"])
    } else {
        ("libsqlite_vector_rs", &[".so"])
    };

    // True when `base` followed by one of the platform suffixes exists.
    let has_library = |base: &str| -> bool {
        suffixes
            .iter()
            .any(|suffix| Path::new(&format!("{base}{suffix}")).exists())
    };

    // 1. Explicit override via the environment.
    let from_env = || {
        std::env::var("SQLITE_VECTOR_RS_LIB")
            .ok()
            .filter(|value| Path::new(value).exists() || has_library(value.as_str()))
    };

    // 2. Next to the running executable.
    let beside_exe = || {
        std::env::current_exe()
            .ok()
            .and_then(|exe| {
                exe.parent()
                    .map(|dir| dir.join(stem).to_string_lossy().into_owned())
            })
            .filter(|candidate| has_library(candidate.as_str()))
    };

    // 3. Cargo build output directories (development convenience), debug first.
    let in_target_dir = || {
        ["debug", "release"]
            .iter()
            .map(|profile| format!("target/{profile}/{stem}"))
            .find(|candidate| has_library(candidate.as_str()))
    };

    from_env().or_else(beside_exe).or_else(in_target_dir)
}