1use crate::{Node, Vector};
2use bincode::{Decode, Encode};
3use rand::Rng;
4use std::collections::{HashMap, HashSet};
5
6pub trait ANNIndexOwned<const N: usize> {
7 fn insert(&mut self, vector: Vector<N>, id: String) {
9 self.insert_with_rng(vector, id, &mut rand::rng());
10 }
11
12 fn insert_with_rng(&mut self, vector: Vector<N>, id: String, rng: &mut impl Rng);
14
15 fn delete_by_id(&mut self, id: String);
17
18 fn get_by_id(&self, id: String) -> Option<&Vector<N>>;
20
21 fn search(&self, query: &Vector<N>, top_k: usize) -> Vec<(String, f32)>;
23}
24
/// In-memory approximate-nearest-neighbour index over `N`-dimensional vectors.
///
/// Vectors are stored densely in `vectors` and addressed internally by their
/// `u32` offset into that `Vec`; string ids are mapped to and from those offsets.
/// Every tree in `trees` is updated on each insert and delete, and a search
/// merges the candidate offsets produced by all trees.
///
/// NOTE(review): the derived bincode encoding follows field declaration order,
/// so reordering or retyping these fields breaks bytes previously produced by
/// `to_bytes`.
#[derive(Encode, Decode)]
pub struct VectorLite<const N: usize> {
    /// Backing storage; a vector's position here is its internal offset.
    /// Nothing removes elements, so deleted vectors keep occupying their slot.
    vectors: Vec<Vector<N>>,
    /// Maps a caller-supplied id to the offset of its vector in `vectors`.
    id_to_offset: HashMap<String, u32>,
    /// Inverse of `id_to_offset`; turns offsets found by search back into ids.
    offset_to_id: HashMap<u32, String>,
    /// Independent trees, each receiving every insert and delete.
    trees: Vec<Node<N>>,
    /// Forwarded to `Node::insert` as its maximum leaf size.
    max_leaf_size: usize,
}
34
35impl<const N: usize> VectorLite<N> {
36 pub fn new(num_trees: usize, max_leaf_size: usize) -> Self {
40 Self {
41 vectors: Vec::new(),
42 id_to_offset: HashMap::new(),
43 offset_to_id: HashMap::new(),
44 trees: (0..num_trees).map(|_| Node::new_empty()).collect(),
45 max_leaf_size,
46 }
47 }
48
49 pub fn len(&self) -> usize {
51 self.vectors.len()
52 }
53
54 pub fn is_empty(&self) -> bool {
56 self.vectors.is_empty()
57 }
58
59 pub fn to_bytes(&self) -> Vec<u8> {
61 let config = bincode::config::standard();
62 bincode::encode_to_vec(self, config).unwrap()
63 }
64
65 pub fn from_bytes(bytes: &[u8]) -> Self {
67 let config = bincode::config::standard();
68 let (index, _) = bincode::decode_from_slice(bytes, config).unwrap();
69 index
70 }
71}
72
73impl<const N: usize> ANNIndexOwned<N> for VectorLite<N> {
74 fn insert_with_rng(&mut self, vector: Vector<N>, id: String, rng: &mut impl Rng) {
75 self.vectors.push(vector);
76 let offset = self.vectors.len() as u32 - 1;
77 self.offset_to_id.insert(offset, id.clone());
78 self.id_to_offset.insert(id, offset);
79 let vector_fn = |idx: u32| &self.vectors[idx as usize];
80 for tree in &mut self.trees {
81 tree.insert(
82 &vector_fn,
83 self.vectors.len() as u32 - 1,
84 rng,
85 self.max_leaf_size,
86 );
87 }
88 }
89
90 fn delete_by_id(&mut self, id: String) {
91 let offset = self.id_to_offset[&id];
92 for tree in &mut self.trees {
93 tree.delete(&self.vectors[offset as usize], offset);
94 }
95 self.id_to_offset.remove(&id);
96 self.offset_to_id.remove(&offset);
97 }
99
100 fn get_by_id(&self, id: String) -> Option<&Vector<N>> {
101 self.id_to_offset
102 .get(&id)
103 .map(|offset| &self.vectors[*offset as usize])
104 }
105
106 fn search(&self, query: &Vector<N>, top_k: usize) -> Vec<(String, f32)> {
107 let mut candidates = HashSet::new();
108 for tree in self.trees.iter() {
109 tree.search(query, top_k, &mut candidates);
110 }
111
112 let mut results = candidates
113 .into_iter()
114 .map(|offset| (offset, self.vectors[offset as usize].sq_euc_dist(query)))
115 .collect::<Vec<_>>();
116 results.sort_by(|a, b| a.1.partial_cmp(&b.1).unwrap());
117 results
118 .into_iter()
119 .take(top_k)
120 .map(|(offset, dist)| (self.offset_to_id[&offset].clone(), dist))
121 .collect()
122 }
123}
124
#[cfg(test)]
mod tests {
    use super::*;
    use rand::rngs::StdRng;
    use rand::SeedableRng;

