// vector_lite/owned_index.rs

1use crate::{Node, Vector};
2use rand::Rng;
3use serde::{Deserialize, Serialize};
4use std::collections::HashSet;
5
6pub trait ANNIndexOwned<const N: usize> {
7    /// Insert a vector into the index.
8    fn insert(&mut self, vector: Vector<N>, id: usize) {
9        self.insert_with_rng(vector, id, &mut rand::rng());
10    }
11
12    /// Insert a vector into the index with a custom rng.
13    fn insert_with_rng(&mut self, vector: Vector<N>, id: usize, rng: &mut impl Rng);
14
15    /// Delete a vector from the index by id.
16    fn delete_by_id(&mut self, id: usize);
17
18    /// Search for the top_k nearest neighbors of the query vector.
19    fn search(&self, query: &Vector<N>, top_k: usize) -> Vec<(usize, f32)>;
20}
21
/// A lightweight LSH-based ANN index.
///
/// Vectors are appended in insertion order. `ids[i]` is the caller-supplied id of
/// `vectors[i]`, and every tree refers to a vector by its position `i` (a `u32`),
/// never by the caller's id. Deletion currently only touches the trees; the
/// entries stay in `vectors`/`ids`.
#[derive(Serialize, Deserialize)]
pub struct VectorLite<const N: usize> {
    /// Every inserted vector, addressed by internal position.
    vectors: Vec<Vector<N>>,
    /// Caller-supplied id for each entry of `vectors` (same length and order).
    ids: Vec<usize>,
    /// The forest; each tree indexes every inserted vector by position.
    trees: Vec<Node<N>>,
    /// Leaf-size threshold passed to `Node::insert` on every insertion.
    max_leaf_size: usize,
}
30
31impl<const N: usize> VectorLite<N> {
32    /// Create a new VectorLite index with the given number of trees and max leaf size.
33    /// More trees means more accuracy but slower search and larger memory usage.
34    /// Lower max_leaf_size means higher accuracy but larger memory usage.
35    pub fn new(num_trees: usize, max_leaf_size: usize) -> Self {
36        Self {
37            vectors: Vec::new(),
38            ids: Vec::new(),
39            trees: (0..num_trees).map(|_| Node::new_empty()).collect(),
40            max_leaf_size,
41        }
42    }
43
44    pub fn to_bytes(&self) -> Vec<u8> {
45        let mut s = flexbuffers::FlexbufferSerializer::new();
46        self.serialize(&mut s).unwrap();
47        s.take_buffer()
48    }
49
50    pub fn from_bytes(bytes: &[u8]) -> Self {
51        let s = flexbuffers::Reader::get_root(bytes).unwrap();
52        let index: Self = VectorLite::deserialize(s).unwrap();
53        index
54    }
55}
56
57impl<const N: usize> ANNIndexOwned<N> for VectorLite<N> {
58    fn insert_with_rng(&mut self, vector: Vector<N>, id: usize, rng: &mut impl Rng) {
59        self.vectors.push(vector);
60        self.ids.push(id);
61        let vector_fn = |idx: u32| &self.vectors[idx as usize];
62        for tree in &mut self.trees {
63            tree.insert(
64                &vector_fn,
65                self.vectors.len() as u32 - 1,
66                rng,
67                self.max_leaf_size,
68            );
69        }
70    }
71
72    fn delete_by_id(&mut self, id: usize) {
73        let vector_fn = |idx: u32| &self.vectors[idx as usize];
74        for tree in &mut self.trees {
75            tree.delete(&vector_fn, id as u32);
76        }
77        // TODO: also remove from vectors and ids.
78    }
79
80    fn search(&self, query: &Vector<N>, top_k: usize) -> Vec<(usize, f32)> {
81        let mut candidates = HashSet::new();
82        for tree in self.trees.iter() {
83            tree.search(query, top_k, &mut candidates);
84        }
85
86        let mut results = candidates
87            .into_iter()
88            .map(|idx| (idx as usize, self.vectors[idx as usize].sq_euc_dist(query)))
89            .collect::<Vec<_>>();
90        results.sort_by(|a, b| a.1.partial_cmp(&b.1).unwrap());
91        results
92            .into_iter()
93            .take(top_k)
94            .map(|(id, dist)| (self.ids[id], dist))
95            .collect()
96    }
97}
98
#[cfg(test)]
mod tests {
    use super::*;
    use rand::SeedableRng;
    use rand::rngs::StdRng;

