vector/
lib.rs

1//! Vector database allowing for efficient search of nearest neighbors.
2//!
3//! The approach is described in “[FANN: Vector Search in 200 Lines of Rust][1]” by Nikhil Garg and
4//! Navya Mehta.
5//!
6//! # Example
7//!
8//! ```
9//! use vector::Index;
10//!
11//! let vectors = vec![
12//!     [4.0, 2.0],
13//!     [5.0, 7.0],
14//!     [2.0, 9.0],
15//!     [7.0, 8.0],
16//! ];
17//! let index = Index::build(&vectors, 1, 1, 42);
18//!
19//! let query = [5.0, 5.0];
20//! let (indices, distances): (Vec<_>, Vec<_>) = index
21//!     .search(&vectors, &query, 2)
22//!     .into_iter()
23//!     .unzip();
24//! assert_eq!(indices, &[1, 0]);
25//! ```
26//!
27//! [1]: https://fennel.ai/blog/vector-search-in-200-lines-of-rust/
28
29use std::cmp::Ordering;
30use std::collections::{BTreeMap, BTreeSet};
31
/// An approximate nearest-neighbor index over a fixed set of vectors.
///
/// The index keeps only vector positions, not the vectors themselves, so the same slice of
/// vectors has to be handed to every search.
pub struct Index<const N: usize> {
    /// Root node of every tree in the forest.
    roots: Vec<Node<N>>,
    /// For each vector stored in the trees, the positions of the other vectors identical to it.
    duplicates: BTreeMap<usize, Vec<usize>>,
}

/// A point in `N`-dimensional space with `f32` coordinates.
pub type Vector<const N: usize> = [f32; N];

/// A hyperplane: the points `x` with `normal · x + offset == 0`.
struct Plane<const N: usize> {
    normal: Vector<N>,
    offset: f32,
}

/// Positions of the vectors held by a leaf.
type Leaf = Vec<usize>;

/// An inner node splitting its vectors by the side of `plane` they fall on.
struct Branch<const N: usize> {
    plane: Plane<N>,
    /// Subtree for vectors strictly on the side `plane.normal` points to.
    above: Node<N>,
    /// Subtree for the remaining vectors (the other side, or on the plane).
    below: Node<N>,
}

