Skip to main content

sqlite_vector_rs/
index.rs

1use std::fmt;
2
3use bytemuck::cast_slice;
4use half::f16;
5
6use crate::distance::{DistanceMetric, vtype_to_scalar_kind};
7use crate::types::VectorType;
8
/// Optional HNSW tuning parameters.
///
/// `HnswIndex::new` forwards these to the usearch index options as
/// `connectivity`, `expansion_add` and `expansion_search` respectively.
/// Use [`HnswParams::default`] for the stock configuration.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct HnswParams {
    /// Passed to usearch as `connectivity`.
    pub m: usize,
    /// Passed to usearch as `expansion_add` (used while inserting vectors).
    pub ef_construction: usize,
    /// Passed to usearch as `expansion_search` (used while querying).
    pub ef_search: usize,
}

impl Default for HnswParams {
    /// Returns `m = 16`, `ef_construction = 200`, `ef_search = 64`.
    fn default() -> Self {
        Self {
            m: 16,
            ef_construction: 200,
            ef_search: 64,
        }
    }
}
26
/// Error returned by index operations.
///
/// The wrapped string is the human-readable message; `Display` renders it
/// with an `index error: ` prefix.
#[derive(Debug)]
pub struct IndexError(pub String);

impl fmt::Display for IndexError {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let Self(message) = self;
        f.write_fmt(format_args!("index error: {}", message))
    }
}

