1use crate::{Node, Vector};
2use bincode::{Decode, Encode};
3use rand::Rng;
4use std::{
5 collections::{HashMap, HashSet},
6 rc::Rc,
7};
8
/// Scoring function used to rank search results.
///
/// The enum is `Copy`, so a metric value can be reused across several searches.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum ScoreMetric {
    /// Cosine similarity between query and stored vector; higher scores rank first.
    Cosine,
    /// Squared Euclidean distance between query and stored vector; lower scores rank first.
    L2,
}
13
14pub trait ANNIndexOwned<const N: usize> {
15 fn insert(&mut self, vector: Vector<N>, id: String) {
17 self.insert_with_rng(vector, id, &mut rand::rng());
18 }
19
20 fn insert_with_rng(&mut self, vector: Vector<N>, id: String, rng: &mut impl Rng);
22
23 fn delete_by_id(&mut self, id: &str) -> bool;
26
27 fn get_by_id(&self, id: &str) -> Option<&Vector<N>>;
29
30 fn search(&self, query: &Vector<N>, top_k: usize) -> Vec<(String, f32)> {
33 self.search_with_metric(query, top_k, ScoreMetric::L2)
34 }
35
36 fn search_with_metric(
39 &self,
40 query: &Vector<N>,
41 top_k: usize,
42 metric: ScoreMetric,
43 ) -> Vec<(String, f32)>;
44}
45
/// Forest of trees that maps a query vector to a set of candidate ids.
///
/// The trees are parameterized over shared id handles (`Rc<String>`). The vectors themselves are
/// kept by the owning [`VectorLite`] and are looked up by id when entries are inserted.
// Field order defines the bincode encoding; reordering fields breaks previously saved bytes.
#[derive(Encode, Decode)]
pub struct VectorLiteIndex<const N: usize> {
    /// Independent trees; a search queries every tree and merges their candidates.
    trees: Vec<Node<N, Rc<String>>>,
    /// Forwarded to every tree insertion. It is presumably the leaf capacity at which a node
    /// splits; TODO confirm against `Node::insert`.
    max_leaf_size: usize,
}
51
52impl<const N: usize> VectorLiteIndex<N> {
53 fn new(num_trees: usize, max_leaf_size: usize) -> Self {
54 Self {
55 trees: (0..num_trees).map(|_| Node::new_empty()).collect(),
56 max_leaf_size,
57 }
58 }
59
60 pub fn search(&self, query: &Vector<N>, top_k: usize) -> Vec<&String> {
64 let mut candidates = HashSet::new();
65 for tree in &self.trees {
66 tree.search(query, top_k, &mut candidates);
67 }
68 candidates.into_iter().map(|id| id.as_ref()).collect()
69 }
70
71 pub fn to_bytes(&self) -> Vec<u8> {
73 let config = bincode::config::standard();
74 bincode::encode_to_vec(self, config).unwrap()
75 }
76
77 pub fn from_bytes(bytes: &[u8]) -> Self {
79 let config = bincode::config::standard();
80 let (index, _) = bincode::decode_from_slice(bytes, config).unwrap();
81 index
82 }
83}
84
/// In-memory vector store that owns every vector and keeps a [`VectorLiteIndex`] over their ids.
// Field order defines the bincode encoding; reordering fields breaks previously saved bytes.
// NOTE(review): bincode presumably decodes each `Rc` independently, so after `from_bytes` the id
// handles are likely no longer shared between `vectors` and the index trees. Confirm if memory
// use after loading matters.
#[derive(Encode, Decode)]
pub struct VectorLite<const N: usize> {
    /// Stored vectors keyed by id. On insertion the same `Rc<String>` handle is also handed to
    /// every index tree.
    vectors: HashMap<Rc<String>, Vector<N>>,
    /// Approximate-search index over the ids in `vectors`.
    index: VectorLiteIndex<N>,
}
91
92impl<const N: usize> VectorLite<N> {
93 pub fn new(num_trees: usize, max_leaf_size: usize) -> Self {
97 Self {
98 vectors: HashMap::new(),
99 index: VectorLiteIndex::new(num_trees, max_leaf_size),
100 }
101 }
102
103 pub fn len(&self) -> usize {
105 self.vectors.len()
106 }
107
108 pub fn is_empty(&self) -> bool {
110 self.vectors.is_empty()
111 }
112
113 pub fn index(&self) -> &VectorLiteIndex<N> {
115 &self.index
116 }
117
118 pub fn to_bytes(&self) -> Vec<u8> {
120 let config = bincode::config::standard();
121 bincode::encode_to_vec(self, config).unwrap()
122 }
123
124 pub fn from_bytes(bytes: &[u8]) -> Self {
126 let config = bincode::config::standard();
127 let (index, _) = bincode::decode_from_slice(bytes, config).unwrap();
128 index
129 }
130}
131
132impl<const N: usize> ANNIndexOwned<N> for VectorLite<N> {
133 fn insert_with_rng(&mut self, vector: Vector<N>, id: String, rng: &mut impl Rng) {
134 let id = Rc::new(id);
135 self.vectors.insert(id.clone(), vector);
136 let vector_fn = |id: &Rc<String>| &self.vectors[id];
137 for tree in &mut self.index.trees {
138 tree.insert(&vector_fn, id.clone(), rng, self.index.max_leaf_size);
139 }
140 }
141
142 fn delete_by_id(&mut self, id: &str) -> bool {
143 let id = Rc::new(id.to_string());
144 let Some(vector) = self.vectors.remove(&id) else {
145 return false;
146 };
147 for tree in &mut self.index.trees {
148 tree.delete(&vector, &id);
149 }
150 true
151 }
152
153 fn get_by_id(&self, id: &str) -> Option<&Vector<N>> {
154 self.vectors.get(&Rc::new(id.to_string()))
155 }
156
157 fn search_with_metric(
158 &self,
159 query: &Vector<N>,
160 top_k: usize,
161 metric: ScoreMetric,
162 ) -> Vec<(String, f32)> {
163 let candidates = self.index.search(query, top_k);
164
165 let results = match metric {
166 ScoreMetric::L2 => {
167 let mut results = candidates
168 .into_iter()
169 .map(|id| {
170 let dist = self.vectors[id].sq_euc_dist(query);
171 (id, dist)
172 })
173 .collect::<Vec<_>>();
174 results.sort_by(|a, b| a.1.partial_cmp(&b.1).unwrap());
175 results
176 }
177 ScoreMetric::Cosine => {
178 let mut results = candidates
179 .into_iter()
180 .map(|id| {
181 let dist = self.vectors[id].cosine_similarity(query);
182 (id, dist)
183 })
184 .collect::<Vec<_>>();
185 results.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap());
186 results
187 }
188 };
189
190 results
191 .into_iter()
192 .take(top_k)
193 .map(|(id, dist)| (id.clone(), dist))
194 .collect()
195 }
196}
197
#[cfg(test)]
mod tests {
    use super::*;
    use rand::SeedableRng;
    use rand::rngs::StdRng;

