Skip to main content

sqlite_vector_rs/
scalar.rs

1use sqlite3_ext::function::FunctionOptions;
2use sqlite3_ext::query::ToParam;
3use sqlite3_ext::*;
4
5use crate::arrow_io;
6use crate::distance::{DistanceMetric, compute_distance};
7use crate::index::HnswIndex;
8use crate::json::{blob_to_json, json_to_blob};
9use crate::types::VectorType;
10use crate::vtab::shadow::ShadowOps;
11
12/// Register all standalone scalar functions on a connection.
13pub fn register_scalar_functions(db: &Connection) -> Result<()> {
14    // vector_distance(blob_a, blob_b, metric, type) -> REAL
15    db.create_scalar_function(
16        "vector_distance",
17        &FunctionOptions::default()
18            .set_n_args(4)
19            .set_deterministic(true),
20        |ctx, args| {
21            // Collect string args as owned values first to avoid borrow conflicts
22            // with the blob borrows that follow.
23            let metric_name = args[2].get_str()?.to_owned();
24            let type_name = args[3].get_str()?.to_owned();
25            let blob_a = args[0].get_blob()?.to_vec();
26            let blob_b = args[1].get_blob()?.to_vec();
27
28            let vtype =
29                VectorType::from_name(&type_name).map_err(|e| Error::Module(e.to_string()))?;
30            let metric = DistanceMetric::from_name(&metric_name)
31                .map_err(|e| Error::Module(e.to_string()))?;
32
33            let dim = blob_a.len() / vtype.element_size();
34            let dist = compute_distance(&blob_a, &blob_b, vtype, metric, dim)
35                .map_err(|e| Error::Module(e.to_string()))?;
36
37            ctx.set_result(dist)?;
38            Ok(())
39        },
40    )?;
41
42    // vector_from_json(json_text, type) -> BLOB
43    db.create_scalar_function(
44        "vector_from_json",
45        &FunctionOptions::default()
46            .set_n_args(2)
47            .set_deterministic(true),
48        |ctx, args| {
49            let json_text = args[0].get_str()?.to_owned();
50            let type_name = args[1].get_str()?.to_owned();
51
52            let vtype =
53                VectorType::from_name(&type_name).map_err(|e| Error::Module(e.to_string()))?;
54            let blob = json_to_blob(&json_text, vtype).map_err(|e| Error::Module(e.to_string()))?;
55
56            ctx.set_result(&blob[..])?;
57            Ok(())
58        },
59    )?;
60
61    // vector_to_json(blob, type) -> TEXT
62    db.create_scalar_function(
63        "vector_to_json",
64        &FunctionOptions::default()
65            .set_n_args(2)
66            .set_deterministic(true),
67        |ctx, args| {
68            let type_name = args[1].get_str()?.to_owned();
69            let blob = args[0].get_blob()?.to_vec();
70
71            let vtype =
72                VectorType::from_name(&type_name).map_err(|e| Error::Module(e.to_string()))?;
73            let json = blob_to_json(&blob, vtype).map_err(|e| Error::Module(e.to_string()))?;
74
75            // Pass owned String — ToContextResult is implemented for String
76            ctx.set_result(json)?;
77            Ok(())
78        },
79    )?;
80
81    // vector_dims(blob, type) -> INTEGER
82    db.create_scalar_function(
83        "vector_dims",
84        &FunctionOptions::default()
85            .set_n_args(2)
86            .set_deterministic(true),
87        |ctx, args| {
88            let type_name = args[1].get_str()?.to_owned();
89            let blob = args[0].get_blob()?;
90
91            let vtype =
92                VectorType::from_name(&type_name).map_err(|e| Error::Module(e.to_string()))?;
93            let dims = blob.len() / vtype.element_size();
94
95            ctx.set_result(dims as i64)?;
96            Ok(())
97        },
98    )?;
99
100    // knn_match(col, query_blob) — placeholder for xFindFunction override.
101    // The global version is a no-op; the vtab's FindFunctionVTab replaces it
102    // when the first argument is a virtual table column.
103    db.create_scalar_function(
104        "knn_match",
105        &FunctionOptions::default().set_n_args(2),
106        |ctx, _args| {
107            ctx.set_result(1i32)?;
108            Ok(())
109        },
110    )?;
111
112    // vector_rebuild_index(table_name, type, metric) -> INTEGER (row count)
113    //
114    // Reads all vectors from the shadow data table, builds a fresh HNSW index,
115    // and serializes it back to the shadow index table. Returns the number of
116    // vectors indexed.
117    //
118    // NOTE: This writes directly to shadow tables, bypassing the vtab's
119    // in-memory index. A running vtab won't see the rebuilt index until
120    // reconnect. Intended for offline maintenance, not live use.
121    db.create_scalar_function(
122        "vector_rebuild_index",
123        &FunctionOptions::default().set_n_args(3),
124        |ctx, args| {
125            let table_name = args[0].get_str()?.to_owned();
126            let type_name = args[1].get_str()?.to_owned();
127            let metric_name = args[2].get_str()?.to_owned();
128
129            let vtype =
130                VectorType::from_name(&type_name).map_err(|e| Error::Module(e.to_string()))?;
131            let metric = DistanceMetric::from_name(&metric_name)
132                .map_err(|e| Error::Module(e.to_string()))?;
133
134            let db = ctx.db();
135
136            // Read all (rowid, vector_blob) pairs from the data shadow table.
137            let sql = ShadowOps::select_all_data_sql(&table_name);
138            let mut stmt = db.prepare(&sql)?;
139            stmt.query(())?;
140
141            let mut rows: Vec<(i64, Vec<u8>)> = Vec::new();
142            while let Some(row) = stmt.next()? {
143                let id = row[0].get_i64();
144                let blob = row[1].get_blob()?.to_vec();
145                rows.push((id, blob));
146            }
147
148            if rows.is_empty() {
149                ctx.set_result(0i64)?;
150                return Ok(());
151            }
152
153            // Infer dimension from the first vector blob.
154            let dim = rows[0].1.len() / vtype.element_size();
155
156            // Build a fresh index and insert every vector.
157            let index = HnswIndex::new(dim, vtype, metric, None)
158                .map_err(|e| Error::Module(e.to_string()))?;
159            for (id, blob) in &rows {
160                index
161                    .add(*id as u64, blob)
162                    .map_err(|e| Error::Module(e.to_string()))?;
163            }
164
165            // Serialize and persist to the _index shadow table.
166            let buf = index
167                .save_to_buffer()
168                .map_err(|e| Error::Module(e.to_string()))?;
169            let upsert_sql = ShadowOps::upsert_index_sql(&table_name);
170            db.insert(&upsert_sql, |stmt: &mut query::Statement| {
171                "hnsw_graph".bind_param(&mut *stmt, 1)?;
172                buf.as_slice().bind_param(&mut *stmt, 2)?;
173                Ok(())
174            })?;
175
176            ctx.set_result(rows.len() as i64)?;
177            Ok(())
178        },
179    )?;
180
181    // vector_export_arrow(table_name, type) -> BLOB (Arrow IPC stream)
182    //
183    // Exports all vectors from the shadow data table as an Arrow IPC byte
184    // buffer. The caller must supply the element type so blobs are decoded
185    // correctly.
186    db.create_scalar_function(
187        "vector_export_arrow",
188        &FunctionOptions::default().set_n_args(2),
189        |ctx, args| {
190            let table_name = args[0].get_str()?.to_owned();
191            let type_name = args[1].get_str()?.to_owned();
192
193            let vtype =
194                VectorType::from_name(&type_name).map_err(|e| Error::Module(e.to_string()))?;
195
196            let db = ctx.db();
197
198            // Collect all vector blobs from the data shadow table.
199            let sql = ShadowOps::select_all_data_sql(&table_name);
200            let mut stmt = db.prepare(&sql)?;
201            stmt.query(())?;
202
203            let mut blobs: Vec<Vec<u8>> = Vec::new();
204            while let Some(row) = stmt.next()? {
205                blobs.push(row[1].get_blob()?.to_vec());
206            }
207
208            if blobs.is_empty() {
209                // Return an empty blob for an empty table.
210                let empty: &[u8] = &[];
211                ctx.set_result(empty)?;
212                return Ok(());
213            }
214
215            let dim = blobs[0].len() / vtype.element_size();
216            let ipc = arrow_io::vectors_to_arrow_ipc(&blobs, vtype, dim)
217                .map_err(|e| Error::Module(e.to_string()))?;
218
219            ctx.set_result(&ipc[..])?;
220            Ok(())
221        },
222    )?;
223
224    // vector_insert_arrow(table_name, type, arrow_ipc_blob) -> INTEGER (row count)
225    //
226    // Imports vectors from an Arrow IPC blob into the shadow data table,
227    // adding one row per vector. Returns the number of rows inserted.
228    // Only inserts the vector column; metadata columns get NULL defaults.
229    //
230    // NOTE: Inserts directly into the shadow table, bypassing the in-memory
231    // HNSW index. Call vector_rebuild_index afterwards to sync the index.
232    db.create_scalar_function(
233        "vector_insert_arrow",
234        &FunctionOptions::default().set_n_args(3),
235        |ctx, args| {
236            let table_name = args[0].get_str()?.to_owned();
237            let type_name = args[1].get_str()?.to_owned();
238            let ipc_blob = args[2].get_blob()?.to_vec();
239
240            let vtype =
241                VectorType::from_name(&type_name).map_err(|e| Error::Module(e.to_string()))?;
242
243            if ipc_blob.is_empty() {
244                ctx.set_result(0i64)?;
245                return Ok(());
246            }
247
248            // Decode the Arrow IPC stream. We need the dimension, which we
249            // infer from the first decoded vector.
250            let blobs = arrow_io::arrow_ipc_to_vectors(&ipc_blob, vtype, 0)
251                .map_err(|e| Error::Module(e.to_string()))?;
252
253            if blobs.is_empty() {
254                ctx.set_result(0i64)?;
255                return Ok(());
256            }
257
258            let db = ctx.db();
259            let insert_sql = ShadowOps::insert_vector_only_sql(&table_name);
260            for blob in &blobs {
261                db.insert(&insert_sql, [blob.as_slice()])?;
262            }
263
264            ctx.set_result(blobs.len() as i64)?;
265            Ok(())
266        },
267    )?;
268
269    Ok(())
270}