1#![warn(missing_docs)]
16
17extern crate order_stat;
18
19extern crate rand;
20
21#[cfg(feature = "strsim")]
22extern crate strsim;
23
24use rand::Rng;
25
26use std::borrow::Borrow;
27use std::cmp::Ordering;
28use std::collections::BinaryHeap;
29use std::fmt;
30
31use dist::{DistFn, KnownDist};
32
33pub mod dist;
34
35mod print;
36
37#[derive(Clone)]
41pub struct VpTree<T, D> {
42 nodes: Vec<Node>,
43 items: Vec<T>,
44 dist_fn: D,
45}
46
47impl<T> VpTree<T, <T as KnownDist>::DistFn>
48 where T: KnownDist
49{
50 pub fn new<I: IntoIterator<Item = T>>(items: I) -> Self {
57 Self::new_with_dist(items, <T as KnownDist>::dist_fn())
58 }
59
60 pub fn from_vec(items: Vec<T>) -> Self {
64 Self::from_vec_with_dist(items, <T as KnownDist>::dist_fn())
65 }
66}
67
68impl<T, D: DistFn<T>> VpTree<T, D> {
69 pub fn new_with_dist<I: IntoIterator<Item = T>>(items: I, dist_fn: D) -> Self {
74 Self::from_vec_with_dist(items.into_iter().collect(), dist_fn)
75 }
76
77 pub fn from_vec_with_dist(items: Vec<T>, dist_fn: D) -> Self {
79 let mut self_ = VpTree {
80 nodes: Vec::with_capacity(items.len()),
81 items: items,
82 dist_fn: dist_fn,
83 };
84
85 self_.rebuild();
86
87 self_
88 }
89
90 pub fn dist_fn<D_: DistFn<T>>(self, dist_fn: D_) -> VpTree<T, D_> {
92 let mut self_ = VpTree {
93 nodes: self.nodes,
94 items: self.items,
95 dist_fn: dist_fn,
96 };
97
98 self_.rebuild();
99
100 self_
101 }
102
103 pub fn rebuild(&mut self) {
109 self.nodes.clear();
110
111 let len = self.items.len();
112 let nodes_cap = self.nodes.capacity();
113
114 if len > nodes_cap {
115 self.nodes.reserve(len - nodes_cap);
116 }
117
118 self.rebuild_in(NO_NODE, 0, len);
119 }
120
121 fn rebuild_in(&mut self, parent_idx: usize, start: usize, end: usize) -> usize {
123 if start == end {
124 return NO_NODE;
125 }
126
127 if start + 1 == end {
128 return self.push_node(start, parent_idx, 0);
129 }
130
131 let pivot_idx = rand::thread_rng().gen_range(start, end);
132 self.items.swap(start, pivot_idx);
133
134 let median_idx = (end - (start + 1)) / 2;
135
136 let threshold = {
137 let (pivot, items) = self.items.split_first_mut().unwrap();
138
139 let dist_fn = &self.dist_fn;
141
142 let median_thresh_item = order_stat::kth_by(items, median_idx, |left, right| {
144 dist_fn.dist(pivot, left).cmp(&dist_fn.dist(pivot, right))
145 });
146
147 dist_fn.dist(pivot, median_thresh_item)
148 };
149
150 let left_start = start + 1;
151
152 let split_idx = left_start + median_idx + 1;
153
154 let self_idx = self.push_node(start, parent_idx, threshold);
155
156 let left_idx = self.rebuild_in(self_idx, left_start, split_idx);
157
158 let right_idx = self.rebuild_in(self_idx, split_idx, end);
159
160 self.nodes[self_idx].left = left_idx;
161 self.nodes[self_idx].right = right_idx;
162
163 self_idx
164 }
165
166 fn push_node(&mut self, idx: usize, parent_idx: usize, threshold: u64) -> usize {
167 let self_idx = self.nodes.len();
168
169 self.nodes.push(Node {
170 idx: idx,
171 parent: parent_idx,
172 left: NO_NODE,
173 right: NO_NODE,
174 threshold: threshold,
175 });
176
177 self_idx
178 }
179
180 #[inline(always)]
181 fn sanity_check(&self) {
182 assert!(self.nodes.len() == self.items.len(),
183 "Attempting to traverse `VpTree` when it is in an invalid state. This can \
184 happen if a panic was thrown while it was being mutated and then caught \
185 outside.")
186 }
187
188 pub fn extend<I: IntoIterator<Item = T>>(&mut self, new_items: I) {
190 self.nodes.clear();
191 self.items.extend(new_items);
192 self.rebuild();
193 }
194
195 pub fn retain<F>(&mut self, ret_fn: F)
200 where F: FnMut(&T) -> bool
201 {
202 self.nodes.clear();
203 self.items.retain(ret_fn);
204 self.rebuild();
205 }
206
207 pub fn items(&self) -> &[T] {
220 &self.items
221 }
222
223 pub fn with_mut_items<F>(&mut self, mut_fn: F)
233 where F: FnOnce(&mut [T])
234 {
235 self.nodes.clear();
236 mut_fn(&mut self.items);
237 self.rebuild();
238 }
239
240 pub fn k_nearest<'t, O: Borrow<T>>(&'t self, origin: O, k: usize) -> Vec<Neighbor<'t, T>> {
254 self.sanity_check();
255
256 let origin = origin.borrow();
257
258 KnnVisitor::new(self, origin, k)
259 .visit_all()
260 .into_vec()
261 }
262
263 pub fn into_vec(self) -> Vec<T> {
267 self.items
268 }
269}
270
271impl<T: fmt::Debug, D: DistFn<T>> fmt::Debug for VpTree<T, D> {
273 fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
274 try!(writeln!(f, "VpTree {{ len: {} }}", self.items.len()));
275
276 if self.nodes.len() == 0 {
277 return f.write_str("[Empty]\n");
278 }
279
280 try!(writeln!(f, "Items: {:?}", self.items));
281
282
283 if self.nodes.len() == self.items.len() {
284 try!(f.write_str("Structure:\n"));
285 print::TreePrinter::new(self).print(f)
286 } else {
287 f.write_str("[Tree is in invalid state]")
288 }
289 }
290}
291
/// Sentinel node index meaning "no node"; used for a missing parent or child.
const NO_NODE: usize = usize::MAX;
294
/// One node of the tree, stored by index in the tree's node vector.
#[derive(Clone, Debug)]
struct Node {
    /// Index of this node's item in the tree's item vector.
    idx: usize,
    /// Index of the parent node, or `NO_NODE` for the root.
    parent: usize,
    /// Index of the left child, or `NO_NODE` if there is none.
    left: usize,
    /// Index of the right child, or `NO_NODE` if there is none.
    right: usize,
    /// Distance from this node's item that separates the left subtree from
    /// the right subtree.
    threshold: u64,
}
303
304struct KnnVisitor<'t, 'o, T: 't + 'o, D: 't> {
305 tree: &'t VpTree<T, D>,
306 origin: &'o T,
307 heap: BinaryHeap<Neighbor<'t, T>>,
308 k: usize,
309 radius: u64,
310}
311
312impl<'t, 'o, T: 't + 'o, D: 't> KnnVisitor<'t, 'o, T, D>
313 where D: DistFn<T>
314{
315 fn new(tree: &'t VpTree<T, D>, origin: &'o T, k: usize) -> Self {
316 KnnVisitor {
317 tree: tree,
318 origin: origin,
319 heap: if k > 0 {
321 BinaryHeap::with_capacity(k + 2)
322 } else {
323 BinaryHeap::new()
324 },
325 k: k,
326 radius: ::std::u64::MAX,
327 }
328 }
329 fn visit_all(mut self) -> Self {
330 if self.k > 0 && self.tree.nodes.len() > 0 {
331 self.visit(0);
332 }
333
334 self
335 }
336
337 fn visit(&mut self, node_idx: usize) {
338 if node_idx == NO_NODE {
339 return;
340 }
341
342 let cur_node = &self.tree.nodes[node_idx];
343
344 let item = &self.tree.items[cur_node.idx];
345
346 let dist_to_cur = self.tree.dist_fn.dist(&self.origin, item);
347
348 if dist_to_cur < self.radius {
349 let neighbor = Neighbor {
350 item: item,
351 dist: dist_to_cur,
352 };
353
354 if self.heap.len() == self.k {
355 *self.heap.peek_mut().unwrap() = neighbor;
357 } else {
358 self.heap.push(neighbor);
359 }
360
361 if self.heap.len() == self.k {
362 self.radius = self.heap.peek().unwrap().dist;
363 }
364 }
365
366 let go_left = dist_to_cur.saturating_sub(self.radius) <= cur_node.threshold;
371 let go_right = dist_to_cur.saturating_add(self.radius) >= cur_node.threshold;
372
373 if dist_to_cur <= cur_node.threshold {
374 if go_left {
375 self.visit(cur_node.left);
376 }
377
378 if go_right {
379 self.visit(cur_node.right);
380 }
381 } else {
382 if go_right {
383 self.visit(cur_node.right);
384 }
385
386 if go_left {
387 self.visit(cur_node.left);
388 }
389 };
390 }
391
392 fn into_vec(self) -> Vec<Neighbor<'t, T>> {
393 self.heap.into_sorted_vec()
394 }
395}
396
/// An item found by a nearest-neighbors search, together with its distance
/// from the search origin.
///
/// Comparison and equality consider only `dist`, never `item`.
#[derive(Debug)]
pub struct Neighbor<'t, T: 't> {
    /// The item, borrowed from the tree that was searched.
    pub item: &'t T,
    /// Distance between the search origin and `item`.
    pub dist: u64,
}

