sif_kdtree/
look_up.rs

1use std::ops::ControlFlow;
2
3use num_traits::Num;
4#[cfg(feature = "rayon")]
5use rayon::join;
6
7use crate::{contains, split, Distance, KdTree, Object, Point};
8
9/// Defines a spatial query by its axis-aligned bounding box (AABB) and a method to test a single point
10///
11/// The AABB of the query is used to limit the points which are tested and therefore the AABB should be as tight as possible while staying aligned to the coordinate axes.
12/// The test method itself can then be relatively expensive like determining the distance of the given position to an arbitrary polygon.
13///
14/// A very simple example of implementing this trait is [`WithinBoundingBox`] whereas a very common example is [`WithinDistance`].
15pub trait Query<P: Point> {
16    /// Return the axis-aligned bounding box (AABB) of the query
17    ///
18    /// Represented by the corners with first the smallest and then the largest coordinate values.
19    ///
20    /// Note that calling this method is assumed to be cheap, returning a reference to an AABB stored in the interior of the object.
21    fn aabb(&self) -> &(P, P);
22
23    /// Check whether a given `position` inside the [axis-aligned bounding box (AABB)][Self::aabb] machtes the query.
24    fn test(&self, position: &P) -> bool;
25}
26
/// A query which yields all objects within a given axis-aligned bounding box (AABB) in `N`-dimensional space
#[derive(Debug)]
pub struct WithinBoundingBox<T, const N: usize> {
    aabb: ([T; N], [T; N]),
}

