1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
use std::mem::swap;

use num_traits::Float;

use crate::{split, Distance, KdTree, Object, Point};

impl<O, S> KdTree<O, S>
where
    O: Object,
    O::Point: Distance,
    <O::Point as Point>::Coord: Float,
    S: AsRef<[O]>,
{
    /// Look up the object that lies closest to the given `target`
    ///
    /// Closeness is measured by the `distance_2` method of the point type.
    ///
    /// Yields `None` for an empty tree, and also when none of the objects
    /// is at a finite distance from the `target`.
    pub fn nearest(&self, target: &O::Point) -> Option<&O> {
        let objects = self.objects.as_ref();

        // Nothing to search in an empty tree.
        if objects.is_empty() {
            return None;
        }

        // The search starts out at infinity and only accepts strictly closer
        // objects, so only objects at a finite distance can become the match.
        let mut state = NearestArgs {
            target,
            distance_2: <O::Point as Point>::Coord::infinity(),
            best_match: None,
        };

        nearest(&mut state, objects, 0);

        state.best_match
    }
}

/// Mutable state threaded through the nearest-neighbour search.
///
/// The first lifetime belongs to the objects being searched, the second one
/// to the borrowed query point.
struct NearestArgs<'objs, 'tgt, O: Object> {
    /// The query point whose nearest neighbour is sought.
    target: &'tgt O::Point,
    /// The `distance_2` value between `target` and `best_match`; the search
    /// compares it against squared per-axis offsets when pruning.
    distance_2: <O::Point as Point>::Coord,
    /// The closest object encountered so far, if any.
    best_match: Option<&'objs O>,
}

/// Worker behind [`KdTree::nearest`]: searches the sub-tree stored in `objects`
/// and records the closest object seen so far in `args`.
///
/// `axis` is the coordinate axis compared at the pivot of `objects`; it advances
/// cyclically (modulo `O::Point::DIM`) with every level that is descended.
///
/// At each level the pivot itself is considered first. The side of the pivot
/// facing the target is always searched, whereas the opposite side is only
/// entered while the squared offset along `axis` is still smaller than the
/// best distance found so far.
///
/// NOTE(review): `objects` is expected to be non-empty. Every call site in this
/// file checks that beforehand; how `split` treats an empty slice is not
/// visible from here.
fn nearest<'a, O>(args: &mut NearestArgs<'a, '_, O>, mut objects: &'a [O], mut axis: usize)
where
    O: Object,
    O::Point: Distance,
    <O::Point as Point>::Coord: Float,
{
    loop {
        // Tentatively treat the first sub-slice as the side facing the target;
        // this is corrected below once the sign of the offset is known.
        let (mut near, pivot, mut far) = split(objects);

        let pivot_position = pivot.position();

        // Only a strictly smaller distance displaces the current best match.
        let candidate = args.target.distance_2(pivot_position);

        if args.distance_2 > candidate {
            args.best_match = Some(pivot);
            args.distance_2 = candidate;
        }

        // Signed offset between target and pivot along the current axis.
        let delta = args.target.coord(axis) - pivot_position.coord(axis);

        if delta.is_sign_positive() {
            swap(&mut near, &mut far);
        }

        axis = (axis + 1) % O::Point::DIM;

        match (near.is_empty(), far.is_empty()) {
            // No children on either side: this branch is exhausted.
            (true, true) => return,
            // Only the facing side exists: descend without recursing.
            (false, true) => objects = near,
            (near_is_empty, false) => {
                if !near_is_empty {
                    nearest(args, near, axis);
                }

                // The opposite side can only hold a closer object if the squared
                // offset along the axis is below the best distance so far.
                if args.distance_2 > delta.powi(2) {
                    objects = far;
                } else {
                    return;
                }
            }
        }
    }
}

#[cfg(test)]
mod tests {
    use super::*;

    use proptest::test_runner::TestRunner;

    use crate::tests::{random_objects, random_points};

    /// Checks `KdTree::nearest` against a brute-force linear scan over the
    /// same objects, using randomly generated objects and targets.
    #[test]
    fn random_nearest() {
        let mut runner = TestRunner::default();

        let strategy = (random_objects(100), random_points(10));

        let outcome = runner.run(&strategy, |(objects, targets)| {
            let index = KdTree::new(objects);

            for target in targets {
                // Reference answer: minimum over all objects by `distance_2`.
                let brute_force = index
                    .iter()
                    .min_by(|lhs, rhs| {
                        let lhs_distance = lhs.0.distance_2(&target);
                        let rhs_distance = rhs.0.distance_2(&target);

                        lhs_distance.partial_cmp(&rhs_distance).unwrap()
                    })
                    .unwrap();

                let via_tree = index.nearest(&target).unwrap();

                assert_eq!(brute_force, via_tree);
            }

            Ok(())
        });

        outcome.unwrap();
    }
}