// Implemented by hand rather than derived: the struct only holds a reference
// to `T`, so copying it must not require `T: Clone`.
impl<'t, T: 't> Clone for Neighbor<'t, T> {
    fn clone(&self) -> Self {
        *self
    }
}

impl<'t, T: 't> Copy for Neighbor<'t, T> {}

impl<'t, T: 't> PartialOrd for Neighbor<'t, T> {
    /// Always `Some`; delegates to the total order on `dist`.
    fn partial_cmp(&self, other: &Self) -> Option<Ordering> {
        Some(self.cmp(other))
    }
}

impl<'t, T: 't> Ord for Neighbor<'t, T> {
    /// Orders neighbors by distance, nearest first.
    fn cmp(&self, other: &Self) -> Ordering {
        self.dist.cmp(&other.dist)
    }
}

impl<'t, T: 't> PartialEq for Neighbor<'t, T> {
    /// Two neighbors are equal when their distances are equal, which keeps
    /// equality consistent with the ordering.
    fn eq(&self, other: &Self) -> bool {
        self.dist == other.dist
    }
}

impl<'t, T: 't> Eq for Neighbor<'t, T> {}
429
#[cfg(test)]
mod test {
    use super::VpTree;

    /// Exclusive upper bound of the values placed in the tree.
    const MAX_TREE_VAL: i32 = 8;
    /// The value whose neighbors are searched for.
    const ORIGIN: i32 = 4;
    /// The values expected to be returned as nearest to `ORIGIN`.
    const NEIGHBORS: &'static [i32] = &[2, 3, 4, 5, 6];

    #[test]
    fn test_k_nearest() {
        let tree = VpTree::new(0i32..MAX_TREE_VAL);

        println!("Tree: {:?}", tree);

        let nearest = tree.k_nearest(&ORIGIN, NEIGHBORS.len());

        println!("Nearest: {:?}", nearest);

        // Without this an empty or short result would pass vacuously.
        assert_eq!(nearest.len(),
                   NEIGHBORS.len(),
                   "Expected {} neighbors, got {:?}",
                   NEIGHBORS.len(),
                   nearest);

        assert!(nearest.windows(2).all(|pair| pair[0].dist <= pair[1].dist),
                "Neighbors are not ordered by distance: {:?}",
                nearest);

        for neighbor in &nearest {
            assert!(NEIGHBORS.contains(neighbor.item),
                    "Was not expecting {:?}",
                    neighbor);
        }

        for expected in NEIGHBORS {
            assert!(nearest.iter().any(|neighbor| neighbor.item == expected),
                    "Missing expected neighbor {:?}",
                    expected);
        }
    }
}
456}