    /// Fixed-seed RNG so tree construction, and therefore search results, is reproducible.
    fn create_test_rng() -> StdRng {
        StdRng::seed_from_u64(42)
    }

    #[test]
    fn test_basic_operations() {
        let mut index = VectorLite::<3>::new(2, 2);
        let mut rng = create_test_rng();

        index.insert_with_rng(Vector::from([1.0, 0.0, 0.0]), "101".to_string(), &mut rng);
        index.insert_with_rng(Vector::from([0.0, 1.0, 0.0]), "102".to_string(), &mut rng);
        index.insert_with_rng(Vector::from([0.0, 0.0, 1.0]), "103".to_string(), &mut rng);
        index.insert_with_rng(Vector::from([1.0, 1.0, 0.0]), "104".to_string(), &mut rng);

        // The query lies closest to "101" on the x axis.
        let results = index.search(&Vector::from([0.9, 0.1, 0.0]), 2);
        assert_eq!(results.len(), 2);
        assert_eq!(results[0].0, "101");

        // The query lies closest to "102" on the y axis.
        let results = index.search(&Vector::from([0.1, 0.9, 0.0]), 1);
        assert_eq!(results.len(), 1);
        assert_eq!(results[0].0, "102");
    }

    #[test]
    fn test_multi_tree_performance() {
        let mut single_tree = VectorLite::<2>::new(1, 2);
        let mut multi_tree = VectorLite::<2>::new(5, 2);

        let mut rng = create_test_rng();

        // 5x5 integer grid; id "xy" marks the point (x, y).
        for x in 0..5 {
            for y in 0..5 {
                let id = x * 10 + y;
                let vector = Vector::from([x as f32, y as f32]);

                single_tree.insert_with_rng(vector.clone(), id.to_string(), &mut rng);
                multi_tree.insert_with_rng(vector, id.to_string(), &mut rng);
            }
        }

        let query = Vector::from([2.3, 2.3]);

        let single_results = single_tree.search(&query, 5);
        let multi_results = multi_tree.search(&query, 5);

        // More trees can only widen the candidate pool.
        assert!(multi_results.len() >= single_results.len());

        assert_eq!(multi_results[0].0, "22");

        // Results are ranked by ascending squared L2 distance.
        for pair in multi_results.windows(2) {
            assert!(pair[0].1 <= pair[1].1, "results not sorted by distance: {pair:?}");
        }
    }