    /// Fixed-seed RNG so tree construction, and therefore search results, are
    /// reproducible from run to run.
    fn create_test_rng() -> StdRng {
        StdRng::seed_from_u64(42)
    }

    #[test]
    fn test_basic_operations() {
        let mut rng = create_test_rng();
        let mut index = VectorLite::<3>::new(2, 2);

        let entries = [
            ([1.0, 0.0, 0.0], "101"),
            ([0.0, 1.0, 0.0], "102"),
            ([0.0, 0.0, 1.0], "103"),
            ([1.0, 1.0, 0.0], "104"),
        ];
        for (coords, id) in entries {
            index.insert_with_rng(Vector::from(coords), id.to_string(), &mut rng);
        }

        let nearest_two = index.search(&Vector::from([0.9, 0.1, 0.0]), 2);
        assert_eq!(nearest_two.len(), 2);
        assert_eq!(nearest_two[0].0, "101");

        let nearest_one = index.search(&Vector::from([0.1, 0.9, 0.0]), 1);
        assert_eq!(nearest_one.len(), 1);
        assert_eq!(nearest_one[0].0, "102");
    }

    #[test]
    fn test_multi_tree_performance() {
        let mut rng = create_test_rng();
        let mut single_tree = VectorLite::<2>::new(1, 2);
        let mut multi_tree = VectorLite::<2>::new(5, 2);

        // Both indexes draw from the same RNG, so the interleaving of inserts is
        // part of what makes this test deterministic.
        for (x, y) in (0..5).flat_map(|x| (0..5).map(move |y| (x, y))) {
            let id = (x * 10 + y).to_string();
            let point = Vector::from([x as f32, y as f32]);
            single_tree.insert_with_rng(point.clone(), id.clone(), &mut rng);
            multi_tree.insert_with_rng(point, id, &mut rng);
        }

        let query = Vector::from([2.3, 2.3]);
        let single_results = single_tree.search(&query, 5);
        let multi_results = multi_tree.search(&query, 5);

        assert!(multi_results.len() >= single_results.len());
        assert_eq!(multi_results[0].0, "22");
        assert!(multi_results.windows(2).all(|pair| pair[1].1 >= pair[0].1));
    }

    #[test]
    fn test_deletion() {
        let mut rng = create_test_rng();
        let mut index = VectorLite::<2>::new(3, 2);
        for i in 0..10 {
            let coord = i as f32;
            index.insert_with_rng(Vector::from([coord, coord]), i.to_string(), &mut rng);
        }
        let probe = Vector::from([5.0, 5.0]);

        let before = index.search(&probe, 1);
        assert_eq!(before.len(), 1);
        assert_eq!(before[0].0, "5");

        index.delete_by_id("5".to_string());
        let after = index.search(&probe, 1);
        assert_eq!(after.len(), 1);
        assert_ne!(after[0].0, "5");
        assert!(after[0].0 == "4" || after[0].0 == "6");

        for id in ["4", "6"] {
            index.delete_by_id(id.to_string());
        }
        let remaining = index.search(&probe, 3);
        assert!(remaining
            .iter()
            .all(|(id, _)| !matches!(id.as_str(), "4" | "5" | "6")));
    }

    #[test]
    fn test_file_operations() {
        let mut rng = create_test_rng();
        let mut index = VectorLite::<3>::new(2, 2);
        for (coords, id) in [
            ([1.0, 0.0, 0.0], "101"),
            ([0.0, 1.0, 0.0], "102"),
            ([0.0, 0.0, 1.0], "103"),
        ] {
            index.insert_with_rng(Vector::from(coords), id.to_string(), &mut rng);
        }

        let loaded_index = VectorLite::<3>::from_bytes(&index.to_bytes());

        let query = Vector::from([0.9, 0.1, 0.0]);
        let original_results = index.search(&query, 2);
        let loaded_results = loaded_index.search(&query, 2);

        assert_eq!(original_results.len(), loaded_results.len());
        for ((orig_id, orig_dist), (loaded_id, loaded_dist)) in
            original_results.iter().zip(&loaded_results)
        {
            assert_eq!(orig_id, loaded_id);
            assert!((orig_dist - loaded_dist).abs() < 1e-6);
        }
    }
}