/// A tree node; both variants are boxed.
enum Node<const N: usize> {
    Branch(Box<Branch<N>>),
    Leaf(Box<Leaf>),
}
58
59impl<const N: usize> Index<N> {
60    /// Build an index.
61    ///
62    /// The forest size is the number of trees built internally, and the leaf size is the maximum
63    /// number of vectors a leaf can have without being split. Both arguments should be greater or
64    /// equal to one.
65    pub fn build(vectors: &[Vector<N>], forest_size: usize, leaf_size: usize, seed: u64) -> Self {
66        debug_assert!(forest_size >= 1);
67        debug_assert!(leaf_size >= 1);
68        let mut source = random::default(seed);
69        let duplicates = unique(vectors);
70        let indices = duplicates.keys().cloned().collect::<Vec<_>>();
71        let roots = (0..forest_size)
72            .map(|_| Node::build(vectors, &indices, leaf_size, &mut source))
73            .collect();
74        Self { roots, duplicates }
75    }
76
77    /// Find `count` vectors close to `query`.
78    ///
79    /// The vectors should be the same as the ones passed to `build`.
80    pub fn search(
81        &self,
82        vectors: &[Vector<N>],
83        query: &Vector<N>,
84        count: usize,
85    ) -> Vec<(usize, f32)> {
86        let mut indices = BTreeSet::new();
87        for root in self.roots.iter() {
88            search(root, &self.duplicates, query, count, &mut indices);
89        }
90        let mut pairs = indices
91            .into_iter()
92            .map(|index| (index, distance(&vectors[index], query)))
93            .collect::<Vec<_>>();
94        pairs.sort_by(|one, other| compare(one.1, other.1));
95        pairs.truncate(count);
96        pairs
97    }
98}
99
100impl<const N: usize> Node<N> {
101    fn build<T: random::Source>(
102        vectors: &[Vector<N>],
103        indices: &[usize],
104        leaf_size: usize,
105        source: &mut T,
106    ) -> Self {
107        if indices.len() <= leaf_size {
108            return Self::Leaf(Box::new(indices.to_vec()));
109        }
110        let (plane, above, below) = Plane::build(vectors, indices, source);
111        let above = Self::build(vectors, &above, leaf_size, source);
112        let below = Self::build(vectors, &below, leaf_size, source);
113        Self::Branch(Box::new(Branch::<N> {
114            plane,
115            above,
116            below,
117        }))
118    }
119}
120
121impl<const N: usize> Plane<N> {
122    fn build<T: random::Source>(
123        vectors: &[Vector<N>],
124        indices: &[usize],
125        source: &mut T,
126    ) -> (Self, Vec<usize>, Vec<usize>) {
127        debug_assert!(vectors.len() > 1);
128        let i = source.read::<usize>() % indices.len();
129        let mut j = i;
130        while i == j {
131            j = source.read::<usize>() % indices.len();
132        }
133        let one = &vectors[indices[i]];
134        let other = &vectors[indices[j]];
135        let normal = subtract(other, one);
136        let offset = -product(&normal, &average(one, other));
137        let plane = Plane::<N> { normal, offset };
138        let (above, below) = indices
139            .iter()
140            .partition(|index| plane.is_above(&vectors[**index]));
141        (plane, above, below)
142    }
143
144    fn is_above(&self, vector: &Vector<N>) -> bool {
145        product(&self.normal, vector) + self.offset > 0.0
146    }
147}
148
149fn average<const N: usize>(one: &Vector<N>, other: &Vector<N>) -> Vector<N> {
150    one.iter()
151        .zip(other)
152        .map(|(one, other)| (one + other) / 2.0)
153        .collect::<Vec<_>>()
154        .try_into()
155        .unwrap()
156}
157
/// Total order on distances: numbers ascending (infinities included), then NaN.
///
/// `partial_cmp` already orders infinities correctly and fails only when a NaN is involved.
/// A NaN distance is not close to anything, so it is placed after every number; this keeps
/// such vectors from being reported as the nearest ones. Two NaNs compare equal.
fn compare(one: f32, other: f32) -> Ordering {
    match (one.is_nan(), other.is_nan()) {
        (false, false) => one
            .partial_cmp(&other)
            .expect("non-NaN floats are always comparable"),
        (true, false) => Ordering::Greater,
        (false, true) => Ordering::Less,
        (true, true) => Ordering::Equal,
    }
}
171
172fn distance<const N: usize>(one: &Vector<N>, other: &Vector<N>) -> f32 {
173    one.iter()
174        .zip(other)
175        .map(|(one, other)| (one - other).powi(2))
176        .sum()
177}
178
179fn product<const N: usize>(one: &Vector<N>, other: &Vector<N>) -> f32 {
180    one.iter().zip(other).map(|(one, other)| one * other).sum()
181}
182
183fn search<const N: usize>(
184    root: &Node<N>,
185    duplicates: &BTreeMap<usize, Vec<usize>>,
186    vector: &Vector<N>,
187    count: usize,
188    indices: &mut BTreeSet<usize>,
189) {
190    match root {
191        Node::Branch(node) => {
192            let (primary, secondary) = if node.plane.is_above(vector) {
193                (&node.above, &node.below)
194            } else {
195                (&node.below, &node.above)
196            };
197            search(primary, duplicates, vector, count, indices);
198            if indices.len() < count {
199                search(secondary, duplicates, vector, count, indices);
200            }
201        }
202        Node::Leaf(node) => {
203            for index in node.iter() {
204                if indices.len() < count {
205                    indices.insert(*index);
206                    for other in duplicates.get(index).unwrap() {
207                        indices.insert(*other);
208                    }
209                } else {
210                    break;
211                }
212            }
213        }
214    }
215}
216
217fn subtract<const N: usize>(one: &Vector<N>, other: &Vector<N>) -> Vector<N> {
218    one.iter()
219        .zip(other)
220        .map(|(one, other)| one - other)
221        .collect::<Vec<_>>()
222        .try_into()
223        .unwrap()
224}
225
226fn unique<const N: usize>(vectors: &[Vector<N>]) -> BTreeMap<usize, Vec<usize>> {
227    let mut duplicates = BTreeMap::<usize, Vec<usize>>::default();
228    let mut seen = BTreeMap::default();
229    for (index, vector) in vectors.iter().enumerate() {
230        let key: [u32; N] = vector
231            .iter()
232            .map(|value| value.to_bits())
233            .collect::<Vec<_>>()
234            .try_into()
235            .unwrap();
236        if let Some(first) = seen.get(&key) {
237            duplicates.get_mut(first).unwrap().push(index);
238        } else {
239            duplicates.insert(index, Default::default());
240            seen.insert(key, index);
241        }
242    }
243    duplicates
244}
245
#[cfg(test)]
mod tests {
    use super::{Index, Plane};

    /// Indices returned for each query, nearest first, by a one-tree index whose leaves hold a
    /// single vector each.
    #[test]
    fn index() {
        let vectors = vec![
            [4.0, 2.0], // A1
            [5.0, 7.0], // B1
            [2.0, 9.0], // A2
            [7.0, 8.0], // B2
            [1.0, 3.0],
            [4.0, 10.0],
            // Three identical vectors, so duplicate handling is exercised.
            [10.0, 10.0],
            [10.0, 10.0],
            [10.0, 10.0],
        ];
        let index = Index::build(&vectors, 1, 1, 42);

        let expect = |query: [f32; 2], count: usize, expected: Vec<usize>| {
            let found = index
                .search(&vectors, &query, count)
                .iter()
                .map(|&(position, _)| position)
                .collect::<Vec<usize>>();
            assert_eq!(found, expected);
        };
        expect([5.0, 0.0], 1, vec![0]);
        expect([0.0, 0.0], 6, vec![4, 0, 1, 2, 3, 5]);
        expect([5.0, 10.0], 1, vec![5]);
        expect([7.0, 8.0], 2, vec![3, 1]);
        expect([10.0, 10.0], 4, vec![6, 7, 8, 3]);
    }

    /// With seed 25 the plane bisects vectors 0 and 1: its normal is their difference and it
    /// passes through their midpoint.
    #[test]
    fn plane() {
        let vectors = vec![
            [4.0, 2.0], // A1
            [5.0, 7.0], // B1
            [2.0, 9.0], // A2
            [7.0, 8.0], // B2
        ];
        let indices: Vec<usize> = (0..vectors.len()).collect();
        let mut source = random::default(25);
        let (plane, above, below) = Plane::build(&vectors, &indices, &mut source);
        assert::close(&plane.normal, &[1.0, 5.0], 1e-6);
        assert::close(plane.offset, -27.0, 1e-6);
        assert_eq!(above, &[1, 2, 3]);
        assert_eq!(below, &[0]);
    }
}