vector_lite/
external_index.rs

1use std::collections::HashSet;
2
3use rand::Rng;
4
5use crate::{Node, Vector};
6
7/// A simple LSH-based ANNIndex implementation.
8pub struct LshExternal<'a, const N: usize> {
9    trees: Vec<Node<N>>,
10    values: &'a [Vector<N>],
11}
12
13impl<const N: usize> LshExternal<'_, N> {
14    /// Count the number of inner nodes in the index.
15    pub fn inner_node_count(&self) -> usize {
16        self.trees
17            .iter()
18            .map(|tree| tree.inner_node_count())
19            .sum::<usize>()
20    }
21}
22
23fn check_unique<const N: usize>(vectors: &[Vector<N>]) -> bool {
24    let mut hashes_seen = HashSet::new();
25    for vector in vectors.iter() {
26        if !hashes_seen.insert(vector) {
27            return false;
28        }
29    }
30    true
31}
32
33impl<'a, const N: usize> ANNIndexExternal<'a, N> for LshExternal<'a, N> {
34    type Index = Self;
35
36    fn build<R: Rng>(
37        num_trees: usize,
38        max_leaf_size: usize,
39        vectors: &'a [Vector<N>],
40        rng: &mut R,
41    ) -> Result<Self::Index, &'static str> {
42        if !check_unique(vectors) {
43            return Err("Vectors are not unique");
44        }
45        if vectors.len() > u32::MAX as usize {
46            return Err("Number of vectors exceeds u32::MAX");
47        }
48
49        let all_indexes: Vec<u32> = (0..vectors.len() as u32).collect();
50        let vector_fn = |idx: u32| &vectors[idx as usize];
51        let trees: Vec<_> = (0..num_trees)
52            .map(|_| Node::build_tree(max_leaf_size, &all_indexes, &vector_fn, rng))
53            .collect();
54
55        Ok(Self {
56            trees,
57            values: vectors,
58        })
59    }
60
61    fn search(&self, query: &Vector<N>, top_k: usize) -> Vec<(usize, f32)> {
62        let mut candidates = HashSet::new();
63        for tree in self.trees.iter() {
64            tree.search(query, top_k, &mut candidates);
65        }
66
67        let mut results = candidates
68            .into_iter()
69            .map(|idx| (idx as usize, self.values[idx as usize].sq_euc_dist(query)))
70            .collect::<Vec<_>>();
71        results.sort_by(|a, b| a.1.partial_cmp(&b.1).unwrap());
72        results.into_iter().take(top_k).collect()
73    }
74
75    /// Get the memory usage of the index.
76    fn memory_usage(&self) -> usize {
77        let tree_size = self
78            .trees
79            .iter()
80            .map(|tree| tree.memory_usage())
81            .sum::<usize>();
82        std::mem::size_of::<Self>() + tree_size
83    }
84}
85
86/// A trait for ANNIndex that references external data.
87/// The index is read-only once built.
88pub trait ANNIndexExternal<'a, const N: usize> {
89    type Index;
90
91    /// Build the index for the given vectors and ids.
92    ///
93    /// # Arguments
94    ///
95    /// * `num_trees` - The number of trees to build, higher means more accurate but slower and larger memory usage.
96    /// * `max_leaf_size` - The maximum number of vectors in a leaf node, lower means higher accuracy but slower search.
97    /// * `vectors` - The vectors to build the index from.
98    fn build<R: Rng>(
99        num_trees: usize,
100        max_leaf_size: usize,
101        vectors: &'a [Vector<N>],
102        rng: &mut R,
103    ) -> Result<Self::Index, &'static str>;
104
105    /// Search for the top_k nearest neighbors of the query vector.
106    ///
107    /// # Arguments
108    ///
109    /// * `query` - The query vector to search for.
110    /// * `top_k` - The number of nearest neighbors to return.
111    ///
112    /// # Returns
113    ///
114    /// * `Vec<(usize, f32)>` - The top_k nearest (index, distance) of the query vector.
115    fn search(&self, query: &Vector<N>, top_k: usize) -> Vec<(usize, f32)>;
116
117    /// Get the memory usage of the index.
118    fn memory_usage(&self) -> usize;
119}
120
121/// A linear search implementation of the ANNIndex trait.
122/// This performs no indexing and simply scans the entire dataset for each query.
123/// It's useful as a baseline comparison and for small datasets.
124pub struct LinearSearchExternal<'a, const N: usize> {
125    values: &'a [Vector<N>],
126}
127
128impl<'a, const N: usize> ANNIndexExternal<'a, N> for LinearSearchExternal<'a, N> {
129    type Index = Self;
130
131    fn build<R: Rng>(
132        _num_trees: usize,
133        _max_leaf_size: usize,
134        vectors: &'a [Vector<N>],
135        _rng: &mut R,
136    ) -> Result<Self::Index, &'static str> {
137        if vectors.is_empty() {
138            return Err("Cannot build index with empty vector set");
139        }
140
141        Ok(Self { values: vectors })
142    }
143
144    fn search(&self, query: &Vector<N>, top_k: usize) -> Vec<(usize, f32)> {
145        if top_k == 0 || self.values.is_empty() {
146            return Vec::new();
147        }
148
149        // Use a custom struct to hold distance and index in a way that can be compared
150        #[derive(PartialEq)]
151        struct Entry(f32, usize);
152
153        #[allow(clippy::non_canonical_partial_ord_impl)]
154        impl PartialOrd for Entry {
155            fn partial_cmp(&self, other: &Self) -> Option<std::cmp::Ordering> {
156                self.0.partial_cmp(&other.0)
157            }
158        }
159
160        impl Eq for Entry {}
161
162        impl Ord for Entry {
163            fn cmp(&self, other: &Self) -> std::cmp::Ordering {
164                self.partial_cmp(other).unwrap_or(std::cmp::Ordering::Equal)
165            }
166        }
167
168        // Use a max-heap to track only top_k elements
169        let mut heap = std::collections::BinaryHeap::with_capacity(top_k);
170
171        // Process each vector
172        for (idx, vec) in self.values.iter().enumerate() {
173            let dist = vec.sq_euc_dist(query);
174
175            if heap.len() < top_k {
176                // If we haven't reached capacity, add to the heap
177                heap.push(Entry(dist, idx));
178            } else if let Some(max_entry) = heap.peek() {
179                // If current distance is smaller than the largest in our heap
180                if dist < max_entry.0 {
181                    // Remove the largest distance and add the new one
182                    heap.pop();
183                    heap.push(Entry(dist, idx));
184                }
185            }
186        }
187
188        // Convert heap back to vector of (idx, dist) pairs
189        let mut result = Vec::with_capacity(heap.len());
190        while let Some(Entry(dist, idx)) = heap.pop() {
191            result.push((idx, dist));
192        }
193
194        // Reverse to get results in ascending order by distance
195        result.reverse();
196        result
197    }
198
199    fn memory_usage(&self) -> usize {
200        // Only count the struct itself, not the referenced vectors
201        std::mem::size_of::<Self>()
202    }
203}
204
#[cfg(test)]
mod tests {
    use rand::{SeedableRng, rngs::StdRng};

