vector_lite/
owned_index.rs

1use crate::{Node, Vector};
2use bincode::{Decode, Encode};
3use rand::Rng;
4use std::{
5    collections::{HashMap, HashSet},
6    rc::Rc,
7};
8
9pub trait ANNIndexOwned<const N: usize> {
10    /// Insert a vector into the index.
11    fn insert(&mut self, vector: Vector<N>, id: String) {
12        self.insert_with_rng(vector, id, &mut rand::rng());
13    }
14
15    /// Insert a vector into the index with a custom rng.
16    fn insert_with_rng(&mut self, vector: Vector<N>, id: String, rng: &mut impl Rng);
17
18    /// Delete a vector from the index by id.
19    fn delete_by_id(&mut self, id: String);
20
21    /// Get a vector from the index by id.
22    fn get_by_id(&self, id: String) -> Option<&Vector<N>>;
23
24    /// Search for the top_k nearest neighbors of the query vector.
25    fn search(&self, query: &Vector<N>, top_k: usize) -> Vec<(String, f32)>;
26}
27
28#[derive(Encode, Decode)]
29pub struct VectorLiteIndex<const N: usize> {
30    trees: Vec<Node<N, Rc<String>>>,
31    max_leaf_size: usize,
32}
33
34impl<const N: usize> VectorLiteIndex<N> {
35    fn new(num_trees: usize, max_leaf_size: usize) -> Self {
36        Self {
37            trees: (0..num_trees).map(|_| Node::new_empty()).collect(),
38            max_leaf_size,
39        }
40    }
41
42    /// Search for the top_k nearest neighbors of the query vector.
43    /// Returns a vector of ids, user may use the ids to compute the distance themselves.
44    /// This is helpful when the vector is stored in a different place.
45    pub fn search(&self, query: &Vector<N>, top_k: usize) -> Vec<String> {
46        let mut candidates = HashSet::new();
47        for tree in &self.trees {
48            tree.search(query, top_k, &mut candidates);
49        }
50        candidates
51            .into_iter()
52            .map(|id| id.as_ref().clone())
53            .collect()
54    }
55
56    /// Serialize the index to a byte vector.
57    pub fn to_bytes(&self) -> Vec<u8> {
58        let config = bincode::config::standard();
59        bincode::encode_to_vec(self, config).unwrap()
60    }
61
62    /// Deserialize the index from a byte vector.
63    pub fn from_bytes(bytes: &[u8]) -> Self {
64        let config = bincode::config::standard();
65        let (index, _) = bincode::decode_from_slice(bytes, config).unwrap();
66        index
67    }
68}
69
70/// A lightweight lsh-based ann index.
71#[derive(Encode, Decode)]
72pub struct VectorLite<const N: usize> {
73    vectors: HashMap<Rc<String>, Vector<N>>,
74    index: VectorLiteIndex<N>,
75}
76
77impl<const N: usize> VectorLite<N> {
78    /// Create a new VectorLite index with the given number of trees and max leaf size.
79    /// More trees means more accuracy but slower search and larger memory usage.
80    /// Lower max_leaf_size means higher accuracy but larger memory usage.
81    pub fn new(num_trees: usize, max_leaf_size: usize) -> Self {
82        Self {
83            vectors: HashMap::new(),
84            index: VectorLiteIndex::new(num_trees, max_leaf_size),
85        }
86    }
87
88    /// Get the number of vectors in the index.
89    pub fn len(&self) -> usize {
90        self.vectors.len()
91    }
92
93    /// Check if the index is empty.
94    pub fn is_empty(&self) -> bool {
95        self.vectors.is_empty()
96    }
97
98    /// Get the index.
99    pub fn index(&self) -> &VectorLiteIndex<N> {
100        &self.index
101    }
102
103    /// Serialize the index to a byte vector.
104    pub fn to_bytes(&self) -> Vec<u8> {
105        let config = bincode::config::standard();
106        bincode::encode_to_vec(self, config).unwrap()
107    }
108
109    /// Deserialize the index from a byte vector.
110    pub fn from_bytes(bytes: &[u8]) -> Self {
111        let config = bincode::config::standard();
112        let (index, _) = bincode::decode_from_slice(bytes, config).unwrap();
113        index
114    }
115}
116
117impl<const N: usize> ANNIndexOwned<N> for VectorLite<N> {
118    fn insert_with_rng(&mut self, vector: Vector<N>, id: String, rng: &mut impl Rng) {
119        let id = Rc::new(id);
120        self.vectors.insert(id.clone(), vector);
121        let vector_fn = |id: &Rc<String>| &self.vectors[id];
122        for tree in &mut self.index.trees {
123            tree.insert(&vector_fn, id.clone(), rng, self.index.max_leaf_size);
124        }
125    }
126
127    fn delete_by_id(&mut self, id: String) {
128        let id = Rc::new(id);
129        for tree in &mut self.index.trees {
130            tree.delete(&self.vectors[&id], &id);
131        }
132        self.vectors.remove(&id);
133    }
134
135    fn get_by_id(&self, id: String) -> Option<&Vector<N>> {
136        self.vectors.get(&id)
137    }
138
139    fn search(&self, query: &Vector<N>, top_k: usize) -> Vec<(String, f32)> {
140        let candidates = self.index.search(query, top_k);
141
142        let mut results = candidates
143            .into_iter()
144            .map(|id| {
145                let dist = self.vectors[&id].sq_euc_dist(query);
146                (id, dist)
147            })
148            .collect::<Vec<_>>();
149        results.sort_by(|a, b| a.1.partial_cmp(&b.1).unwrap());
150        results.into_iter().take(top_k).collect()
151    }
152}
153
#[cfg(test)]
mod tests {
    use super::*;
    use rand::SeedableRng;
    use rand::rngs::StdRng;

