sif_kdtree/
lib.rs

1#![forbid(unsafe_code)]
2#![deny(missing_docs, missing_debug_implementations)]
3
4//! A simple library implementing an immutable, flat representation of a [k-d tree](https://en.wikipedia.org/wiki/K-d_tree)
5//!
6//! The library supports arbitrary spatial queries via the [`Query`] trait and nearest neighbour search.
7//! Its implementation is simple as the objects in the index are fixed after construction.
8//! This also enables a flat and thereby cache-friendly memory layout which can be backed by memory maps.
9//!
10//! The library provides optional integration with [rayon] for parallel construction and queries and [serde] for (de-)serialization of the trees.
11//!
12//! # Example
13//!
14//! ```
15//! use std::ops::ControlFlow;
16//!
17//! use sif_kdtree::{KdTree, Object, WithinDistance};
18//!
19//! struct Something(usize, [f64; 2]);
20//!
21//! impl Object for Something {
22//!     type Point = [f64; 2];
23//!
24//!     fn position(&self) -> &Self::Point {
25//!         &self.1
26//!     }
27//! }
28//!
29//! let index = KdTree::new(
30//!     vec![
31//!         Something(0, [-0.4, -3.3]),
32//!         Something(1, [-4.5, -1.8]),
33//!         Something(2, [0.7, 2.0]),
34//!         Something(3, [1.7, 1.5]),
35//!         Something(4, [-1.3, 2.3]),
36//!         Something(5, [2.2, 1.0]),
37//!         Something(6, [-3.7, 3.8]),
38//!         Something(7, [-3.2, -0.1]),
39//!         Something(8, [1.4, 2.7]),
40//!         Something(9, [3.1, -0.0]),
41//!         Something(10, [4.3, 0.8]),
42//!         Something(11, [3.9, -3.3]),
43//!         Something(12, [0.4, -3.2]),
44//!     ],
45//! );
46//!
47//! let mut close_by = Vec::new();
48//!
49//! index.look_up(&WithinDistance::new([0., 0.], 3.), |thing| {
50//!     close_by.push(thing.0);
51//!
52//!     ControlFlow::Continue(())
53//! });
54//!
55//! assert_eq!(close_by, [2, 4, 5, 3]);
56//!
57//! let closest = index.nearest(&[0., 0.]).unwrap().0;
58//!
59//! assert_eq!(closest, 2);
60//! ```
61//!
62//! The [`KdTree`] data structure is generic over its backing storage as long as it can be converted into a slice via the [`AsRef`] trait.
63//! This can for instance be used to memory map k-d trees from persistent storage.
64//!
65//! ```no_run
66//! # fn main() -> std::io::Result<()> {
67//! use std::fs::File;
68//! use std::mem::size_of;
69//! use std::slice::from_raw_parts;
70//!
71//! use memmap2::Mmap;
72//!
73//! use sif_kdtree::{KdTree, Object};
74//!
75//! #[derive(Clone, Copy)]
76//! struct Point([f64; 3]);
77//!
78//! impl Object for Point {
79//!     type Point = [f64; 3];
80//!
81//!     fn position(&self) -> &Self::Point {
82//!         &self.0
83//!     }
84//! }
85//!
86//! let file = File::open("index.bin")?;
87//! let map = unsafe { Mmap::map(&file)? };
88//!
89//! struct PointCloud(Mmap);
90//!
91//! impl AsRef<[Point]> for PointCloud {
92//!     fn as_ref(&self) -> &[Point] {
93//!         let ptr = self.0.as_ptr().cast();
94//!         let len = self.0.len() / size_of::<Point>();
95//!
96//!         unsafe { from_raw_parts(ptr, len) }
97//!     }
98//! }
99//!
100//! let index = KdTree::new_unchecked(PointCloud(map));
101//! # Ok(()) }
102//! ```
103
104mod look_up;
105mod nearest;
106mod sort;
107
108pub use look_up::{Query, WithinBoundingBox, WithinDistance};
109
110use std::marker::PhantomData;
111use std::ops::Deref;
112
113use num_traits::Num;
114#[cfg(feature = "serde")]
115use serde::{Deserialize, Serialize};
116
117/// Defines a [finite-dimensional][Self::DIM] space in terms of [coordinate values][Self::coord] along a chosen set of axes
118pub trait Point {
119    /// The dimension of the underlying space
120    const DIM: usize;
121
122    /// The type of the coordinate values
123    type Coord: Num + Copy + PartialOrd;
124
125    /// Access the coordinate value of the point along the given `axis`
126    fn coord(&self, axis: usize) -> Self::Coord;
127}
128
129/// Extends the [`Point`] trait by a distance metric required for nearest neighbour search
130pub trait Distance: Point {
131    /// Return the squared distance between `self` and `other`
132    ///
133    /// This is called during nearest neighbour search and hence only the relation between two distance values is required so that computing square roots can be avoided.
134    fn distance_2(&self, other: &Self) -> Self::Coord;
135}
136
137/// `N`-dimensional space using [Euclidean distance](https://en.wikipedia.org/wiki/Euclidean_distance)
138impl<T, const N: usize> Point for [T; N]
139where
140    T: Num + Copy + PartialOrd,
141{
142    const DIM: usize = N;
143
144    type Coord = T;
145
146    fn coord(&self, axis: usize) -> Self::Coord {
147        self[axis]
148    }
149}
150
151impl<T, const N: usize> Distance for [T; N]
152where
153    T: Num + Copy + PartialOrd,
154{
155    fn distance_2(&self, other: &Self) -> Self::Coord {
156        (0..N).fold(T::zero(), |res, axis| {
157            let diff = self[axis] - other[axis];
158
159            res + diff * diff
160        })
161    }
162}
163
164/// Defines the objects which can be organized in a [`KdTree`] by positioning them in the vector space defined via the [`Point`] trait
165pub trait Object {
166    /// The [`Point`] implementation used to represent the [position][`Self::position`] of these objects
167    type Point: Point;
168
169    /// Return the position associated with this object
170    ///
171    /// Note that calling this method is assumed to be cheap, returning a reference to a point stored in the interior of the object.
172    fn position(&self) -> &Self::Point;
173}
174
/// A [k-d tree](https://en.wikipedia.org/wiki/K-d_tree) stored as an immutable, flat sequence of objects
///
/// Spatial queries and nearest neighbour search are accelerated by keeping the objects sorted according to the coordinate values of their positions.
///
/// The backing storage `S` can be any type which converts into a slice of objects via [`AsRef`] and defaults to a boxed slice.
///
/// Note that this tree dereferences to and deserializes as a slice of objects.
/// Modifying object positions through interior mutability or deserializing a modified sequence is safe but will lead to incorrect results.
#[derive(Clone, Debug, Default)]
#[cfg_attr(
    feature = "serde",
    derive(Serialize, Deserialize),
    serde(transparent)
)]
pub struct KdTree<O, S = Box<[O]>>
where
    S: AsRef<[O]>,
{
    /// The storage holding the objects
    objects: S,
    /// Ties the object type `O` to the tree without storing a value of it
    _marker: PhantomData<O>,
}
191
192impl<O, S> KdTree<O, S>
193where
194    O: Object,
195    S: AsRef<[O]>,
196{
197    /// Interprets the given `objects` as a tree
198    ///
199    /// Supplying `objects` which are not actually sorted as a k-d tree is safe but will lead to incorrect results.
200    pub fn new_unchecked(objects: S) -> Self {
201        Self {
202            objects,
203            _marker: PhantomData,
204        }
205    }
206}
207
208impl<O, S> Deref for KdTree<O, S>
209where
210    S: AsRef<[O]>,
211{
212    type Target = [O];
213
214    fn deref(&self) -> &Self::Target {
215        self.objects.as_ref()
216    }
217}
218
219impl<O, S> AsRef<[O]> for KdTree<O, S>
220where
221    S: AsRef<[O]>,
222{
223    fn as_ref(&self) -> &[O] {
224        self.objects.as_ref()
225    }
226}
227
/// Splits `objects` at its middle into the left part, the middle object and the right part
///
/// The middle object sits at index `objects.len() / 2`, so for an even length the left part is one object longer than the right part.
///
/// # Panics
///
/// Panics if `objects` is empty as there is no middle object then.
fn split<O>(objects: &[O]) -> (&[O], &O, &[O]) {
    let pivot = objects.len() / 2;

    let (mid, right) = objects[pivot..].split_first().unwrap();

    (&objects[..pivot], mid, right)
}
234
235fn contains<P>(aabb: &(P, P), position: &P) -> bool
236where
237    P: Point,
238{
239    (0..P::DIM).all(|axis| {
240        aabb.0.coord(axis) <= position.coord(axis) && position.coord(axis) <= aabb.1.coord(axis)
241    })
242}
243
#[cfg(test)]
mod tests {
    use super::*;

