sif_kdtree/lib.rs
1#![forbid(unsafe_code)]
2#![deny(missing_docs, missing_debug_implementations)]
3
4//! A simple library implementing an immutable, flat representation of a [k-d tree](https://en.wikipedia.org/wiki/K-d_tree)
5//!
6//! The library supports arbitrary spatial queries via the [`Query`] trait and nearest neighbour search.
7//! Its implementation is simple as the objects in the index are fixed after construction.
8//! This also enables a flat and thereby cache-friendly memory layout which can be backed by memory maps.
9//!
10//! The library provides optional integration with [rayon] for parallel construction and queries and [serde] for (de-)serialization of the trees.
11//!
12//! # Example
13//!
14//! ```
15//! use std::ops::ControlFlow;
16//!
17//! use sif_kdtree::{KdTree, Object, WithinDistance};
18//!
19//! struct Something(usize, [f64; 2]);
20//!
21//! impl Object for Something {
22//! type Point = [f64; 2];
23//!
24//! fn position(&self) -> &Self::Point {
25//! &self.1
26//! }
27//! }
28//!
29//! let index = KdTree::new(
30//! vec![
31//! Something(0, [-0.4, -3.3]),
32//! Something(1, [-4.5, -1.8]),
33//! Something(2, [0.7, 2.0]),
34//! Something(3, [1.7, 1.5]),
35//! Something(4, [-1.3, 2.3]),
36//! Something(5, [2.2, 1.0]),
37//! Something(6, [-3.7, 3.8]),
38//! Something(7, [-3.2, -0.1]),
39//! Something(8, [1.4, 2.7]),
40//! Something(9, [3.1, -0.0]),
41//! Something(10, [4.3, 0.8]),
42//! Something(11, [3.9, -3.3]),
43//! Something(12, [0.4, -3.2]),
44//! ],
45//! );
46//!
47//! let mut close_by = Vec::new();
48//!
49//! index
50//! .look_up(&WithinDistance::new([0., 0.], 3.), |thing| {
51//! close_by.push(thing.0);
52//!
53//! ControlFlow::<()>::Continue(())
54//! })
55//! .continue_value()
56//! .unwrap();
57//!
58//! assert_eq!(close_by, [2, 4, 5, 3]);
59//!
60//! let closest = index.nearest(&[0., 0.]).unwrap().0;
61//!
62//! assert_eq!(closest, 2);
63//! ```
64//!
65//! The [`KdTree`] data structure is generic over its backing storage as long as it can be converted into a slice via the [`AsRef`] trait.
66//! This can for instance be used to memory map k-d trees from persistent storage.
67//!
68//! ```no_run
69//! # fn main() -> std::io::Result<()> {
70//! use std::fs::File;
71//! use std::mem::size_of;
72//! use std::slice::from_raw_parts;
73//!
74//! use memmap2::Mmap;
75//!
76//! use sif_kdtree::{KdTree, Object};
77//!
78//! #[derive(Clone, Copy)]
79//! struct Point([f64; 3]);
80//!
81//! impl Object for Point {
82//! type Point = [f64; 3];
83//!
84//! fn position(&self) -> &Self::Point {
85//! &self.0
86//! }
87//! }
88//!
89//! let file = File::open("index.bin")?;
90//! let map = unsafe { Mmap::map(&file)? };
91//!
92//! struct PointCloud(Mmap);
93//!
94//! impl AsRef<[Point]> for PointCloud {
95//! fn as_ref(&self) -> &[Point] {
96//! let ptr = self.0.as_ptr().cast();
97//! let len = self.0.len() / size_of::<Point>();
98//!
99//! unsafe { from_raw_parts(ptr, len) }
100//! }
101//! }
102//!
103//! let index = KdTree::new_unchecked(PointCloud(map));
104//! # Ok(()) }
105//! ```
106
107mod look_up;
108mod nearest;
109mod sort;
110
111pub use look_up::{Query, WithinBoundingBox, WithinDistance};
112
113use std::marker::PhantomData;
114use std::ops::Deref;
115
116use num_traits::Num;
117#[cfg(feature = "serde")]
118use serde::{Deserialize, Serialize};
119
120/// Defines a [finite-dimensional][Self::DIM] space in terms of [coordinate values][Self::coord] along a chosen set of axes
121pub trait Point {
122 /// The dimension of the underlying space
123 const DIM: usize;
124
125 /// The type of the coordinate values
126 type Coord: Num + Copy + PartialOrd;
127
128 /// Access the coordinate value of the point along the given `axis`
129 fn coord(&self, axis: usize) -> Self::Coord;
130}
131
132/// Extends the [`Point`] trait by a distance metric required for nearest neighbour search
133pub trait Distance: Point {
134 /// Return the squared distance between `self` and `other`
135 ///
136 /// This is called during nearest neighbour search and hence only the relation between two distance values is required so that computing square roots can be avoided.
137 fn distance_2(&self, other: &Self) -> Self::Coord;
138}
139
140/// `N`-dimensional space using [Euclidean distance](https://en.wikipedia.org/wiki/Euclidean_distance)
141impl<T, const N: usize> Point for [T; N]
142where
143 T: Num + Copy + PartialOrd,
144{
145 const DIM: usize = N;
146
147 type Coord = T;
148
149 fn coord(&self, axis: usize) -> Self::Coord {
150 self[axis]
151 }
152}
153
154impl<T, const N: usize> Distance for [T; N]
155where
156 T: Num + Copy + PartialOrd,
157{
158 fn distance_2(&self, other: &Self) -> Self::Coord {
159 (0..N).fold(T::zero(), |res, axis| {
160 let diff = self[axis] - other[axis];
161
162 res + diff * diff
163 })
164 }
165}
166
167/// Defines the objects which can be organized in a [`KdTree`] by positioning them in the vector space defined via the [`Point`] trait
168pub trait Object {
169 /// The [`Point`] implementation used to represent the [position][`Self::position`] of these objects
170 type Point: Point;
171
172 /// Return the position associated with this object
173 ///
174 /// Note that calling this method is assumed to be cheap, returning a reference to a point stored in the interior of the object.
175 fn position(&self) -> &Self::Point;
176}
177
/// An immutable [k-d tree](https://en.wikipedia.org/wiki/K-d_tree) stored as a flat sequence of objects
///
/// Spatial queries and nearest neighbour search are accelerated by keeping the objects sorted according to the coordinate values of their positions.
///
/// The backing storage `S` only needs to expose the objects as a slice via [`AsRef`] and defaults to a boxed slice.
///
/// Note that this tree dereferences to and deserializes as a slice of objects.
/// Changing object positions through interior mutability or deserializing a modified sequence is safe, but queries will then yield incorrect results.
#[derive(Debug, Default, Clone)]
#[cfg_attr(
    feature = "serde",
    derive(Serialize, Deserialize),
    serde(transparent)
)]
pub struct KdTree<O, S: AsRef<[O]> = Box<[O]>> {
    // The objects in k-d tree order, held by the backing storage.
    objects: S,
    // Ties the otherwise unused object type `O` to the struct.
    _marker: PhantomData<O>,
}
194
195impl<O, S> KdTree<O, S>
196where
197 O: Object,
198 S: AsRef<[O]>,
199{
200 /// Interprets the given `objects` as a tree
201 ///
202 /// Supplying `objects` which are not actually sorted as a k-d tree is safe but will lead to incorrect results.
203 pub fn new_unchecked(objects: S) -> Self {
204 Self {
205 objects,
206 _marker: PhantomData,
207 }
208 }
209}
210
211impl<O, S> Deref for KdTree<O, S>
212where
213 S: AsRef<[O]>,
214{
215 type Target = [O];
216
217 fn deref(&self) -> &Self::Target {
218 self.objects.as_ref()
219 }
220}
221
222impl<O, S> AsRef<[O]> for KdTree<O, S>
223where
224 S: AsRef<[O]>,
225{
226 fn as_ref(&self) -> &[O] {
227 self.objects.as_ref()
228 }
229}
230
/// Splits `objects` into the left half, the middle object and the right half
///
/// The middle object is the one at index `objects.len() / 2`, so for an even
/// number of objects the left half is one object longer than the right half.
///
/// # Panics
///
/// Panics if `objects` is empty as there is no middle object in that case.
fn split<O>(objects: &[O]) -> (&[O], &O, &[O]) {
    let (left, rest) = objects.split_at(objects.len() / 2);

    // `rest` starts at the middle index and is only empty if `objects` is.
    let (mid, right) = rest
        .split_first()
        .expect("split requires a non-empty slice of objects");

    (left, mid, right)
}
237
238fn contains<P>(aabb: &(P, P), position: &P) -> bool
239where
240 P: Point,
241{
242 (0..P::DIM).all(|axis| {
243 aabb.0.coord(axis) <= position.coord(axis) && position.coord(axis) <= aabb.1.coord(axis)
244 })
245}
246
#[cfg(test)]
mod tests {
    use super::*;

