1use std::cmp::Ordering;
30use std::collections::{BTreeMap, BTreeSet};
31
/// A forest of random-hyperplane trees for nearest-neighbor search over
/// `N`-dimensional vectors.
pub struct Index<const N: usize> {
    /// Root node of every tree in the forest.
    roots: Vec<Node<N>>,
    /// Keyed by the index of the first occurrence of each distinct vector; the
    /// value lists the indices of later vectors with the same bit pattern.
    duplicates: BTreeMap<usize, Vec<usize>>,
}
37
/// A point with `N` single-precision coordinates.
pub type Vector<const N: usize> = [f32; N];
40
/// A node of a search tree.
enum Node<const N: usize> {
    /// An inner node that splits its points with a hyperplane.
    Branch(Box<Branch<N>>),
    /// A terminal node holding vector indices.
    Leaf(Box<Leaf>),
}
45
/// An inner tree node: a splitting plane and the two subtrees it separates.
struct Branch<const N: usize> {
    /// The hyperplane used to route points to one of the two subtrees.
    plane: Plane<N>,
    /// Subtree built from the points for which `Plane::is_above` is true.
    above: Node<N>,
    /// Subtree built from the remaining points.
    below: Node<N>,
}
51
/// The payload of a leaf node: indices into the slice of vectors.
type Leaf = Vec<usize>;
53
/// A hyperplane; a point `x` is above it when `normal · x + offset > 0`.
struct Plane<const N: usize> {
    /// Normal vector of the plane (not normalized).
    normal: Vector<N>,
    /// Constant term added to the dot product with `normal`.
    offset: f32,
}
58
59impl<const N: usize> Index<N> {
60 pub fn build(vectors: &[Vector<N>], forest_size: usize, leaf_size: usize, seed: u64) -> Self {
66 debug_assert!(forest_size >= 1);
67 debug_assert!(leaf_size >= 1);
68 let mut source = random::default(seed);
69 let duplicates = unique(vectors);
70 let indices = duplicates.keys().cloned().collect::<Vec<_>>();
71 let roots = (0..forest_size)
72 .map(|_| Node::build(vectors, &indices, leaf_size, &mut source))
73 .collect();
74 Self { roots, duplicates }
75 }
76
77 pub fn search(
81 &self,
82 vectors: &[Vector<N>],
83 query: &Vector<N>,
84 count: usize,
85 ) -> Vec<(usize, f32)> {
86 let mut indices = BTreeSet::new();
87 for root in self.roots.iter() {
88 search(root, &self.duplicates, query, count, &mut indices);
89 }
90 let mut pairs = indices
91 .into_iter()
92 .map(|index| (index, distance(&vectors[index], query)))
93 .collect::<Vec<_>>();
94 pairs.sort_by(|one, other| compare(one.1, other.1));
95 pairs.truncate(count);
96 pairs
97 }
98}
99
100impl<const N: usize> Node<N> {
101 fn build<T: random::Source>(
102 vectors: &[Vector<N>],
103 indices: &[usize],
104 leaf_size: usize,
105 source: &mut T,
106 ) -> Self {
107 if indices.len() <= leaf_size {
108 return Self::Leaf(Box::new(indices.to_vec()));
109 }
110 let (plane, above, below) = Plane::build(vectors, indices, source);
111 let above = Self::build(vectors, &above, leaf_size, source);
112 let below = Self::build(vectors, &below, leaf_size, source);
113 Self::Branch(Box::new(Branch::<N> {
114 plane,
115 above,
116 below,
117 }))
118 }
119}
120
121impl<const N: usize> Plane<N> {
122 fn build<T: random::Source>(
123 vectors: &[Vector<N>],
124 indices: &[usize],
125 source: &mut T,
126 ) -> (Self, Vec<usize>, Vec<usize>) {
127 debug_assert!(vectors.len() > 1);
128 let i = source.read::<usize>() % indices.len();
129 let mut j = i;
130 while i == j {
131 j = source.read::<usize>() % indices.len();
132 }
133 let one = &vectors[indices[i]];
134 let other = &vectors[indices[j]];
135 let normal = subtract(other, one);
136 let offset = -product(&normal, &average(one, other));
137 let plane = Plane::<N> { normal, offset };
138 let (above, below) = indices
139 .iter()
140 .partition(|index| plane.is_above(&vectors[**index]));
141 (plane, above, below)
142 }
143
144 fn is_above(&self, vector: &Vector<N>) -> bool {
145 product(&self.normal, vector) + self.offset > 0.0
146 }
147}
148
149fn average<const N: usize>(one: &Vector<N>, other: &Vector<N>) -> Vector<N> {
150 one.iter()
151 .zip(other)
152 .map(|(one, other)| (one + other) / 2.0)
153 .collect::<Vec<_>>()
154 .try_into()
155 .unwrap()
156}
157
/// Totally order two floats for sorting.
///
/// Ordinary values, including infinities, compare as usual. NaN sorts before
/// every other value and equal to another NaN. Keeping this a consistent total
/// order matters because `sort_by` may produce an unspecified order, or panic,
/// when the comparator contradicts itself.
///
/// NOTE(review): NaN distances therefore rank as the nearest results; confirm
/// that this is the desired placement.
fn compare(one: f32, other: f32) -> Ordering {
    match one.partial_cmp(&other) {
        Some(ordering) => ordering,
        // `partial_cmp` fails only when at least one side is NaN.
        None => match (one.is_nan(), other.is_nan()) {
            (true, false) => Ordering::Less,
            (false, true) => Ordering::Greater,
            _ => Ordering::Equal,
        },
    }
}
171
172fn distance<const N: usize>(one: &Vector<N>, other: &Vector<N>) -> f32 {
173 one.iter()
174 .zip(other)
175 .map(|(one, other)| (one - other).powi(2))
176 .sum()
177}
178
179fn product<const N: usize>(one: &Vector<N>, other: &Vector<N>) -> f32 {
180 one.iter().zip(other).map(|(one, other)| one * other).sum()
181}
182
183fn search<const N: usize>(
184 root: &Node<N>,
185 duplicates: &BTreeMap<usize, Vec<usize>>,
186 vector: &Vector<N>,
187 count: usize,
188 indices: &mut BTreeSet<usize>,
189) {
190 match root {
191 Node::Branch(node) => {
192 let (primary, secondary) = if node.plane.is_above(vector) {
193 (&node.above, &node.below)
194 } else {
195 (&node.below, &node.above)
196 };
197 search(primary, duplicates, vector, count, indices);
198 if indices.len() < count {
199 search(secondary, duplicates, vector, count, indices);
200 }
201 }
202 Node::Leaf(node) => {
203 for index in node.iter() {
204 if indices.len() < count {
205 indices.insert(*index);
206 for other in duplicates.get(index).unwrap() {
207 indices.insert(*other);
208 }
209 } else {
210 break;
211 }
212 }
213 }
214 }
215}
216
217fn subtract<const N: usize>(one: &Vector<N>, other: &Vector<N>) -> Vector<N> {
218 one.iter()
219 .zip(other)
220 .map(|(one, other)| one - other)
221 .collect::<Vec<_>>()
222 .try_into()
223 .unwrap()
224}
225
226fn unique<const N: usize>(vectors: &[Vector<N>]) -> BTreeMap<usize, Vec<usize>> {
227 let mut duplicates = BTreeMap::<usize, Vec<usize>>::default();
228 let mut seen = BTreeMap::default();
229 for (index, vector) in vectors.iter().enumerate() {
230 let key: [u32; N] = vector
231 .iter()
232 .map(|value| value.to_bits())
233 .collect::<Vec<_>>()
234 .try_into()
235 .unwrap();
236 if let Some(first) = seen.get(&key) {
237 duplicates.get_mut(first).unwrap().push(index);
238 } else {
239 duplicates.insert(index, Default::default());
240 seen.insert(key, index);
241 }
242 }
243 duplicates
244}
245
#[cfg(test)]
mod tests {
    use super::{Index, Plane};

