sif_rtree/
look_up.rs

1use std::num::NonZeroUsize;
2use std::ops::ControlFlow;
3
4use num_traits::Zero;
5
6use crate::{iter::branch_for_each, Distance, Node, Object, Point, RTree, ROOT_IDX};
7
8impl<O, S> RTree<O, S>
9where
10    O: Object,
11    S: AsRef<[Node<O>]>,
12{
13    /// Locates all objects whose axis-aligned bounding box (AABB) is contained in the queried AABB
14    pub fn look_up_aabb_contains<'a, V, R>(
15        &'a self,
16        query: &(O::Point, O::Point),
17        visitor: V,
18    ) -> ControlFlow<R>
19    where
20        V: FnMut(&'a O) -> ControlFlow<R>,
21    {
22        let query = |node: &Node<O>| match node {
23            Node::Branch { aabb, .. } => intersects(query, aabb),
24            Node::Twig(_) => unreachable!(),
25            Node::Leaf(obj) => contains(query, &obj.aabb()),
26        };
27
28        self.look_up(query, visitor)
29    }
30
31    /// Locates all objects whose axis-aligned bounding box (AABB) intersects the queried AABB
32    pub fn look_up_aabb_intersects<'a, V, R>(
33        &'a self,
34        query: &(O::Point, O::Point),
35        visitor: V,
36    ) -> ControlFlow<R>
37    where
38        V: FnMut(&'a O) -> ControlFlow<R>,
39    {
40        let query = |node: &Node<O>| match node {
41            Node::Branch { aabb, .. } => intersects(query, aabb),
42            Node::Twig(_) => unreachable!(),
43            Node::Leaf(obj) => intersects(query, &obj.aabb()),
44        };
45
46        self.look_up(query, visitor)
47    }
48
49    /// Locates all objects which contain the queried point
50    pub fn look_up_at_point<'a, V, R>(&'a self, query: &O::Point, visitor: V) -> ControlFlow<R>
51    where
52        O: Distance<O::Point>,
53        O::Point: Distance<O::Point>,
54        V: FnMut(&'a O) -> ControlFlow<R>,
55    {
56        let query = |node: &Node<O>| match node {
57            Node::Branch { aabb, .. } => aabb.contains(query),
58            Node::Twig(_) => unreachable!(),
59            Node::Leaf(obj) => obj.contains(query),
60        };
61
62        self.look_up(query, visitor)
63    }
64
65    /// Locates all objects which are within the given `distance` of the given `center`
66    pub fn look_up_within_distance_of_point<'a, V, R>(
67        &'a self,
68        center: &O::Point,
69        distance: <O::Point as Point>::Coord,
70        visitor: V,
71    ) -> ControlFlow<R>
72    where
73        O: Distance<O::Point>,
74        O::Point: Distance<O::Point>,
75        V: FnMut(&'a O) -> ControlFlow<R>,
76    {
77        let distance_2 = distance * distance;
78
79        let query = |node: &Node<O>| match node {
80            Node::Branch { aabb, .. } => aabb.distance_2(center) <= distance_2,
81            Node::Twig(_) => unreachable!(),
82            Node::Leaf(obj) => obj.distance_2(center) <= distance_2,
83        };
84
85        self.look_up(query, visitor)
86    }
87
88    fn look_up<'a, Q, V, R>(&'a self, query: Q, visitor: V) -> ControlFlow<R>
89    where
90        Q: FnMut(&'a Node<O>) -> bool,
91        V: FnMut(&'a O) -> ControlFlow<R>,
92    {
93        let mut args = LookUpArgs {
94            nodes: self.nodes.as_ref(),
95            query,
96            visitor,
97        };
98
99        let [node, rest @ ..] = &args.nodes[ROOT_IDX..] else {
100            unreachable!()
101        };
102
103        if (args.query)(node) {
104            match node {
105                Node::Branch { len, .. } => look_up(&mut args, len, rest)?,
106                Node::Twig(_) | Node::Leaf(_) => unreachable!(),
107            }
108        }
109
110        ControlFlow::Continue(())
111    }
112}
113
114struct LookUpArgs<'a, O, Q, V>
115where
116    O: Object,
117{
118    nodes: &'a [Node<O>],
119    query: Q,
120    visitor: V,
121}
122
123fn look_up<'a, O, Q, V, R>(
124    args: &mut LookUpArgs<'a, O, Q, V>,
125    mut len: &'a NonZeroUsize,
126    mut twigs: &'a [Node<O>],
127) -> ControlFlow<R>
128where
129    O: Object,
130    Q: FnMut(&'a Node<O>) -> bool,
131    V: FnMut(&'a O) -> ControlFlow<R>,
132{
133    loop {
134        let mut branch = None;
135
136        branch_for_each(len, twigs, |idx| {
137            let [node, rest @ ..] = &args.nodes[idx..] else {
138                unreachable!()
139            };
140
141            if (args.query)(node) {
142                match node {
143                    Node::Branch { len, .. } => {
144                        if let Some((len1, twigs1)) = branch.replace((len, rest)) {
145                            look_up(args, len1, twigs1)?;
146                        }
147                    }
148                    Node::Twig(_) => unreachable!(),
149                    Node::Leaf(obj) => (args.visitor)(obj)?,
150                }
151            }
152
153            ControlFlow::Continue(())
154        })?;
155
156        if let Some((len1, twigs1)) = branch {
157            len = len1;
158            twigs = twigs1;
159        } else {
160            return ControlFlow::Continue(());
161        }
162    }
163}
164
165fn intersects<P>(lhs: &(P, P), rhs: &(P, P)) -> bool
166where
167    P: Point,
168{
169    (0..P::DIM).all(|axis| {
170        lhs.0.coord(axis) <= rhs.1.coord(axis) && lhs.1.coord(axis) >= rhs.0.coord(axis)
171    })
172}
173
174fn contains<P>(lhs: &(P, P), rhs: &(P, P)) -> bool
175where
176    P: Point,
177{
178    (0..P::DIM).all(|axis| {
179        lhs.0.coord(axis) <= rhs.0.coord(axis) && lhs.1.coord(axis) >= rhs.1.coord(axis)
180    })
181}
182
183impl<P> Distance<P> for (P, P)
184where
185    P: Point + Distance<P>,
186{
187    fn distance_2(&self, point: &P) -> P::Coord {
188        if !self.contains(point) {
189            let min = self.1.min(&self.0.max(point));
190
191            min.distance_2(point)
192        } else {
193            P::Coord::zero()
194        }
195    }
196
197    fn contains(&self, point: &P) -> bool {
198        (0..P::DIM).all(|axis| {
199            self.0.coord(axis) <= point.coord(axis) && point.coord(axis) <= self.1.coord(axis)
200        })
201    }
202}
203
#[cfg(test)]
mod tests {
    use super::*;

    use proptest::{collection::vec, test_runner::TestRunner};

    use crate::{
        tests::{random_objects, random_points},
        DEF_NODE_LEN,
    };

    // Each look-up test compares the tree-based query against a brute-force filter over all
    // objects; results are sorted because the traversal order is not part of the contract.

    #[test]
    fn random_look_up_aabb_contains() {
        TestRunner::default()
            .run(
                &(random_objects(100), random_objects(10)),
                |(objects, queries)| {
                    let index = RTree::new(DEF_NODE_LEN, objects);

                    for query in queries {
                        let mut results1 = index
                            .objects()
                            .filter(|obj| contains(&query.aabb(), &obj.aabb()))
                            .collect::<Vec<_>>();

                        let mut results2 = Vec::new();
                        index
                            .look_up_aabb_contains(&query.aabb(), |obj| {
                                results2.push(obj);
                                ControlFlow::<()>::Continue(())
                            })
                            .continue_value()
                            .unwrap();

                        results1.sort_unstable();
                        results2.sort_unstable();
                        assert_eq!(results1, results2);
                    }

                    Ok(())
                },
            )
            .unwrap();
    }

    #[test]
    fn random_look_up_aabb_intersects() {
        TestRunner::default()
            .run(
                &(random_objects(100), random_objects(10)),
                |(objects, queries)| {
                    let index = RTree::new(DEF_NODE_LEN, objects);

                    for query in queries {
                        let mut results1 = index
                            .objects()
                            .filter(|obj| intersects(&query.aabb(), &obj.aabb()))
                            .collect::<Vec<_>>();

                        let mut results2 = Vec::new();
                        index
                            .look_up_aabb_intersects(&query.aabb(), |obj| {
                                results2.push(obj);
                                ControlFlow::<()>::Continue(())
                            })
                            .continue_value()
                            .unwrap();

                        results1.sort_unstable();
                        results2.sort_unstable();
                        assert_eq!(results1, results2);
                    }

                    Ok(())
                },
            )
            .unwrap();
    }

    /// Breaking out of the visitor must stop the traversal immediately and hand the break
    /// value back to the caller. The assertion is checked for every random query.
    #[test]
    fn random_look_up_stops_at_first_break() {
        TestRunner::default()
            .run(
                &(random_objects(100), random_objects(10)),
                |(objects, queries)| {
                    let index = RTree::new(DEF_NODE_LEN, objects);

                    for query in queries {
                        let mut visited = 0_usize;

                        let result = index.look_up_aabb_intersects(&query.aabb(), |obj| {
                            visited += 1;
                            ControlFlow::Break(obj)
                        });

                        match result {
                            ControlFlow::Break(obj) => {
                                // Exactly one visit: nothing may run after the first Break.
                                assert_eq!(visited, 1);
                                assert!(intersects(&query.aabb(), &obj.aabb()));
                            }
                            ControlFlow::Continue(()) => {
                                // No Break means no object matched at all.
                                assert_eq!(visited, 0);
                                assert!(!index
                                    .objects()
                                    .any(|obj| intersects(&query.aabb(), &obj.aabb())));
                            }
                        }
                    }

                    Ok(())
                },
            )
            .unwrap();
    }

    #[test]
    fn random_look_up_at_point() {
        TestRunner::default()
            .run(
                &(random_objects(100), random_points(10)),
                |(objects, queries)| {
                    let index = RTree::new(DEF_NODE_LEN, objects);

                    for query in queries {
                        let mut results1 = index
                            .objects()
                            .filter(|obj| obj.contains(&query))
                            .collect::<Vec<_>>();

                        let mut results2 = Vec::new();
                        index
                            .look_up_at_point(&query, |obj| {
                                results2.push(obj);
                                ControlFlow::<()>::Continue(())
                            })
                            .continue_value()
                            .unwrap();

                        results1.sort_unstable();
                        results2.sort_unstable();
                        assert_eq!(results1, results2);
                    }

                    Ok(())
                },
            )
            .unwrap();
    }

    #[test]
    fn random_look_up_within_distance_of_point() {
        TestRunner::default()
            .run(
                &(
                    random_objects(100),
                    random_points(10),
                    vec(0.0_f32..=1.0, 10),
                ),
                |(objects, centers, distances)| {
                    let index = RTree::new(DEF_NODE_LEN, objects);

                    for (center, distance) in centers.iter().zip(distances) {
                        let mut results1 = index
                            .objects()
                            .filter(|obj| obj.distance_2(center) <= distance * distance)
                            .collect::<Vec<_>>();

                        let mut results2 = Vec::new();
                        index
                            .look_up_within_distance_of_point(center, distance, |obj| {
                                results2.push(obj);
                                ControlFlow::<()>::Continue(())
                            })
                            .continue_value()
                            .unwrap();

                        results1.sort_unstable();
                        results2.sort_unstable();
                        assert_eq!(results1, results2);
                    }

                    Ok(())
                },
            )
            .unwrap();
    }
}