    use std::cmp::Ordering;

    use proptest::{collection::vec, strategy::Strategy};

    /// Strategy producing vectors of `len` two-dimensional points with both coordinates drawn from `0.0..=1.0`
    pub fn random_points(len: usize) -> impl Strategy<Value = Vec<[f32; 2]>> {
        let xs = vec(0.0_f32..=1.0, len);
        let ys = vec(0.0_f32..=1.0, len);

        (xs, ys).prop_map(|(xs, ys)| xs.into_iter().zip(ys).map(|(x, y)| [x, y]).collect())
    }

    /// Minimal [`Object`] used by the tests, consisting only of its position
    #[derive(Debug, PartialEq)]
    pub struct RandomObject(pub [f32; 2]);

    // Cannot be derived as `f32` is not `Eq`.
    impl Eq for RandomObject {}

    impl PartialOrd for RandomObject {
        fn partial_cmp(&self, other: &Self) -> Option<Ordering> {
            Some(Ord::cmp(self, other))
        }
    }

    impl Ord for RandomObject {
        /// Orders objects by their positions, panicking if the positions are incomparable
        fn cmp(&self, other: &Self) -> Ordering {
            let (Self(lhs), Self(rhs)) = (self, other);

            lhs.partial_cmp(rhs).unwrap()
        }
    }

    impl Object for RandomObject {
        type Point = [f32; 2];

        fn position(&self) -> &Self::Point {
            let Self(position) = self;

            position
        }
    }

    /// Strategy wrapping the points produced by [`random_points`] into a boxed slice of [`RandomObject`]
    pub fn random_objects(len: usize) -> impl Strategy<Value = Box<[RandomObject]>> {
        random_points(len).prop_map(|points| {
            points
                .into_iter()
                .map(RandomObject)
                .collect::<Box<[_]>>()
        })
    }
}