vector_lite/
owned_index.rs

1use crate::{Node, Vector};
2use bincode::{Decode, Encode};
3use rand::Rng;
4use std::collections::{HashMap, HashSet};
5
6pub trait ANNIndexOwned<const N: usize> {
7    /// Insert a vector into the index.
8    fn insert(&mut self, vector: Vector<N>, id: String) {
9        self.insert_with_rng(vector, id, &mut rand::rng());
10    }
11
12    /// Insert a vector into the index with a custom rng.
13    fn insert_with_rng(&mut self, vector: Vector<N>, id: String, rng: &mut impl Rng);
14
15    /// Delete a vector from the index by id.
16    fn delete_by_id(&mut self, id: String);
17
18    /// Search for the top_k nearest neighbors of the query vector.
19    fn search(&self, query: &Vector<N>, top_k: usize) -> Vec<(String, f32)>;
20}
21
/// A lightweight LSH-based ANN (approximate nearest neighbor) index that owns its vectors.
///
/// NOTE(review): the derived `Encode`/`Decode` presumably serialize the fields in
/// declaration order, so reordering them would change the byte format produced by
/// `to_bytes` — confirm before rearranging.
#[derive(Encode, Decode)]
pub struct VectorLite<const N: usize> {
    /// Every inserted vector, in insertion order; a vector's position here is its offset.
    /// Deleted vectors are currently left in place (see the TODO in `delete_by_id`).
    vectors: Vec<Vector<N>>,
    /// Caller-supplied id -> offset into `vectors`.
    id_to_offset: HashMap<String, u32>,
    /// Offset into `vectors` -> caller-supplied id; used to turn search hits back into ids.
    offset_to_id: HashMap<u32, String>,
    /// The trees that are updated on insert/delete and consulted on search;
    /// `new` creates `num_trees` of them.
    trees: Vec<Node<N>>,
    /// Leaf-size limit forwarded to `Node::insert` on every insertion.
    max_leaf_size: usize,
}
31
32impl<const N: usize> VectorLite<N> {
33    /// Create a new VectorLite index with the given number of trees and max leaf size.
34    /// More trees means more accuracy but slower search and larger memory usage.
35    /// Lower max_leaf_size means higher accuracy but larger memory usage.
36    pub fn new(num_trees: usize, max_leaf_size: usize) -> Self {
37        Self {
38            vectors: Vec::new(),
39            id_to_offset: HashMap::new(),
40            offset_to_id: HashMap::new(),
41            trees: (0..num_trees).map(|_| Node::new_empty()).collect(),
42            max_leaf_size,
43        }
44    }
45
46    pub fn to_bytes(&self) -> Vec<u8> {
47        let config = bincode::config::standard();
48        bincode::encode_to_vec(self, config).unwrap()
49    }
50
51    pub fn from_bytes(bytes: &[u8]) -> Self {
52        let config = bincode::config::standard();
53        let (index, _) = bincode::decode_from_slice(bytes, config).unwrap();
54        index
55    }
56}
57
58impl<const N: usize> ANNIndexOwned<N> for VectorLite<N> {
59    fn insert_with_rng(&mut self, vector: Vector<N>, id: String, rng: &mut impl Rng) {
60        self.vectors.push(vector);
61        let offset = self.vectors.len() as u32 - 1;
62        self.offset_to_id.insert(offset, id.clone());
63        self.id_to_offset.insert(id, offset);
64        let vector_fn = |idx: u32| &self.vectors[idx as usize];
65        for tree in &mut self.trees {
66            tree.insert(
67                &vector_fn,
68                self.vectors.len() as u32 - 1,
69                rng,
70                self.max_leaf_size,
71            );
72        }
73    }
74
75    fn delete_by_id(&mut self, id: String) {
76        let offset = self.id_to_offset[&id];
77        for tree in &mut self.trees {
78            tree.delete(&self.vectors[offset as usize], offset);
79        }
80        self.id_to_offset.remove(&id);
81        self.offset_to_id.remove(&offset);
82        // TODO: also remove from vectors.
83    }
84
85    fn search(&self, query: &Vector<N>, top_k: usize) -> Vec<(String, f32)> {
86        let mut candidates = HashSet::new();
87        for tree in self.trees.iter() {
88            tree.search(query, top_k, &mut candidates);
89        }
90
91        let mut results = candidates
92            .into_iter()
93            .map(|offset| (offset, self.vectors[offset as usize].cosine_dist(query)))
94            .collect::<Vec<_>>();
95        results.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap());
96        results
97            .into_iter()
98            .take(top_k)
99            .map(|(offset, dist)| (self.offset_to_id[&offset].clone(), dist))
100            .collect()
101    }
102}
103
#[cfg(test)]
mod tests {
    use super::*;
    use rand::SeedableRng;
    use rand::rngs::StdRng;

    // Returns an RNG with a fixed seed so that every test run feeds the same sequence of
    // random draws to `insert_with_rng`.
    fn create_test_rng() -> StdRng {
        StdRng::seed_from_u64(42)
    }