// No underlying source error is tracked, so the default trait methods suffice.
impl std::error::Error for IndexError {}
37
38/// Wrapper around usearch::Index providing a typed interface.
39pub struct HnswIndex {
40    inner: usearch::Index,
41    _dim: usize,
42    vtype: VectorType,
43}
44
45impl HnswIndex {
46    /// Create a new empty HNSW index.
47    pub fn new(
48        dim: usize,
49        vtype: VectorType,
50        metric: DistanceMetric,
51        params: Option<HnswParams>,
52    ) -> Result<Self, IndexError> {
53        let p = params.unwrap_or_default();
54        let opts = usearch::IndexOptions {
55            dimensions: dim,
56            metric: metric.to_usearch(),
57            quantization: vtype_to_scalar_kind(vtype),
58            connectivity: p.m,
59            expansion_add: p.ef_construction,
60            expansion_search: p.ef_search,
61            multi: false,
62        };
63        let inner = usearch::Index::new(&opts).map_err(|e| IndexError(e.to_string()))?;
64        Ok(Self {
65            inner,
66            _dim: dim,
67            vtype,
68        })
69    }
70
71    /// Number of vectors in the index.
72    pub fn len(&self) -> usize {
73        self.inner.size()
74    }
75
76    pub fn is_empty(&self) -> bool {
77        self.len() == 0
78    }
79
80    /// Add a vector to the index. The blob must match the index's type and dimension.
81    pub fn add(&self, key: u64, blob: &[u8]) -> Result<(), IndexError> {
82        self.reserve_if_needed()?;
83        match self.vtype {
84            VectorType::Float4 => {
85                let v: &[f32] = cast_slice(blob);
86                self.inner
87                    .add(key, v)
88                    .map_err(|e| IndexError(e.to_string()))
89            }
90            VectorType::Float8 => {
91                let v: &[f64] = cast_slice(blob);
92                self.inner
93                    .add(key, v)
94                    .map_err(|e| IndexError(e.to_string()))
95            }
96            VectorType::Int1 => {
97                let v: &[i8] = cast_slice(blob);
98                self.inner
99                    .add(key, v)
100                    .map_err(|e| IndexError(e.to_string()))
101            }
102            // Float2 (f16), Int2 (i16), Int4 (i32) are not natively supported by the usearch
103            // generic VectorType trait as half::f16 — convert to f32 for index operations.
104            VectorType::Float2 => {
105                let v: &[f16] = cast_slice(blob);
106                let f: Vec<f32> = v.iter().map(|x| x.to_f32()).collect();
107                self.inner
108                    .add(key, &f)
109                    .map_err(|e| IndexError(e.to_string()))
110            }
111            VectorType::Int2 => {
112                let v: &[i16] = cast_slice(blob);
113                let f: Vec<f32> = v.iter().map(|x| *x as f32).collect();
114                self.inner
115                    .add(key, &f)
116                    .map_err(|e| IndexError(e.to_string()))
117            }
118            VectorType::Int4 => {
119                let v: &[i32] = cast_slice(blob);
120                let f: Vec<f32> = v.iter().map(|x| *x as f32).collect();
121                self.inner
122                    .add(key, &f)
123                    .map_err(|e| IndexError(e.to_string()))
124            }
125        }
126    }
127
128    /// Search for k nearest neighbors. Returns vec of (key, distance) pairs
129    /// sorted by distance ascending.
130    pub fn search(&self, query_blob: &[u8], k: usize) -> Result<Vec<(u64, f32)>, IndexError> {
131        if self.is_empty() {
132            return Ok(Vec::new());
133        }
134
135        let matches = match self.vtype {
136            VectorType::Float4 => {
137                let q: &[f32] = cast_slice(query_blob);
138                self.inner.search(q, k)
139            }
140            VectorType::Float8 => {
141                let q: &[f64] = cast_slice(query_blob);
142                self.inner.search(q, k)
143            }
144            VectorType::Int1 => {
145                let q: &[i8] = cast_slice(query_blob);
146                self.inner.search(q, k)
147            }
148            VectorType::Float2 => {
149                let q: &[f16] = cast_slice(query_blob);
150                let f: Vec<f32> = q.iter().map(|x| x.to_f32()).collect();
151                self.inner.search(&f, k)
152            }
153            VectorType::Int2 => {
154                let q: &[i16] = cast_slice(query_blob);
155                let f: Vec<f32> = q.iter().map(|x| *x as f32).collect();
156                self.inner.search(&f, k)
157            }
158            VectorType::Int4 => {
159                let q: &[i32] = cast_slice(query_blob);
160                let f: Vec<f32> = q.iter().map(|x| *x as f32).collect();
161                self.inner.search(&f, k)
162            }
163        }
164        .map_err(|e| IndexError(e.to_string()))?;
165
166        Ok(matches.keys.into_iter().zip(matches.distances).collect())
167    }
168
169    /// Remove a vector by key (soft delete).
170    pub fn remove(&self, key: u64) -> Result<(), IndexError> {
171        self.inner
172            .remove(key)
173            .map(|_| ())
174            .map_err(|e| IndexError(e.to_string()))
175    }
176
177    /// Serialize the index to a byte buffer.
178    pub fn save_to_buffer(&self) -> Result<Vec<u8>, IndexError> {
179        let len = self.inner.serialized_length();
180        let mut buf = vec![0u8; len];
181        self.inner
182            .save_to_buffer(&mut buf)
183            .map_err(|e| IndexError(e.to_string()))?;
184        Ok(buf)
185    }
186
187    /// Load index state from a byte buffer. Replaces current index contents.
188    pub fn load_from_buffer(&self, buf: &[u8]) -> Result<(), IndexError> {
189        self.inner
190            .load_from_buffer(buf)
191            .map_err(|e| IndexError(e.to_string()))
192    }
193
194    /// Reserve capacity if needed (doubles current capacity).
195    fn reserve_if_needed(&self) -> Result<(), IndexError> {
196        if self.inner.size() >= self.inner.capacity() {
197            let new_cap = (self.inner.capacity() * 2).max(64);
198            self.inner
199                .reserve(new_cap)
200                .map_err(|e| IndexError(e.to_string()))?;
201        }
202        Ok(())
203    }
204}
205
#[cfg(test)]
mod tests {
    use super::*;
    use bytemuck::cast_slice;

    // ----------------------------------------------------------------
    // Fixtures
    // ----------------------------------------------------------------

    /// The x, y and z unit vectors as `f32`.
    const AXES_F32: [[f32; 3]; 3] = [[1.0, 0.0, 0.0], [0.0, 1.0, 0.0], [0.0, 0.0, 1.0]];

    /// The x, y and z unit vectors as `f64`.
    const AXES_F64: [[f64; 3]; 3] = [[1.0, 0.0, 0.0], [0.0, 1.0, 0.0], [0.0, 0.0, 1.0]];

    /// Raw bytes of an `f32` vector, i.e. a Float4 blob.
    fn f32_blob(values: &[f32]) -> Vec<u8> {
        Vec::from(cast_slice::<f32, u8>(values))
    }

    /// Raw bytes of an `f64` vector, i.e. a Float8 blob.
    fn f64_blob(values: &[f64]) -> Vec<u8> {
        Vec::from(cast_slice::<f64, u8>(values))
    }