impl<T, const N: usize> WithinBoundingBox<T, N> {
    /// Construct a query from the corner with the smallest coordinate values, `lower`,
    /// and the corner with the largest coordinate values, `upper`
    ///
    /// The corners are stored as given without validation, so every coordinate of `lower`
    /// should not exceed the corresponding coordinate of `upper`.
    pub fn new(lower: [T; N], upper: [T; N]) -> Self {
        Self {
            aabb: (lower, upper),
        }
    }
}
41
42impl<T, const N: usize> Query<[T; N]> for WithinBoundingBox<T, N>
43where
44    T: Num + Copy + PartialOrd,
45{
46    fn aabb(&self) -> &([T; N], [T; N]) {
47        &self.aabb
48    }
49
50    fn test(&self, _position: &[T; N]) -> bool {
51        true
52    }
53}
54
55/// A query which yields all objects within a given distance to a central point in `N`-dimensional real space
56#[derive(Debug)]
57pub struct WithinDistance<T, const N: usize> {
58    aabb: ([T; N], [T; N]),
59    center: [T; N],
60    distance_2: T,
61}
62
63impl<T, const N: usize> WithinDistance<T, N>
64where
65    T: Num + Copy + PartialOrd,
66{
67    /// Construct a query from the `center` and the largest allowed Euclidean `distance` to it
68    pub fn new(center: [T; N], distance: T) -> Self {
69        Self {
70            aabb: (
71                center.map(|coord| coord - distance),
72                center.map(|coord| coord + distance),
73            ),
74            center,
75            distance_2: distance * distance,
76        }
77    }
78}
79
80impl<T, const N: usize> Query<[T; N]> for WithinDistance<T, N>
81where
82    T: Num + Copy + PartialOrd,
83{
84    fn aabb(&self) -> &([T; N], [T; N]) {
85        &self.aabb
86    }
87
88    fn test(&self, position: &[T; N]) -> bool {
89        self.center.distance_2(position) <= self.distance_2
90    }
91}
92
93impl<O, S> KdTree<O, S>
94where
95    O: Object,
96    S: AsRef<[O]>,
97{
98    /// Find objects matching the given `query`
99    ///
100    /// Queries are defined by passing an implementor of the [`Query`] trait.
101    ///
102    /// Objects matching the `query` are passed to the `visitor` as they are found.
103    /// Depending on its [return value][`ControlFlow`], the search is continued or stopped.
104    pub fn look_up<'a, Q, V, R>(&'a self, query: &Q, visitor: V) -> ControlFlow<R>
105    where
106        Q: Query<O::Point>,
107        V: FnMut(&'a O) -> ControlFlow<R>,
108    {
109        let objects = self.objects.as_ref();
110
111        if !objects.is_empty() {
112            look_up(&mut LookUpArgs { query, visitor }, objects, 0)?;
113        }
114
115        ControlFlow::Continue(())
116    }
117
118    #[cfg(feature = "rayon")]
119    /// Find objects matching the given `query`, in parallel
120    ///
121    /// Queries are defined by passing an implementor of the [`Query`] trait.
122    ///
123    /// Objects matching the `query` are passed to the `visitor` as they are found.
124    /// In contrast to the [serial version][Self::look_up], parts of the search can continue
125    /// even after it has been stopped.
126    ///
127    /// Requires the `rayon` feature and dispatches tasks into the current [thread pool][rayon::ThreadPool].
128    pub fn par_look_up<'a, Q, V, R>(&'a self, query: &Q, visitor: V) -> ControlFlow<R>
129    where
130        O: Send + Sync,
131        O::Point: Sync,
132        Q: Query<O::Point> + Sync,
133        V: Fn(&'a O) -> ControlFlow<R> + Sync,
134        R: Send,
135    {
136        let objects = self.objects.as_ref();
137
138        if !objects.is_empty() {
139            par_look_up(&LookUpArgs { query, visitor }, objects, 0)?;
140        }
141
142        ControlFlow::Continue(())
143    }
144}
145
/// State shared by every step of a look-up
///
/// Both recursive helpers receive this by reference instead of taking the query and the
/// visitor as separate arguments.
struct LookUpArgs<'a, Q, V> {
    /// The query being answered, borrowed from the caller
    query: &'a Q,
    /// Callback invoked for every matching object
    visitor: V,
}
150
151fn look_up<'a, O, Q, V, R>(
152    args: &mut LookUpArgs<Q, V>,
153    mut objects: &'a [O],
154    mut axis: usize,
155) -> ControlFlow<R>
156where
157    O: Object,
158    Q: Query<O::Point>,
159    V: FnMut(&'a O) -> ControlFlow<R>,
160{
161    loop {
162        let (left, object, right) = split(objects);
163
164        let position = object.position();
165
166        if contains(args.query.aabb(), position) && args.query.test(position) {
167            (args.visitor)(object)?;
168        }
169
170        let search_left =
171            !left.is_empty() && args.query.aabb().0.coord(axis) <= position.coord(axis);
172
173        let search_right =
174            !right.is_empty() && position.coord(axis) <= args.query.aabb().1.coord(axis);
175
176        axis = (axis + 1) % O::Point::DIM;
177
178        match (search_left, search_right) {
179            (true, true) => {
180                look_up(args, left, axis)?;
181
182                objects = right;
183            }
184            (true, false) => objects = left,
185            (false, true) => objects = right,
186            (false, false) => return ControlFlow::Continue(()),
187        }
188    }
189}
190
/// Parallel counterpart of [`look_up`]
///
/// The pivot and side selection are identical, but when both sides must be searched they are
/// processed concurrently via [`rayon::join`]. A break from the `before` side takes precedence
/// over one from the `after` side.
#[cfg(feature = "rayon")]
fn par_look_up<'a, O, Q, V, R>(
    args: &LookUpArgs<Q, V>,
    mut objects: &'a [O],
    mut axis: usize,
) -> ControlFlow<R>
where
    O: Object + Send + Sync,
    O::Point: Sync,
    Q: Query<O::Point> + Sync,
    V: Fn(&'a O) -> ControlFlow<R> + Sync,
    R: Send,
{
    loop {
        let (before, pivot, after) = split(objects);
        let pos = pivot.position();

        if contains(args.query.aabb(), pos) && args.query.test(pos) {
            (args.visitor)(pivot)?;
        }

        let visit_before =
            !before.is_empty() && args.query.aabb().0.coord(axis) <= pos.coord(axis);
        let visit_after =
            !after.is_empty() && pos.coord(axis) <= args.query.aabb().1.coord(axis);

        axis = (axis + 1) % O::Point::DIM;

        objects = match (visit_before, visit_after) {
            (false, false) => return ControlFlow::Continue(()),
            (true, false) => before,
            (false, true) => after,
            (true, true) => {
                let (before_flow, after_flow) = join(
                    || par_look_up(args, before, axis),
                    || par_look_up(args, after, axis),
                );

                before_flow?;
                return after_flow;
            }
        };
    }
}
239
#[cfg(test)]
mod tests {
    use super::*;

    #[cfg(feature = "rayon")]
    use std::sync::Mutex;

    use proptest::{collection::vec, strategy::Strategy, test_runner::TestRunner};

    use crate::tests::{random_objects, random_points};

    /// Strategy yielding `len` [`WithinDistance`] queries with random centers and distances in `0.0..=1.0`
    pub fn random_queries(len: usize) -> impl Strategy<Value = Vec<WithinDistance<f32, 2>>> {
        let radii = vec(0.0_f32..=1.0, len);

        (random_points(len), radii).prop_map(|(centers, radii)| {
            centers
                .into_iter()
                .zip(radii)
                .map(|(center, radius)| WithinDistance::new(center, radius))
                .collect()
        })
    }

    #[test]
    fn random_look_up() {
        let strategy = (random_objects(100), random_queries(10));

        TestRunner::default()
            .run(&strategy, |(objects, queries)| {
                let index = KdTree::new(objects);

                for query in &queries {
                    // Brute-force reference result.
                    let mut expected: Vec<_> = index
                        .iter()
                        .filter(|object| query.test(object.position()))
                        .collect();

                    let mut actual = Vec::new();
                    let flow = index.look_up(query, |object| {
                        actual.push(object);
                        ControlFlow::<()>::Continue(())
                    });
                    assert!(flow.is_continue());

                    expected.sort_unstable();
                    actual.sort_unstable();
                    assert_eq!(expected, actual);
                }

                Ok(())
            })
            .unwrap();
    }

    #[cfg(feature = "rayon")]
    #[test]
    fn random_par_look_up() {
        let strategy = (random_objects(100), random_queries(10));

        TestRunner::default()
            .run(&strategy, |(objects, queries)| {
                let index = KdTree::par_new(objects);

                for query in &queries {
                    // Brute-force reference result.
                    let mut expected: Vec<_> = index
                        .iter()
                        .filter(|object| query.test(object.position()))
                        .collect();

                    let found = Mutex::new(Vec::new());
                    let flow = index.par_look_up(query, |object| {
                        found.lock().unwrap().push(object);
                        ControlFlow::<()>::Continue(())
                    });
                    assert!(flow.is_continue());

                    let mut actual = found.into_inner().unwrap();

                    expected.sort_unstable();
                    actual.sort_unstable();
                    assert_eq!(expected, actual);
                }

                Ok(())
            })
            .unwrap();
    }
}