vector_lite/
owned_index.rs

1use crate::{Node, Vector};
2use bincode::{Decode, Encode};
3use rand::Rng;
4use std::collections::{HashMap, HashSet};
5
6pub trait ANNIndexOwned<const N: usize> {
7    /// Insert a vector into the index.
8    fn insert(&mut self, vector: Vector<N>, id: String) {
9        self.insert_with_rng(vector, id, &mut rand::rng());
10    }
11
12    /// Insert a vector into the index with a custom rng.
13    fn insert_with_rng(&mut self, vector: Vector<N>, id: String, rng: &mut impl Rng);
14
15    /// Delete a vector from the index by id.
16    fn delete_by_id(&mut self, id: String);
17
18    /// Get a vector from the index by id.
19    fn get_by_id(&self, id: String) -> Option<&Vector<N>>;
20
21    /// Search for the top_k nearest neighbors of the query vector.
22    fn search(&self, query: &Vector<N>, top_k: usize) -> Vec<(String, f32)>;
23}
24
25/// A lightweight lsh-based ann index.
26#[derive(Encode, Decode)]
27pub struct VectorLite<const N: usize> {
28    vectors: Vec<Vector<N>>,
29    id_to_offset: HashMap<String, u32>,
30    offset_to_id: HashMap<u32, String>,
31    trees: Vec<Node<N>>,
32    max_leaf_size: usize,
33}
34
35impl<const N: usize> VectorLite<N> {
36    /// Create a new VectorLite index with the given number of trees and max leaf size.
37    /// More trees means more accuracy but slower search and larger memory usage.
38    /// Lower max_leaf_size means higher accuracy but larger memory usage.
39    pub fn new(num_trees: usize, max_leaf_size: usize) -> Self {
40        Self {
41            vectors: Vec::new(),
42            id_to_offset: HashMap::new(),
43            offset_to_id: HashMap::new(),
44            trees: (0..num_trees).map(|_| Node::new_empty()).collect(),
45            max_leaf_size,
46        }
47    }
48
49    pub fn to_bytes(&self) -> Vec<u8> {
50        let config = bincode::config::standard();
51        bincode::encode_to_vec(self, config).unwrap()
52    }
53
54    pub fn from_bytes(bytes: &[u8]) -> Self {
55        let config = bincode::config::standard();
56        let (index, _) = bincode::decode_from_slice(bytes, config).unwrap();
57        index
58    }
59}
60
61impl<const N: usize> ANNIndexOwned<N> for VectorLite<N> {
62    fn insert_with_rng(&mut self, vector: Vector<N>, id: String, rng: &mut impl Rng) {
63        self.vectors.push(vector);
64        let offset = self.vectors.len() as u32 - 1;
65        self.offset_to_id.insert(offset, id.clone());
66        self.id_to_offset.insert(id, offset);
67        let vector_fn = |idx: u32| &self.vectors[idx as usize];
68        for tree in &mut self.trees {
69            tree.insert(
70                &vector_fn,
71                self.vectors.len() as u32 - 1,
72                rng,
73                self.max_leaf_size,
74            );
75        }
76    }
77
78    fn delete_by_id(&mut self, id: String) {
79        let offset = self.id_to_offset[&id];
80        for tree in &mut self.trees {
81            tree.delete(&self.vectors[offset as usize], offset);
82        }
83        self.id_to_offset.remove(&id);
84        self.offset_to_id.remove(&offset);
85        // TODO: also remove from vectors.
86    }
87
88    fn get_by_id(&self, id: String) -> Option<&Vector<N>> {
89        self.id_to_offset
90            .get(&id)
91            .map(|offset| &self.vectors[*offset as usize])
92    }
93
94    fn search(&self, query: &Vector<N>, top_k: usize) -> Vec<(String, f32)> {
95        let mut candidates = HashSet::new();
96        for tree in self.trees.iter() {
97            tree.search(query, top_k, &mut candidates);
98        }
99
100        let mut results = candidates
101            .into_iter()
102            .map(|offset| {
103                (
104                    offset,
105                    self.vectors[offset as usize].cosine_similarity(query),
106                )
107            })
108            .collect::<Vec<_>>();
109        results.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap());
110        results
111            .into_iter()
112            .take(top_k)
113            .map(|(offset, dist)| (self.offset_to_id[&offset].clone(), dist))
114            .collect()
115    }
116}
117
#[cfg(test)]
mod tests {
    use super::*;
    use rand::SeedableRng;
    use rand::rngs::StdRng;

    /// Fixed-seed RNG so that tree construction is reproducible across runs.
    fn create_test_rng() -> StdRng {
        StdRng::seed_from_u64(42)
    }