    use super::*;

    /// Approximate float equality within an absolute tolerance.
    fn approx_eq(a: f32, b: f32, tol: f32) -> bool {
        (a - b).abs() < tol
    }

    /// Each vector must be its own nearest neighbour, and must stay the
    /// nearest one for a query nudged slightly away from it.
    #[test]
    fn test_basic_nearest_neighbor() {
        let mut seed_rng = StdRng::seed_from_u64(42);
        let vectors = vec![
            Vector::from([10.0, 20.0, 30.0]),
            Vector::from([10.0, 30.0, 20.0]),
            Vector::from([20.0, 10.0, 30.0]),
            Vector::from([20.0, 30.0, 10.0]),
            Vector::from([30.0, 20.0, 10.0]),
            Vector::from([30.0, 10.0, 20.0]),
        ];

        // One tree with a max leaf size of 1.
        let index = LshExternal::build(1, 1, &vectors, &mut seed_rng).unwrap();

        // Querying with an indexed vector returns that vector.
        for (i, vector) in vectors.iter().enumerate() {
            let results = index.search(vector, 1);
            assert_eq!(results.len(), 1);
            assert_eq!(results[0].0, i);
        }

        // A query a small distance away still ranks the original vector first.
        for (i, vector) in vectors.iter().enumerate() {
            let query = vector.add(&Vector::from([0.1, 0.1, 0.1]));
            let results = index.search(&query, 2);
            assert_eq!(results.len(), 2);
            assert_eq!(results[0].0, i);
        }

        // The same expectations hold for the brute-force baseline.
        let index = LinearSearchExternal::build(1, 1, &vectors, &mut seed_rng).unwrap();
        for (i, vector) in vectors.iter().enumerate() {
            let results = index.search(vector, 1);
            assert_eq!(results.len(), 1);
            assert_eq!(results[0].0, i);
        }
        for (i, vector) in vectors.iter().enumerate() {
            let query = vector.add(&Vector::from([0.1, 0.1, 0.1]));
            let results = index.search(&query, 2);

            assert_eq!(results.len(), 2);
            assert_eq!(results[0].0, i);
        }
    }

