sif_kdtree/
lib.rs

1#![forbid(unsafe_code)]
2#![deny(missing_docs, missing_debug_implementations)]
3
4//! A simple library implementing an immutable, flat representation of a [k-d tree](https://en.wikipedia.org/wiki/K-d_tree)
5//!
6//! The library supports arbitrary spatial queries via the [`Query`] trait and nearest neighbour search.
7//! Its implementation is simple as the objects in the index are fixed after construction.
8//! This also enables a flat and thereby cache-friendly memory layout which can be backed by memory maps.
9//!
10//! The library provides optional integration with [rayon] for parallel construction and queries and [serde] for (de-)serialization of the trees.
11//!
12//! # Example
13//!
14//! ```
15//! use std::ops::ControlFlow;
16//!
17//! use sif_kdtree::{KdTree, Object, WithinDistance};
18//!
19//! struct Something(usize, [f64; 2]);
20//!
21//! impl Object for Something {
22//!     type Point = [f64; 2];
23//!
24//!     fn position(&self) -> &Self::Point {
25//!         &self.1
26//!     }
27//! }
28//!
29//! let index = KdTree::new(
30//!     vec![
31//!         Something(0, [-0.4, -3.3]),
32//!         Something(1, [-4.5, -1.8]),
33//!         Something(2, [0.7, 2.0]),
34//!         Something(3, [1.7, 1.5]),
35//!         Something(4, [-1.3, 2.3]),
36//!         Something(5, [2.2, 1.0]),
37//!         Something(6, [-3.7, 3.8]),
38//!         Something(7, [-3.2, -0.1]),
39//!         Something(8, [1.4, 2.7]),
40//!         Something(9, [3.1, -0.0]),
41//!         Something(10, [4.3, 0.8]),
42//!         Something(11, [3.9, -3.3]),
43//!         Something(12, [0.4, -3.2]),
44//!     ],
45//! );
46//!
47//! let mut close_by = Vec::new();
48//!
49//! index
50//!     .look_up(&WithinDistance::new([0., 0.], 3.), |thing| {
51//!         close_by.push(thing.0);
52//!
53//!         ControlFlow::<()>::Continue(())
54//!     })
55//!     .continue_value()
56//!     .unwrap();
57//!
58//! assert_eq!(close_by, [2, 4, 5, 3]);
59//!
60//! let closest = index.nearest(&[0., 0.]).unwrap().0;
61//!
62//! assert_eq!(closest, 2);
63//! ```
64//!
65//! The [`KdTree`] data structure is generic over its backing storage as long as it can be converted into a slice via the [`AsRef`] trait.
66//! This can for instance be used to memory map k-d trees from persistent storage.
67//!
68//! ```no_run
69//! # fn main() -> std::io::Result<()> {
70//! use std::fs::File;
71//! use std::mem::size_of;
72//! use std::slice::from_raw_parts;
73//!
74//! use memmap2::Mmap;
75//!
76//! use sif_kdtree::{KdTree, Object};
77//!
78//! #[derive(Clone, Copy)]
79//! struct Point([f64; 3]);
80//!
81//! impl Object for Point {
82//!     type Point = [f64; 3];
83//!
84//!     fn position(&self) -> &Self::Point {
85//!         &self.0
86//!     }
87//! }
88//!
89//! let file = File::open("index.bin")?;
90//! let map = unsafe { Mmap::map(&file)? };
91//!
92//! struct PointCloud(Mmap);
93//!
94//! impl AsRef<[Point]> for PointCloud {
95//!     fn as_ref(&self) -> &[Point] {
96//!         let ptr = self.0.as_ptr().cast();
97//!         let len = self.0.len() / size_of::<Point>();
98//!
99//!         unsafe { from_raw_parts(ptr, len) }
100//!     }
101//! }
102//!
103//! let index = KdTree::new_unchecked(PointCloud(map));
104//! # Ok(()) }
105//! ```
106
107mod look_up;
108mod nearest;
109mod sort;
110
111pub use look_up::{Query, WithinBoundingBox, WithinDistance};
112
113use std::marker::PhantomData;
114use std::ops::Deref;
115
116use num_traits::Num;
117#[cfg(feature = "serde")]
118use serde::{Deserialize, Serialize};
119
120/// Defines a [finite-dimensional][Self::DIM] space in terms of [coordinate values][Self::coord] along a chosen set of axes
121pub trait Point {
122    /// The dimension of the underlying space
123    const DIM: usize;
124
125    /// The type of the coordinate values
126    type Coord: Num + Copy + PartialOrd;
127
128    /// Access the coordinate value of the point along the given `axis`
129    fn coord(&self, axis: usize) -> Self::Coord;
130}
131
132/// Extends the [`Point`] trait by a distance metric required for nearest neighbour search
133pub trait Distance: Point {
134    /// Return the squared distance between `self` and `other`
135    ///
136    /// This is called during nearest neighbour search and hence only the relation between two distance values is required so that computing square roots can be avoided.
137    fn distance_2(&self, other: &Self) -> Self::Coord;
138}
139
140/// `N`-dimensional space using [Euclidean distance](https://en.wikipedia.org/wiki/Euclidean_distance)
141impl<T, const N: usize> Point for [T; N]
142where
143    T: Num + Copy + PartialOrd,
144{
145    const DIM: usize = N;
146
147    type Coord = T;
148
149    fn coord(&self, axis: usize) -> Self::Coord {
150        self[axis]
151    }
152}
153
154impl<T, const N: usize> Distance for [T; N]
155where
156    T: Num + Copy + PartialOrd,
157{
158    fn distance_2(&self, other: &Self) -> Self::Coord {
159        (0..N).fold(T::zero(), |res, axis| {
160            let diff = self[axis] - other[axis];
161
162            res + diff * diff
163        })
164    }
165}
166
167/// Defines the objects which can be organized in a [`KdTree`] by positioning them in the vector space defined via the [`Point`] trait
168pub trait Object {
169    /// The [`Point`] implementation used to represent the [position][`Self::position`] of these objects
170    type Point: Point;
171
172    /// Return the position associated with this object
173    ///
174    /// Note that calling this method is assumed to be cheap, returning a reference to a point stored in the interior of the object.
175    fn position(&self) -> &Self::Point;
176}
177
/// An immutable [k-d tree](https://en.wikipedia.org/wiki/K-d_tree) stored as a flat sequence of objects
///
/// Spatial queries and nearest neighbour search are accelerated by keeping the objects sorted according to the coordinate values of their positions.
///
/// The backing storage `S` only has to be viewable as a slice of objects via [`AsRef`] and defaults to a boxed slice.
///
/// Note that this tree dereferences to and deserializes as a slice of objects.
/// Modifying object positions through interior mutability or deserializing a modified sequence is safe but will lead to incorrect results.
#[derive(Clone, Debug, Default)]
#[cfg_attr(
    feature = "serde",
    derive(Serialize, Deserialize),
    serde(transparent)
)]
pub struct KdTree<O, S = Box<[O]>>
where
    S: AsRef<[O]>,
{
    // The backing storage holding the objects of the tree
    objects: S,
    // Uses the object type `O`, which otherwise only appears in the bound on `S`
    _marker: PhantomData<O>,
}
194
195impl<O, S> KdTree<O, S>
196where
197    O: Object,
198    S: AsRef<[O]>,
199{
200    /// Interprets the given `objects` as a tree
201    ///
202    /// Supplying `objects` which are not actually sorted as a k-d tree is safe but will lead to incorrect results.
203    pub fn new_unchecked(objects: S) -> Self {
204        Self {
205            objects,
206            _marker: PhantomData,
207        }
208    }
209}
210
211impl<O, S> Deref for KdTree<O, S>
212where
213    S: AsRef<[O]>,
214{
215    type Target = [O];
216
217    fn deref(&self) -> &Self::Target {
218        self.objects.as_ref()
219    }
220}
221
222impl<O, S> AsRef<[O]> for KdTree<O, S>
223where
224    S: AsRef<[O]>,
225{
226    fn as_ref(&self) -> &[O] {
227        self.objects.as_ref()
228    }
229}
230
/// Splits `objects` into the elements before the middle, the middle element itself and the elements after it
///
/// The middle element is the one at index `objects.len() / 2`, so the left part holds `objects.len() / 2` elements and the right part holds the rest.
///
/// # Panics
///
/// Panics if `objects` is empty as there is no middle element to return in that case.
fn split<O>(objects: &[O]) -> (&[O], &O, &[O]) {
    let (left, rest) = objects.split_at(objects.len() / 2);

    // `rest` starts with the middle element and is only empty if `objects` itself is empty.
    let (mid, right) = rest
        .split_first()
        .expect("cannot split an empty slice of objects");

    (left, mid, right)
}
237
238fn contains<P>(aabb: &(P, P), position: &P) -> bool
239where
240    P: Point,
241{
242    (0..P::DIM).all(|axis| {
243        aabb.0.coord(axis) <= position.coord(axis) && position.coord(axis) <= aabb.1.coord(axis)
244    })
245}
246
#[cfg(test)]
mod tests {
    use super::*;

