1use std::collections::HashSet;
2
3use rand::Rng;
4
5use crate::{Node, Vector};
6
7pub struct LshExternal<'a, const N: usize> {
9 trees: Vec<Node<N>>,
10 values: &'a [Vector<N>],
11}
12
13impl<const N: usize> LshExternal<'_, N> {
14 pub fn inner_node_count(&self) -> usize {
16 self.trees
17 .iter()
18 .map(|tree| tree.inner_node_count())
19 .sum::<usize>()
20 }
21}
22
23fn check_unique<const N: usize>(vectors: &[Vector<N>]) -> bool {
24 let mut hashes_seen = HashSet::new();
25 for vector in vectors.iter() {
26 if !hashes_seen.insert(vector) {
27 return false;
28 }
29 }
30 true
31}
32
33impl<'a, const N: usize> ANNIndexExternal<'a, N> for LshExternal<'a, N> {
34 type Index = Self;
35
36 fn build<R: Rng>(
37 num_trees: usize,
38 max_leaf_size: usize,
39 vectors: &'a [Vector<N>],
40 rng: &mut R,
41 ) -> Result<Self::Index, &'static str> {
42 if !check_unique(vectors) {
43 return Err("Vectors are not unique");
44 }
45 if vectors.len() > u32::MAX as usize {
46 return Err("Number of vectors exceeds u32::MAX");
47 }
48
49 let all_indexes: Vec<u32> = (0..vectors.len() as u32).collect();
50 let vector_fn = |idx: u32| &vectors[idx as usize];
51 let trees: Vec<_> = (0..num_trees)
52 .map(|_| Node::build_tree(max_leaf_size, &all_indexes, &vector_fn, rng))
53 .collect();
54
55 Ok(Self {
56 trees,
57 values: vectors,
58 })
59 }
60
61 fn search(&self, query: &Vector<N>, top_k: usize) -> Vec<(usize, f32)> {
62 let mut candidates = HashSet::new();
63 for tree in self.trees.iter() {
64 tree.search(query, top_k, &mut candidates);
65 }
66
67 let mut results = candidates
68 .into_iter()
69 .map(|idx| (idx as usize, self.values[idx as usize].sq_euc_dist(query)))
70 .collect::<Vec<_>>();
71 results.sort_by(|a, b| a.1.partial_cmp(&b.1).unwrap());
72 results.into_iter().take(top_k).collect()
73 }
74
75 fn memory_usage(&self) -> usize {
77 let tree_size = self
78 .trees
79 .iter()
80 .map(|tree| tree.memory_usage())
81 .sum::<usize>();
82 std::mem::size_of::<Self>() + tree_size
83 }
84}
85
86pub trait ANNIndexExternal<'a, const N: usize> {
89 type Index;
90
91 fn build<R: Rng>(
99 num_trees: usize,
100 max_leaf_size: usize,
101 vectors: &'a [Vector<N>],
102 rng: &mut R,
103 ) -> Result<Self::Index, &'static str>;
104
105 fn search(&self, query: &Vector<N>, top_k: usize) -> Vec<(usize, f32)>;
116
117 fn memory_usage(&self) -> usize;
119}
120
121pub struct LinearSearchExternal<'a, const N: usize> {
125 values: &'a [Vector<N>],
126}
127
128impl<'a, const N: usize> ANNIndexExternal<'a, N> for LinearSearchExternal<'a, N> {
129 type Index = Self;
130
131 fn build<R: Rng>(
132 _num_trees: usize,
133 _max_leaf_size: usize,
134 vectors: &'a [Vector<N>],
135 _rng: &mut R,
136 ) -> Result<Self::Index, &'static str> {
137 if vectors.is_empty() {
138 return Err("Cannot build index with empty vector set");
139 }
140
141 Ok(Self { values: vectors })
142 }
143
144 fn search(&self, query: &Vector<N>, top_k: usize) -> Vec<(usize, f32)> {
145 if top_k == 0 || self.values.is_empty() {
146 return Vec::new();
147 }
148
149 #[derive(PartialEq)]
151 struct Entry(f32, usize);
152
153 #[allow(clippy::non_canonical_partial_ord_impl)]
154 impl PartialOrd for Entry {
155 fn partial_cmp(&self, other: &Self) -> Option<std::cmp::Ordering> {
156 self.0.partial_cmp(&other.0)
157 }
158 }
159
160 impl Eq for Entry {}
161
162 impl Ord for Entry {
163 fn cmp(&self, other: &Self) -> std::cmp::Ordering {
164 self.partial_cmp(other).unwrap_or(std::cmp::Ordering::Equal)
165 }
166 }
167
168 let mut heap = std::collections::BinaryHeap::with_capacity(top_k);
170
171 for (idx, vec) in self.values.iter().enumerate() {
173 let dist = vec.sq_euc_dist(query);
174
175 if heap.len() < top_k {
176 heap.push(Entry(dist, idx));
178 } else if let Some(max_entry) = heap.peek() {
179 if dist < max_entry.0 {
181 heap.pop();
183 heap.push(Entry(dist, idx));
184 }
185 }
186 }
187
188 let mut result = Vec::with_capacity(heap.len());
190 while let Some(Entry(dist, idx)) = heap.pop() {
191 result.push((idx, dist));
192 }
193
194 result.reverse();
196 result
197 }
198
199 fn memory_usage(&self) -> usize {
200 std::mem::size_of::<Self>()
202 }
203}
204
#[cfg(test)]
mod tests {
    use rand::{SeedableRng, rngs::StdRng};

