sif_rtree/
nearest.rs

1use std::cmp::Ordering;
2use std::collections::BinaryHeap;
3use std::mem::swap;
4use std::ops::ControlFlow;
5
6use num_traits::{Float, Zero};
7
8use crate::{iter::branch_for_each, Distance, Node, Object, Point, RTree, ROOT_IDX};
9
10impl<O, S> RTree<O, S>
11where
12    O: Object,
13    O: Distance<O::Point>,
14    O::Point: Distance<O::Point>,
15    <O::Point as Point>::Coord: Float,
16    S: AsRef<[Node<O>]>,
17{
18    /// Find the object nearest to the given `target`
19    ///
20    /// Returns a reference to the object and its squared distance to the `target`.
21    ///
22    /// Returns `None` if no object has a finite distance to the `target`.
23    pub fn nearest(&self, target: &O::Point) -> Option<(&O, <O::Point as Point>::Coord)> {
24        let mut min_minmax_distance_2 = <O::Point as Point>::Coord::infinity();
25
26        let nearest = from_near_to_far(
27            self.nodes.as_ref(),
28            target,
29            |aabb, distance_2| {
30                if min_minmax_distance_2 >= distance_2 {
31                    let minmax_distance_2 = minmax_distance_2(aabb, target);
32
33                    if min_minmax_distance_2 > minmax_distance_2 {
34                        min_minmax_distance_2 = minmax_distance_2;
35                    }
36
37                    true
38                } else {
39                    false
40                }
41            },
42            |object, distance_2| ControlFlow::Break((object, distance_2)),
43        );
44
45        match nearest {
46            ControlFlow::Break(nearest) => Some(nearest),
47            ControlFlow::Continue(()) => None,
48        }
49    }
50
51    /// Visit all objects in ascending order of their distance to the given `target`
52    ///
53    /// Yields references to the objects and their squared distances to the `target`.
54    pub fn from_near_to_far<'a, V, R>(&'a self, target: &O::Point, visitor: V) -> ControlFlow<R>
55    where
56        V: FnMut(&'a O, <O::Point as Point>::Coord) -> ControlFlow<R>,
57    {
58        from_near_to_far(
59            self.nodes.as_ref(),
60            target,
61            |_aabb, _distance_2| true,
62            visitor,
63        )
64    }
65}
66
67fn from_near_to_far<'a, O, F, V, R>(
68    nodes: &'a [Node<O>],
69    target: &O::Point,
70    mut filter: F,
71    mut visitor: V,
72) -> ControlFlow<R>
73where
74    O: Object,
75    O: Distance<O::Point>,
76    O::Point: Distance<O::Point>,
77    <O::Point as Point>::Coord: Float,
78    F: FnMut(&(O::Point, O::Point), <O::Point as Point>::Coord) -> bool,
79    V: FnMut(&'a O, <O::Point as Point>::Coord) -> ControlFlow<R>,
80{
81    let mut items = BinaryHeap::new();
82
83    items.push(NearestItem {
84        idx: ROOT_IDX,
85        distance_2: <O::Point as Point>::Coord::nan(),
86    });
87
88    while let Some(item) = items.pop() {
89        let [node, rest @ ..] = &nodes[item.idx..] else {
90            unreachable!()
91        };
92
93        match node {
94            Node::Branch { len, .. } => branch_for_each(len, rest, |idx| {
95                let obj_aabb;
96
97                let (aabb, distance_2) = match &nodes[idx] {
98                    Node::Branch { aabb, .. } => (aabb, aabb.distance_2(target)),
99                    Node::Twig(_) => unreachable!(),
100                    Node::Leaf(obj) => {
101                        obj_aabb = obj.aabb();
102
103                        (&obj_aabb, obj.distance_2(target))
104                    }
105                };
106
107                if filter(aabb, distance_2) {
108                    items.push(NearestItem { idx, distance_2 });
109                }
110
111                ControlFlow::Continue(())
112            })?,
113            Node::Twig(_) => unreachable!(),
114            Node::Leaf(obj) => visitor(obj, item.distance_2)?,
115        }
116    }
117
118    ControlFlow::Continue(())
119}
120
121fn minmax_distance_2<P>(aabb: &(P, P), target: &P) -> P::Coord
122where
123    P: Point,
124    P::Coord: Float,
125{
126    let mut max_diff = P::Coord::zero();
127    let mut max_diff_axis = 0;
128    let mut max_diff_min_2 = P::Coord::zero();
129
130    let max_2 = P::build(|axis| {
131        let lower = aabb.0.coord(axis);
132        let upper = aabb.1.coord(axis);
133        let target = target.coord(axis);
134
135        let mut min_2 = (lower - target).powi(2);
136        let mut max_2 = (upper - target).powi(2);
137
138        if min_2 > max_2 {
139            swap(&mut min_2, &mut max_2);
140        }
141
142        let diff = max_2 - min_2;
143
144        if max_diff <= diff {
145            max_diff = diff;
146            max_diff_axis = axis;
147            max_diff_min_2 = min_2;
148        }
149
150        max_2
151    });
152
153    (0..P::DIM).fold(P::Coord::zero(), |res, axis| {
154        let minmax_2 = if axis == max_diff_axis {
155            max_diff_min_2
156        } else {
157            max_2.coord(axis)
158        };
159
160        res + minmax_2
161    })
162}
163
164struct NearestItem<F> {
165    idx: usize,
166    distance_2: F,
167}
168
169impl<F> PartialEq for NearestItem<F>
170where
171    F: Float,
172{
173    fn eq(&self, other: &Self) -> bool {
174        other.distance_2 == self.distance_2
175    }
176}
177
178impl<F> Eq for NearestItem<F> where F: Float {}
179
180impl<F> PartialOrd for NearestItem<F>
181where
182    F: Float,
183{
184    fn partial_cmp(&self, other: &Self) -> Option<Ordering> {
185        Some(self.cmp(other))
186    }
187}
188
189impl<F> Ord for NearestItem<F>
190where
191    F: Float,
192{
193    fn cmp(&self, other: &Self) -> Ordering {
194        other.distance_2.partial_cmp(&self.distance_2).unwrap()
195    }
196}
197
#[cfg(test)]
mod tests {
    use super::*;

    use proptest::test_runner::TestRunner;

    use crate::{
        tests::{random_objects, random_points},
        DEF_NODE_LEN,
    };

    /// `nearest` must agree with a brute-force scan over all objects.
    #[test]
    fn random_nearest() {
        let strategy = (random_objects(100), random_points(10));

        TestRunner::default()
            .run(&strategy, |(objects, targets)| {
                let index = RTree::new(DEF_NODE_LEN, objects);

                for target in targets {
                    let expected = index
                        .objects()
                        .map(|obj| obj.distance_2(&target))
                        .min_by(|lhs, rhs| lhs.partial_cmp(rhs).unwrap())
                        .unwrap();

                    let (nearest_obj, actual) = index.nearest(&target).unwrap();

                    // The reported distance must belong to the reported object...
                    assert_eq!(nearest_obj.distance_2(&target), actual);
                    // ...and be the global minimum.
                    assert_eq!(expected, actual);
                }

                Ok(())
            })
            .unwrap();
    }

    /// `from_near_to_far` must visit every object, in sorted order of distance.
    #[test]
    fn random_from_near_to_far() {
        let strategy = (random_objects(100), random_points(10));

        TestRunner::default()
            .run(&strategy, |(objects, targets)| {
                let index = RTree::new(DEF_NODE_LEN, objects);

                for target in targets {
                    let mut expected: Vec<_> = index
                        .objects()
                        .map(|obj| obj.distance_2(&target))
                        .collect();

                    expected.sort_unstable_by(|lhs, rhs| lhs.partial_cmp(rhs).unwrap());

                    let mut visited = Vec::new();

                    let flow = index.from_near_to_far(&target, |obj, distance_2| {
                        assert_eq!(obj.distance_2(&target), distance_2);

                        visited.push(distance_2);
                        ControlFlow::<()>::Continue(())
                    });

                    flow.continue_value().unwrap();

                    assert_eq!(expected, visited);
                }

                Ok(())
            })
            .unwrap();
    }
}