1use crate::{Node, Vector};
2use bincode::{Decode, Encode};
3use rand::Rng;
4use std::collections::{HashMap, HashSet};
5
6pub trait ANNIndexOwned<const N: usize> {
7 fn insert(&mut self, vector: Vector<N>, id: String) {
9 self.insert_with_rng(vector, id, &mut rand::rng());
10 }
11
12 fn insert_with_rng(&mut self, vector: Vector<N>, id: String, rng: &mut impl Rng);
14
15 fn delete_by_id(&mut self, id: String);
17
18 fn get_by_id(&self, id: String) -> Option<&Vector<N>>;
20
21 fn search(&self, query: &Vector<N>, top_k: usize) -> Vec<(String, f32)>;
23}
24
/// An in-memory approximate-nearest-neighbour index over `N`-dimensional vectors,
/// backed by a collection of trees (`Node<N>`).
///
/// The whole struct is serialized with bincode (see [`VectorLite::to_bytes`] and
/// [`VectorLite::from_bytes`]). The derived `Encode`/`Decode` impls follow field
/// declaration order, so reordering or retyping fields changes the byte format.
#[derive(Encode, Decode)]
pub struct VectorLite<const N: usize> {
    /// Every vector ever inserted, addressed by its `u32` offset (its position here).
    /// Deletion does not remove vectors from this list; it only unlinks them from the
    /// id maps and the trees.
    vectors: Vec<Vector<N>>,
    /// Maps an id to the offset of the vector most recently inserted under it.
    id_to_offset: HashMap<String, u32>,
    /// Maps an offset back to the id it was inserted under.
    offset_to_id: HashMap<u32, String>,
    /// The trees; every insert adds the new offset to each of them, and `search`
    /// unions the candidates produced by each of them.
    trees: Vec<Node<N>>,
    /// Passed unchanged to every `Node::insert` call. Presumably the maximum number of
    /// points a leaf holds before splitting (TODO confirm against `Node`).
    max_leaf_size: usize,
}
34
35impl<const N: usize> VectorLite<N> {
36 pub fn new(num_trees: usize, max_leaf_size: usize) -> Self {
40 Self {
41 vectors: Vec::new(),
42 id_to_offset: HashMap::new(),
43 offset_to_id: HashMap::new(),
44 trees: (0..num_trees).map(|_| Node::new_empty()).collect(),
45 max_leaf_size,
46 }
47 }
48
49 pub fn to_bytes(&self) -> Vec<u8> {
50 let config = bincode::config::standard();
51 bincode::encode_to_vec(self, config).unwrap()
52 }
53
54 pub fn from_bytes(bytes: &[u8]) -> Self {
55 let config = bincode::config::standard();
56 let (index, _) = bincode::decode_from_slice(bytes, config).unwrap();
57 index
58 }
59}
60
61impl<const N: usize> ANNIndexOwned<N> for VectorLite<N> {
62 fn insert_with_rng(&mut self, vector: Vector<N>, id: String, rng: &mut impl Rng) {
63 self.vectors.push(vector);
64 let offset = self.vectors.len() as u32 - 1;
65 self.offset_to_id.insert(offset, id.clone());
66 self.id_to_offset.insert(id, offset);
67 let vector_fn = |idx: u32| &self.vectors[idx as usize];
68 for tree in &mut self.trees {
69 tree.insert(
70 &vector_fn,
71 self.vectors.len() as u32 - 1,
72 rng,
73 self.max_leaf_size,
74 );
75 }
76 }
77
78 fn delete_by_id(&mut self, id: String) {
79 let offset = self.id_to_offset[&id];
80 for tree in &mut self.trees {
81 tree.delete(&self.vectors[offset as usize], offset);
82 }
83 self.id_to_offset.remove(&id);
84 self.offset_to_id.remove(&offset);
85 }
87
88 fn get_by_id(&self, id: String) -> Option<&Vector<N>> {
89 self.id_to_offset
90 .get(&id)
91 .map(|offset| &self.vectors[*offset as usize])
92 }
93
94 fn search(&self, query: &Vector<N>, top_k: usize) -> Vec<(String, f32)> {
95 let mut candidates = HashSet::new();
96 for tree in self.trees.iter() {
97 tree.search(query, top_k, &mut candidates);
98 }
99
100 let mut results = candidates
101 .into_iter()
102 .map(|offset| {
103 (
104 offset,
105 self.vectors[offset as usize].cosine_similarity(query),
106 )
107 })
108 .collect::<Vec<_>>();
109 results.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap());
110 results
111 .into_iter()
112 .take(top_k)
113 .map(|(offset, dist)| (self.offset_to_id[&offset].clone(), dist))
114 .collect()
115 }
116}
117
#[cfg(test)]
mod tests {
    use super::*;
    use rand::SeedableRng;
    use rand::rngs::StdRng;
    use std::f32::consts::TAU;

