1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
use std::mem::swap;

use num_traits::Float;

use crate::{split, KdTree, Object, Point};

impl<O, S> KdTree<O, S>
where
    O: Object,
    <O::Point as Point>::Coord: Float,
    S: AsRef<[O]>,
{
    /// Look up the object that lies closest to `target`
    ///
    /// Closeness is measured using [`Point::distance_2`].
    ///
    /// Yields `None` when the tree holds no objects, and also when no object
    /// has a finite distance to the `target`: the search starts from an
    /// infinite best distance and only accepts strictly smaller ones.
    pub fn nearest(&self, target: &O::Point) -> Option<&O> {
        let objects = self.objects.as_ref();

        // Nothing to search, hence nothing to report.
        if objects.is_empty() {
            return None;
        }

        let mut state = NearestArgs {
            target,
            distance_2: <O::Point as Point>::Coord::infinity(),
            best_match: None,
        };

        // Start at the root of the tree, comparing along the first axis.
        nearest(&mut state, objects, 0);

        state.best_match
    }
}

/// Search state shared by all levels of the recursive `nearest` search
///
/// `'a` is the lifetime of the borrowed objects which can be returned as the
/// result, `'b` is the lifetime of the borrowed target point.
struct NearestArgs<'a, 'b, O>
where
    O: Object,
{
    /// The point for which the nearest object is searched
    target: &'b O::Point,
    /// The smallest `distance_2` between `target` and any object accepted so far
    ///
    /// `KdTree::nearest` initialises this to infinity.
    distance_2: <O::Point as Point>::Coord,
    /// The object which yielded `distance_2`, or `None` if no object was accepted yet
    best_match: Option<&'a O>,
}

/// Worker behind [`KdTree::nearest`] which updates `args` with the best match found in `objects`
///
/// `objects` is the sub-slice making up the current subtree and `axis` is the
/// coordinate index along which the target is compared against the node of
/// that subtree. The axis advances cyclically (modulo `Point::DIM`) with
/// every level.
///
/// One side of each node is visited by recursion, the other one by continuing
/// the loop with the remaining sub-slice.
///
/// All call sites in this file pass a non-empty slice; how `split` treats an
/// empty slice is not visible here.
fn nearest<'a, O>(args: &mut NearestArgs<'a, '_, O>, mut objects: &'a [O], mut axis: usize)
where
    O: Object,
    <O::Point as Point>::Coord: Float,
{
    loop {
        let (mut near, object, mut far) = split(objects);
        let position = object.position();

        // Accept the node of this subtree if it is strictly closer than the best match so far.
        let candidate = args.target.distance_2(position);
        if args.distance_2 > candidate {
            args.best_match = Some(object);
            args.distance_2 = candidate;
        }

        // Signed offset of the target from the node along the current axis.
        let offset = args.target.coord(axis) - position.coord(axis);

        // Until here `near` is the first and `far` the last part yielded by `split`.
        // For an offset with positive sign the roles are exchanged so that `near`
        // is always the part which is searched unconditionally.
        // NOTE(review): presumably the last part holds the larger coordinates,
        // i.e. this selects the target's side of the node; confirm against `split`.
        if offset.is_sign_positive() {
            swap(&mut near, &mut far);
        }

        axis = (axis + 1) % O::Point::DIM;

        if far.is_empty() {
            if near.is_empty() {
                // Reached a leaf.
                return;
            }

            // Only the unconditionally searched part exists, so descend without recursion.
            objects = near;
            continue;
        }

        if !near.is_empty() {
            nearest(args, near, axis);
        }

        // The other part is only visited if the squared offset along the axis
        // is still strictly below the best distance found so far.
        if args.distance_2 > offset.powi(2) {
            objects = far;
        } else {
            return;
        }
    }
}

#[cfg(test)]
mod tests {
    use super::*;

    use proptest::test_runner::TestRunner;

    use crate::tests::{random_objects, random_points};

    /// Compares `KdTree::nearest` against a linear scan over all objects for generated inputs
    #[test]
    fn random_nearest() {
        let strategy = (random_objects(100), random_points(10));

        let outcome = TestRunner::default().run(&strategy, |(objects, targets)| {
            let index = KdTree::new(objects);

            for target in targets {
                // Reference result: minimum of the distance over every object in the index.
                let expected = index
                    .iter()
                    .min_by(|lhs, rhs| {
                        let lhs = lhs.0.distance_2(&target);
                        let rhs = rhs.0.distance_2(&target);

                        lhs.partial_cmp(&rhs).unwrap()
                    })
                    .unwrap();

                let actual = index.nearest(&target).unwrap();

                assert_eq!(expected, actual);
            }

            Ok(())
        });

        outcome.unwrap();
    }
}