    use super::*;

    /// Returns `true` when `a` and `b` differ by less than `tol`.
    fn approx_eq(a: f32, b: f32, tol: f32) -> bool {
        (a - b).abs() < tol
    }

    /// Every stored vector, and a slightly shifted copy of it, must have that
    /// vector as its nearest neighbour in both index implementations.
    #[test]
    fn test_basic_nearest_neighbor() {
        let mut seed_rng = StdRng::seed_from_u64(42);
        let vectors = vec![
            Vector::from([10.0, 20.0, 30.0]),
            Vector::from([10.0, 30.0, 20.0]),
            Vector::from([20.0, 10.0, 30.0]),
            Vector::from([20.0, 30.0, 10.0]),
            Vector::from([30.0, 20.0, 10.0]),
            Vector::from([30.0, 10.0, 20.0]),
        ];

        let index = LshExternal::build(1, 1, &vectors, &mut seed_rng).unwrap();

        // Exact matches.
        for (i, vector) in vectors.iter().enumerate() {
            let results = index.search(vector, 1);
            assert_eq!(results.len(), 1);
            assert_eq!(results[0].0, i);
        }

        // Slightly perturbed queries.
        for (i, vector) in vectors.iter().enumerate() {
            let query = vector.add(&Vector::from([0.1, 0.1, 0.1]));
            let results = index.search(&query, 2);
            assert_eq!(results.len(), 2);
            assert_eq!(results[0].0, i);
        }

        let index = LinearSearchExternal::build(1, 1, &vectors, &mut seed_rng).unwrap();
        for (i, vector) in vectors.iter().enumerate() {
            let results = index.search(vector, 1);
            assert_eq!(results.len(), 1);
            assert_eq!(results[0].0, i);
        }
        for (i, vector) in vectors.iter().enumerate() {
            let query = vector.add(&Vector::from([0.1, 0.1, 0.1]));
            let results = index.search(&query, 2);

            assert_eq!(results.len(), 2);
            assert_eq!(results[0].0, i);
        }
    }

    /// Vectors come in close pairs (positions 0/1, 2/3, 4/5); a top-2 query
    /// for any member must return exactly its own pair.
    #[test]
    fn test_top_2_nearest_neighbor() {
        let mut seed_rng = StdRng::seed_from_u64(42);
        let vectors = vec![
            Vector::from([10.0, 20.0, 30.0]),
            Vector::from([10.0, 20.0, 30.1]),
            Vector::from([20.0, 30.0, 10.0]),
            Vector::from([20.0, 30.1, 10.0]),
            Vector::from([30.0, 20.0, 10.0]),
            Vector::from([30.1, 20.0, 10.0]),
        ];

        let index = LshExternal::build(1, 2, &vectors, &mut seed_rng).unwrap();

        for (i, vector) in vectors.iter().enumerate() {
            let results = index.search(vector, 2);
            assert_eq!(results.len(), 2);

            // The pair is identified by the sum of its two positions, which
            // makes the check independent of the order of the two results.
            let id_bucket = i / 2 * 2;
            assert_eq!(results[0].0 + results[1].0, id_bucket + id_bucket + 1);
        }

        let index = LinearSearchExternal::build(1, 2, &vectors, &mut seed_rng).unwrap();
        for (i, vector) in vectors.iter().enumerate() {
            let results = index.search(vector, 2);
            // Guard the indexing below with a clear failure message.
            assert_eq!(results.len(), 2);

            let id_bucket = i / 2 * 2;
            assert_eq!(results[0].0 + results[1].0, id_bucket + id_bucket + 1);
        }
    }

    /// Asking for more neighbours than there are vectors must return every
    /// vector once, for both index implementations, on 4-dimensional data.
    #[test]
    fn test_top_k_exceeds_total_and_high_dim() {
        let vectors = vec![
            Vector::from([0.0, 0.0, 0.0, 0.0]),
            Vector::from([1.0, 0.0, 0.0, 0.0]),
            Vector::from([0.0, 1.0, 0.0, 0.0]),
            Vector::from([0.0, 0.0, 1.0, 0.0]),
            Vector::from([0.0, 0.0, 0.0, 1.0]),
        ];

        let mut seed_rng = StdRng::seed_from_u64(42);
        let index = LshExternal::build(3, 2, &vectors, &mut seed_rng).unwrap();

        let query = Vector::from([0.5, 0.5, 0.5, 0.5]);
        let top_k = 10;
        let results = index.search(&query, top_k);

        assert_eq!(results.len(), 5);

        // The query is expected to be at distance 1.0 from every vector.
        for &(_, dist) in results.iter() {
            assert!(approx_eq(dist, 1.0, 0.0001));
        }

        // Same expectations for the exhaustive implementation, which the
        // other tests in this module also cover.
        let index = LinearSearchExternal::build(3, 2, &vectors, &mut seed_rng).unwrap();
        let results = index.search(&query, top_k);

        assert_eq!(results.len(), 5);

        for &(_, dist) in results.iter() {
            assert!(approx_eq(dist, 1.0, 0.0001));
        }
    }
}