1pub mod arrow_io;
2pub mod distance;
3pub mod index;
4pub mod json;
5pub mod scalar;
6pub mod types;
7pub mod vtab;
8
9#[cfg(feature = "loadable_extension")]
10use sqlite3_ext::*;
11
#[cfg(feature = "loadable_extension")]
/// Entry point used when this crate is built as a loadable SQLite extension.
///
/// Registers the `vector` virtual-table module (implemented by `vtab::VectorTable`) on `db`,
/// and then the crate's scalar SQL functions via `scalar::register_scalar_functions`.
///
/// # Errors
///
/// Returns the first error raised while creating the module or registering the scalar
/// functions; registration stops at that point.
#[sqlite3_ext_main(persistent)]
fn sqlite3_extension_init(db: &Connection) -> Result<()> {
    use sqlite3_ext::vtab::{Module, StandardModule};

    // Build the module with the `with_update`, `with_transactions` and `with_find_function`
    // capabilities enabled on `StandardModule`.
    let vector_module = StandardModule::<vtab::VectorTable<'_>>::new()
        .with_update()
        .with_transactions()
        .with_find_function();

    // The module is registered before the scalar functions, matching the original order.
    db.create_module("vector", vector_module, ())?;
    scalar::register_scalar_functions(db)?;

    Ok(())
}
25
#[cfg(feature = "library")]
/// Loads the sqlite-vector-rs extension library into an existing rusqlite connection.
///
/// The library path comes from `find_extension_path`, which consults the
/// `SQLITE_VECTOR_RS_LIB` environment variable before falling back to build-output
/// locations. Extension loading is switched on for `conn` only for the duration of the
/// load and is switched off again before this function returns, whether or not the load
/// succeeded.
///
/// # Errors
///
/// - [`rusqlite::Error::ModuleError`] if no extension library could be found.
/// - Any error from enabling extension loading on `conn`.
/// - Any error from loading the library. This takes priority over a later disable error.
/// - Any error from disabling extension loading after a successful load. It is reported
///   rather than ignored because it would otherwise leave extension loading enabled on the
///   caller's connection without them knowing.
pub fn register(conn: &rusqlite::Connection) -> std::result::Result<(), rusqlite::Error> {
    let path = find_extension_path().ok_or_else(|| {
        rusqlite::Error::ModuleError(
            "Could not find the sqlite-vector-rs extension library. \
             Build with `cargo build` first, or set SQLITE_VECTOR_RS_LIB to the library path."
                .into(),
        )
    })?;

    // SAFETY: enabling extension loading lets SQLite execute native code from shared
    // libraries. It is enabled only long enough to load the library found above, and is
    // disabled again unconditionally below.
    unsafe {
        conn.load_extension_enable()?;
    }

    // SAFETY: `path` is produced by `find_extension_path`. A user-supplied
    // `SQLITE_VECTOR_RS_LIB` override is trusted to name this crate's library. `None` lets
    // SQLite choose its default entry point.
    let loaded = unsafe { conn.load_extension(&path, None::<&str>) };

    // Disable loading on every path. A load failure is the more useful error, so report it
    // first. Otherwise surface a failed disable instead of silently leaving loading on.
    let disabled = conn.load_extension_disable();

    loaded.and(disabled)
}
72
73#[cfg(feature = "library")]
/// Locates the compiled sqlite-vector-rs extension library on disk.
///
/// Candidates are tried in this order, and the first match wins:
///
/// 1. The `SQLITE_VECTOR_RS_LIB` environment variable. It is ignored when unset, empty, or
///    not valid Unicode. The value is accepted if it names an existing file as-is, or if it
///    names an existing file once the platform library suffix (`.so` / `.dylib` / `.dll`)
///    is appended. Either way, the value is returned unchanged.
/// 2. The directory containing the running executable. When that directory is Cargo's
///    `deps` or `examples` output directory (where test, bench and example binaries live),
///    its parent profile directory is tried as well, since that is where the `cdylib` ends up.
/// 3. `target/debug`, then `target/release`, relative to the current working directory.
///
/// For candidates 2 and 3, the returned path omits the platform suffix. SQLite's extension
/// loader retries with the platform suffix appended when the bare name cannot be loaded.
/// Directories whose paths are not valid UTF-8 are skipped, because the result is a `String`.
///
/// Returns `None` when no candidate exists.
fn find_extension_path() -> Option<String> {
    use std::ffi::OsString;
    use std::path::{Path, PathBuf};

    // Cargo names the cdylib `<crate>.<ext>` on Windows and `lib<crate>.<ext>` elsewhere.
    let stem = if cfg!(target_os = "windows") {
        "sqlite_vector_rs"
    } else {
        "libsqlite_vector_rs"
    };

    let suffixes: &[&str] = if cfg!(target_os = "macos") {
        &[".dylib"]
    } else if cfg!(target_os = "windows") {
        &[".dll"]
    } else {
        &[".so"]
    };

    // True when `base` plus one of the platform suffixes names an existing file. The suffix
    // is appended to the raw OS string, not via `set_extension`, so that a base containing
    // dots is never truncated.
    let exists_with_suffix = |base: &Path| -> bool {
        suffixes.iter().any(|suffix| {
            let mut candidate: OsString = base.as_os_str().to_os_string();
            candidate.push(suffix);
            Path::new(&candidate).exists()
        })
    };

    // 1. Explicit override. An empty value would otherwise resolve relative to the working
    //    directory, so it is treated as unset.
    if let Some(val) = std::env::var("SQLITE_VECTOR_RS_LIB")
        .ok()
        .filter(|val| !val.is_empty())
    {
        let candidate = Path::new(&val);
        if candidate.exists() || exists_with_suffix(candidate) {
            return Some(val);
        }
    }

    let mut search_dirs: Vec<PathBuf> = Vec::new();

    // 2. Next to the running executable, plus the profile directory above Cargo's `deps`
    //    and `examples` directories.
    if let Some(exe_dir) = std::env::current_exe()
        .ok()
        .and_then(|exe| exe.parent().map(Path::to_path_buf))
    {
        let profile_dir = match exe_dir.file_name().and_then(|name| name.to_str()) {
            Some("deps") | Some("examples") => exe_dir.parent().map(Path::to_path_buf),
            _ => None,
        };
        search_dirs.push(exe_dir);
        search_dirs.extend(profile_dir);
    }

    // 3. Conventional Cargo output directories relative to the working directory.
    search_dirs.extend(
        ["debug", "release"]
            .iter()
            .map(|profile| Path::new("target").join(profile)),
    );

    search_dirs
        .into_iter()
        .filter_map(|dir| dir.join(stem).into_os_string().into_string().ok())
        .find(|base| exists_with_suffix(Path::new(base)))
}