    /// Vectors come in three tight pairs; the top-2 of any vector must be its own pair.
    #[test]
    fn test_top_2_nearest_neighbor() {
        let mut seed_rng = StdRng::seed_from_u64(42);
        let vectors = vec![
            Vector::from([10.0, 20.0, 30.0]),
            Vector::from([10.0, 20.0, 30.1]),
            Vector::from([20.0, 30.0, 10.0]),
            Vector::from([20.0, 30.1, 10.0]),
            Vector::from([30.0, 20.0, 10.0]),
            Vector::from([30.1, 20.0, 10.0]),
        ];

        // One tree with a max leaf size of 2, so a leaf can hold a whole pair.
        let index = LshExternal::build(1, 2, &vectors, &mut seed_rng).unwrap();

        for (i, vector) in vectors.iter().enumerate() {
            let results = index.search(vector, 2);
            assert_eq!(results.len(), 2);

            // Pairs occupy positions (0, 1), (2, 3), (4, 5); compare index sums so
            // the order inside the pair does not matter.
            let id_bucket = i / 2 * 2;
            assert_eq!(results[0].0 + results[1].0, id_bucket + id_bucket + 1);
        }

        let index = LinearSearchExternal::build(1, 2, &vectors, &mut seed_rng).unwrap();
        for (i, vector) in vectors.iter().enumerate() {
            let results = index.search(vector, 2);
            assert_eq!(results.len(), 2);

            let id_bucket = i / 2 * 2;
            assert_eq!(results[0].0 + results[1].0, id_bucket + id_bucket + 1);
        }
    }

    /// Test 3: High-dimensional vectors with a top_k value exceeding the number of unique vectors.
    /// This test uses 4D vectors and checks that:
    /// - All available unique vectors are returned (when top_k is too high).
    /// - The computed squared distances are as expected.
    #[test]
    fn test_top_k_exceeds_total_and_high_dim() {
        // Create five 4D vectors.
        let vectors = vec![
            Vector::from([0.0, 0.0, 0.0, 0.0]),
            Vector::from([1.0, 0.0, 0.0, 0.0]),
            Vector::from([0.0, 1.0, 0.0, 0.0]),
            Vector::from([0.0, 0.0, 1.0, 0.0]),
            Vector::from([0.0, 0.0, 0.0, 1.0]),
        ];

        // Build the index with 3 trees and a max leaf size of 2.
        let mut seed_rng = StdRng::seed_from_u64(42);
        let index = LshExternal::build(3, 2, &vectors, &mut seed_rng).unwrap();

        // Query with a vector that lies equidistant from all the given vectors.
        let query = Vector::from([0.5, 0.5, 0.5, 0.5]);
        let top_k = 10; // Request more neighbors than there are unique vectors.
        let results = index.search(&query, top_k);

        // Since there are only 5 unique vectors, we expect 5 results.
        assert_eq!(results.len(), 5);

        // The squared distance for each vector should be 1.0:
        // For example, for v0: (0.5-0.0)² * 4 = 0.25 * 4 = 1.0.
        // For the others, the differences still sum to 1.0.
        for &(_, dist) in results.iter() {
            assert!(approx_eq(dist, 1.0, 0.0001));
        }

        // The brute-force baseline must cap the result at the dataset size as well.
        let index = LinearSearchExternal::build(3, 2, &vectors, &mut seed_rng).unwrap();
        let results = index.search(&query, top_k);
        assert_eq!(results.len(), 5);
        for &(_, dist) in results.iter() {
            assert!(approx_eq(dist, 1.0, 0.0001));
        }
    }

    /// The LSH index refuses input that contains the same vector twice.
    #[test]
    fn test_lsh_build_rejects_duplicate_vectors() {
        let mut seed_rng = StdRng::seed_from_u64(42);
        let vectors = vec![
            Vector::from([1.0, 2.0, 3.0]),
            Vector::from([1.0, 2.0, 3.0]),
        ];

        let result = LshExternal::build(1, 1, &vectors, &mut seed_rng);
        assert_eq!(result.err(), Some("Vectors are not unique"));
    }

    /// The brute-force index refuses an empty dataset.
    #[test]
    fn test_linear_build_rejects_empty_input() {
        let mut seed_rng = StdRng::seed_from_u64(42);
        let vectors: Vec<Vector<3>> = Vec::new();

        let result = LinearSearchExternal::build(1, 1, &vectors, &mut seed_rng);
        assert_eq!(
            result.err(),
            Some("Cannot build index with empty vector set")
        );
    }
}