1use std::marker::PhantomData;
2
3#[cfg(feature = "rayon")]
4use rayon::join;
5
6use crate::{KdTree, Object, Point};
7
8impl<O, S> KdTree<O, S>
9where
10 O: Object,
11 S: AsRef<[O]> + AsMut<[O]>,
12{
13 pub fn new(mut objects: S) -> Self {
15 sort(objects.as_mut(), 0);
16
17 Self {
18 objects,
19 _marker: PhantomData,
20 }
21 }
22
23 #[cfg(feature = "rayon")]
24 pub fn par_new(mut objects: S) -> Self
28 where
29 O: Send,
30 {
31 par_sort(objects.as_mut(), 0);
32
33 Self {
34 objects,
35 _marker: PhantomData,
36 }
37 }
38}
39
40fn sort<O>(objects: &mut [O], axis: usize)
41where
42 O: Object,
43{
44 if objects.len() <= 1 {
45 return;
46 }
47
48 let (left, right, next_axis) = sort_axis(objects, axis);
49
50 sort(left, next_axis);
51 sort(right, next_axis);
52}
53
/// Variant of `sort` that passes the two parts of every split to
/// `rayon::join` instead of processing them one after the other.
///
/// Slices with at most one object are left untouched. `O: Send` is required
/// because the parts are handed to closures given to `join`.
#[cfg(feature = "rayon")]
fn par_sort<O>(objects: &mut [O], axis: usize)
where
    O: Object + Send,
{
    // Nothing to split once zero or one object remains.
    if objects.len() > 1 {
        let (lower, upper, following_axis) = sort_axis(objects, axis);

        join(
            || par_sort(lower, following_axis),
            || par_sort(upper, following_axis),
        );
    }
}
67
68fn sort_axis<O>(objects: &mut [O], axis: usize) -> (&mut [O], &mut [O], usize)
69where
70 O: Object,
71{
72 let mid = objects.len() / 2;
73
74 let (left, _, right) = objects.select_nth_unstable_by(mid, |lhs, rhs| {
75 let lhs = lhs.position().coord(axis);
76 let rhs = rhs.position().coord(axis);
77
78 lhs.partial_cmp(&rhs).unwrap()
79 });
80
81 let next_axis = (axis + 1) % O::Point::DIM;
82
83 (left, right, next_axis)
84}