    /// Empty 3-dimensional index with default HNSW parameters.
    fn empty_index(vtype: VectorType, metric: DistanceMetric) -> HnswIndex {
        HnswIndex::new(3, vtype, metric, None).unwrap()
    }

    /// Insert the three `f32` unit vectors under `keys` (x, y, z order).
    fn add_axes_f32(idx: &HnswIndex, keys: [u64; 3]) {
        for (key, axis) in keys.iter().zip(AXES_F32.iter()) {
            idx.add(*key, &f32_blob(axis)).unwrap();
        }
    }

    /// Insert the three `f64` unit vectors under `keys` (x, y, z order).
    fn add_axes_f64(idx: &HnswIndex, keys: [u64; 3]) {
        for (key, axis) in keys.iter().zip(AXES_F64.iter()) {
            idx.add(*key, &f64_blob(axis)).unwrap();
        }
    }

    /// Assert that index construction succeeded.
    fn assert_created(outcome: Result<HnswIndex, IndexError>) {
        assert!(outcome.is_ok(), "expected Ok, got {:?}", outcome.err());
    }

    // ----------------------------------------------------------------
    // HnswParams::default
    // ----------------------------------------------------------------

    #[test]
    fn hnsw_params_default_values() {
        let HnswParams {
            m,
            ef_construction,
            ef_search,
        } = HnswParams::default();
        assert_eq!(m, 16);
        assert_eq!(ef_construction, 200);
        assert_eq!(ef_search, 64);
    }

    // ----------------------------------------------------------------
    // HnswIndex::new
    // ----------------------------------------------------------------

    #[test]
    fn new_float4_l2_does_not_error() {
        assert_created(HnswIndex::new(
            3,
            VectorType::Float4,
            DistanceMetric::L2,
            None,
        ));
    }

    #[test]
    fn new_float8_cosine_does_not_error() {
        assert_created(HnswIndex::new(
            4,
            VectorType::Float8,
            DistanceMetric::Cosine,
            None,
        ));
    }

    #[test]
    fn new_with_custom_params_does_not_error() {
        let tuned = HnswParams {
            m: 8,
            ef_construction: 64,
            ef_search: 32,
        };
        assert_created(HnswIndex::new(
            3,
            VectorType::Float4,
            DistanceMetric::L2,
            Some(tuned),
        ));
    }

    // ----------------------------------------------------------------
    // len / is_empty
    // ----------------------------------------------------------------

    #[test]
    fn len_empty_index_is_zero() {
        let idx = empty_index(VectorType::Float4, DistanceMetric::L2);
        assert_eq!(idx.len(), 0);
        assert!(idx.is_empty());
    }

    #[test]
    fn len_increases_after_add() {
        let idx = empty_index(VectorType::Float4, DistanceMetric::L2);

        // Keys are 1, 2, 3; the length must track every insertion.
        for (position, axis) in AXES_F32.iter().enumerate() {
            idx.add(position as u64 + 1, &f32_blob(axis)).unwrap();
            assert_eq!(idx.len(), position + 1);
            assert!(!idx.is_empty());
        }
    }

    // ----------------------------------------------------------------
    // add + search: orthogonal Float4 unit vectors
    // ----------------------------------------------------------------

    #[test]
    fn search_nearest_orthogonal_float4() {
        let idx = empty_index(VectorType::Float4, DistanceMetric::L2);
        add_axes_f32(&idx, [1, 2, 3]);

        // The query leans towards the x axis, which is stored under key 1.
        let hits = idx.search(&f32_blob(&[0.9, 0.1, 0.0]), 1).unwrap();
        assert_eq!(hits.len(), 1);
        assert_eq!(
            hits[0].0, 1,
            "expected key 1 ([1,0,0]) as nearest, got key {}",
            hits[0].0
        );
    }

    #[test]
    fn search_returns_empty_on_empty_index() {
        let idx = empty_index(VectorType::Float4, DistanceMetric::L2);
        let hits = idx.search(&f32_blob(&AXES_F32[0]), 5).unwrap();
        assert!(hits.is_empty());
    }

    #[test]
    fn search_k_larger_than_index_returns_all_vectors() {
        let idx = empty_index(VectorType::Float4, DistanceMetric::L2);
        idx.add(1, &f32_blob(&AXES_F32[0])).unwrap();
        idx.add(2, &f32_blob(&AXES_F32[1])).unwrap();

        // Asking for 10 neighbours of a 2-vector index yields just those 2.
        let hits = idx.search(&f32_blob(&AXES_F32[0]), 10).unwrap();
        assert_eq!(hits.len(), 2);
    }

