1use crate::{Node, Vector};
2use bincode::{Decode, Encode};
3use rand::Rng;
4use std::{
5 collections::{HashMap, HashSet},
6 rc::Rc,
7};
8
9pub trait ANNIndexOwned<const N: usize> {
10 fn insert(&mut self, vector: Vector<N>, id: String) {
12 self.insert_with_rng(vector, id, &mut rand::rng());
13 }
14
15 fn insert_with_rng(&mut self, vector: Vector<N>, id: String, rng: &mut impl Rng);
17
18 fn delete_by_id(&mut self, id: String);
20
21 fn get_by_id(&self, id: String) -> Option<&Vector<N>>;
23
24 fn search(&self, query: &Vector<N>, top_k: usize) -> Vec<(String, f32)>;
26}
27
28#[derive(Encode, Decode)]
29pub struct VectorLiteIndex<const N: usize> {
30 trees: Vec<Node<N, Rc<String>>>,
31 max_leaf_size: usize,
32}
33
34impl<const N: usize> VectorLiteIndex<N> {
35 fn new(num_trees: usize, max_leaf_size: usize) -> Self {
36 Self {
37 trees: (0..num_trees).map(|_| Node::new_empty()).collect(),
38 max_leaf_size,
39 }
40 }
41
42 pub fn search(&self, query: &Vector<N>, top_k: usize) -> Vec<String> {
46 let mut candidates = HashSet::new();
47 for tree in &self.trees {
48 tree.search(query, top_k, &mut candidates);
49 }
50 candidates
51 .into_iter()
52 .map(|id| id.as_ref().clone())
53 .collect()
54 }
55
56 pub fn to_bytes(&self) -> Vec<u8> {
58 let config = bincode::config::standard();
59 bincode::encode_to_vec(self, config).unwrap()
60 }
61
62 pub fn from_bytes(bytes: &[u8]) -> Self {
64 let config = bincode::config::standard();
65 let (index, _) = bincode::decode_from_slice(bytes, config).unwrap();
66 index
67 }
68}
69
70#[derive(Encode, Decode)]
72pub struct VectorLite<const N: usize> {
73 vectors: HashMap<Rc<String>, Vector<N>>,
74 index: VectorLiteIndex<N>,
75}
76
77impl<const N: usize> VectorLite<N> {
78 pub fn new(num_trees: usize, max_leaf_size: usize) -> Self {
82 Self {
83 vectors: HashMap::new(),
84 index: VectorLiteIndex::new(num_trees, max_leaf_size),
85 }
86 }
87
88 pub fn len(&self) -> usize {
90 self.vectors.len()
91 }
92
93 pub fn is_empty(&self) -> bool {
95 self.vectors.is_empty()
96 }
97
98 pub fn index(&self) -> &VectorLiteIndex<N> {
100 &self.index
101 }
102
103 pub fn to_bytes(&self) -> Vec<u8> {
105 let config = bincode::config::standard();
106 bincode::encode_to_vec(self, config).unwrap()
107 }
108
109 pub fn from_bytes(bytes: &[u8]) -> Self {
111 let config = bincode::config::standard();
112 let (index, _) = bincode::decode_from_slice(bytes, config).unwrap();
113 index
114 }
115}
116
117impl<const N: usize> ANNIndexOwned<N> for VectorLite<N> {
118 fn insert_with_rng(&mut self, vector: Vector<N>, id: String, rng: &mut impl Rng) {
119 let id = Rc::new(id);
120 self.vectors.insert(id.clone(), vector);
121 let vector_fn = |id: &Rc<String>| &self.vectors[id];
122 for tree in &mut self.index.trees {
123 tree.insert(&vector_fn, id.clone(), rng, self.index.max_leaf_size);
124 }
125 }
126
127 fn delete_by_id(&mut self, id: String) {
128 let id = Rc::new(id);
129 for tree in &mut self.index.trees {
130 tree.delete(&self.vectors[&id], &id);
131 }
132 self.vectors.remove(&id);
133 }
134
135 fn get_by_id(&self, id: String) -> Option<&Vector<N>> {
136 self.vectors.get(&id)
137 }
138
139 fn search(&self, query: &Vector<N>, top_k: usize) -> Vec<(String, f32)> {
140 let candidates = self.index.search(query, top_k);
141
142 let mut results = candidates
143 .into_iter()
144 .map(|id| {
145 let dist = self.vectors[&id].sq_euc_dist(query);
146 (id, dist)
147 })
148 .collect::<Vec<_>>();
149 results.sort_by(|a, b| a.1.partial_cmp(&b.1).unwrap());
150 results.into_iter().take(top_k).collect()
151 }
152}
153
#[cfg(test)]
mod tests {
    use super::*;
    use rand::SeedableRng;
    use rand::rngs::StdRng;

