voracious_radix_sort/sorts/
ska_sort.rs1use super::super::{RadixKey, Radixable};
2use super::american_flag_sort::serial_radixsort_rec;
3use super::comparative_sort::insertion_sort;
4use super::utils::{get_histogram, prefix_sums, Params};
5
6const UNROLL_SIZE: usize = 4;
7
8pub fn ska_swap<T: Radixable<K>, K: RadixKey>(
9 arr: &mut [T],
10 heads: &mut Vec<usize>,
11 tails: &[usize],
12 mask: <<T as Radixable<K>>::Key as RadixKey>::Key,
13 shift: usize,
14) {
15 let mut buckets_size = Vec::new();
16 for i in 0..heads.len() {
17 buckets_size.push((i, tails[i] - heads[i]))
18 }
19 buckets_size.sort_unstable_by_key(|elt| elt.1);
20 buckets_size.pop();
21
22 while !buckets_size.is_empty() {
23 let mut to_remove = Vec::new();
24 for (i, (computed_index, _)) in buckets_size.iter().enumerate().rev() {
25 let span = tails[*computed_index] - heads[*computed_index];
26
27 if span > 0 {
28 let offset = heads[*computed_index];
29 let quotient = span / UNROLL_SIZE;
30 let remainder = span % UNROLL_SIZE;
31
32 for q in 0..quotient {
33 let o = offset + q * UNROLL_SIZE;
34
35 unsafe {
36 let tb0 = arr.get_unchecked(o).extract(mask, shift);
37 let tb1 = arr.get_unchecked(o + 1).extract(mask, shift);
38 let tb2 = arr.get_unchecked(o + 2).extract(mask, shift);
39 let tb3 = arr.get_unchecked(o + 3).extract(mask, shift);
40
41 let dest_index_0 = *heads.get_unchecked(tb0);
42 heads[tb0] += 1;
43 let dest_index_1 = *heads.get_unchecked(tb1);
44 heads[tb1] += 1;
45 let dest_index_2 = *heads.get_unchecked(tb2);
46 heads[tb2] += 1;
47 let dest_index_3 = *heads.get_unchecked(tb3);
48 heads[tb3] += 1;
49
50 arr.swap(o, dest_index_0);
51 arr.swap(o + 1, dest_index_1);
52 arr.swap(o + 2, dest_index_2);
53 arr.swap(o + 3, dest_index_3);
54 }
55 }
56
57 let n_o = offset + UNROLL_SIZE * quotient;
58
59 for i in 0..remainder {
60 unsafe {
61 let b = arr.get_unchecked(n_o + i).extract(mask, shift);
62 arr.swap(n_o + i, heads[b]);
63 heads[b] += 1;
64 }
65 }
66 } else {
67 to_remove.push(i);
68 }
69 }
70
71 for i in to_remove.iter() {
72 buckets_size.remove(*i);
73 }
74 }
75}
76
77pub fn ska_sort_rec<T: Radixable<K>, K: RadixKey>(arr: &mut [T], p: Params) {
78 if arr.len() <= 64 {
79 insertion_sort(arr);
80 return;
81 }
82 if arr.len() <= 1024 {
83 serial_radixsort_rec(arr, p);
84 return;
85 }
86
87 let dummy = arr[0];
88 let (mask, shift) = dummy.get_mask_and_shift_from_left(&p);
89 let histogram = get_histogram(arr, &p, mask, shift);
90 let (p_sums, mut heads, tails) = prefix_sums(&histogram);
91
92 ska_swap(arr, &mut heads, &tails, mask, shift);
93
94 let mut rest = arr;
95 if p.level < p.max_level - 1 {
96 for i in 0..(p.radix_range) {
97 let bucket_end = p_sums[i + 1] - p_sums[i];
98 let (first_part, second_part) = rest.split_at_mut(bucket_end);
99 rest = second_part;
100 if histogram[i] > 1 {
101 let new_params = p.new_level(p.level + 1);
102 ska_sort_rec(first_part, new_params);
103 }
104 }
105 }
106}
107
108pub fn ska_sort<T: Radixable<K>, K: RadixKey>(arr: &mut [T], radix: usize) {
116 if arr.len() <= 64 {
117 insertion_sort(arr);
118 return;
119 }
120
121 let dummy = arr[0];
122 let (_, raw_offset) = dummy.compute_offset(arr, radix);
123 let max_level = dummy.compute_max_level(raw_offset, radix);
124
125 if max_level == 0 {
126 return;
127 }
128
129 let params = Params::new(0, radix, raw_offset, max_level);
130
131 ska_sort_rec(arr, params);
132}