    // ----------------------------------------------------------------
    // remove
    // ----------------------------------------------------------------

    #[test]
    fn remove_decreases_len() {
        let idx = empty_index(VectorType::Float4, DistanceMetric::L2);
        add_axes_f32(&idx, [10, 20, 30]);
        assert_eq!(idx.len(), 3);

        idx.remove(20).unwrap();
        assert_eq!(idx.len(), 2);
    }

    #[test]
    fn remove_key_no_longer_returned_by_search() {
        let idx = empty_index(VectorType::Float4, DistanceMetric::L2);
        add_axes_f32(&idx, [1, 2, 3]);

        // Key 2 holds the y axis, which would be the exact match for the query.
        idx.remove(2).unwrap();

        let hits = idx.search(&f32_blob(&AXES_F32[1]), 3).unwrap();
        let returned_keys: Vec<u64> = hits.into_iter().map(|(key, _)| key).collect();
        assert!(
            !returned_keys.contains(&2),
            "removed key 2 should not appear in search results, got {:?}",
            returned_keys
        );
    }

    // ----------------------------------------------------------------
    // save_to_buffer / load_from_buffer round-trip (Float4)
    // ----------------------------------------------------------------

    #[test]
    fn save_load_roundtrip_float4() {
        // Populate a source index and serialize it.
        let src = empty_index(VectorType::Float4, DistanceMetric::L2);
        add_axes_f32(&src, [1, 2, 3]);

        let buf = src.save_to_buffer().unwrap();
        assert!(!buf.is_empty(), "serialized buffer must not be empty");

        // Restore into a second index built with the same configuration.
        let dst = empty_index(VectorType::Float4, DistanceMetric::L2);
        dst.load_from_buffer(&buf).unwrap();

        // Same number of vectors as the source.
        assert_eq!(dst.len(), src.len());

        // Nearest-neighbour lookup still resolves to the x axis.
        let hits = dst.search(&f32_blob(&[0.9, 0.1, 0.0]), 1).unwrap();
        assert_eq!(hits.len(), 1);
        assert_eq!(
            hits[0].0, 1,
            "post-load search should return key 1, got {}",
            hits[0].0
        );
    }

    // ----------------------------------------------------------------
    // Float8 (f64) type: add + search
    // ----------------------------------------------------------------

    #[test]
    fn add_search_float8() {
        let idx = empty_index(VectorType::Float8, DistanceMetric::L2);
        add_axes_f64(&idx, [1, 2, 3]);

        let hits = idx.search(&f64_blob(&[0.1, 0.0, 0.9]), 1).unwrap();
        assert_eq!(hits.len(), 1);
        assert_eq!(
            hits[0].0, 3,
            "expected key 3 ([0,0,1]) as nearest, got key {}",
            hits[0].0
        );
    }

    #[test]
    fn save_load_roundtrip_float8() {
        let src = empty_index(VectorType::Float8, DistanceMetric::L2);
        add_axes_f64(&src, [1, 2, 3]);

        let buf = src.save_to_buffer().unwrap();

        let dst = empty_index(VectorType::Float8, DistanceMetric::L2);
        dst.load_from_buffer(&buf).unwrap();

        assert_eq!(dst.len(), 3);

        let hits = dst.search(&f64_blob(&[0.0, 0.9, 0.1]), 1).unwrap();
        assert_eq!(hits.len(), 1);
        assert_eq!(
            hits[0].0, 2,
            "post-load search should return key 2, got {}",
            hits[0].0
        );
    }

    // ----------------------------------------------------------------
    // Custom HnswParams
    // ----------------------------------------------------------------

    #[test]
    fn custom_params_index_behaves_correctly() {
        let tuned = HnswParams {
            m: 4,
            ef_construction: 32,
            ef_search: 16,
        };
        let idx =
            HnswIndex::new(3, VectorType::Float4, DistanceMetric::Cosine, Some(tuned)).unwrap();
        add_axes_f32(&idx, [1, 2, 3]);

        assert_eq!(idx.len(), 3);

        let hits = idx.search(&f32_blob(&[0.0, 0.1, 0.9]), 1).unwrap();
        assert_eq!(hits.len(), 1);
        assert_eq!(
            hits[0].0, 3,
            "expected key 3 ([0,0,1]) as nearest under cosine, got {}",
            hits[0].0
        );
    }
}