vpsearch/
lib.rs

1//! A relatively simple and readable Rust implementation of Vantage Point tree search algorithm.
2//!
3//! The VP tree algorithm doesn't need to know coordinates of items, only distances between them. It can efficiently search multi-dimensional spaces and abstract things as long as you can define similarity between them (e.g. points, colors, and even images).
4//!
5//! [Project page](https://github.com/kornelski/vpsearch).
6//!
7//!
8//! **This algorithm does not work with squared distances. When implementing Euclidean distance, you *MUST* use `sqrt()`**. You really really can't use that optimization. There's no way around it. Vantage Point trees require [metric spaces](https://en.wikipedia.org/wiki/Metric_space).
9//!
10//! ```rust
11//! #[derive(Copy, Clone)]
12//! struct Point {
13//!     x: f32, y: f32,
14//! }
15//!
16//! impl vpsearch::MetricSpace for Point {
17//!     type UserData = ();
18//!     type Distance = f32;
19//!
20//!     fn distance(&self, other: &Self, _: &Self::UserData) -> Self::Distance {
21//!         let dx = self.x - other.x;
22//!         let dy = self.y - other.y;
23//!         (dx*dx + dy*dy).sqrt() // sqrt is required
24//!     }
25//! }
26//!
27//! fn main() {
28//!     let points = vec![Point{x:2.0,y:3.0}, Point{x:0.0,y:1.0}, Point{x:4.0,y:5.0}];
29//!     let vp = vpsearch::Tree::new(&points);
30//!     let (index, _) = vp.find_nearest(&Point{x:1.0,y:2.0});
31//!     println!("The nearest point is at ({}, {})", points[index].x, points[index].y);
32//! }
33//! ```
34//!
35//! ```rust
36//! #[derive(Clone)]
37//! struct LotsaDimensions<'a>(&'a [u8; 64]);
38//!
39//! impl<'a> vpsearch::MetricSpace for LotsaDimensions<'a> {
40//!     type UserData = ();
41//!     type Distance = f64;
42//!
43//!     fn distance(&self, other: &Self, _: &Self::UserData) -> Self::Distance {
44//!         let dist_squared = self.0.iter().copied().zip(other.0.iter().copied())
45//!             .map(|(a, b)| {
46//!                 (a as i32 - b as i32).pow(2) as u32
47//!             }).sum::<u32>();
48//!
49//!         (dist_squared as f64).sqrt() // sqrt is required
50//!     }
51//! }
52//!
53//! fn main() {
54//!     let points = vec![LotsaDimensions(&[0; 64]), LotsaDimensions(&[5; 64]), LotsaDimensions(&[10; 64])];
55//!     let vp = vpsearch::Tree::new(&points);
56//!     let (index, _) = vp.find_nearest(&LotsaDimensions(&[6; 64]));
57//!     println!("The {}th element is the nearest", index);
58//! }
59//! ```
60
61
62
63use std::cmp::Ordering;
64use std::ops::Add;
65use std::marker::Sized;
66use num_traits::Bounded;
67
68#[cfg(test)]
69mod test;
70mod debug;
71
/// Marker wrapper used as the `Ownership` type parameter of [`Tree`] when the tree
/// stores its own `UserData` (it is the default, `Owned<()>`). The wrapped value is
/// the user data given to `Tree::new_with_user_data_owned`, and it is read back by
/// the owning `find_nearest()`.
#[doc(hidden)]
pub struct Owned<T>(T);
74
75/// Elements you're searching for must be comparable using this trait.
76///
77/// You can ignore `UserImplementationType` if you're implementing `MetricSpace` for your custom type.
78/// However, if you're implementing `MetricSpace` for a type from std or another crate, then you need
79/// to uniquely identify your implementation (that's because of Rust's Orphan Rules).
80///
81/// ```rust,ignore
82/// impl MetricSpace for MyInt {/*…*/}
83///
84/// /// That dummy struct disambiguates between yours and everyone else's impl for a tuple:
85/// struct MyXYCoordinates;
86/// impl MetricSpace<MyXYCoordinates> for (f32,f32) {/*…*/}
87pub trait MetricSpace<UserImplementationType=()> {
88    /// This is used as a context for comparisons. Use `()` if the elements already contain all the data you need.
89    type UserData;
90
91    /// This is a fancy way of saying it should be `f32` or `u32`
92    type Distance: Copy + PartialOrd + Bounded + Add<Output=Self::Distance>;
93
94    /**
95     * This function must return distance between two items that meets triangle inequality.
96     * Specifically, it **MUST NOT return a squared distance** (you must use sqrt if you use Euclidean distance)
97     *
98     * * `user_data` —Whatever you want. Passed from `new_with_user_data_*()`
99     */
100    fn distance(&self, other: &Self, user_data: &Self::UserData) -> Self::Distance;
101}
102
103/// You can implement this if you want to peek at all visited elements
104///
105/// ```rust
106/// # use vpsearch::*;
107/// struct Impl;
108/// struct ReturnByIndex<I: MetricSpace<Impl>> {
109///    distance: I::Distance,
110///    idx: usize,
111/// }
112///
113/// impl<Item: MetricSpace<Impl> + Clone> BestCandidate<Item, Impl> for ReturnByIndex<Item> {
114///     type Output = (usize, Item::Distance);
115///
116///     fn consider(&mut self, _: &Item, distance: Item::Distance, candidate_index: usize, _: &Item::UserData) {
117///         if distance < self.distance {
118///             self.distance = distance;
119///             self.idx = candidate_index;
120///         }
121///     }
122///     fn distance(&self) -> Item::Distance {
123///         self.distance
124///     }
125///     fn result(self, _: &Item::UserData) -> Self::Output {
126///         (self.idx, self.distance)
127///     }
128/// }
129/// ```
130pub trait BestCandidate<Item: MetricSpace<Impl> + Clone, Impl> where Self: Sized {
131    /// `find_nearest()` will return this type
132    type Output;
133
134    /// This is a visitor method. If the given distance is smaller than previously seen, keep the item (or its index).
135    /// `UserData` is the same as for `MetricSpace<Impl>`, and it's `()` by default.
136    fn consider(&mut self, item: &Item, distance: Item::Distance, candidate_index: usize, user_data: &Item::UserData);
137
138    /// Minimum distance seen so far
139    fn distance(&self) -> Item::Distance;
140
141    /// Called once after all relevant nodes in the tree were visited
142    fn result(self, user_data: &Item::UserData) -> Self::Output;
143}
144
145impl<Item: MetricSpace<Impl> + Clone, Impl> BestCandidate<Item, Impl> for ReturnByIndex<Item, Impl> {
146    type Output = (usize, Item::Distance);
147
148    #[inline]
149    fn consider(&mut self, _: &Item, distance: Item::Distance, candidate_index: usize, _: &Item::UserData) {
150        if distance < self.distance {
151            self.distance = distance;
152            self.idx = candidate_index;
153        }
154    }
155
156    #[inline]
157    fn distance(&self) -> Item::Distance {
158        self.distance
159    }
160
161    fn result(self, _: &Item::UserData) -> (usize, Item::Distance) {
162        (self.idx, self.distance)
163    }
164}
165
/// "No child" marker stored in `Node::near` / `Node::far`. All bits are set (it equals
/// `u32::MAX`), which is never a valid node index because `create_root_node` only accepts
/// fewer than `u32::MAX / 2` items, so `nodes.get(NO_NODE as usize)` yields `None`.
const NO_NODE: u32 = !0;
167
168struct Node<Item: MetricSpace<Impl> + Clone, Impl> {
169    near: u32,
170    far: u32,
171    vantage_point: Item, // Pointer to the item (value) represented by the current node
172    radius: Item::Distance,    // How far the `near` node stretches
173    idx: u32,             // Index of the `vantage_point` in the original items array
174}
175
176/// The VP-Tree.
177pub struct Tree<Item: MetricSpace<Impl> + Clone, Impl=(), Ownership=Owned<()>> {
178    nodes: Vec<Node<Item, Impl>>,
179    root: u32,
180    user_data: Ownership,
181}
182
183/* Temporary object used to reorder/track distance between items without modifying the orignial items array
184   (also used during search to hold the two properties).
185*/
186struct Tmp<Item: MetricSpace<Impl>, Impl> {
187    distance: Item::Distance,
188    idx: u32,
189}
190
191struct ReturnByIndex<Item: MetricSpace<Impl>, Impl> {
192    distance: Item::Distance,
193    idx: usize,
194}
195
196impl<Item: MetricSpace<Impl>, Impl> ReturnByIndex<Item, Impl> {
197    fn new() -> Self {
198        ReturnByIndex {
199            distance: <Item::Distance as Bounded>::max_value(),
200            idx: 0,
201        }
202    }
203}
204
205impl<Item: MetricSpace<Impl, UserData = ()> + Clone, Impl> Tree<Item, Impl, Owned<()>> {
206
207    /**
208     * Creates a new tree from items. Maximum number of items is 2^31.
209     *
210     * See `Tree::new_with_user_data_owned`.
211     */
212    pub fn new(items: &[Item]) -> Self {
213        Self::new_with_user_data_owned(items, ())
214    }
215}
216
217impl<U, Impl, Item: MetricSpace<Impl, UserData = U> + Clone> Tree<Item, Impl, Owned<U>> {
218    /**
219     * Finds item closest to the given `needle` (that can be any item) and returns *index* of the item in items array from `new()`.
220     *
221     * Returns the index of the nearest item (index from the items slice passed to `new()`) found and the distance from the nearest item.
222     */
223    #[inline]
224    pub fn find_nearest(&self, needle: &Item) -> (usize, Item::Distance) {
225        self.find_nearest_with_user_data(needle, &self.user_data.0)
226    }
227}
228
229impl<Item: MetricSpace<Impl> + Clone, Ownership, Impl> Tree<Item, Impl, Ownership> {
230    fn sort_indexes_by_distance(vantage_point: Item, indexes: &mut [Tmp<Item, Impl>], items: &[Item], user_data: &Item::UserData) {
231        for i in indexes.iter_mut() {
232            i.distance = vantage_point.distance(&items[i.idx as usize], user_data);
233        }
234        indexes.sort_unstable_by(|a, b| if a.distance < b.distance {Ordering::Less} else {Ordering::Greater});
235    }
236
237    fn create_node(indexes: &mut [Tmp<Item, Impl>], nodes: &mut Vec<Node<Item, Impl>>, items: &[Item], user_data: &Item::UserData) -> u32 {
238        if indexes.len() == 0 {
239            return NO_NODE;
240        }
241
242        if indexes.len() == 1 {
243            let node_idx = nodes.len();
244            nodes.push(Node{
245                near: NO_NODE, far: NO_NODE,
246                vantage_point: items[indexes[0].idx as usize].clone(),
247                idx: indexes[0].idx,
248                radius: <Item::Distance as Bounded>::max_value(),
249            });
250            return node_idx as u32;
251        }
252
253        let last = indexes.len()-1;
254        let ref_idx = indexes[last].idx;
255
256        // Removes the `ref_idx` item from remaining items, because it's included in the current node
257        let rest = &mut indexes[..last];
258
259        Self::sort_indexes_by_distance(items[ref_idx as usize].clone(), rest, items, user_data);
260
261        // Remaining items are split by the median distance
262        let half_idx = rest.len()/2;
263
264        let (near_indexes, far_indexes) = rest.split_at_mut(half_idx);
265        let vantage_point = items[ref_idx as usize].clone();
266        let radius = far_indexes[0].distance;
267
268        // push first to reserve space before its children
269        let node_idx = nodes.len();
270        nodes.push(Node{
271            vantage_point,
272            idx: ref_idx,
273            radius,
274            near: NO_NODE,
275            far: NO_NODE,
276        });
277
278        let near = Self::create_node(near_indexes, nodes, items, user_data);
279        let far = Self::create_node(far_indexes, nodes, items, user_data);
280        nodes[node_idx].near = near;
281        nodes[node_idx].far = far;
282        node_idx as u32
283    }
284}
285
286impl<Item: MetricSpace<Impl> + Clone, Impl> Tree<Item, Impl, Owned<Item::UserData>> {
287    /**
288     * Create a Vantage Point tree for fast nearest neighbor search.
289     *
290     * * `items` —       Array of items that will be searched.
291     * * `user_data` —   Reference to any object that is passed down to item.distance()
292     */
293    pub fn new_with_user_data_owned(items: &[Item], user_data: Item::UserData) -> Self {
294        let mut nodes = Vec::with_capacity(items.len());
295        let root = Self::create_root_node(items, &mut nodes, &user_data);
296        Tree {
297            root,
298            nodes,
299            user_data: Owned(user_data),
300        }
301    }
302}
303
304impl<Item: MetricSpace<Impl> + Clone, Impl> Tree<Item, Impl, ()> {
305    /// The tree doesn't have to own the UserData. You can keep passing it to find_nearest().
306    pub fn new_with_user_data_ref(items: &[Item], user_data: &Item::UserData) -> Self {
307        let mut nodes = Vec::with_capacity(items.len());
308        let root = Self::create_root_node(items, &mut nodes, &user_data);
309        Tree {
310            root,
311            nodes,
312            user_data: (),
313        }
314    }
315
316    #[inline]
317    pub fn find_nearest(&self, needle: &Item, user_data: &Item::UserData) -> (usize, Item::Distance) {
318        self.find_nearest_with_user_data(needle, user_data)
319    }
320}
321
322impl<Item: MetricSpace<Impl> + Clone, Ownership, Impl> Tree<Item, Impl, Ownership> {
323    fn create_root_node(items: &[Item], nodes: &mut Vec<Node<Item, Impl>>, user_data: &Item::UserData) -> u32 {
324        assert!(items.len() < (u32::max_value()/2) as usize);
325
326        let mut indexes: Vec<_> = (0..items.len() as u32).map(|i| Tmp{
327            idx: i, distance: <Item::Distance as Bounded>::max_value(),
328        }).collect();
329
330        Self::create_node(&mut indexes[..], nodes, items, user_data) as u32
331    }
332
333    fn search_node<B: BestCandidate<Item, Impl>>(node: &Node<Item, Impl>, nodes: &[Node<Item, Impl>], needle: &Item, best_candidate: &mut B, user_data: &Item::UserData) {
334        let distance = needle.distance(&node.vantage_point, user_data);
335
336        best_candidate.consider(&node.vantage_point, distance, node.idx as usize, user_data);
337
338        // Recurse towards most likely candidate first to narrow best candidate's distance as soon as possible
339        if distance < node.radius {
340            // No-node case uses out-of-bounds index, so this reuses a safe bounds check as the "null" check
341            if let Some(near) = nodes.get(node.near as usize) {
342                Self::search_node(near, nodes, needle, best_candidate, user_data);
343            }
344            // The best node (final answer) may be just ouside the radius, but not farther than
345            // the best distance we know so far. The search_node above should have narrowed
346            // best_candidate.distance, so this path is rarely taken.
347            if let Some(far) = nodes.get(node.far as usize) {
348                if distance + best_candidate.distance() >= node.radius {
349                    Self::search_node(far, nodes, needle, best_candidate, user_data);
350                }
351            }
352        } else {
353            if let Some(far) = nodes.get(node.far as usize) {
354                Self::search_node(far, nodes, needle, best_candidate, user_data);
355            }
356            if let Some(near) = nodes.get(node.near as usize) {
357                if distance <= node.radius + best_candidate.distance() {
358                    Self::search_node(near, nodes, needle, best_candidate, user_data);
359                }
360            }
361        }
362    }
363
364    #[inline]
365    fn find_nearest_with_user_data(&self, needle: &Item, user_data: &Item::UserData) -> (usize, Item::Distance) {
366        self.find_nearest_custom(needle, user_data, ReturnByIndex::new())
367    }
368
369    #[inline]
370    /// All the bells and whistles version. For best_candidate implement `BestCandidate<Item, Impl>` trait.
371    pub fn find_nearest_custom<ReturnBy: BestCandidate<Item, Impl>>(&self, needle: &Item, user_data: &Item::UserData, mut best_candidate: ReturnBy) -> ReturnBy::Output {
372        if let Some(root) = self.nodes.get(self.root as usize) {
373            Self::search_node(root, &self.nodes, needle, &mut best_candidate, user_data);
374        }
375
376        best_candidate.result(user_data)
377    }
378}