    use std::cmp::Ordering;

    use proptest::{collection::vec, strategy::Strategy};

    /// Strategy producing `len` two-dimensional points with both coordinates taken from `0.0..=1.0`
    pub fn random_points(len: usize) -> impl Strategy<Value = Vec<[f32; 2]>> {
        let xs = vec(0.0_f32..=1.0, len);
        let ys = vec(0.0_f32..=1.0, len);

        (xs, ys).prop_map(|(xs, ys)| xs.into_iter().zip(ys).map(|(x, y)| [x, y]).collect())
    }

    /// Minimal [`Object`] wrapping just its position, with a total order so that collections of it can be sorted and compared
    #[derive(Debug, PartialEq)]
    pub struct RandomObject(pub [f32; 2]);

    impl Eq for RandomObject {}

    impl Ord for RandomObject {
        fn cmp(&self, other: &Self) -> Ordering {
            let (Self(lhs), Self(rhs)) = (self, other);

            // Unwrapping panics if the positions cannot be ordered, e.g. if a coordinate is NaN.
            lhs.partial_cmp(rhs).unwrap()
        }
    }

    impl PartialOrd for RandomObject {
        fn partial_cmp(&self, other: &Self) -> Option<Ordering> {
            Some(self.cmp(other))
        }
    }

    impl Object for RandomObject {
        type Point = [f32; 2];

        fn position(&self) -> &Self::Point {
            let Self(position) = self;

            position
        }
    }

    /// Strategy producing `len` objects positioned at [`random_points`]
    pub fn random_objects(len: usize) -> impl Strategy<Value = Box<[RandomObject]>> {
        random_points(len).prop_map(|positions| {
            positions
                .into_iter()
                .map(RandomObject)
                .collect::<Box<[RandomObject]>>()
        })
    }
}