    #[test]
    fn test_deletion() {
        let mut index = VectorLite::<2>::new(3, 2);
        let mut rng = create_test_rng();

        for i in 0..10 {
            let x = i as f32;
            index.insert_with_rng(Vector::from([x, x]), i.to_string(), &mut rng);
        }

        let results = index.search(&Vector::from([5.0, 5.0]), 1);
        assert_eq!(results.len(), 1);
        assert_eq!(results[0].0, "5");

        assert!(index.delete_by_id("5"));

        // With "5" gone, one of its diagonal neighbours must win.
        let results = index.search(&Vector::from([5.0, 5.0]), 1);
        assert_eq!(results.len(), 1);
        assert_ne!(results[0].0, "5");
        assert!(results[0].0 == "4" || results[0].0 == "6");

        assert!(index.delete_by_id("4"));
        assert!(index.delete_by_id("6"));

        let results = index.search(&Vector::from([5.0, 5.0]), 3);
        for (id, _) in &results {
            assert!(id != "4" && id != "5" && id != "6", "deleted id {id} was returned");
        }
    }

    #[test]
    fn test_file_operations() {
        let mut index = VectorLite::<3>::new(2, 2);
        let mut rng = create_test_rng();

        index.insert_with_rng(Vector::from([1.0, 0.0, 0.0]), "101".to_string(), &mut rng);
        index.insert_with_rng(Vector::from([0.0, 1.0, 0.0]), "102".to_string(), &mut rng);
        index.insert_with_rng(Vector::from([0.0, 0.0, 1.0]), "103".to_string(), &mut rng);

        let serialized = index.to_bytes();
        let loaded_index = VectorLite::<3>::from_bytes(&serialized);

        let query = Vector::from([0.9, 0.1, 0.0]);
        let original_results = index.search(&query, 2);
        let loaded_results = loaded_index.search(&query, 2);

        // A round trip must reproduce the same ranking and scores.
        assert_eq!(original_results.len(), loaded_results.len());
        for (original, loaded) in original_results.iter().zip(&loaded_results) {
            assert_eq!(original.0, loaded.0);
            assert!((original.1 - loaded.1).abs() < 1e-6);
        }
    }

    #[test]
    fn test_deleting_nonexistent_id() {
        let mut index = VectorLite::<3>::new(2, 2);
        let mut rng = create_test_rng();

        index.insert_with_rng(Vector::from([1.0, 0.0, 0.0]), "101".to_string(), &mut rng);
        index.insert_with_rng(Vector::from([0.0, 1.0, 0.0]), "102".to_string(), &mut rng);

        // Deleting an unknown id reports failure and leaves the store untouched.
        assert!(!index.delete_by_id("non_existent_id"));
        assert_eq!(index.len(), 2);
        assert!(index.get_by_id("101").is_some());
        assert!(index.get_by_id("102").is_some());

        assert!(index.delete_by_id("101"));
        assert_eq!(index.len(), 1);
    }

    #[test]
    fn test_cosine_metric_single_vector() {
        let mut index = VectorLite::<2>::new(2, 2);
        let mut rng = create_test_rng();

        index.insert_with_rng(Vector::from([1.0, 2.0]), "a".to_string(), &mut rng);

        // A parallel query has cosine similarity 1 with the only stored vector.
        let results = index.search_with_metric(&Vector::from([2.0, 4.0]), 1, ScoreMetric::Cosine);
        assert_eq!(results.len(), 1);
        assert_eq!(results[0].0, "a");
        assert!((results[0].1 - 1.0).abs() < 1e-5);
    }
}