1use crate::{Node, Vector};
2use rand::Rng;
3use serde::{Deserialize, Serialize};
4use std::collections::HashSet;
5
6pub trait ANNIndexOwned<const N: usize> {
7 fn insert(&mut self, vector: Vector<N>, id: usize) {
9 self.insert_with_rng(vector, id, &mut rand::rng());
10 }
11
12 fn insert_with_rng(&mut self, vector: Vector<N>, id: usize, rng: &mut impl Rng);
14
15 fn delete_by_id(&mut self, id: usize);
17
18 fn search(&self, query: &Vector<N>, top_k: usize) -> Vec<(usize, f32)>;
20}
21
/// Index that keeps its vectors, their caller-supplied ids and a collection of
/// search trees together in one serializable value.
///
/// `N` is the dimensionality of every stored [`Vector`]. The `Serialize` /
/// `Deserialize` derives are what `to_bytes` and `from_bytes` build on.
#[derive(Serialize, Deserialize)]
pub struct VectorLite<const N: usize> {
    /// Stored vectors in insertion order; the trees refer to a vector by its
    /// position in this list.
    vectors: Vec<Vector<N>>,
    /// Caller-supplied id of the vector at the same position in `vectors`.
    ids: Vec<usize>,
    /// One root [`Node`] per tree; every insert and delete is applied to each.
    trees: Vec<Node<N>>,
    /// Forwarded to `Node::insert` on every insertion (presumably the maximum
    /// number of entries a leaf may hold before it is split; confirm in `Node`).
    max_leaf_size: usize,
}
30
31impl<const N: usize> VectorLite<N> {
32 pub fn new(num_trees: usize, max_leaf_size: usize) -> Self {
36 Self {
37 vectors: Vec::new(),
38 ids: Vec::new(),
39 trees: (0..num_trees).map(|_| Node::new_empty()).collect(),
40 max_leaf_size,
41 }
42 }
43
44 pub fn to_bytes(&self) -> Vec<u8> {
45 let mut s = flexbuffers::FlexbufferSerializer::new();
46 self.serialize(&mut s).unwrap();
47 s.take_buffer()
48 }
49
50 pub fn from_bytes(bytes: &[u8]) -> Self {
51 let s = flexbuffers::Reader::get_root(bytes).unwrap();
52 let index: Self = VectorLite::deserialize(s).unwrap();
53 index
54 }
55}
56
57impl<const N: usize> ANNIndexOwned<N> for VectorLite<N> {
58 fn insert_with_rng(&mut self, vector: Vector<N>, id: usize, rng: &mut impl Rng) {
59 self.vectors.push(vector);
60 self.ids.push(id);
61 let vector_fn = |idx: u32| &self.vectors[idx as usize];
62 for tree in &mut self.trees {
63 tree.insert(
64 &vector_fn,
65 self.vectors.len() as u32 - 1,
66 rng,
67 self.max_leaf_size,
68 );
69 }
70 }
71
72 fn delete_by_id(&mut self, id: usize) {
73 let vector_fn = |idx: u32| &self.vectors[idx as usize];
74 for tree in &mut self.trees {
75 tree.delete(&vector_fn, id as u32);
76 }
77 }
79
80 fn search(&self, query: &Vector<N>, top_k: usize) -> Vec<(usize, f32)> {
81 let mut candidates = HashSet::new();
82 for tree in self.trees.iter() {
83 tree.search(query, top_k, &mut candidates);
84 }
85
86 let mut results = candidates
87 .into_iter()
88 .map(|idx| (idx as usize, self.vectors[idx as usize].sq_euc_dist(query)))
89 .collect::<Vec<_>>();
90 results.sort_by(|a, b| a.1.partial_cmp(&b.1).unwrap());
91 results
92 .into_iter()
93 .take(top_k)
94 .map(|(id, dist)| (self.ids[id], dist))
95 .collect()
96 }
97}
98
#[cfg(test)]
mod tests {
    use super::*;
    use rand::SeedableRng;
    use rand::rngs::StdRng;

    /// Builds a generator from a fixed seed so every run inserts identically.
    fn create_test_rng() -> StdRng {
        StdRng::seed_from_u64(42)
    }

