1use std::cmp::Ordering;
2use std::collections::BinaryHeap;
3use std::mem::swap;
4use std::ops::ControlFlow;
5
6use num_traits::{Float, Zero};
7
8use crate::{iter::branch_for_each, Distance, Node, Object, Point, RTree, ROOT_IDX};
9
10impl<O, S> RTree<O, S>
11where
12 O: Object,
13 O: Distance<O::Point>,
14 O::Point: Distance<O::Point>,
15 <O::Point as Point>::Coord: Float,
16 S: AsRef<[Node<O>]>,
17{
18 pub fn nearest(&self, target: &O::Point) -> Option<(&O, <O::Point as Point>::Coord)> {
24 let mut min_minmax_distance_2 = <O::Point as Point>::Coord::infinity();
25
26 let nearest = from_near_to_far(
27 self.nodes.as_ref(),
28 target,
29 |aabb, distance_2| {
30 if min_minmax_distance_2 >= distance_2 {
31 let minmax_distance_2 = minmax_distance_2(aabb, target);
32
33 if min_minmax_distance_2 > minmax_distance_2 {
34 min_minmax_distance_2 = minmax_distance_2;
35 }
36
37 true
38 } else {
39 false
40 }
41 },
42 |object, distance_2| ControlFlow::Break((object, distance_2)),
43 );
44
45 match nearest {
46 ControlFlow::Break(nearest) => Some(nearest),
47 ControlFlow::Continue(()) => None,
48 }
49 }
50
51 pub fn from_near_to_far<'a, V, R>(&'a self, target: &O::Point, visitor: V) -> ControlFlow<R>
55 where
56 V: FnMut(&'a O, <O::Point as Point>::Coord) -> ControlFlow<R>,
57 {
58 from_near_to_far(
59 self.nodes.as_ref(),
60 target,
61 |_aabb, _distance_2| true,
62 visitor,
63 )
64 }
65}
66
67fn from_near_to_far<'a, O, F, V, R>(
68 nodes: &'a [Node<O>],
69 target: &O::Point,
70 mut filter: F,
71 mut visitor: V,
72) -> ControlFlow<R>
73where
74 O: Object,
75 O: Distance<O::Point>,
76 O::Point: Distance<O::Point>,
77 <O::Point as Point>::Coord: Float,
78 F: FnMut(&(O::Point, O::Point), <O::Point as Point>::Coord) -> bool,
79 V: FnMut(&'a O, <O::Point as Point>::Coord) -> ControlFlow<R>,
80{
81 let mut items = BinaryHeap::new();
82
83 items.push(NearestItem {
84 idx: ROOT_IDX,
85 distance_2: <O::Point as Point>::Coord::nan(),
86 });
87
88 while let Some(item) = items.pop() {
89 let [node, rest @ ..] = &nodes[item.idx..] else {
90 unreachable!()
91 };
92
93 match node {
94 Node::Branch { len, .. } => branch_for_each(len, rest, |idx| {
95 let obj_aabb;
96
97 let (aabb, distance_2) = match &nodes[idx] {
98 Node::Branch { aabb, .. } => (aabb, aabb.distance_2(target)),
99 Node::Twig(_) => unreachable!(),
100 Node::Leaf(obj) => {
101 obj_aabb = obj.aabb();
102
103 (&obj_aabb, obj.distance_2(target))
104 }
105 };
106
107 if filter(aabb, distance_2) {
108 items.push(NearestItem { idx, distance_2 });
109 }
110
111 ControlFlow::Continue(())
112 })?,
113 Node::Twig(_) => unreachable!(),
114 Node::Leaf(obj) => visitor(obj, item.distance_2)?,
115 }
116 }
117
118 ControlFlow::Continue(())
119}
120
121fn minmax_distance_2<P>(aabb: &(P, P), target: &P) -> P::Coord
122where
123 P: Point,
124 P::Coord: Float,
125{
126 let mut max_diff = P::Coord::zero();
127 let mut max_diff_axis = 0;
128 let mut max_diff_min_2 = P::Coord::zero();
129
130 let max_2 = P::build(|axis| {
131 let lower = aabb.0.coord(axis);
132 let upper = aabb.1.coord(axis);
133 let target = target.coord(axis);
134
135 let mut min_2 = (lower - target).powi(2);
136 let mut max_2 = (upper - target).powi(2);
137
138 if min_2 > max_2 {
139 swap(&mut min_2, &mut max_2);
140 }
141
142 let diff = max_2 - min_2;
143
144 if max_diff <= diff {
145 max_diff = diff;
146 max_diff_axis = axis;
147 max_diff_min_2 = min_2;
148 }
149
150 max_2
151 });
152
153 (0..P::DIM).fold(P::Coord::zero(), |res, axis| {
154 let minmax_2 = if axis == max_diff_axis {
155 max_diff_min_2
156 } else {
157 max_2.coord(axis)
158 };
159
160 res + minmax_2
161 })
162}
163
164struct NearestItem<F> {
165 idx: usize,
166 distance_2: F,
167}
168
169impl<F> PartialEq for NearestItem<F>
170where
171 F: Float,
172{
173 fn eq(&self, other: &Self) -> bool {
174 other.distance_2 == self.distance_2
175 }
176}
177
178impl<F> Eq for NearestItem<F> where F: Float {}
179
180impl<F> PartialOrd for NearestItem<F>
181where
182 F: Float,
183{
184 fn partial_cmp(&self, other: &Self) -> Option<Ordering> {
185 Some(self.cmp(other))
186 }
187}
188
189impl<F> Ord for NearestItem<F>
190where
191 F: Float,
192{
193 fn cmp(&self, other: &Self) -> Ordering {
194 other.distance_2.partial_cmp(&self.distance_2).unwrap()
195 }
196}
197
#[cfg(test)]
mod tests {
    use super::*;

    use proptest::test_runner::TestRunner;

    use crate::{
        tests::{random_objects, random_points},
        DEF_NODE_LEN,
    };

    /// `nearest` must report the same squared distance as a brute-force minimum over all
    /// objects, and that distance must belong to the object it returns.
    #[test]
    fn random_nearest() {
        let strategy = (random_objects(100), random_points(10));

        let outcome = TestRunner::default().run(&strategy, |(objects, targets)| {
            let index = RTree::new(DEF_NODE_LEN, objects);

            for target in targets {
                let brute_force = index
                    .objects()
                    .map(|obj| obj.distance_2(&target))
                    .min_by(|lhs, rhs| lhs.partial_cmp(rhs).unwrap())
                    .unwrap();

                let (found, found_distance_2) = index.nearest(&target).unwrap();

                assert_eq!(found.distance_2(&target), found_distance_2);
                assert_eq!(brute_force, found_distance_2);
            }

            Ok(())
        });

        outcome.unwrap();
    }

    /// `from_near_to_far` must visit every object exactly once, in sorted order of squared
    /// distance, and report each object's true squared distance.
    #[test]
    fn random_from_near_to_far() {
        let strategy = (random_objects(100), random_points(10));

        let outcome = TestRunner::default().run(&strategy, |(objects, targets)| {
            let index = RTree::new(DEF_NODE_LEN, objects);

            for target in targets {
                let mut expected: Vec<_> = index
                    .objects()
                    .map(|obj| obj.distance_2(&target))
                    .collect();

                expected.sort_unstable_by(|lhs, rhs| lhs.partial_cmp(rhs).unwrap());

                let mut visited = Vec::new();

                let flow = index.from_near_to_far(&target, |obj, distance_2| {
                    assert_eq!(obj.distance_2(&target), distance_2);

                    visited.push(distance_2);
                    ControlFlow::<()>::Continue(())
                });

                flow.continue_value().unwrap();

                assert_eq!(expected, visited);
            }

            Ok(())
        });

        outcome.unwrap();
    }
}