1use crate::{Node, Vector};
2use bincode::{Decode, Encode};
3use rand::Rng;
4use std::collections::{HashMap, HashSet};
5
6pub trait ANNIndexOwned<const N: usize> {
7 fn insert(&mut self, vector: Vector<N>, id: String) {
9 self.insert_with_rng(vector, id, &mut rand::rng());
10 }
11
12 fn insert_with_rng(&mut self, vector: Vector<N>, id: String, rng: &mut impl Rng);
14
15 fn delete_by_id(&mut self, id: String);
17
18 fn search(&self, query: &Vector<N>, top_k: usize) -> Vec<(String, f32)>;
20}
21
22#[derive(Encode, Decode)]
24pub struct VectorLite<const N: usize> {
25 vectors: Vec<Vector<N>>,
26 id_to_offset: HashMap<String, u32>,
27 offset_to_id: HashMap<u32, String>,
28 trees: Vec<Node<N>>,
29 max_leaf_size: usize,
30}
31
32impl<const N: usize> VectorLite<N> {
33 pub fn new(num_trees: usize, max_leaf_size: usize) -> Self {
37 Self {
38 vectors: Vec::new(),
39 id_to_offset: HashMap::new(),
40 offset_to_id: HashMap::new(),
41 trees: (0..num_trees).map(|_| Node::new_empty()).collect(),
42 max_leaf_size,
43 }
44 }
45
46 pub fn to_bytes(&self) -> Vec<u8> {
47 let config = bincode::config::standard();
48 bincode::encode_to_vec(self, config).unwrap()
49 }
50
51 pub fn from_bytes(bytes: &[u8]) -> Self {
52 let config = bincode::config::standard();
53 let (index, _) = bincode::decode_from_slice(bytes, config).unwrap();
54 index
55 }
56}
57
58impl<const N: usize> ANNIndexOwned<N> for VectorLite<N> {
59 fn insert_with_rng(&mut self, vector: Vector<N>, id: String, rng: &mut impl Rng) {
60 self.vectors.push(vector);
61 let offset = self.vectors.len() as u32 - 1;
62 self.offset_to_id.insert(offset, id.clone());
63 self.id_to_offset.insert(id, offset);
64 let vector_fn = |idx: u32| &self.vectors[idx as usize];
65 for tree in &mut self.trees {
66 tree.insert(
67 &vector_fn,
68 self.vectors.len() as u32 - 1,
69 rng,
70 self.max_leaf_size,
71 );
72 }
73 }
74
75 fn delete_by_id(&mut self, id: String) {
76 let offset = self.id_to_offset[&id];
77 for tree in &mut self.trees {
78 tree.delete(&self.vectors[offset as usize], offset);
79 }
80 self.id_to_offset.remove(&id);
81 self.offset_to_id.remove(&offset);
82 }
84
85 fn search(&self, query: &Vector<N>, top_k: usize) -> Vec<(String, f32)> {
86 let mut candidates = HashSet::new();
87 for tree in self.trees.iter() {
88 tree.search(query, top_k, &mut candidates);
89 }
90
91 let mut results = candidates
92 .into_iter()
93 .map(|offset| (offset, self.vectors[offset as usize].cosine_dist(query)))
94 .collect::<Vec<_>>();
95 results.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap());
96 results
97 .into_iter()
98 .take(top_k)
99 .map(|(offset, dist)| (self.offset_to_id[&offset].clone(), dist))
100 .collect()
101 }
102}
103
#[cfg(test)]
mod tests {
    use super::*;
    use rand::SeedableRng;
    use rand::rngs::StdRng;

    /// Fixed-seed generator so every run builds the same trees.
    fn create_test_rng() -> StdRng {
        StdRng::seed_from_u64(42)
    }

