voracious_radix_sort/sorts/ska_sort.rs

1use super::super::{RadixKey, Radixable};
2use super::american_flag_sort::serial_radixsort_rec;
3use super::comparative_sort::insertion_sort;
4use super::utils::{get_histogram, prefix_sums, Params};
5
6const UNROLL_SIZE: usize = 4;
7
8pub fn ska_swap<T: Radixable<K>, K: RadixKey>(
9    arr: &mut [T],
10    heads: &mut Vec<usize>,
11    tails: &[usize],
12    mask: <<T as Radixable<K>>::Key as RadixKey>::Key,
13    shift: usize,
14) {
15    let mut buckets_size = Vec::new();
16    for i in 0..heads.len() {
17        buckets_size.push((i, tails[i] - heads[i]))
18    }
19    buckets_size.sort_unstable_by_key(|elt| elt.1);
20    buckets_size.pop();
21
22    while !buckets_size.is_empty() {
23        let mut to_remove = Vec::new();
24        for (i, (computed_index, _)) in buckets_size.iter().enumerate().rev() {
25            let span = tails[*computed_index] - heads[*computed_index];
26
27            if span > 0 {
28                let offset = heads[*computed_index];
29                let quotient = span / UNROLL_SIZE;
30                let remainder = span % UNROLL_SIZE;
31
32                for q in 0..quotient {
33                    let o = offset + q * UNROLL_SIZE;
34
35                    unsafe {
36                        let tb0 = arr.get_unchecked(o).extract(mask, shift);
37                        let tb1 = arr.get_unchecked(o + 1).extract(mask, shift);
38                        let tb2 = arr.get_unchecked(o + 2).extract(mask, shift);
39                        let tb3 = arr.get_unchecked(o + 3).extract(mask, shift);
40
41                        let dest_index_0 = *heads.get_unchecked(tb0);
42                        heads[tb0] += 1;
43                        let dest_index_1 = *heads.get_unchecked(tb1);
44                        heads[tb1] += 1;
45                        let dest_index_2 = *heads.get_unchecked(tb2);
46                        heads[tb2] += 1;
47                        let dest_index_3 = *heads.get_unchecked(tb3);
48                        heads[tb3] += 1;
49
50                        arr.swap(o, dest_index_0);
51                        arr.swap(o + 1, dest_index_1);
52                        arr.swap(o + 2, dest_index_2);
53                        arr.swap(o + 3, dest_index_3);
54                    }
55                }
56
57                let n_o = offset + UNROLL_SIZE * quotient;
58
59                for i in 0..remainder {
60                    unsafe {
61                        let b = arr.get_unchecked(n_o + i).extract(mask, shift);
62                        arr.swap(n_o + i, heads[b]);
63                        heads[b] += 1;
64                    }
65                }
66            } else {
67                to_remove.push(i);
68            }
69        }
70
71        for i in to_remove.iter() {
72            buckets_size.remove(*i);
73        }
74    }
75}
76
77pub fn ska_sort_rec<T: Radixable<K>, K: RadixKey>(arr: &mut [T], p: Params) {
78    if arr.len() <= 64 {
79        insertion_sort(arr);
80        return;
81    }
82    if arr.len() <= 1024 {
83        serial_radixsort_rec(arr, p);
84        return;
85    }
86
87    let dummy = arr[0];
88    let (mask, shift) = dummy.get_mask_and_shift_from_left(&p);
89    let histogram = get_histogram(arr, &p, mask, shift);
90    let (p_sums, mut heads, tails) = prefix_sums(&histogram);
91
92    ska_swap(arr, &mut heads, &tails, mask, shift);
93
94    let mut rest = arr;
95    if p.level < p.max_level - 1 {
96        for i in 0..(p.radix_range) {
97            let bucket_end = p_sums[i + 1] - p_sums[i];
98            let (first_part, second_part) = rest.split_at_mut(bucket_end);
99            rest = second_part;
100            if histogram[i] > 1 {
101                let new_params = p.new_level(p.level + 1);
102                ska_sort_rec(first_part, new_params);
103            }
104        }
105    }
106}
107
108/// # Ska sort
109///
110/// An implementation of the
111/// [Ska sort](https://probablydance.com/2016/12/27/i-wrote-a-faster-sorting-algorithm/)
112/// algorithm.
113///
114/// The Ska sort is an in place unstable radix sort.
115pub fn ska_sort<T: Radixable<K>, K: RadixKey>(arr: &mut [T], radix: usize) {
116    if arr.len() <= 64 {
117        insertion_sort(arr);
118        return;
119    }
120
121    let dummy = arr[0];
122    let (_, raw_offset) = dummy.compute_offset(arr, radix);
123    let max_level = dummy.compute_max_level(raw_offset, radix);
124
125    if max_level == 0 {
126        return;
127    }
128
129    let params = Params::new(0, radix, raw_offset, max_level);
130
131    ska_sort_rec(arr, params);
132}