vector_lite/
owned_index.rs

1use crate::{Node, Vector};
2use bincode::{Decode, Encode};
3use rand::Rng;
4use std::{
5    collections::{HashMap, HashSet},
6    rc::Rc,
7};
8
/// Scoring function used to rank search candidates against a query vector.
///
/// The derives make the metric cheap to pass around, printable in logs and
/// assertions, and usable as a map/set key.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum ScoreMetric {
    /// Rank candidates by cosine similarity to the query.
    Cosine,
    /// Rank candidates by L2 (Euclidean) distance to the query.
    L2,
}
13
14pub trait ANNIndexOwned<const N: usize> {
15    /// Insert a vector into the index.
16    fn insert(&mut self, vector: Vector<N>, id: String) {
17        self.insert_with_rng(vector, id, &mut rand::rng());
18    }
19
20    /// Insert a vector into the index with a custom rng.
21    fn insert_with_rng(&mut self, vector: Vector<N>, id: String, rng: &mut impl Rng);
22
23    /// Delete a vector from the index by id.
24    /// Returns true if the vector was deleted, false if it was not found.
25    fn delete_by_id(&mut self, id: &str) -> bool;
26
27    /// Get a vector from the index by id.
28    fn get_by_id(&self, id: &str) -> Option<&Vector<N>>;
29
30    /// Search for the top_k nearest neighbors of the query vector.
31    /// Returns a array of (id, score) pairs, higher score means closer.
32    fn search(&self, query: &Vector<N>, top_k: usize) -> Vec<(String, f32)> {
33        self.search_with_metric(query, top_k, ScoreMetric::L2)
34    }
35
36    /// Search for the top_k nearest neighbors of the query vector.
37    /// Returns a array of (id, score) pairs, higher score means closer.
38    fn search_with_metric(
39        &self,
40        query: &Vector<N>,
41        top_k: usize,
42        metric: ScoreMetric,
43    ) -> Vec<(String, f32)>;
44}
45
46#[derive(Encode, Decode)]
47pub struct VectorLiteIndex<const N: usize> {
48    trees: Vec<Node<N, Rc<String>>>,
49    max_leaf_size: usize,
50}
51
52impl<const N: usize> VectorLiteIndex<N> {
53    fn new(num_trees: usize, max_leaf_size: usize) -> Self {
54        Self {
55            trees: (0..num_trees).map(|_| Node::new_empty()).collect(),
56            max_leaf_size,
57        }
58    }
59
60    /// Search for the top_k nearest neighbors of the query vector.
61    /// Returns a vector of ids, user may use the ids to compute the distance themselves.
62    /// This is helpful when the vector is stored in a different place.
63    pub fn search(&self, query: &Vector<N>, top_k: usize) -> Vec<&String> {
64        let mut candidates = HashSet::new();
65        for tree in &self.trees {
66            tree.search(query, top_k, &mut candidates);
67        }
68        candidates.into_iter().map(|id| id.as_ref()).collect()
69    }
70
71    /// Serialize the index to a byte vector.
72    pub fn to_bytes(&self) -> Vec<u8> {
73        let config = bincode::config::standard();
74        bincode::encode_to_vec(self, config).unwrap()
75    }
76
77    /// Deserialize the index from a byte vector.
78    pub fn from_bytes(bytes: &[u8]) -> Self {
79        let config = bincode::config::standard();
80        let (index, _) = bincode::decode_from_slice(bytes, config).unwrap();
81        index
82    }
83}
84
85/// A lightweight lsh-based ann index.
86#[derive(Encode, Decode)]
87pub struct VectorLite<const N: usize> {
88    vectors: HashMap<Rc<String>, Vector<N>>,
89    index: VectorLiteIndex<N>,
90}
91
92impl<const N: usize> VectorLite<N> {
93    /// Create a new VectorLite index with the given number of trees and max leaf size.
94    /// More trees means more accuracy but slower search and larger memory usage.
95    /// Lower max_leaf_size means higher accuracy but larger memory usage.
96    pub fn new(num_trees: usize, max_leaf_size: usize) -> Self {
97        Self {
98            vectors: HashMap::new(),
99            index: VectorLiteIndex::new(num_trees, max_leaf_size),
100        }
101    }
102
103    /// Get the number of vectors in the index.
104    pub fn len(&self) -> usize {
105        self.vectors.len()
106    }
107
108    /// Check if the index is empty.
109    pub fn is_empty(&self) -> bool {
110        self.vectors.is_empty()
111    }
112
113    /// Get the index.
114    pub fn index(&self) -> &VectorLiteIndex<N> {
115        &self.index
116    }
117
118    /// Serialize the index to a byte vector.
119    pub fn to_bytes(&self) -> Vec<u8> {
120        let config = bincode::config::standard();
121        bincode::encode_to_vec(self, config).unwrap()
122    }
123
124    /// Deserialize the index from a byte vector.
125    pub fn from_bytes(bytes: &[u8]) -> Self {
126        let config = bincode::config::standard();
127        let (index, _) = bincode::decode_from_slice(bytes, config).unwrap();
128        index
129    }
130}
131
132impl<const N: usize> ANNIndexOwned<N> for VectorLite<N> {
133    fn insert_with_rng(&mut self, vector: Vector<N>, id: String, rng: &mut impl Rng) {
134        let id = Rc::new(id);
135        self.vectors.insert(id.clone(), vector);
136        let vector_fn = |id: &Rc<String>| &self.vectors[id];
137        for tree in &mut self.index.trees {
138            tree.insert(&vector_fn, id.clone(), rng, self.index.max_leaf_size);
139        }
140    }
141
142    fn delete_by_id(&mut self, id: &str) -> bool {
143        let id = Rc::new(id.to_string());
144        let Some(vector) = self.vectors.remove(&id) else {
145            return false;
146        };
147        for tree in &mut self.index.trees {
148            tree.delete(&vector, &id);
149        }
150        true
151    }
152
153    fn get_by_id(&self, id: &str) -> Option<&Vector<N>> {
154        self.vectors.get(&Rc::new(id.to_string()))
155    }
156
157    fn search_with_metric(
158        &self,
159        query: &Vector<N>,
160        top_k: usize,
161        metric: ScoreMetric,
162    ) -> Vec<(String, f32)> {
163        let candidates = self.index.search(query, top_k);
164
165        let results = match metric {
166            ScoreMetric::L2 => {
167                let mut results = candidates
168                    .into_iter()
169                    .map(|id| {
170                        let dist = self.vectors[id].sq_euc_dist(query);
171                        (id, dist)
172                    })
173                    .collect::<Vec<_>>();
174                results.sort_by(|a, b| a.1.partial_cmp(&b.1).unwrap());
175                results
176            }
177            ScoreMetric::Cosine => {
178                let mut results = candidates
179                    .into_iter()
180                    .map(|id| {
181                        let dist = self.vectors[id].cosine_similarity(query);
182                        (id, dist)
183                    })
184                    .collect::<Vec<_>>();
185                results.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap());
186                results
187            }
188        };
189
190        results
191            .into_iter()
192            .take(top_k)
193            .map(|(id, dist)| (id.clone(), dist))
194            .collect()
195    }
196}
197
#[cfg(test)]
mod tests {
    use super::*;
    use rand::SeedableRng;
    use rand::rngs::StdRng;

