sif_kdtree/lib.rs
1#![forbid(unsafe_code)]
2#![deny(missing_docs, missing_debug_implementations)]
3
//! A simple library implementing an immutable, flat representation of a [k-d tree](https://en.wikipedia.org/wiki/K-d_tree).
//!
//! The library supports arbitrary spatial queries via the [`Query`] trait and nearest neighbour search.
//! Its implementation is simple because the objects in the index are fixed after construction.
//! This also enables a flat and thereby cache-friendly memory layout which can be backed by memory maps.
9//!
10//! The library provides optional integration with [rayon] for parallel construction and queries and [serde] for (de-)serialization of the trees.
11//!
12//! # Example
13//!
14//! ```
15//! use std::ops::ControlFlow;
16//!
17//! use sif_kdtree::{KdTree, Object, WithinDistance};
18//!
19//! struct Something(usize, [f64; 2]);
20//!
21//! impl Object for Something {
22//! type Point = [f64; 2];
23//!
24//! fn position(&self) -> &Self::Point {
25//! &self.1
26//! }
27//! }
28//!
29//! let index = KdTree::new(
30//! vec![
31//! Something(0, [-0.4, -3.3]),
32//! Something(1, [-4.5, -1.8]),
33//! Something(2, [0.7, 2.0]),
34//! Something(3, [1.7, 1.5]),
35//! Something(4, [-1.3, 2.3]),
36//! Something(5, [2.2, 1.0]),
37//! Something(6, [-3.7, 3.8]),
38//! Something(7, [-3.2, -0.1]),
39//! Something(8, [1.4, 2.7]),
40//! Something(9, [3.1, -0.0]),
41//! Something(10, [4.3, 0.8]),
42//! Something(11, [3.9, -3.3]),
43//! Something(12, [0.4, -3.2]),
44//! ],
45//! );
46//!
47//! let mut close_by = Vec::new();
48//!
49//! index.look_up(&WithinDistance::new([0., 0.], 3.), |thing| {
50//! close_by.push(thing.0);
51//!
52//! ControlFlow::Continue(())
53//! });
54//!
55//! assert_eq!(close_by, [2, 4, 5, 3]);
56//!
57//! let closest = index.nearest(&[0., 0.]).unwrap().0;
58//!
59//! assert_eq!(closest, 2);
60//! ```
61//!
62//! The [`KdTree`] data structure is generic over its backing storage as long as it can be converted into a slice via the [`AsRef`] trait.
63//! This can for instance be used to memory map k-d trees from persistent storage.
64//!
65//! ```no_run
66//! # fn main() -> std::io::Result<()> {
67//! use std::fs::File;
68//! use std::mem::size_of;
69//! use std::slice::from_raw_parts;
70//!
71//! use memmap2::Mmap;
72//!
73//! use sif_kdtree::{KdTree, Object};
74//!
75//! #[derive(Clone, Copy)]
76//! struct Point([f64; 3]);
77//!
78//! impl Object for Point {
79//! type Point = [f64; 3];
80//!
81//! fn position(&self) -> &Self::Point {
82//! &self.0
83//! }
84//! }
85//!
86//! let file = File::open("index.bin")?;
87//! let map = unsafe { Mmap::map(&file)? };
88//!
89//! struct PointCloud(Mmap);
90//!
91//! impl AsRef<[Point]> for PointCloud {
92//! fn as_ref(&self) -> &[Point] {
93//! let ptr = self.0.as_ptr().cast();
94//! let len = self.0.len() / size_of::<Point>();
95//!
96//! unsafe { from_raw_parts(ptr, len) }
97//! }
98//! }
99//!
100//! let index = KdTree::new_unchecked(PointCloud(map));
101//! # Ok(()) }
102//! ```
103
104mod look_up;
105mod nearest;
106mod sort;
107
108pub use look_up::{Query, WithinBoundingBox, WithinDistance};
109
110use std::marker::PhantomData;
111use std::ops::Deref;
112
113use num_traits::Num;
114#[cfg(feature = "serde")]
115use serde::{Deserialize, Serialize};
116
117/// Defines a [finite-dimensional][Self::DIM] space in terms of [coordinate values][Self::coord] along a chosen set of axes
118pub trait Point {
119 /// The dimension of the underlying space
120 const DIM: usize;
121
122 /// The type of the coordinate values
123 type Coord: Num + Copy + PartialOrd;
124
125 /// Access the coordinate value of the point along the given `axis`
126 fn coord(&self, axis: usize) -> Self::Coord;
127}
128
129/// Extends the [`Point`] trait by a distance metric required for nearest neighbour search
130pub trait Distance: Point {
131 /// Return the squared distance between `self` and `other`
132 ///
133 /// This is called during nearest neighbour search and hence only the relation between two distance values is required so that computing square roots can be avoided.
134 fn distance_2(&self, other: &Self) -> Self::Coord;
135}
136
137/// `N`-dimensional space using [Euclidean distance](https://en.wikipedia.org/wiki/Euclidean_distance)
138impl<T, const N: usize> Point for [T; N]
139where
140 T: Num + Copy + PartialOrd,
141{
142 const DIM: usize = N;
143
144 type Coord = T;
145
146 fn coord(&self, axis: usize) -> Self::Coord {
147 self[axis]
148 }
149}
150
151impl<T, const N: usize> Distance for [T; N]
152where
153 T: Num + Copy + PartialOrd,
154{
155 fn distance_2(&self, other: &Self) -> Self::Coord {
156 (0..N).fold(T::zero(), |res, axis| {
157 let diff = self[axis] - other[axis];
158
159 res + diff * diff
160 })
161 }
162}
163
164/// Defines the objects which can be organized in a [`KdTree`] by positioning them in the vector space defined via the [`Point`] trait
165pub trait Object {
166 /// The [`Point`] implementation used to represent the [position][`Self::position`] of these objects
167 type Point: Point;
168
169 /// Return the position associated with this object
170 ///
171 /// Note that calling this method is assumed to be cheap, returning a reference to a point stored in the interior of the object.
172 fn position(&self) -> &Self::Point;
173}
174
/// An immutable, flat representation of a [k-d tree](https://en.wikipedia.org/wiki/K-d_tree)
///
/// Spatial queries and nearest neighbour search are accelerated by keeping the objects sorted according to the coordinate values of their positions.
///
/// The backing storage `S` only needs to expose the objects as a slice via [`AsRef`] and defaults to a boxed slice.
///
/// Note that this tree dereferences to and deserializes as a slice of objects.
/// Modifying object positions through interior mutability or deserializing a modified sequence is safe but will lead to incorrect results.
#[derive(Clone, Debug, Default)]
#[cfg_attr(
    feature = "serde",
    derive(Serialize, Deserialize),
    serde(transparent)
)]
pub struct KdTree<O, S = Box<[O]>>
where
    S: AsRef<[O]>,
{
    /// Backing storage through which the objects are accessed as a slice
    objects: S,
    /// Ties the object type `O` to the tree even though only `S` is stored
    _marker: PhantomData<O>,
}
191
192impl<O, S> KdTree<O, S>
193where
194 O: Object,
195 S: AsRef<[O]>,
196{
197 /// Interprets the given `objects` as a tree
198 ///
199 /// Supplying `objects` which are not actually sorted as a k-d tree is safe but will lead to incorrect results.
200 pub fn new_unchecked(objects: S) -> Self {
201 Self {
202 objects,
203 _marker: PhantomData,
204 }
205 }
206}
207
208impl<O, S> Deref for KdTree<O, S>
209where
210 S: AsRef<[O]>,
211{
212 type Target = [O];
213
214 fn deref(&self) -> &Self::Target {
215 self.objects.as_ref()
216 }
217}
218
219impl<O, S> AsRef<[O]> for KdTree<O, S>
220where
221 S: AsRef<[O]>,
222{
223 fn as_ref(&self) -> &[O] {
224 self.objects.as_ref()
225 }
226}
227
/// Splits `objects` into the part before the middle element, the middle element itself and the part after it
///
/// The middle element is the one at index `objects.len() / 2`, so for an even number of objects the left part is one element longer than the right part.
///
/// # Panics
///
/// Panics if `objects` is empty as there is no middle element in that case.
fn split<O>(objects: &[O]) -> (&[O], &O, &[O]) {
    let middle = objects.len() / 2;

    let left = &objects[..middle];
    let (mid, right) = objects[middle..].split_first().unwrap();

    (left, mid, right)
}
234
235fn contains<P>(aabb: &(P, P), position: &P) -> bool
236where
237 P: Point,
238{
239 (0..P::DIM).all(|axis| {
240 aabb.0.coord(axis) <= position.coord(axis) && position.coord(axis) <= aabb.1.coord(axis)
241 })
242}
243
#[cfg(test)]
mod tests {
    use super::*;