    // Inserts four 3-d vectors and checks that the closest one is ranked first.
    #[test]
    fn test_basic_operations() {
        // 2 trees, at most 2 vectors per leaf.
        let mut index = VectorLite::<3>::new(2, 2);
        let mut rng = create_test_rng();

        // Insert some vectors
        index.insert_with_rng(Vector::from([1.0, 0.0, 0.0]), "101".to_string(), &mut rng);
        index.insert_with_rng(Vector::from([0.0, 1.0, 0.0]), "102".to_string(), &mut rng);
        index.insert_with_rng(Vector::from([0.0, 0.0, 1.0]), "103".to_string(), &mut rng);
        index.insert_with_rng(Vector::from([1.0, 1.0, 0.0]), "104".to_string(), &mut rng);

        // Search for nearest neighbors
        let results = index.search(&Vector::from([0.9, 0.1, 0.0]), 2);

        // Verify correct results
        assert_eq!(results.len(), 2);
        assert_eq!(results[0].0, "101"); // First result should be ID 101 ([1.0, 0.0, 0.0])

        // Query close to second vector
        let results = index.search(&Vector::from([0.1, 0.9, 0.0]), 1);
        assert_eq!(results.len(), 1);
        assert_eq!(results[0].0, "102"); // Should be ID 102 ([0.0, 1.0, 0.0])
    }

    // Despite the name, this does not measure speed: it compares the results of a
    // one-tree index with those of a five-tree index built from the same data.
    #[test]
    fn test_multi_tree_performance() {
        // Same max leaf size (2), different number of trees (1 vs. 5).
        let mut single_tree = VectorLite::<2>::new(1, 2);
        let mut multi_tree = VectorLite::<2>::new(5, 2);

        // One RNG is shared by both indexes, so their random draws are interleaved.
        let mut rng = create_test_rng();

        // Create a 5x5 grid of vectors; the id encodes the coordinates as "x*10 + y".
        for x in 0..5 {
            for y in 0..5 {
                let id = x * 10 + y;
                let vector = Vector::from([x as f32, y as f32]);

                // Insert into both indexes
                single_tree.insert_with_rng(vector.clone(), id.to_string(), &mut rng);
                multi_tree.insert_with_rng(vector, id.to_string(), &mut rng);
            }
        }

        // Query in a spot where approximation might occur
        let query = Vector::from([2.3, 2.3]);

        let single_results = single_tree.search(&query, 5);
        let multi_results = multi_tree.search(&query, 5);

        // Multi-tree should find at least as many results as single tree
        assert!(multi_results.len() >= single_results.len());

        // Grid point (2, 2) is expected to be ranked first.
        assert_eq!(multi_results[0].0, "22");

        // The results should be sorted by ascending distance (nearest first)
        for i in 1..multi_results.len() {
            assert!(multi_results[i].1 >= multi_results[i - 1].1);
        }
    }

    // Checks that deleted ids never show up in later search results.
    #[test]
    fn test_deletion() {
        let mut index = VectorLite::<2>::new(3, 2);
        let mut rng = create_test_rng();

        // Insert vectors [i, i] with sequential IDs "0".."9".
        // NOTE(review): every vector here is a multiple of [1, 1] and `search` ranks by
        // `Vector::cosine_dist`; if that metric ignores magnitude these points would tie,
        // which the assertions below do not allow for — confirm the metric.
        for i in 0..10 {
            let x = i as f32;
            index.insert_with_rng(Vector::from([x, x]), i.to_string(), &mut rng);
        }

        // Search for a point and verify it exists
        let results = index.search(&Vector::from([5.0, 5.0]), 1);
        assert_eq!(results.len(), 1);
        assert_eq!(results[0].0, "5");

        // Delete that point
        index.delete_by_id("5".to_string());

        // Search again - should find a different point now
        let results = index.search(&Vector::from([5.0, 5.0]), 1);
        assert_eq!(results.len(), 1);
        assert_ne!(results[0].0, "5"); // Should not find the deleted point

        // The nearest should now be either 4 or 6
        assert!(results[0].0 == "4" || results[0].0 == "6");

        // Delete several points and verify none are found
        index.delete_by_id("4".to_string());
        index.delete_by_id("6".to_string());

        let results = index.search(&Vector::from([5.0, 5.0]), 3);
        for result in results {
            assert!(result.0 != "4" && result.0 != "5" && result.0 != "6");
        }
    }

    // Round-trips an index through `to_bytes` / `from_bytes` (in memory; no file is
    // touched) and checks that both copies answer a query identically.
    #[test]
    fn test_file_operations() {
        // Create a new index
        let mut index = VectorLite::<3>::new(2, 2);
        let mut rng = create_test_rng();

        // Insert some test vectors
        index.insert_with_rng(Vector::from([1.0, 0.0, 0.0]), "101".to_string(), &mut rng);
        index.insert_with_rng(Vector::from([0.0, 1.0, 0.0]), "102".to_string(), &mut rng);
        index.insert_with_rng(Vector::from([0.0, 0.0, 1.0]), "103".to_string(), &mut rng);

        let serialized = index.to_bytes();
        let loaded_index = VectorLite::<3>::from_bytes(&serialized);

        // Verify search results match: same ids in the same order, scores equal within 1e-6
        let query = Vector::from([0.9, 0.1, 0.0]);
        let original_results = index.search(&query, 2);
        let loaded_results = loaded_index.search(&query, 2);

        assert_eq!(original_results.len(), loaded_results.len());
        for i in 0..original_results.len() {
            assert_eq!(original_results[i].0, loaded_results[i].0);
            assert!((original_results[i].1 - loaded_results[i].1).abs() < 1e-6);
        }
    }
}