    use std::cmp::Ordering;

    use proptest::{collection::vec, strategy::Strategy};

    /// Strategy yielding `len` two-dimensional points with both coordinates drawn from `0.0..=1.0`
    pub fn random_points(len: usize) -> impl Strategy<Value = Vec<[f32; 2]>> {
        let xs = vec(0.0_f32..=1.0, len);
        let ys = vec(0.0_f32..=1.0, len);

        (xs, ys).prop_map(|(xs, ys)| {
            let pairs = xs.into_iter().zip(ys);

            pairs.map(|(x, y)| [x, y]).collect()
        })
    }

    /// Minimal [`Object`] consisting of nothing but its two-dimensional position
    #[derive(Debug, PartialEq)]
    pub struct RandomObject(pub [f32; 2]);

    // `f32` does not implement `Eq`, hence it is asserted manually here.
    impl Eq for RandomObject {}

    impl PartialOrd for RandomObject {
        fn partial_cmp(&self, other: &Self) -> Option<Ordering> {
            Some(Ord::cmp(self, other))
        }
    }

    impl Ord for RandomObject {
        // Panics if the positions are incomparable, e.g. if a coordinate is NaN.
        fn cmp(&self, other: &Self) -> Ordering {
            let Self(lhs) = self;
            let Self(rhs) = other;

            lhs.partial_cmp(rhs).unwrap()
        }
    }

    impl Object for RandomObject {
        type Point = [f32; 2];

        fn position(&self) -> &Self::Point {
            let Self(position) = self;

            position
        }
    }

    /// Strategy yielding the points from [`random_points`] wrapped as [`RandomObject`]
    pub fn random_objects(len: usize) -> impl Strategy<Value = Box<[RandomObject]>> {
        random_points(len).prop_map(|points| {
            let objects = points.into_iter().map(RandomObject);

            objects.collect()
        })
    }
}
288}