vec_vp_tree/
lib.rs

1// Copyright 2016 Austin Bonander
2//
3// Licensed under the Apache License, Version 2.0, <LICENSE-APACHE or
4// http://apache.org/licenses/LICENSE-2.0> or the MIT license <LICENSE-MIT or
5// http://opensource.org/licenses/MIT>, at your option. This file may not be
6// copied, modified, or distributed except according to those terms.
7//
8// Implementation adapted from C++ code at http://stevehanov.ca/blog/index.php?id=130
9// (accessed October 14, 2016). No copyright or license information is provided,
10// so the original code is assumed to be public domain.
11
12//! An implementation of a [vantage-point tree][vp-tree] backed by a vector.
13//!
14//! [vp-tree]: https://en.wikipedia.org/wiki/Vantage-point_tree
15#![warn(missing_docs)]
16
17extern crate order_stat;
18
19extern crate rand;
20
21#[cfg(feature = "strsim")]
22extern crate strsim;
23
24use rand::Rng;
25
26use std::borrow::Borrow;
27use std::cmp::Ordering;
28use std::collections::BinaryHeap;
29use std::fmt;
30
31use dist::{DistFn, KnownDist};
32
33pub mod dist;
34
35mod print;
36
37/// An implementation of a vantage-point tree backed by a vector.
38///
39/// Only bulk insert/removals are provided in order to keep the tree balanced.
40#[derive(Clone)]
41pub struct VpTree<T, D> {
42    nodes: Vec<Node>,
43    items: Vec<T>,
44    dist_fn: D,
45}
46
47impl<T> VpTree<T, <T as KnownDist>::DistFn>
48    where T: KnownDist
49{
50    /// Collect the results of `items` into the tree, and build the tree using the known distance
51    /// function for `T`.
52    ///
53    /// If `items` is `Vec<T>`, use `from_vec` to avoid a copy.
54    ///
55    /// If you want to provide a custom distance function, use `new_with_dist()` instead.
56    pub fn new<I: IntoIterator<Item = T>>(items: I) -> Self {
57        Self::new_with_dist(items, <T as KnownDist>::dist_fn())
58    }
59
60    /// Build the tree directly from `items` using the known distance function for `T`.
61    ///
62    /// If you want to provide a custom distance function, use `from_vec_with_dist()` instead.
63    pub fn from_vec(items: Vec<T>) -> Self {
64        Self::from_vec_with_dist(items, <T as KnownDist>::dist_fn())
65    }
66}
67
68impl<T, D: DistFn<T>> VpTree<T, D> {
69    /// Collect the results of `items` into the tree, and build the tree using the given distance
70    /// function `dist_fn`.
71    ///
72    /// If `items` is `Vec<T>`, use `from_vec_with_dist` to avoid a copy.
73    pub fn new_with_dist<I: IntoIterator<Item = T>>(items: I, dist_fn: D) -> Self {
74        Self::from_vec_with_dist(items.into_iter().collect(), dist_fn)
75    }
76
77    /// Build the tree directly from `items` using the given distance function `dist_fn`.
78    pub fn from_vec_with_dist(items: Vec<T>, dist_fn: D) -> Self {
79        let mut self_ = VpTree {
80            nodes: Vec::with_capacity(items.len()),
81            items: items,
82            dist_fn: dist_fn,
83        };
84
85        self_.rebuild();
86
87        self_
88    }
89
90    /// Apply a new distance function and rebuild the tree, returning the transformed type.
91    pub fn dist_fn<D_: DistFn<T>>(self, dist_fn: D_) -> VpTree<T, D_> {
92        let mut self_ = VpTree {
93            nodes: self.nodes,
94            items: self.items,
95            dist_fn: dist_fn,
96        };
97
98        self_.rebuild();
99
100        self_
101    }
102
103    /// Rebuild the full tree.
104    ///
105    /// This is only necessary if the one or more properties of a contained
106    /// item which determine their distance via `D: DistFn<T>` was somehow changed without
107    /// the tree being rebuilt, or a panic occurred during a mutation and was caught.
108    pub fn rebuild(&mut self) {
109        self.nodes.clear();
110
111        let len = self.items.len();
112        let nodes_cap = self.nodes.capacity();
113
114        if len > nodes_cap {
115            self.nodes.reserve(len - nodes_cap);
116        }
117
118        self.rebuild_in(NO_NODE, 0, len);
119    }
120
121    /// Rebuild the tree in [start, end)
122    fn rebuild_in(&mut self, parent_idx: usize, start: usize, end: usize) -> usize {
123        if start == end {
124            return NO_NODE;
125        }
126
127        if start + 1 == end {
128            return self.push_node(start, parent_idx, 0);
129        }
130
131        let pivot_idx = rand::thread_rng().gen_range(start, end);
132        self.items.swap(start, pivot_idx);
133
134        let median_idx = (end - (start + 1)) / 2;
135
136        let threshold = {
137            let (pivot, items) = self.items.split_first_mut().unwrap();
138
139            // Without this reborrow, the closure will try to borrow all of `self`.
140            let dist_fn = &self.dist_fn;
141
142            // This function will partition around the median element
143            let median_thresh_item = order_stat::kth_by(items, median_idx, |left, right| {
144                dist_fn.dist(pivot, left).cmp(&dist_fn.dist(pivot, right))
145            });
146
147            dist_fn.dist(pivot, median_thresh_item)
148        };
149
150        let left_start = start + 1;
151
152        let split_idx = left_start + median_idx + 1;
153
154        let self_idx = self.push_node(start, parent_idx, threshold);
155
156        let left_idx = self.rebuild_in(self_idx, left_start, split_idx);
157
158        let right_idx = self.rebuild_in(self_idx, split_idx, end);
159
160        self.nodes[self_idx].left = left_idx;
161        self.nodes[self_idx].right = right_idx;
162
163        self_idx
164    }
165
166    fn push_node(&mut self, idx: usize, parent_idx: usize, threshold: u64) -> usize {
167        let self_idx = self.nodes.len();
168
169        self.nodes.push(Node {
170            idx: idx,
171            parent: parent_idx,
172            left: NO_NODE,
173            right: NO_NODE,
174            threshold: threshold,
175        });
176
177        self_idx
178    }
179
180    #[inline(always)]
181    fn sanity_check(&self) {
182        assert!(self.nodes.len() == self.items.len(),
183                "Attempting to traverse `VpTree` when it is in an invalid state. This can \
184                 happen if a panic was thrown while it was being mutated and then caught \
185                 outside.")
186    }
187
188    /// Add `new_items` to the tree and rebuild it.
189    pub fn extend<I: IntoIterator<Item = T>>(&mut self, new_items: I) {
190        self.nodes.clear();
191        self.items.extend(new_items);
192        self.rebuild();
193    }
194
195    /// Iterate over the contained items, dropping them if `ret_fn` returns `false`,
196    /// keeping them otherwise.
197    ///
198    /// The tree will be rebuilt afterwards.
199    pub fn retain<F>(&mut self, ret_fn: F)
200        where F: FnMut(&T) -> bool
201    {
202        self.nodes.clear();
203        self.items.retain(ret_fn);
204        self.rebuild();
205    }
206
207    /// Get a slice of the items in the tree.
208    ///
209    /// These items may have been rearranged from the order which they were inserted.
210    ///
211    /// ## Note
212    /// It is a logic error for an item to be modified in such a way that the item's distance
213    /// to any other item, as determined by `D: DistFn<T>`, changes while it is in the tree
214    /// without the tree being rebuilt.
215    /// This is normally only possible through `Cell`, `RefCell`, global state, I/O, or unsafe code.
216    ///
217    /// If you wish to mutate one or more of the contained items, use `.with_mut_items()` instead,
218    /// to ensure the tree is rebuilt after the mutation.
219    pub fn items(&self) -> &[T] {
220        &self.items
221    }
222
223    /// Get a scoped mutable slice to the contained items.
224    ///
225    /// The tree will be rebuilt after `mut_fn` returns, in assumption that it will modify one or
226    /// more of the contained items such that their distance to others,
227    /// as determined by `D: DistFn<T>`, changes.
228    ///
229    /// ## Note
230    /// If a panic is initiated in `mut_fn` and then caught outside this method,
231    /// the tree will need to be manually rebuilt with `.rebuild()`.
232    pub fn with_mut_items<F>(&mut self, mut_fn: F)
233        where F: FnOnce(&mut [T])
234    {
235        self.nodes.clear();
236        mut_fn(&mut self.items);
237        self.rebuild();
238    }
239
240    /// Get a vector of the `k` nearest neighbors to `origin`, sorted in ascending order
241    /// by the distance.
242    ///
243    /// ## Note
244    /// If `origin` is contained within the tree, which is allowed by the API contract,
245    /// it will be returned in the results. In this case, it may be preferable to start with a
246    /// higher `k` and filter out duplicate entries.
247    ///
248    /// If `k > self.items.len()`, then obviously only `self.items.len()` items will be returned.
249    ///
250    /// ## Panics
251    /// If the tree was in an invalid state. This can happen if a panic occurred during
252    /// a mutation and was then caught without calling `.rebuild()`.
253    pub fn k_nearest<'t, O: Borrow<T>>(&'t self, origin: O, k: usize) -> Vec<Neighbor<'t, T>> {
254        self.sanity_check();
255
256        let origin = origin.borrow();
257
258        KnnVisitor::new(self, origin, k)
259            .visit_all()
260            .into_vec()
261    }
262
263    /// Consume `self` and return the vector of items.
264    ///
265    /// The items may have been rearranged from the order in which they were inserted.
266    pub fn into_vec(self) -> Vec<T> {
267        self.items
268    }
269}
270
271/// Prints the contained items as well as the tree structure, if it is in a valid state.
272impl<T: fmt::Debug, D: DistFn<T>> fmt::Debug for VpTree<T, D> {
273    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
274        try!(writeln!(f, "VpTree {{ len: {} }}", self.items.len()));
275
276        if self.nodes.len() == 0 {
277            return f.write_str("[Empty]\n");
278        }
279
280        try!(writeln!(f, "Items: {:?}", self.items));
281
282
283        if self.nodes.len() == self.items.len() {
284            try!(f.write_str("Structure:\n"));
285            print::TreePrinter::new(self).print(f)
286        } else {
287            f.write_str("[Tree is in invalid state]")
288        }
289    }
290}
291
/// Sentinel meaning "no such node": stored in a `Node`'s `parent`, `left`, or `right` field
/// when that link does not exist (every bit set, i.e. the largest `usize`).
const NO_NODE: usize = !0;
294
/// One entry of the flattened tree structure.
#[derive(Clone, Debug)]
struct Node {
    /// Index into the tree's item vector of this node's vantage point.
    idx: usize,
    /// Index of the parent node, or `NO_NODE` for the root.
    parent: usize,
    /// Index of the subtree holding items no farther than `threshold`, or `NO_NODE`.
    left: usize,
    /// Index of the subtree holding items at least `threshold` away, or `NO_NODE`.
    right: usize,
    /// Median distance from the vantage point to the items below it; 0 for leaves.
    threshold: u64,
}
303
304struct KnnVisitor<'t, 'o, T: 't + 'o, D: 't> {
305    tree: &'t VpTree<T, D>,
306    origin: &'o T,
307    heap: BinaryHeap<Neighbor<'t, T>>,
308    k: usize,
309    radius: u64,
310}
311
312impl<'t, 'o, T: 't + 'o, D: 't> KnnVisitor<'t, 'o, T, D>
313    where D: DistFn<T>
314{
315    fn new(tree: &'t VpTree<T, D>, origin: &'o T, k: usize) -> Self {
316        KnnVisitor {
317            tree: tree,
318            origin: origin,
319            // Preallocate enough scratch space but don't allocate if `k = 0`
320            heap: if k > 0 {
321                BinaryHeap::with_capacity(k + 2)
322            } else {
323                BinaryHeap::new()
324            },
325            k: k,
326            radius: ::std::u64::MAX,
327        }
328    }
329    fn visit_all(mut self) -> Self {
330        if self.k > 0 && self.tree.nodes.len() > 0 {
331            self.visit(0);
332        }
333
334        self
335    }
336
337    fn visit(&mut self, node_idx: usize) {
338        if node_idx == NO_NODE {
339            return;
340        }
341
342        let cur_node = &self.tree.nodes[node_idx];
343
344        let item = &self.tree.items[cur_node.idx];
345
346        let dist_to_cur = self.tree.dist_fn.dist(&self.origin, item);
347
348        if dist_to_cur < self.radius {
349            let neighbor = Neighbor {
350                item: item,
351                dist: dist_to_cur,
352            };
353
354            if self.heap.len() == self.k {
355                // Equivalent to .push_pop(), k is assured to be > 0
356                *self.heap.peek_mut().unwrap() = neighbor;
357            } else {
358                self.heap.push(neighbor);
359            }
360
361            if self.heap.len() == self.k {
362                self.radius = self.heap.peek().unwrap().dist;
363            }
364        }
365
366        // Original implementation used `double`, which could go negative and so didn't
367        // worry about wrapping.
368        // Unsigned integer distances make more sense for most use cases and are faster
369        // to work with, but require care, especially when `self.radius` is near or at max value.
370        let go_left = dist_to_cur.saturating_sub(self.radius) <= cur_node.threshold;
371        let go_right = dist_to_cur.saturating_add(self.radius) >= cur_node.threshold;
372
373        if dist_to_cur <= cur_node.threshold {
374            if go_left {
375                self.visit(cur_node.left);
376            }
377
378            if go_right {
379                self.visit(cur_node.right);
380            }
381        } else {
382            if go_right {
383                self.visit(cur_node.right);
384            }
385
386            if go_left {
387                self.visit(cur_node.left);
388            }
389        };
390    }
391
392    fn into_vec(self) -> Vec<Neighbor<'t, T>> {
393        self.heap.into_sorted_vec()
394    }
395}
396
/// Wrapper of an item and its distance from the query origin, returned by
/// `VpTree::k_nearest()`.
///
/// Comparisons (`Eq`, `Ord`) look only at `dist`, never at `item`.
#[derive(Debug, Clone)]
pub struct Neighbor<'t, T: 't> {
    /// The item that this entry concerns.
    pub item: &'t T,
    /// The distance between `item` and the origin passed to `VpTree::k_nearest()`.
    pub dist: u64,
}
405
406/// Returns the comparison of the distances only.
407impl<'t, T: 't> PartialOrd for Neighbor<'t, T> {
408    fn partial_cmp(&self, other: &Self) -> Option<Ordering> {
409        Some(self.cmp(other))
410    }
411}
412
413/// Returns the comparison of the distances only.
414impl<'t, T: 't> Ord for Neighbor<'t, T> {
415    fn cmp(&self, other: &Self) -> Ordering {
416        self.dist.cmp(&other.dist)
417    }
418}
419
420/// Returns the equality of the distances only.
421impl<'t, T: 't> PartialEq for Neighbor<'t, T> {
422    fn eq(&self, other: &Self) -> bool {
423        self.dist == other.dist
424    }
425}
426
427/// Returns the equality of the distances only.
428impl<'t, T: 't> Eq for Neighbor<'t, T> {}
429
#[cfg(test)]
mod test {
    use super::VpTree;

