sif_kdtree/
sort.rs

1use std::marker::PhantomData;
2
3#[cfg(feature = "rayon")]
4use rayon::join;
5
6use crate::{KdTree, Object, Point};
7
8impl<O, S> KdTree<O, S>
9where
10    O: Object,
11    S: AsRef<[O]> + AsMut<[O]>,
12{
13    /// Construct a new tree by sorting the given `objects`
14    pub fn new(mut objects: S) -> Self {
15        sort(objects.as_mut(), 0);
16
17        Self {
18            objects,
19            _marker: PhantomData,
20        }
21    }
22
23    #[cfg(feature = "rayon")]
24    /// Construct a new tree by sorting the given `objects`, in parallel
25    ///
26    /// Requires the `rayon` feature and dispatches tasks into the current [thread pool][rayon::ThreadPool].
27    pub fn par_new(mut objects: S) -> Self
28    where
29        O: Send,
30    {
31        par_sort(objects.as_mut(), 0);
32
33        Self {
34            objects,
35            _marker: PhantomData,
36        }
37    }
38}
39
40fn sort<O>(objects: &mut [O], axis: usize)
41where
42    O: Object,
43{
44    if objects.len() <= 1 {
45        return;
46    }
47
48    let (left, right, next_axis) = sort_axis(objects, axis);
49
50    sort(left, next_axis);
51    sort(right, next_axis);
52}
53
/// Parallel variant of `sort`: arranges `objects` into k-d tree order starting at `axis`.
///
/// After each median split, the two disjoint halves are handed to `rayon::join`, which
/// may process them on different worker threads. Slices with fewer than two elements need
/// no work.
#[cfg(feature = "rayon")]
fn par_sort<O>(objects: &mut [O], axis: usize)
where
    O: Object + Send,
{
    if objects.len() < 2 {
        return;
    }

    let (lower, upper, axis) = sort_axis(objects, axis);

    // The halves do not overlap, so each closure gets exclusive ownership of its own slice.
    join(move || par_sort(lower, axis), move || par_sort(upper, axis));
}
67
68fn sort_axis<O>(objects: &mut [O], axis: usize) -> (&mut [O], &mut [O], usize)
69where
70    O: Object,
71{
72    let mid = objects.len() / 2;
73
74    let (left, _, right) = objects.select_nth_unstable_by(mid, |lhs, rhs| {
75        let lhs = lhs.position().coord(axis);
76        let rhs = rhs.position().coord(axis);
77
78        lhs.partial_cmp(&rhs).unwrap()
79    });
80
81    let next_axis = (axis + 1) % O::Point::DIM;
82
83    (left, right, next_axis)
84}