    /// Fixed-seed RNG so tree construction, and therefore each test's outcome, is
    /// reproducible.
    fn create_test_rng() -> StdRng {
        StdRng::seed_from_u64(42)
    }

    /// Unit vector `steps` slices around a circle divided into `parts` equal slices.
    ///
    /// Points on the unit circle differ only in direction, so the best match for a query
    /// under cosine similarity (what `search` ranks by) is unique. Collinear points such
    /// as `[x, x]` would all score the same against `[5, 5]`.
    fn on_circle(steps: f32, parts: usize) -> Vector<2> {
        let theta = steps * TAU / parts as f32;
        Vector::from([theta.cos(), theta.sin()])
    }

    #[test]
    fn test_basic_operations() {
        let mut index = VectorLite::<3>::new(2, 2);
        let mut rng = create_test_rng();

        index.insert_with_rng(Vector::from([1.0, 0.0, 0.0]), "101".to_string(), &mut rng);
        index.insert_with_rng(Vector::from([0.0, 1.0, 0.0]), "102".to_string(), &mut rng);
        index.insert_with_rng(Vector::from([0.0, 0.0, 1.0]), "103".to_string(), &mut rng);
        index.insert_with_rng(Vector::from([1.0, 1.0, 0.0]), "104".to_string(), &mut rng);

        let results = index.search(&Vector::from([0.9, 0.1, 0.0]), 2);
        assert_eq!(results.len(), 2);
        assert_eq!(results[0].0, "101");

        let results = index.search(&Vector::from([0.1, 0.9, 0.0]), 1);
        assert_eq!(results.len(), 1);
        assert_eq!(results[0].0, "102");
    }

    #[test]
    fn test_multi_tree_performance() {
        const POINTS: usize = 25;
        let mut single_tree = VectorLite::<2>::new(1, 2);
        let mut multi_tree = VectorLite::<2>::new(5, 2);
        let mut rng = create_test_rng();

        for i in 0..POINTS {
            let vector = on_circle(i as f32, POINTS);
            single_tree.insert_with_rng(vector.clone(), i.to_string(), &mut rng);
            multi_tree.insert_with_rng(vector, i.to_string(), &mut rng);
        }

        // Slightly past point 12, so point 12 is the unique closest direction.
        let query = on_circle(12.3, POINTS);
        let single_results = single_tree.search(&query, 5);
        let multi_results = multi_tree.search(&query, 5);

        assert!(multi_results.len() >= single_results.len());
        assert_eq!(multi_results[0].0, "12");

        // `search` sorts from most to least similar, so scores must never increase.
        assert!(multi_results.windows(2).all(|pair| pair[0].1 >= pair[1].1));
    }

    #[test]
    fn test_deletion() {
        const POINTS: usize = 10;
        let mut index = VectorLite::<2>::new(3, 2);
        let mut rng = create_test_rng();

        for i in 0..POINTS {
            index.insert_with_rng(on_circle(i as f32, POINTS), i.to_string(), &mut rng);
        }

        let query = on_circle(5.0, POINTS);
        let results = index.search(&query, 1);
        assert_eq!(results.len(), 1);
        assert_eq!(results[0].0, "5");

        index.delete_by_id("5".to_string());

        // Points 4 and 6 are equally far from the query in angle; either may win.
        let results = index.search(&query, 1);
        assert_eq!(results.len(), 1);
        assert_ne!(results[0].0, "5");
        assert!(results[0].0 == "4" || results[0].0 == "6");

        index.delete_by_id("4".to_string());
        index.delete_by_id("6".to_string());

        let results = index.search(&query, 3);
        for result in results {
            assert!(result.0 != "4" && result.0 != "5" && result.0 != "6");
        }
    }

    #[test]
    fn test_file_operations() {
        let mut index = VectorLite::<3>::new(2, 2);
        let mut rng = create_test_rng();

        index.insert_with_rng(Vector::from([1.0, 0.0, 0.0]), "101".to_string(), &mut rng);
        index.insert_with_rng(Vector::from([0.0, 1.0, 0.0]), "102".to_string(), &mut rng);
        index.insert_with_rng(Vector::from([0.0, 0.0, 1.0]), "103".to_string(), &mut rng);

        let serialized = index.to_bytes();
        let loaded_index = VectorLite::<3>::from_bytes(&serialized);

        let query = Vector::from([0.9, 0.1, 0.0]);
        let original_results = index.search(&query, 2);
        let loaded_results = loaded_index.search(&query, 2);

        assert_eq!(original_results.len(), loaded_results.len());
        for (original, loaded) in original_results.iter().zip(&loaded_results) {
            assert_eq!(original.0, loaded.0);
            assert!((original.1 - loaded.1).abs() < 1e-6);
        }
    }
}