    /// Inserting a few axis-aligned vectors and querying near one of them
    /// returns that vector first.
    #[test]
    fn test_basic_operations() {
        let mut index = VectorLite::<3>::new(2, 2);
        let mut rng = create_test_rng();

        // Insert some vectors
        index.insert_with_rng(Vector::from([1.0, 0.0, 0.0]), "101".to_string(), &mut rng);
        index.insert_with_rng(Vector::from([0.0, 1.0, 0.0]), "102".to_string(), &mut rng);
        index.insert_with_rng(Vector::from([0.0, 0.0, 1.0]), "103".to_string(), &mut rng);
        index.insert_with_rng(Vector::from([1.0, 1.0, 0.0]), "104".to_string(), &mut rng);

        // Query close to the first vector
        let results = index.search(&Vector::from([0.9, 0.1, 0.0]), 2);
        assert_eq!(results.len(), 2);
        assert_eq!(results[0].0, "101"); // First result should be ID 101 ([1.0, 0.0, 0.0])

        // Query close to the second vector
        let results = index.search(&Vector::from([0.1, 0.9, 0.0]), 1);
        assert_eq!(results.len(), 1);
        assert_eq!(results[0].0, "102"); // Should be ID 102 ([0.0, 1.0, 0.0])
    }

    /// An index with more trees finds at least as many results as a
    /// single-tree index, and results come back best-first.
    #[test]
    fn test_multi_tree_performance() {
        let mut single_tree = VectorLite::<2>::new(1, 2);
        let mut multi_tree = VectorLite::<2>::new(5, 2);

        let mut rng = create_test_rng();

        // Create a 5x5 grid of vectors and insert each into both indexes
        for x in 0..5 {
            for y in 0..5 {
                let id = x * 10 + y;
                let vector = Vector::from([x as f32, y as f32]);

                single_tree.insert_with_rng(vector.clone(), id.to_string(), &mut rng);
                multi_tree.insert_with_rng(vector, id.to_string(), &mut rng);
            }
        }

        // Query in a spot where approximation might occur
        let query = Vector::from([2.3, 2.3]);

        let single_results = single_tree.search(&query, 5);
        let multi_results = multi_tree.search(&query, 5);

        // Multi-tree should find at least as many results as single tree
        assert!(multi_results.len() >= single_results.len());

        assert_eq!(multi_results[0].0, "22");

        // `search` sorts by score from highest to lowest, so scores must be
        // non-increasing along the result list.
        for pair in multi_results.windows(2) {
            assert!(pair[1].1 <= pair[0].1);
        }
    }

    /// Deleted ids never show up in later search results.
    #[test]
    fn test_deletion() {
        let mut index = VectorLite::<2>::new(3, 2);
        let mut rng = create_test_rng();

        // Insert vectors with sequential IDs
        for i in 0..10 {
            let x = i as f32;
            index.insert_with_rng(Vector::from([x, x]), i.to_string(), &mut rng);
        }

        // Search for a point and verify it exists
        let results = index.search(&Vector::from([5.0, 5.0]), 1);
        assert_eq!(results.len(), 1);
        assert_eq!(results[0].0, "5");

        // Delete that point
        index.delete_by_id("5".to_string());

        // Search again - should find a different point now
        let results = index.search(&Vector::from([5.0, 5.0]), 1);
        assert_eq!(results.len(), 1);
        assert_ne!(results[0].0, "5"); // Should not find the deleted point

        // The nearest should now be either 4 or 6
        assert!(results[0].0 == "4" || results[0].0 == "6");

        // Delete several points and verify none are found
        index.delete_by_id("4".to_string());
        index.delete_by_id("6".to_string());

        let results = index.search(&Vector::from([5.0, 5.0]), 3);
        for (id, _score) in results {
            assert!(id != "4" && id != "5" && id != "6");
        }
    }

    /// An index that went through `to_bytes` / `from_bytes` answers queries
    /// the same way as the original.
    #[test]
    fn test_file_operations() {
        let mut index = VectorLite::<3>::new(2, 2);
        let mut rng = create_test_rng();

        // Insert some test vectors
        index.insert_with_rng(Vector::from([1.0, 0.0, 0.0]), "101".to_string(), &mut rng);
        index.insert_with_rng(Vector::from([0.0, 1.0, 0.0]), "102".to_string(), &mut rng);
        index.insert_with_rng(Vector::from([0.0, 0.0, 1.0]), "103".to_string(), &mut rng);

        let serialized = index.to_bytes();
        let loaded_index = VectorLite::<3>::from_bytes(&serialized);

        // Verify search results match
        let query = Vector::from([0.9, 0.1, 0.0]);
        let original_results = index.search(&query, 2);
        let loaded_results = loaded_index.search(&query, 2);

        assert_eq!(original_results.len(), loaded_results.len());
        for (original, loaded) in original_results.iter().zip(&loaded_results) {
            assert_eq!(original.0, loaded.0);
            assert!((original.1 - loaded.1).abs() < 1e-6);
        }
    }
}