    // End-to-end check: build a single tree with single-element leaves and
    // compare the returned indices, nearest first, for several queries.
    #[test]
    fn index() {
        // The last three vectors are identical, exercising duplicate handling.
        let vectors = vec![
            [4.0, 2.0], [5.0, 7.0], [2.0, 9.0], [7.0, 8.0], [1.0, 3.0],
            [4.0, 10.0],
            [10.0, 10.0],
            [10.0, 10.0],
            [10.0, 10.0],
        ];
        // Each case is (query, count, expected indices).
        let cases = vec![
            ([5.0, 0.0], 1, vec![0]),
            ([0.0, 0.0], 6, vec![4, 0, 1, 2, 3, 5]),
            ([5.0, 10.0], 1, vec![5]),
            ([7.0, 8.0], 2, vec![3, 1]),
            ([10.0, 10.0], 4, vec![6, 7, 8, 3]),
        ];

        // One tree, at most one distinct vector per leaf, seed 42.
        let index = Index::build(&vectors, 1, 1, 42);
        for (query, count, indices) in cases {
            assert_eq!(
                index
                    .search(&vectors, &query, count)
                    .into_iter()
                    .map(|(index, _)| index)
                    .collect::<Vec<_>>(),
                indices,
            );
        }
    }

    // Check the plane chosen for a fixed seed and the resulting partition.
    #[test]
    fn plane() {
        let mut source = random::default(25);
        let vectors = vec![
            [4.0, 2.0], [5.0, 7.0], [2.0, 9.0], [7.0, 8.0], ];
        let indices = (0..vectors.len()).collect::<Vec<_>>();
        let (plane, above, below) = Plane::build(&vectors, &indices, &mut source);
        // The expected normal and offset match the bisector of vectors 0 and 1:
        // normal = [5, 7] - [4, 2], offset = -normal · midpoint = -27.
        assert::close(&plane.normal, &[1.0, 5.0], 1e-6);
        assert::close(plane.offset, -27.0, 1e-6);
        assert_eq!(above, &[1, 2, 3]);
        assert_eq!(below, &[0]);
    }
}