    /// Fixed seed so that tree construction is reproducible across runs.
    const TEST_SEED: u64 = 42;

    /// Build the deterministic RNG used by every test.
    fn create_test_rng() -> StdRng {
        StdRng::seed_from_u64(TEST_SEED)
    }

    #[test]
    fn test_basic_operations() {
        let mut index = VectorLite::<3>::new(2, 2);
        let mut rng = create_test_rng();

        // Three unit vectors plus one diagonal.
        let fixtures: [([f32; 3], &str); 4] = [
            ([1.0, 0.0, 0.0], "101"),
            ([0.0, 1.0, 0.0], "102"),
            ([0.0, 0.0, 1.0], "103"),
            ([1.0, 1.0, 0.0], "104"),
        ];
        for (coords, id) in fixtures {
            index.insert_with_rng(Vector::from(coords), id.to_string(), &mut rng);
        }

        // A query next to the x axis must rank "101" first.
        let near_x = index.search(&Vector::from([0.9, 0.1, 0.0]), 2);
        assert_eq!(near_x.len(), 2);
        assert_eq!(near_x[0].0, "101");

        // A query next to the y axis must return "102".
        let near_y = index.search(&Vector::from([0.1, 0.9, 0.0]), 1);
        assert_eq!(near_y.len(), 1);
        assert_eq!(near_y[0].0, "102");
    }