    /// Fixed-seed RNG so that tree construction is reproducible across runs.
    fn create_test_rng() -> StdRng {
        const SEED: u64 = 42;
        StdRng::seed_from_u64(SEED)
    }

    #[test]
    fn test_basic_operations() {
        let mut rng = create_test_rng();
        let mut index = VectorLite::<3>::new(2, 2);

        let points = [
            ([1.0, 0.0, 0.0], 101),
            ([0.0, 1.0, 0.0], 102),
            ([0.0, 0.0, 1.0], 103),
            ([1.0, 1.0, 0.0], 104),
        ];
        for (coords, id) in points {
            index.insert_with_rng(Vector::from(coords), id, &mut rng);
        }

        // A query near the x axis should rank ID 101 ([1.0, 0.0, 0.0]) first.
        let near_x = index.search(&Vector::from([0.9, 0.1, 0.0]), 2);
        assert_eq!(near_x.len(), 2);
        assert_eq!(near_x[0].0, 101);

        // A query near the y axis should return ID 102 ([0.0, 1.0, 0.0]).
        let near_y = index.search(&Vector::from([0.1, 0.9, 0.0]), 1);
        assert_eq!(near_y.len(), 1);
        assert_eq!(near_y[0].0, 102);
    }

    #[test]
    fn test_multi_tree_performance() {
        let mut rng = create_test_rng();
        let mut single_tree = VectorLite::<2>::new(1, 2);
        let mut multi_tree = VectorLite::<2>::new(5, 2);

        // A 5x5 integer grid; each id encodes its coordinates as x * 10 + y.
        let grid = (0..5).flat_map(|x| (0..5).map(move |y| (x, y)));
        for (x, y) in grid {
            let id = x * 10 + y;
            let point = Vector::from([x as f32, y as f32]);
            // Both indexes draw from the same RNG, single-tree first.
            single_tree.insert_with_rng(point.clone(), id, &mut rng);
            multi_tree.insert_with_rng(point, id, &mut rng);
        }

        // An off-grid query where a single tree may only approximate.
        let query = Vector::from([2.3, 2.3]);
        let single_results = single_tree.search(&query, 5);
        let multi_results = multi_tree.search(&query, 5);

        assert!(multi_results.len() >= single_results.len());
        assert_eq!(multi_results[0].0, 22);
        assert!(multi_results.windows(2).all(|pair| pair[1].1 >= pair[0].1));
    }

    #[test]
    fn test_deletion() {
        let mut rng = create_test_rng();
        let mut index = VectorLite::<2>::new(3, 2);

        for i in 0..10 {
            let coord = i as f32;
            index.insert_with_rng(Vector::from([coord, coord]), i, &mut rng);
        }
        let center = Vector::from([5.0, 5.0]);

        let before = index.search(&center, 1);
        assert_eq!(before.len(), 1);
        assert_eq!(before[0].0, 5);

        index.delete_by_id(5);

        // With 5 gone, either diagonal neighbour is an acceptable nearest hit.
        let after = index.search(&center, 1);
        assert_eq!(after.len(), 1);
        assert_ne!(after[0].0, 5);
        assert!(matches!(after[0].0, 4 | 6));

        for id in [4, 6] {
            index.delete_by_id(id);
        }

        let remaining = index.search(&center, 3);
        assert!(remaining.iter().all(|&(id, _)| !(4..=6).contains(&id)));
    }

    #[test]
    fn test_file_operations() {
        let mut rng = create_test_rng();
        let mut index = VectorLite::<3>::new(2, 2);

        let points = [
            ([1.0, 0.0, 0.0], 101),
            ([0.0, 1.0, 0.0], 102),
            ([0.0, 0.0, 1.0], 103),
        ];
        for (coords, id) in points {
            index.insert_with_rng(Vector::from(coords), id, &mut rng);
        }

        // Round-trip through the byte representation.
        let loaded_index = VectorLite::<3>::from_bytes(&index.to_bytes());

        let query = Vector::from([0.9, 0.1, 0.0]);
        let original_results = index.search(&query, 2);
        let loaded_results = loaded_index.search(&query, 2);

        assert_eq!(original_results.len(), loaded_results.len());
        for ((orig_id, orig_dist), (loaded_id, loaded_dist)) in
            original_results.iter().zip(&loaded_results)
        {
            assert_eq!(orig_id, loaded_id);
            assert!((orig_dist - loaded_dist).abs() < 1e-6);
        }
    }
}