sqlite_vector_rs/
scalar.rs1use sqlite3_ext::function::FunctionOptions;
2use sqlite3_ext::query::ToParam;
3use sqlite3_ext::*;
4
5use crate::arrow_io;
6use crate::distance::{DistanceMetric, compute_distance};
7use crate::index::HnswIndex;
8use crate::json::{blob_to_json, json_to_blob};
9use crate::types::VectorType;
10use crate::vtab::shadow::ShadowOps;
11
12pub fn register_scalar_functions(db: &Connection) -> Result<()> {
14 db.create_scalar_function(
16 "vector_distance",
17 &FunctionOptions::default()
18 .set_n_args(4)
19 .set_deterministic(true),
20 |ctx, args| {
21 let metric_name = args[2].get_str()?.to_owned();
24 let type_name = args[3].get_str()?.to_owned();
25 let blob_a = args[0].get_blob()?.to_vec();
26 let blob_b = args[1].get_blob()?.to_vec();
27
28 let vtype =
29 VectorType::from_name(&type_name).map_err(|e| Error::Module(e.to_string()))?;
30 let metric = DistanceMetric::from_name(&metric_name)
31 .map_err(|e| Error::Module(e.to_string()))?;
32
33 let dim = blob_a.len() / vtype.element_size();
34 let dist = compute_distance(&blob_a, &blob_b, vtype, metric, dim)
35 .map_err(|e| Error::Module(e.to_string()))?;
36
37 ctx.set_result(dist)?;
38 Ok(())
39 },
40 )?;
41
42 db.create_scalar_function(
44 "vector_from_json",
45 &FunctionOptions::default()
46 .set_n_args(2)
47 .set_deterministic(true),
48 |ctx, args| {
49 let json_text = args[0].get_str()?.to_owned();
50 let type_name = args[1].get_str()?.to_owned();
51
52 let vtype =
53 VectorType::from_name(&type_name).map_err(|e| Error::Module(e.to_string()))?;
54 let blob = json_to_blob(&json_text, vtype).map_err(|e| Error::Module(e.to_string()))?;
55
56 ctx.set_result(&blob[..])?;
57 Ok(())
58 },
59 )?;
60
61 db.create_scalar_function(
63 "vector_to_json",
64 &FunctionOptions::default()
65 .set_n_args(2)
66 .set_deterministic(true),
67 |ctx, args| {
68 let type_name = args[1].get_str()?.to_owned();
69 let blob = args[0].get_blob()?.to_vec();
70
71 let vtype =
72 VectorType::from_name(&type_name).map_err(|e| Error::Module(e.to_string()))?;
73 let json = blob_to_json(&blob, vtype).map_err(|e| Error::Module(e.to_string()))?;
74
75 ctx.set_result(json)?;
77 Ok(())
78 },
79 )?;
80
81 db.create_scalar_function(
83 "vector_dims",
84 &FunctionOptions::default()
85 .set_n_args(2)
86 .set_deterministic(true),
87 |ctx, args| {
88 let type_name = args[1].get_str()?.to_owned();
89 let blob = args[0].get_blob()?;
90
91 let vtype =
92 VectorType::from_name(&type_name).map_err(|e| Error::Module(e.to_string()))?;
93 let dims = blob.len() / vtype.element_size();
94
95 ctx.set_result(dims as i64)?;
96 Ok(())
97 },
98 )?;
99
100 db.create_scalar_function(
104 "knn_match",
105 &FunctionOptions::default().set_n_args(2),
106 |ctx, _args| {
107 ctx.set_result(1i32)?;
108 Ok(())
109 },
110 )?;
111
112 db.create_scalar_function(
122 "vector_rebuild_index",
123 &FunctionOptions::default().set_n_args(3),
124 |ctx, args| {
125 let table_name = args[0].get_str()?.to_owned();
126 let type_name = args[1].get_str()?.to_owned();
127 let metric_name = args[2].get_str()?.to_owned();
128
129 let vtype =
130 VectorType::from_name(&type_name).map_err(|e| Error::Module(e.to_string()))?;
131 let metric = DistanceMetric::from_name(&metric_name)
132 .map_err(|e| Error::Module(e.to_string()))?;
133
134 let db = ctx.db();
135
136 let sql = ShadowOps::select_all_data_sql(&table_name);
138 let mut stmt = db.prepare(&sql)?;
139 stmt.query(())?;
140
141 let mut rows: Vec<(i64, Vec<u8>)> = Vec::new();
142 while let Some(row) = stmt.next()? {
143 let id = row[0].get_i64();
144 let blob = row[1].get_blob()?.to_vec();
145 rows.push((id, blob));
146 }
147
148 if rows.is_empty() {
149 ctx.set_result(0i64)?;
150 return Ok(());
151 }
152
153 let dim = rows[0].1.len() / vtype.element_size();
155
156 let index = HnswIndex::new(dim, vtype, metric, None)
158 .map_err(|e| Error::Module(e.to_string()))?;
159 for (id, blob) in &rows {
160 index
161 .add(*id as u64, blob)
162 .map_err(|e| Error::Module(e.to_string()))?;
163 }
164
165 let buf = index
167 .save_to_buffer()
168 .map_err(|e| Error::Module(e.to_string()))?;
169 let upsert_sql = ShadowOps::upsert_index_sql(&table_name);
170 db.insert(&upsert_sql, |stmt: &mut query::Statement| {
171 "hnsw_graph".bind_param(&mut *stmt, 1)?;
172 buf.as_slice().bind_param(&mut *stmt, 2)?;
173 Ok(())
174 })?;
175
176 ctx.set_result(rows.len() as i64)?;
177 Ok(())
178 },
179 )?;
180
181 db.create_scalar_function(
187 "vector_export_arrow",
188 &FunctionOptions::default().set_n_args(2),
189 |ctx, args| {
190 let table_name = args[0].get_str()?.to_owned();
191 let type_name = args[1].get_str()?.to_owned();
192
193 let vtype =
194 VectorType::from_name(&type_name).map_err(|e| Error::Module(e.to_string()))?;
195
196 let db = ctx.db();
197
198 let sql = ShadowOps::select_all_data_sql(&table_name);
200 let mut stmt = db.prepare(&sql)?;
201 stmt.query(())?;
202
203 let mut blobs: Vec<Vec<u8>> = Vec::new();
204 while let Some(row) = stmt.next()? {
205 blobs.push(row[1].get_blob()?.to_vec());
206 }
207
208 if blobs.is_empty() {
209 let empty: &[u8] = &[];
211 ctx.set_result(empty)?;
212 return Ok(());
213 }
214
215 let dim = blobs[0].len() / vtype.element_size();
216 let ipc = arrow_io::vectors_to_arrow_ipc(&blobs, vtype, dim)
217 .map_err(|e| Error::Module(e.to_string()))?;
218
219 ctx.set_result(&ipc[..])?;
220 Ok(())
221 },
222 )?;
223
224 db.create_scalar_function(
233 "vector_insert_arrow",
234 &FunctionOptions::default().set_n_args(3),
235 |ctx, args| {
236 let table_name = args[0].get_str()?.to_owned();
237 let type_name = args[1].get_str()?.to_owned();
238 let ipc_blob = args[2].get_blob()?.to_vec();
239
240 let vtype =
241 VectorType::from_name(&type_name).map_err(|e| Error::Module(e.to_string()))?;
242
243 if ipc_blob.is_empty() {
244 ctx.set_result(0i64)?;
245 return Ok(());
246 }
247
248 let blobs = arrow_io::arrow_ipc_to_vectors(&ipc_blob, vtype, 0)
251 .map_err(|e| Error::Module(e.to_string()))?;
252
253 if blobs.is_empty() {
254 ctx.set_result(0i64)?;
255 return Ok(());
256 }
257
258 let db = ctx.db();
259 let insert_sql = ShadowOps::insert_vector_only_sql(&table_name);
260 for blob in &blobs {
261 db.insert(&insert_sql, [blob.as_slice()])?;
262 }
263
264 ctx.set_result(blobs.len() as i64)?;
265 Ok(())
266 },
267 )?;
268
269 Ok(())
270}