    #[test]
    fn test_multi_tree_performance() {
        // Same data in a one-tree and a five-tree index.
        let mut single_tree = VectorLite::<2>::new(1, 2);
        let mut multi_tree = VectorLite::<2>::new(5, 2);

        let mut rng = create_test_rng();

        // 5x5 integer grid; the id encodes the coordinates as "xy".
        for x in 0..5 {
            for y in 0..5 {
                let id = (x * 10 + y).to_string();
                let vector = Vector::from([x as f32, y as f32]);

                single_tree.insert_with_rng(vector.clone(), id.clone(), &mut rng);
                multi_tree.insert_with_rng(vector, id, &mut rng);
            }
        }

        // Off-grid query, where approximation might lose neighbors.
        let query = Vector::from([2.3, 2.3]);

        let single_results = single_tree.search(&query, 5);
        let multi_results = multi_tree.search(&query, 5);

        // More trees must not find fewer results.
        assert!(multi_results.len() >= single_results.len());

        assert_eq!(multi_results[0].0, "22");

        // Scores must be non-decreasing (closest first for L2).
        for pair in multi_results.windows(2) {
            assert!(pair[1].1 >= pair[0].1);
        }
    }

    #[test]
    fn test_deletion() {
        let mut index = VectorLite::<2>::new(3, 2);
        let mut rng = create_test_rng();

        // Points (i, i) for i in 0..10, with the id being i.
        for i in 0..10 {
            let coord = i as f32;
            index.insert_with_rng(Vector::from([coord, coord]), i.to_string(), &mut rng);
        }

        let target = Vector::from([5.0, 5.0]);

        // The exact point is found before deletion.
        let before = index.search(&target, 1);
        assert_eq!(before.len(), 1);
        assert_eq!(before[0].0, "5");

        index.delete_by_id("5");

        // After deletion a neighbor takes its place.
        let after = index.search(&target, 1);
        assert_eq!(after.len(), 1);
        assert_ne!(after[0].0, "5");
        assert!(after[0].0 == "4" || after[0].0 == "6");

        // Remove both neighbors too; none of the three may come back.
        index.delete_by_id("4");
        index.delete_by_id("6");

        let remaining = index.search(&target, 3);
        for (id, _score) in remaining {
            assert!(id != "4" && id != "5" && id != "6");
        }
    }

    #[test]
    fn test_file_operations() {
        let mut index = VectorLite::<3>::new(2, 2);
        let mut rng = create_test_rng();

        let fixtures: [([f32; 3], &str); 3] = [
            ([1.0, 0.0, 0.0], "101"),
            ([0.0, 1.0, 0.0], "102"),
            ([0.0, 0.0, 1.0], "103"),
        ];
        for (coords, id) in fixtures {
            index.insert_with_rng(Vector::from(coords), id.to_string(), &mut rng);
        }

        // Round-trip through the byte representation.
        let serialized = index.to_bytes();
        let loaded_index = VectorLite::<3>::from_bytes(&serialized);

        // Both indexes must answer the same query identically.
        let query = Vector::from([0.9, 0.1, 0.0]);
        let original_results = index.search(&query, 2);
        let loaded_results = loaded_index.search(&query, 2);

        assert_eq!(original_results.len(), loaded_results.len());
        for (original, loaded) in original_results.iter().zip(loaded_results.iter()) {
            assert_eq!(original.0, loaded.0);
            assert!((original.1 - loaded.1).abs() < 1e-6);
        }
    }

    #[test]
    fn test_deleting_nonexistent_id() {
        let mut index = VectorLite::<3>::new(2, 2);
        let mut rng = create_test_rng();

        index.insert_with_rng(Vector::from([1.0, 0.0, 0.0]), "101".to_string(), &mut rng);
        index.insert_with_rng(Vector::from([0.0, 1.0, 0.0]), "102".to_string(), &mut rng);

        // An unknown id is reported as not deleted and leaves the index intact.
        let removed_missing = index.delete_by_id("non_existent_id");
        assert!(!removed_missing);

        assert_eq!(index.len(), 2);
        assert!(index.get_by_id("101").is_some());
        assert!(index.get_by_id("102").is_some());

        // A known id is deleted and the count drops.
        let removed_existing = index.delete_by_id("101");
        assert!(removed_existing);
        assert_eq!(index.len(), 1);
    }
}
353}