    use std::cmp::Ordering;

    use proptest::{collection::vec, strategy::Strategy};

    /// Strategy producing `len` two-dimensional points with both coordinates drawn from the closed interval `0.0..=1.0`
    pub fn random_points(len: usize) -> impl Strategy<Value = Vec<[f32; 2]>> {
        let xs = vec(0.0_f32..=1.0, len);
        let ys = vec(0.0_f32..=1.0, len);

        // Both coordinate lists are generated independently and paired up element by element.
        (xs, ys).prop_map(|(xs, ys)| xs.into_iter().zip(ys).map(|(x, y)| [x, y]).collect())
    }

    /// Minimal [`Object`] implementation which consists of nothing but its position
    #[derive(Debug, PartialEq)]
    pub struct RandomObject(pub [f32; 2]);

    impl Object for RandomObject {
        type Point = [f32; 2];

        fn position(&self) -> &Self::Point {
            let Self(position) = self;

            position
        }
    }

    // Floating-point values provide only a partial order, hence `Eq` and `Ord` are implemented manually.
    impl Eq for RandomObject {}

    impl Ord for RandomObject {
        /// Compares the positions and panics if they cannot be ordered, e.g. due to a NaN coordinate
        fn cmp(&self, other: &Self) -> Ordering {
            let ordering = self.0.partial_cmp(&other.0);

            ordering.unwrap()
        }
    }

    impl PartialOrd for RandomObject {
        fn partial_cmp(&self, other: &Self) -> Option<Ordering> {
            Some(Ord::cmp(self, other))
        }
    }

    /// Strategy wrapping the points produced by [`random_points`] into a boxed slice of [`RandomObject`]
    pub fn random_objects(len: usize) -> impl Strategy<Value = Box<[RandomObject]>> {
        random_points(len).prop_map(|points| points.into_iter().map(RandomObject).collect())
    }
}
285}