1use std::mem::swap;
2
3use num_traits::Float;
4
5use crate::{split, Distance, KdTree, Object, Point};
6
7impl<O, S> KdTree<O, S>
8where
9 O: Object,
10 O::Point: Distance,
11 <O::Point as Point>::Coord: Float,
12 S: AsRef<[O]>,
13{
14 pub fn nearest(&self, target: &O::Point) -> Option<&O> {
20 let mut args = NearestArgs {
21 target,
22 distance_2: <O::Point as Point>::Coord::infinity(),
23 best_match: None,
24 };
25
26 let objects = self.objects.as_ref();
27
28 if !objects.is_empty() {
29 nearest(&mut args, objects, 0);
30 }
31
32 args.best_match
33 }
34}
35
36struct NearestArgs<'a, 'b, O>
37where
38 O: Object,
39{
40 target: &'b O::Point,
41 distance_2: <O::Point as Point>::Coord,
42 best_match: Option<&'a O>,
43}
44
45fn nearest<'a, O>(args: &mut NearestArgs<'a, '_, O>, mut objects: &'a [O], mut axis: usize)
46where
47 O: Object,
48 O::Point: Distance,
49 <O::Point as Point>::Coord: Float,
50{
51 loop {
52 let (mut left, object, mut right) = split(objects);
53
54 let position = object.position();
55
56 let distance_2 = args.target.distance_2(position);
57
58 if args.distance_2 > distance_2 {
59 args.distance_2 = distance_2;
60 args.best_match = Some(object);
61 }
62
63 let offset = args.target.coord(axis) - position.coord(axis);
64
65 if offset.is_sign_positive() {
66 swap(&mut left, &mut right);
67 }
68
69 let search_left = !left.is_empty();
70 let search_right = !right.is_empty();
71
72 axis = (axis + 1) % O::Point::DIM;
73
74 if search_right {
75 if search_left {
76 nearest(args, left, axis);
77 }
78
79 if args.distance_2 > offset.powi(2) {
80 objects = right;
81 } else {
82 return;
83 }
84 } else if search_left {
85 objects = left;
86 } else {
87 return;
88 }
89 }
90}
91
#[cfg(test)]
mod tests {
    use super::*;

    use proptest::test_runner::TestRunner;

    use crate::tests::{random_objects, random_points};

    /// Checks `KdTree::nearest` against a linear scan over all objects that
    /// picks the one with the minimal squared distance to each target.
    #[test]
    fn random_nearest() {
        let strategy = (random_objects(100), random_points(10));
        let mut runner = TestRunner::default();

        runner
            .run(&strategy, |(objects, targets)| {
                let tree = KdTree::new(objects);

                for target in targets {
                    // Reference answer: exhaustive scan over every stored object.
                    let expected = tree
                        .iter()
                        .min_by(|a, b| {
                            let a_dist_2 = a.0.distance_2(&target);
                            let b_dist_2 = b.0.distance_2(&target);

                            a_dist_2.partial_cmp(&b_dist_2).unwrap()
                        })
                        .unwrap();

                    // Answer under test: the tree-based search.
                    let actual = tree.nearest(&target).unwrap();

                    assert_eq!(expected, actual);
                }

                Ok(())
            })
            .unwrap();
    }
}