    /// Fixed-seed RNG so tree construction, and therefore the approximate search
    /// results asserted below, are reproducible.
    fn create_test_rng() -> StdRng {
        StdRng::seed_from_u64(42)
    }

    /// Inserts `points` into `index` in slice order, drawing randomness from `rng`.
    fn insert_points(index: &mut VectorLite<3>, points: &[(&str, [f32; 3])], rng: &mut StdRng) {
        for &(id, coords) in points {
            index.insert_with_rng(Vector::from(coords), id.to_string(), &mut *rng);
        }
    }

    #[test]
    fn test_basic_operations() {
        let mut index = VectorLite::<3>::new(2, 2);
        let mut rng = create_test_rng();

        insert_points(
            &mut index,
            &[
                ("101", [1.0, 0.0, 0.0]),
                ("102", [0.0, 1.0, 0.0]),
                ("103", [0.0, 0.0, 1.0]),
                ("104", [1.0, 1.0, 0.0]),
            ],
            &mut rng,
        );

        // A query close to [1, 0, 0] should rank id 101 first.
        let near_x = index.search(&Vector::from([0.9, 0.1, 0.0]), 2);
        assert_eq!(near_x.len(), 2);
        assert_eq!(near_x[0].0, "101");

        // A query close to [0, 1, 0] should return id 102.
        let near_y = index.search(&Vector::from([0.1, 0.9, 0.0]), 1);
        assert_eq!(near_y.len(), 1);
        assert_eq!(near_y[0].0, "102");
    }

    #[test]
    fn test_multi_tree_performance() {
        // The same 5x5 grid goes into a one-tree and a five-tree index.
        let mut single_tree = VectorLite::<2>::new(1, 2);
        let mut multi_tree = VectorLite::<2>::new(5, 2);

        let mut rng = create_test_rng();

        // Row-major walk (x outer, y inner), alternating between the two indexes so the
        // shared RNG is consumed in a fixed order.
        for (x, y) in (0..5).flat_map(|x| (0..5).map(move |y| (x, y))) {
            let id = (x * 10 + y).to_string();
            let point = Vector::from([x as f32, y as f32]);
            single_tree.insert_with_rng(point.clone(), id.clone(), &mut rng);
            multi_tree.insert_with_rng(point, id, &mut rng);
        }

        // A query between grid points, where approximation could show up.
        let query = Vector::from([2.3, 2.3]);

        let single_results = single_tree.search(&query, 5);
        let multi_results = multi_tree.search(&query, 5);

        // More trees should never find fewer candidates.
        assert!(multi_results.len() >= single_results.len());

        assert_eq!(multi_results[0].0, "22");

        // Distances must be non-decreasing.
        assert!(multi_results.windows(2).all(|pair| pair[0].1 <= pair[1].1));
    }

    #[test]
    fn test_deletion() {
        let mut index = VectorLite::<2>::new(3, 2);
        let mut rng = create_test_rng();

        // Points along the diagonal: id i sits at [i, i].
        (0..10).for_each(|i: i32| {
            let coord = i as f32;
            index.insert_with_rng(Vector::from([coord, coord]), i.to_string(), &mut rng);
        });

        let query = Vector::from([5.0, 5.0]);

        let before = index.search(&query, 1);
        assert_eq!(before.len(), 1);
        assert_eq!(before[0].0, "5");

        index.delete_by_id("5".to_string());

        // With 5 gone, the nearest must be a different point, namely 4 or 6.
        let after = index.search(&query, 1);
        assert_eq!(after.len(), 1);
        assert_ne!(after[0].0, "5");
        assert!(matches!(after[0].0.as_str(), "4" | "6"));

        for removed in ["4", "6"] {
            index.delete_by_id(removed.to_string());
        }

        let remaining = index.search(&query, 3);
        assert!(remaining
            .iter()
            .all(|(id, _)| !matches!(id.as_str(), "4" | "5" | "6")));
    }

    #[test]
    fn test_file_operations() {
        let mut index = VectorLite::<3>::new(2, 2);
        let mut rng = create_test_rng();

        insert_points(
            &mut index,
            &[
                ("101", [1.0, 0.0, 0.0]),
                ("102", [0.0, 1.0, 0.0]),
                ("103", [0.0, 0.0, 1.0]),
            ],
            &mut rng,
        );

        // Round-trip through bytes.
        let restored = VectorLite::<3>::from_bytes(&index.to_bytes());

        // The restored index must answer the same query identically.
        let query = Vector::from([0.9, 0.1, 0.0]);
        let expected = index.search(&query, 2);
        let actual = restored.search(&query, 2);

        assert_eq!(expected.len(), actual.len());
        for ((expected_id, expected_dist), (actual_id, actual_dist)) in
            expected.iter().zip(&actual)
        {
            assert_eq!(expected_id, actual_id);
            assert!((expected_dist - actual_dist).abs() < 1e-6);
        }
    }
}