    const MAX_TREE_VAL: i32 = 8;
    const ORIGIN: i32 = 4;
    const NEIGHBORS: &'static [i32] = &[2, 3, 4, 5, 6];

    #[test]
    fn test_k_nearest() {
        let tree = VpTree::new(0i32..MAX_TREE_VAL);

        println!("Tree: {:?}", tree);

        let nearest = tree.k_nearest(&ORIGIN, NEIGHBORS.len());

        println!("Nearest: {:?}", nearest);

        // The tree holds more than `k` items, so exactly `k` must come back.
        assert_eq!(nearest.len(), NEIGHBORS.len());

        for neighbor in &nearest {
            assert!(NEIGHBORS.contains(neighbor.item),
                    "Was not expecting {:?}",
                    neighbor);
        }

        // Results are documented to be sorted by ascending distance.
        assert!(nearest.windows(2).all(|pair| pair[0].dist <= pair[1].dist));
    }

    /// A pruned search must report the same distances as an exhaustive one.
    #[test]
    fn test_k_nearest_matches_exhaustive() {
        let tree = VpTree::new(0i32..64);
        let len = tree.items().len();

        for origin in -4i32..68 {
            // With `k == len` the radius stays at its maximum until every item has been
            // collected, so nothing is pruned: this is a full ranking.
            let all = tree.k_nearest(&origin, len);
            assert_eq!(all.len(), len);

            for k in 0..len + 2 {
                let some = tree.k_nearest(&origin, k);
                assert_eq!(some.len(), ::std::cmp::min(k, len));

                for (got, want) in some.iter().zip(&all) {
                    assert_eq!(got.dist, want.dist, "origin {}, k {}", origin, k);
                }
            }
        }
    }
}