    /// Returns an RNG with a fixed seed so every run builds the same trees.
    fn create_test_rng() -> StdRng {
        StdRng::seed_from_u64(42)
    }

    /// Inserting a handful of vectors and querying near two of them returns
    /// the expected nearest IDs.
    #[test]
    fn test_basic_operations() {
        let mut index = VectorLite::<3>::new(2, 2);
        let mut rng = create_test_rng();

        let points: [([f32; 3], &str); 4] = [
            ([1.0, 0.0, 0.0], "101"),
            ([0.0, 1.0, 0.0], "102"),
            ([0.0, 0.0, 1.0], "103"),
            ([1.0, 1.0, 0.0], "104"),
        ];
        for (coords, id) in points {
            index.insert_with_rng(Vector::from(coords), id.to_string(), &mut rng);
        }

        let near_x = index.search(&Vector::from([0.9, 0.1, 0.0]), 2);
        assert_eq!(near_x.len(), 2);
        assert_eq!(near_x[0].0, "101");

        let near_y = index.search(&Vector::from([0.1, 0.9, 0.0]), 1);
        assert_eq!(near_y.len(), 1);
        assert_eq!(near_y[0].0, "102");
    }

    /// A five-tree index returns at least as many results as a single-tree
    /// index on the same data, finds the closest grid point, and returns its
    /// results ordered by ascending distance.
    #[test]
    fn test_multi_tree_performance() {
        let mut single_tree = VectorLite::<2>::new(1, 2);
        let mut multi_tree = VectorLite::<2>::new(5, 2);
        let mut rng = create_test_rng();

        // 5x5 integer grid; the point (x, y) gets the ID `x * 10 + y`.
        // Both indexes draw from the same RNG, alternating per point.
        for x in 0..5 {
            for y in 0..5 {
                let label = x * 10 + y;
                let point = Vector::from([x as f32, y as f32]);

                single_tree.insert_with_rng(point.clone(), label.to_string(), &mut rng);
                multi_tree.insert_with_rng(point, label.to_string(), &mut rng);
            }
        }

        let query = Vector::from([2.3, 2.3]);
        let single_results = single_tree.search(&query, 5);
        let multi_results = multi_tree.search(&query, 5);

        assert!(multi_results.len() >= single_results.len());
        assert_eq!(multi_results[0].0, "22");

        // Distances must be non-decreasing from one result to the next.
        assert!(multi_results.windows(2).all(|pair| pair[1].1 >= pair[0].1));
    }

    /// Deleted IDs are never returned by later searches.
    #[test]
    fn test_deletion() {
        let mut index = VectorLite::<2>::new(3, 2);
        let mut rng = create_test_rng();

        // Ten points on the diagonal: ID `i` lives at (i, i).
        for i in 0..10 {
            let coord = i as f32;
            index.insert_with_rng(Vector::from([coord, coord]), i.to_string(), &mut rng);
        }

        let centre = Vector::from([5.0, 5.0]);

        let before = index.search(&centre, 1);
        assert_eq!(before.len(), 1);
        assert_eq!(before[0].0, "5");

        index.delete_by_id("5".to_string());

        // With "5" gone, its two neighbours are equally close.
        let after_one = index.search(&centre, 1);
        assert_eq!(after_one.len(), 1);
        assert_ne!(after_one[0].0, "5");
        assert!(matches!(after_one[0].0.as_str(), "4" | "6"));

        index.delete_by_id("4".to_string());
        index.delete_by_id("6".to_string());

        let after_three = index.search(&centre, 3);
        for (id, _) in &after_three {
            assert!(!["4", "5", "6"].contains(&id.as_str()));
        }
    }

    /// An index restored from its serialized bytes answers a query the same
    /// way as the original.
    #[test]
    fn test_file_operations() {
        let mut index = VectorLite::<3>::new(2, 2);
        let mut rng = create_test_rng();

        let points: [([f32; 3], &str); 3] = [
            ([1.0, 0.0, 0.0], "101"),
            ([0.0, 1.0, 0.0], "102"),
            ([0.0, 0.0, 1.0], "103"),
        ];
        for (coords, id) in points {
            index.insert_with_rng(Vector::from(coords), id.to_string(), &mut rng);
        }

        let serialized = index.to_bytes();
        let loaded_index = VectorLite::<3>::from_bytes(&serialized);

        let query = Vector::from([0.9, 0.1, 0.0]);
        let original_results = index.search(&query, 2);
        let loaded_results = loaded_index.search(&query, 2);

        assert_eq!(original_results.len(), loaded_results.len());
        for (original, loaded) in original_results.iter().zip(&loaded_results) {
            assert_eq!(original.0, loaded.0);
            assert!((original.1 - loaded.1).abs() < 1e-6);
        }
    }
}