use std::mem::swap;
use num_traits::Float;
use crate::{split, KdTree, Object, Point};
impl<O, S> KdTree<O, S>
where
    O: Object,
    <O::Point as Point>::Coord: Float,
    S: AsRef<[O]>,
{
    /// Finds the stored object whose position is closest to `target`.
    ///
    /// Closeness is measured with `Point::distance_2` (squared distance), and the
    /// search starts at the root on axis 0. Returns `None` when the tree holds no
    /// objects.
    pub fn nearest(&self, target: &O::Point) -> Option<&O> {
        let objects = self.objects.as_ref();
        if objects.is_empty() {
            return None;
        }

        // Start with an unbounded search radius and no candidate.
        let mut args = NearestArgs {
            target,
            distance_2: <O::Point as Point>::Coord::infinity(),
            best_match: None,
        };
        nearest(&mut args, objects, 0);
        args.best_match
    }
}
/// Search state threaded through the recursive nearest-neighbour query.
struct NearestArgs<'obj, 'tgt, O: Object> {
    /// The query point.
    target: &'tgt O::Point,
    /// Squared distance of `best_match` from `target`; holds the initial bound
    /// (infinity, as set by `KdTree::nearest`) until a match is recorded.
    distance_2: <O::Point as Point>::Coord,
    /// Closest object recorded so far, if any.
    best_match: Option<&'obj O>,
}
/// Search worker for `KdTree::nearest`; records its findings in `args`.
///
/// Assumes `objects` is a subtree laid out so that [`split`] yields its left half,
/// root object and right half, with the root partitioning the halves along `axis`.
/// The near half (the side of the splitting plane the target lies on) is searched
/// first. The far half is searched only if the splitting plane is strictly closer
/// than the best squared distance found so far. The last remaining half is handled
/// by looping rather than recursing, so each level makes at most one recursive call.
fn nearest<'a, O>(args: &mut NearestArgs<'a, '_, O>, mut objects: &'a [O], mut axis: usize)
where
    O: Object,
    <O::Point as Point>::Coord: Float,
{
    loop {
        let (mut left, object, mut right) = split(objects);

        let position = object.position();
        let distance_2 = args.target.distance_2(position);
        // Once a match exists, only a strictly closer object replaces it, so the first
        // visited of several equidistant objects wins. Before any match exists, `<=`
        // also accepts an object whose squared distance equals the initial bound
        // (infinity, e.g. after overflow with very large coordinates); otherwise a
        // non-empty tree could yield `None`. NaN fails both comparisons and is never taken.
        let improves = if args.best_match.is_some() {
            distance_2 < args.distance_2
        } else {
            distance_2 <= args.distance_2
        };
        if improves {
            args.distance_2 = distance_2;
            args.best_match = Some(object);
        }

        // Signed offset of the target from the splitting plane along `axis`.
        let offset = args.target.coord(axis) - position.coord(axis);
        // Orient the halves so that `left` is the near side and `right` the far side.
        if offset.is_sign_positive() {
            swap(&mut left, &mut right);
        }

        let search_near = !left.is_empty();
        let search_far = !right.is_empty();
        axis = (axis + 1) % O::Point::DIM;

        if search_far {
            if search_near {
                nearest(args, left, axis);
            }
            // Objects on the far side lie at least `|offset|` from the target along this
            // axis, so they can only beat the current best if `offset²` is smaller.
            if args.distance_2 > offset.powi(2) {
                objects = right;
            } else {
                return;
            }
        } else if search_near {
            objects = left;
        } else {
            return;
        }
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use proptest::test_runner::TestRunner;
    use crate::tests::{random_objects, random_points};

    /// Property test: `KdTree::nearest` must agree with a brute-force linear scan.
    ///
    /// The achieved squared distances are compared rather than object identities.
    /// When several objects are equidistant from the target, the scan and the tree
    /// may legitimately pick different ones.
    #[test]
    fn random_nearest() {
        TestRunner::default()
            .run(
                &(random_objects(100), random_points(10)),
                |(objects, targets)| {
                    let index = KdTree::new(objects);
                    for target in targets {
                        let expected = index
                            .iter()
                            .map(|object| object.0.distance_2(&target))
                            .min_by(|lhs, rhs| lhs.partial_cmp(rhs).unwrap())
                            .unwrap();
                        let found = index.nearest(&target).unwrap();
                        assert_eq!(expected, found.0.distance_2(&target));
                    }
                    Ok(())
                },
            )
            .unwrap();
    }
}