vector_lite/
owned_index.rs

1use crate::{Node, Vector};
2use bincode::{Decode, Encode};
3use rand::Rng;
4use std::collections::{HashMap, HashSet};
5
6pub trait ANNIndexOwned<const N: usize> {
7    /// Insert a vector into the index.
8    fn insert(&mut self, vector: Vector<N>, id: String) {
9        self.insert_with_rng(vector, id, &mut rand::rng());
10    }
11
12    /// Insert a vector into the index with a custom rng.
13    fn insert_with_rng(&mut self, vector: Vector<N>, id: String, rng: &mut impl Rng);
14
15    /// Delete a vector from the index by id.
16    fn delete_by_id(&mut self, id: String);
17
18    /// Get a vector from the index by id.
19    fn get_by_id(&self, id: String) -> Option<&Vector<N>>;
20
21    /// Search for the top_k nearest neighbors of the query vector.
22    fn search(&self, query: &Vector<N>, top_k: usize) -> Vec<(String, f32)>;
23}
24
25/// A lightweight lsh-based ann index.
26#[derive(Encode, Decode)]
27pub struct VectorLite<const N: usize> {
28    vectors: Vec<Vector<N>>,
29    id_to_offset: HashMap<String, u32>,
30    offset_to_id: HashMap<u32, String>,
31    trees: Vec<Node<N>>,
32    max_leaf_size: usize,
33}
34
35impl<const N: usize> VectorLite<N> {
36    /// Create a new VectorLite index with the given number of trees and max leaf size.
37    /// More trees means more accuracy but slower search and larger memory usage.
38    /// Lower max_leaf_size means higher accuracy but larger memory usage.
39    pub fn new(num_trees: usize, max_leaf_size: usize) -> Self {
40        Self {
41            vectors: Vec::new(),
42            id_to_offset: HashMap::new(),
43            offset_to_id: HashMap::new(),
44            trees: (0..num_trees).map(|_| Node::new_empty()).collect(),
45            max_leaf_size,
46        }
47    }
48
49    /// Get the number of vectors in the index.
50    pub fn len(&self) -> usize {
51        self.vectors.len()
52    }
53
54    /// Check if the index is empty.
55    pub fn is_empty(&self) -> bool {
56        self.vectors.is_empty()
57    }
58
59    /// Serialize the index to a byte vector.
60    pub fn to_bytes(&self) -> Vec<u8> {
61        let config = bincode::config::standard();
62        bincode::encode_to_vec(self, config).unwrap()
63    }
64
65    /// Deserialize the index from a byte vector.
66    pub fn from_bytes(bytes: &[u8]) -> Self {
67        let config = bincode::config::standard();
68        let (index, _) = bincode::decode_from_slice(bytes, config).unwrap();
69        index
70    }
71}
72
73impl<const N: usize> ANNIndexOwned<N> for VectorLite<N> {
74    fn insert_with_rng(&mut self, vector: Vector<N>, id: String, rng: &mut impl Rng) {
75        self.vectors.push(vector);
76        let offset = self.vectors.len() as u32 - 1;
77        self.offset_to_id.insert(offset, id.clone());
78        self.id_to_offset.insert(id, offset);
79        let vector_fn = |idx: u32| &self.vectors[idx as usize];
80        for tree in &mut self.trees {
81            tree.insert(
82                &vector_fn,
83                self.vectors.len() as u32 - 1,
84                rng,
85                self.max_leaf_size,
86            );
87        }
88    }
89
90    fn delete_by_id(&mut self, id: String) {
91        let offset = self.id_to_offset[&id];
92        for tree in &mut self.trees {
93            tree.delete(&self.vectors[offset as usize], offset);
94        }
95        self.id_to_offset.remove(&id);
96        self.offset_to_id.remove(&offset);
97        // TODO: also remove from vectors.
98    }
99
100    fn get_by_id(&self, id: String) -> Option<&Vector<N>> {
101        self.id_to_offset
102            .get(&id)
103            .map(|offset| &self.vectors[*offset as usize])
104    }
105
106    fn search(&self, query: &Vector<N>, top_k: usize) -> Vec<(String, f32)> {
107        let mut candidates = HashSet::new();
108        for tree in self.trees.iter() {
109            tree.search(query, top_k, &mut candidates);
110        }
111
112        let mut results = candidates
113            .into_iter()
114            .map(|offset| (offset, self.vectors[offset as usize].sq_euc_dist(query)))
115            .collect::<Vec<_>>();
116        results.sort_by(|a, b| a.1.partial_cmp(&b.1).unwrap());
117        results
118            .into_iter()
119            .take(top_k)
120            .map(|(offset, dist)| (self.offset_to_id[&offset].clone(), dist))
121            .collect()
122    }
123}
124
#[cfg(test)]
mod tests {
    use super::*;
    use rand::SeedableRng;
    use rand::rngs::StdRng;