    #[test]
    fn test_basic_operations() {
        let mut rng = create_test_rng();
        let mut index = VectorLite::<3>::new(2, 2);

        let entries: [([f32; 3], &str); 4] = [
            ([1.0, 0.0, 0.0], "101"),
            ([0.0, 1.0, 0.0], "102"),
            ([0.0, 0.0, 1.0], "103"),
            ([1.0, 1.0, 0.0], "104"),
        ];
        for (coords, id) in entries {
            index.insert_with_rng(Vector::from(coords), id.to_string(), &mut rng);
        }

        // A query leaning towards the x axis should rank "101" first.
        let near_x = index.search(&Vector::from([0.9, 0.1, 0.0]), 2);
        assert_eq!(near_x.len(), 2);
        assert_eq!(near_x[0].0, "101");

        // A query leaning towards the y axis should find "102".
        let near_y = index.search(&Vector::from([0.1, 0.9, 0.0]), 1);
        assert_eq!(near_y.len(), 1);
        assert_eq!(near_y[0].0, "102");
    }

    #[test]
    fn test_multi_tree_performance() {
        let mut rng = create_test_rng();
        let mut single_tree = VectorLite::<2>::new(1, 2);
        let mut multi_tree = VectorLite::<2>::new(5, 2);

        // Populate both indexes with the same 5x5 integer grid; the id
        // encodes the coordinates as `x * 10 + y`.
        for x in 0..5 {
            for y in 0..5 {
                let id = (x * 10 + y).to_string();
                let point = Vector::from([x as f32, y as f32]);

                single_tree.insert_with_rng(point.clone(), id.clone(), &mut rng);
                multi_tree.insert_with_rng(point, id, &mut rng);
            }
        }

        let query = Vector::from([2.3, 2.3]);
        let single_results = single_tree.search(&query, 5);
        let multi_results = multi_tree.search(&query, 5);

        // More trees must never yield fewer results.
        assert!(multi_results.len() >= single_results.len());

        assert_eq!(multi_results[0].0, "22");

        // Scores must be non-decreasing from the first result to the last.
        assert!(multi_results.windows(2).all(|pair| pair[1].1 >= pair[0].1));
    }

    #[test]
    fn test_deletion() {
        let mut rng = create_test_rng();
        let mut index = VectorLite::<2>::new(3, 2);

        // Ten points along the diagonal, with id equal to the coordinate.
        for i in 0..10 {
            let coord = i as f32;
            index.insert_with_rng(Vector::from([coord, coord]), i.to_string(), &mut rng);
        }

        let query = Vector::from([5.0, 5.0]);

        let before = index.search(&query, 1);
        assert_eq!(before.len(), 1);
        assert_eq!(before[0].0, "5");

        index.delete_by_id("5".to_string());

        // With "5" gone, one of its two neighbours must take its place.
        let after = index.search(&query, 1);
        assert_eq!(after.len(), 1);
        assert_ne!(after[0].0, "5");
        assert!(matches!(after[0].0.as_str(), "4" | "6"));

        index.delete_by_id("4".to_string());
        index.delete_by_id("6".to_string());

        // None of the deleted ids may show up again.
        let remaining = index.search(&query, 3);
        assert!(
            remaining
                .iter()
                .all(|(id, _)| !["4", "5", "6"].contains(&id.as_str()))
        );
    }

    #[test]
    fn test_file_operations() {
        let mut rng = create_test_rng();
        let mut index = VectorLite::<3>::new(2, 2);

        let entries: [([f32; 3], &str); 3] = [
            ([1.0, 0.0, 0.0], "101"),
            ([0.0, 1.0, 0.0], "102"),
            ([0.0, 0.0, 1.0], "103"),
        ];
        for (coords, id) in entries {
            index.insert_with_rng(Vector::from(coords), id.to_string(), &mut rng);
        }

        // Round-trip through the byte representation.
        let serialized = index.to_bytes();
        let loaded_index = VectorLite::<3>::from_bytes(&serialized);

        let query = Vector::from([0.9, 0.1, 0.0]);
        let original_results = index.search(&query, 2);
        let loaded_results = loaded_index.search(&query, 2);

        // The reloaded index must answer exactly like the original one.
        assert_eq!(original_results.len(), loaded_results.len());
        for (original, loaded) in original_results.iter().zip(loaded_results.iter()) {
            assert_eq!(original.0, loaded.0);
            assert!((original.1 - loaded.1).abs() < 1e-6);
        }
    }
}