1use std::ops::ControlFlow;
2
3use num_traits::Num;
4#[cfg(feature = "rayon")]
5use rayon::join;
6
7use crate::{contains, split, Distance, KdTree, Object, Point};
8
9pub trait Query<P: Point> {
16 fn aabb(&self) -> &(P, P);
22
23 fn test(&self, position: &P) -> bool;
25}
26
/// Query selecting every object whose position lies inside an axis-aligned bounding box.
///
/// Derives `Clone` and `PartialEq` (in addition to `Debug`) so queries can be duplicated
/// and compared; the derived impls only require the corresponding trait on `T`.
#[derive(Debug, Clone, PartialEq)]
pub struct WithinBoundingBox<T, const N: usize> {
    /// `(lower, upper)` corners of the box, exactly as passed to the constructor.
    aabb: ([T; N], [T; N]),
}
32
33impl<T, const N: usize> WithinBoundingBox<T, N> {
34 pub fn new(lower: [T; N], upper: [T; N]) -> Self {
36 Self {
37 aabb: (lower, upper),
38 }
39 }
40}
41
42impl<T, const N: usize> Query<[T; N]> for WithinBoundingBox<T, N>
43where
44 T: Num + Copy + PartialOrd,
45{
46 fn aabb(&self) -> &([T; N], [T; N]) {
47 &self.aabb
48 }
49
50 fn test(&self, _position: &[T; N]) -> bool {
51 true
52 }
53}
54
/// Query selecting every object whose position is within a given distance of a center point.
///
/// Membership compares `center.distance_2(position)` against the squared radius. The
/// struct derives `Clone` and `PartialEq` (in addition to `Debug`) so queries can be
/// duplicated and compared; the derived impls only require the corresponding trait on `T`.
#[derive(Debug, Clone, PartialEq)]
pub struct WithinDistance<T, const N: usize> {
    /// Box `(center - distance, center + distance)` per axis, used to prune the traversal.
    aabb: ([T; N], [T; N]),
    /// Center of the query ball.
    center: [T; N],
    /// Square of the radius passed to the constructor.
    distance_2: T,
}
62
63impl<T, const N: usize> WithinDistance<T, N>
64where
65 T: Num + Copy + PartialOrd,
66{
67 pub fn new(center: [T; N], distance: T) -> Self {
69 Self {
70 aabb: (
71 center.map(|coord| coord - distance),
72 center.map(|coord| coord + distance),
73 ),
74 center,
75 distance_2: distance * distance,
76 }
77 }
78}
79
80impl<T, const N: usize> Query<[T; N]> for WithinDistance<T, N>
81where
82 T: Num + Copy + PartialOrd,
83{
84 fn aabb(&self) -> &([T; N], [T; N]) {
85 &self.aabb
86 }
87
88 fn test(&self, position: &[T; N]) -> bool {
89 self.center.distance_2(position) <= self.distance_2
90 }
91}
92
93impl<O, S> KdTree<O, S>
94where
95 O: Object,
96 S: AsRef<[O]>,
97{
98 pub fn look_up<'a, Q, V, R>(&'a self, query: &Q, visitor: V) -> ControlFlow<R>
105 where
106 Q: Query<O::Point>,
107 V: FnMut(&'a O) -> ControlFlow<R>,
108 {
109 let objects = self.objects.as_ref();
110
111 if !objects.is_empty() {
112 look_up(&mut LookUpArgs { query, visitor }, objects, 0)?;
113 }
114
115 ControlFlow::Continue(())
116 }
117
118 #[cfg(feature = "rayon")]
119 pub fn par_look_up<'a, Q, V, R>(&'a self, query: &Q, visitor: V) -> ControlFlow<R>
129 where
130 O: Send + Sync,
131 O::Point: Sync,
132 Q: Query<O::Point> + Sync,
133 V: Fn(&'a O) -> ControlFlow<R> + Sync,
134 R: Send,
135 {
136 let objects = self.objects.as_ref();
137
138 if !objects.is_empty() {
139 par_look_up(&LookUpArgs { query, visitor }, objects, 0)?;
140 }
141
142 ControlFlow::Continue(())
143 }
144}
145
/// Per-search state shared by every level of the traversal helpers, bundled so each
/// recursive call passes a single argument.
struct LookUpArgs<'a, Q, V> {
    /// Query providing the pruning box and the exact membership test.
    query: &'a Q,
    /// Callback receiving each selected object.
    visitor: V,
}
150
151fn look_up<'a, O, Q, V, R>(
152 args: &mut LookUpArgs<Q, V>,
153 mut objects: &'a [O],
154 mut axis: usize,
155) -> ControlFlow<R>
156where
157 O: Object,
158 Q: Query<O::Point>,
159 V: FnMut(&'a O) -> ControlFlow<R>,
160{
161 loop {
162 let (left, object, right) = split(objects);
163
164 let position = object.position();
165
166 if contains(args.query.aabb(), position) && args.query.test(position) {
167 (args.visitor)(object)?;
168 }
169
170 let search_left =
171 !left.is_empty() && args.query.aabb().0.coord(axis) <= position.coord(axis);
172
173 let search_right =
174 !right.is_empty() && position.coord(axis) <= args.query.aabb().1.coord(axis);
175
176 axis = (axis + 1) % O::Point::DIM;
177
178 match (search_left, search_right) {
179 (true, true) => {
180 look_up(args, left, axis)?;
181
182 objects = right;
183 }
184 (true, false) => objects = left,
185 (false, true) => objects = right,
186 (false, false) => return ControlFlow::Continue(()),
187 }
188 }
189}
190
/// Parallel search of the tree stored in `objects`, whose root splits along `axis`.
///
/// This mirrors the sequential `look_up`. When both halves of a node qualify, they are
/// searched concurrently with `rayon::join` and this call returns once both finish. A Break
/// from the left branch takes precedence; otherwise the right branch's result is returned.
/// `objects` must be non-empty.
#[cfg(feature = "rayon")]
fn par_look_up<'a, O, Q, V, R>(
    args: &LookUpArgs<Q, V>,
    mut objects: &'a [O],
    mut axis: usize,
) -> ControlFlow<R>
where
    O: Object + Send + Sync,
    O::Point: Sync,
    Q: Query<O::Point> + Sync,
    V: Fn(&'a O) -> ControlFlow<R> + Sync,
    R: Send,
{
    let query = args.query;
    let visitor = &args.visitor;

    loop {
        let (left, pivot, right) = split(objects);
        let position = pivot.position();

        if contains(query.aabb(), position) && query.test(position) {
            visitor(pivot)?;
        }

        let descend_left =
            !left.is_empty() && query.aabb().0.coord(axis) <= position.coord(axis);
        let descend_right =
            !right.is_empty() && position.coord(axis) <= query.aabb().1.coord(axis);

        // Children split along the next axis, cycling through all dimensions.
        axis = (axis + 1) % O::Point::DIM;

        objects = match (descend_left, descend_right) {
            (false, false) => return ControlFlow::Continue(()),
            (true, false) => left,
            (false, true) => right,
            (true, true) => {
                let (left_flow, right_flow) = join(
                    || par_look_up(args, left, axis),
                    || par_look_up(args, right, axis),
                );

                left_flow?;
                return right_flow;
            }
        };
    }
}
239
#[cfg(test)]
mod tests {
    use super::*;

    #[cfg(feature = "rayon")]
    use std::sync::Mutex;

    use proptest::{collection::vec, strategy::Strategy, test_runner::TestRunner};

    use crate::tests::{random_objects, random_points};

    /// Strategy producing `len` [`WithinDistance`] queries. Centers come from
    /// `random_points` and radii are drawn from `0.0..=1.0`.
    pub fn random_queries(len: usize) -> impl Strategy<Value = Vec<WithinDistance<f32, 2>>> {
        let centers = random_points(len);
        let radii = vec(0.0_f32..=1.0, len);

        (centers, radii).prop_map(|(centers, radii)| {
            centers
                .into_iter()
                .zip(radii)
                .map(|(center, radius)| WithinDistance::new(center, radius))
                .collect()
        })
    }

    /// `look_up` must report exactly the objects a brute-force scan with `Query::test` finds.
    #[test]
    fn random_look_up() {
        let strategy = (random_objects(100), random_queries(10));

        TestRunner::default()
            .run(&strategy, |(objects, queries)| {
                let index = KdTree::new(objects);

                for query in queries {
                    let mut expected: Vec<_> = index
                        .iter()
                        .filter(|object| query.test(object.position()))
                        .collect();

                    let mut visited = Vec::new();
                    index
                        .look_up(&query, |object| {
                            visited.push(object);
                            ControlFlow::<()>::Continue(())
                        })
                        .continue_value()
                        .unwrap();

                    expected.sort_unstable();
                    visited.sort_unstable();
                    assert_eq!(expected, visited);
                }

                Ok(())
            })
            .unwrap();
    }

    /// `par_look_up` must report exactly the objects a brute-force scan with `Query::test`
    /// finds.
    #[cfg(feature = "rayon")]
    #[test]
    fn random_par_look_up() {
        let strategy = (random_objects(100), random_queries(10));

        TestRunner::default()
            .run(&strategy, |(objects, queries)| {
                let index = KdTree::par_new(objects);

                for query in queries {
                    let mut expected: Vec<_> = index
                        .iter()
                        .filter(|object| query.test(object.position()))
                        .collect();

                    // The visitor may run on several threads, so collect behind a lock.
                    let collected = Mutex::new(Vec::new());
                    index
                        .par_look_up(&query, |object| {
                            collected.lock().unwrap().push(object);
                            ControlFlow::<()>::Continue(())
                        })
                        .continue_value()
                        .unwrap();
                    let mut visited = collected.into_inner().unwrap();

                    expected.sort_unstable();
                    visited.sort_unstable();
                    assert_eq!(expected, visited);
                }

                Ok(())
            })
            .unwrap();
    }
}