    /// Fixed-seed RNG so tree construction (and therefore search results) is reproducible.
    fn create_test_rng() -> StdRng {
        StdRng::seed_from_u64(42)
    }

    #[test]
    fn test_basic_operations() {
        let mut index = VectorLite::<3>::new(2, 2);
        let mut rng = create_test_rng();

        let points = [
            ([1.0, 0.0, 0.0], "101"),
            ([0.0, 1.0, 0.0], "102"),
            ([0.0, 0.0, 1.0], "103"),
            ([1.0, 1.0, 0.0], "104"),
        ];
        for (coords, id) in points {
            index.insert_with_rng(Vector::from(coords), id.to_string(), &mut rng);
        }

        // Near the x axis: "101" ([1, 0, 0]) should rank first.
        let near_x = index.search(&Vector::from([0.9, 0.1, 0.0]), 2);
        assert_eq!(near_x.len(), 2);
        assert_eq!(near_x[0].0, "101");

        // Near the y axis: "102" ([0, 1, 0]) is the single best hit.
        let near_y = index.search(&Vector::from([0.1, 0.9, 0.0]), 1);
        assert_eq!(near_y.len(), 1);
        assert_eq!(near_y[0].0, "102");
    }

    #[test]
    fn test_multi_tree_performance() {
        let mut single_tree = VectorLite::<2>::new(1, 2);
        let mut multi_tree = VectorLite::<2>::new(5, 2);
        let mut rng = create_test_rng();

        // 5x5 integer grid; ids encode the coordinates as "xy". Both indexes share one RNG,
        // alternating single -> multi for each point.
        let grid = (0..5).flat_map(|x| (0..5).map(move |y| (x, y)));
        for (x, y) in grid {
            let id = (x * 10 + y).to_string();
            let point = Vector::from([x as f32, y as f32]);
            single_tree.insert_with_rng(point.clone(), id.clone(), &mut rng);
            multi_tree.insert_with_rng(point, id, &mut rng);
        }

        // A query between grid points, where approximation may differ between the indexes.
        let query = Vector::from([2.3, 2.3]);
        let single_results = single_tree.search(&query, 5);
        let multi_results = multi_tree.search(&query, 5);

        assert!(multi_results.len() >= single_results.len());
        assert_eq!(multi_results[0].0, "22");

        // Distances must be non-decreasing.
        assert!(multi_results.windows(2).all(|pair| pair[1].1 >= pair[0].1));
    }

    #[test]
    fn test_deletion() {
        let mut index = VectorLite::<2>::new(3, 2);
        let mut rng = create_test_rng();

        // Points on the diagonal: id "i" lives at (i, i).
        for i in 0..10 {
            let coord = i as f32;
            index.insert_with_rng(Vector::from([coord, coord]), i.to_string(), &mut rng);
        }
        let center = Vector::from([5.0, 5.0]);

        let before = index.search(&center, 1);
        assert_eq!(before.len(), 1);
        assert_eq!(before[0].0, "5");

        index.delete_by_id("5".to_string());

        // The deleted point must not come back; a diagonal neighbour takes its place.
        let after = index.search(&center, 1);
        assert_eq!(after.len(), 1);
        assert_ne!(after[0].0, "5");
        assert!(matches!(after[0].0.as_str(), "4" | "6"));

        for id in ["4", "6"] {
            index.delete_by_id(id.to_string());
        }

        let remaining = index.search(&center, 3);
        assert!(
            remaining
                .iter()
                .all(|(id, _)| !matches!(id.as_str(), "4" | "5" | "6"))
        );
    }

    #[test]
    fn test_file_operations() {
        let mut index = VectorLite::<3>::new(2, 2);
        let mut rng = create_test_rng();

        let points = [
            ([1.0, 0.0, 0.0], "101"),
            ([0.0, 1.0, 0.0], "102"),
            ([0.0, 0.0, 1.0], "103"),
        ];
        for (coords, id) in points {
            index.insert_with_rng(Vector::from(coords), id.to_string(), &mut rng);
        }

        // Round-trip through the byte encoding.
        let serialized = index.to_bytes();
        let restored = VectorLite::<3>::from_bytes(&serialized);

        let query = Vector::from([0.9, 0.1, 0.0]);
        let expected = index.search(&query, 2);
        let actual = restored.search(&query, 2);

        assert_eq!(expected.len(), actual.len());
        for ((want_id, want_dist), (got_id, got_dist)) in expected.iter().zip(actual.iter()) {
            assert_eq!(want_id, got_id);
            assert!((want_dist - got_dist).abs() < 1e-6);
        }
    }
}