    /// Inserts four 3-d vectors and checks that the closest id comes first.
    #[test]
    fn test_basic_operations() {
        let mut rng = create_test_rng();
        let mut index = VectorLite::<3>::new(2, 2);

        let entries: [([f32; 3], usize); 4] = [
            ([1.0, 0.0, 0.0], 101),
            ([0.0, 1.0, 0.0], 102),
            ([0.0, 0.0, 1.0], 103),
            ([1.0, 1.0, 0.0], 104),
        ];
        for (coords, id) in entries {
            index.insert_with_rng(Vector::from(coords), id, &mut rng);
        }

        let near_x_axis = index.search(&Vector::from([0.9, 0.1, 0.0]), 2);
        assert_eq!(near_x_axis.len(), 2);
        assert_eq!(near_x_axis[0].0, 101);

        let near_y_axis = index.search(&Vector::from([0.1, 0.9, 0.0]), 1);
        assert_eq!(near_y_axis.len(), 1);
        assert_eq!(near_y_axis[0].0, 102);
    }

    /// Fills a 5x5 grid into a one-tree and a five-tree index and compares
    /// what each returns for the same query.
    #[test]
    fn test_multi_tree_performance() {
        let mut rng = create_test_rng();
        let mut single_tree = VectorLite::<2>::new(1, 2);
        let mut multi_tree = VectorLite::<2>::new(5, 2);

        // Row-major walk over the grid; both indexes share one generator and
        // are fed alternately, single-tree first.
        let grid = (0..5usize).flat_map(|x| (0..5usize).map(move |y| (x, y)));
        for (x, y) in grid {
            let id = x * 10 + y;
            let point = Vector::from([x as f32, y as f32]);

            single_tree.insert_with_rng(point.clone(), id, &mut rng);
            multi_tree.insert_with_rng(point, id, &mut rng);
        }

        let query = Vector::from([2.3, 2.3]);
        let single_hits = single_tree.search(&query, 5);
        let multi_hits = multi_tree.search(&query, 5);

        // More trees must not yield fewer results.
        assert!(multi_hits.len() >= single_hits.len());

        // Grid point (2, 2) is the closest to the query.
        assert_eq!(multi_hits[0].0, 22);

        // Distances must be non-decreasing.
        for pair in multi_hits.windows(2) {
            assert!(pair[1].1 >= pair[0].1);
        }
    }

    /// Deletes ids one after another and checks they stop showing up.
    #[test]
    fn test_deletion() {
        let mut rng = create_test_rng();
        let mut index = VectorLite::<2>::new(3, 2);

        for id in 0..10usize {
            let coord = id as f32;
            index.insert_with_rng(Vector::from([coord, coord]), id, &mut rng);
        }

        let probe: Vector<2> = Vector::from([5.0, 5.0]);

        let before = index.search(&probe, 1);
        assert_eq!(before.len(), 1);
        assert_eq!(before[0].0, 5);

        index.delete_by_id(5);

        let after = index.search(&probe, 1);
        assert_eq!(after.len(), 1);
        let closest = after[0].0;
        assert_ne!(closest, 5);
        assert!(matches!(closest, 4 | 6));

        for id in [4, 6] {
            index.delete_by_id(id);
        }

        for (id, _) in index.search(&probe, 3) {
            assert!(!matches!(id, 4 | 5 | 6));
        }
    }

    /// Round-trips an index through `to_bytes` / `from_bytes` and checks that
    /// both copies answer a query the same way.
    #[test]
    fn test_file_operations() {
        let mut rng = create_test_rng();
        let mut index = VectorLite::<3>::new(2, 2);

        let entries: [([f32; 3], usize); 3] = [
            ([1.0, 0.0, 0.0], 101),
            ([0.0, 1.0, 0.0], 102),
            ([0.0, 0.0, 1.0], 103),
        ];
        for (coords, id) in entries {
            index.insert_with_rng(Vector::from(coords), id, &mut rng);
        }

        let bytes = index.to_bytes();
        let restored = VectorLite::<3>::from_bytes(&bytes);

        let query = Vector::from([0.9, 0.1, 0.0]);
        let expected = index.search(&query, 2);
        let actual = restored.search(&query, 2);

        assert_eq!(expected.len(), actual.len());
        for (want, got) in expected.iter().zip(actual.iter()) {
            assert_eq!(want.0, got.0);
            assert!((want.1 - got.1).abs() < 1e-6);
        }
    }
}