1use core::mem;
2
3use crate::invoke_macro_for_types;
4
/// A fixed-size value that [`Sort::rdx_sort`] can order one key byte at a time.
///
/// Round `0` addresses the least significant key byte and round
/// `size_of::<Self>() - 1` the most significant one.
pub trait RadixType
where
    Self: Copy + Clone + Default,
{
    /// Returns byte number `round` of the value's sort key, counting from the
    /// least significant byte.
    fn key(&self, round: usize) -> u8;

    /// Whether the high bit of the most significant key byte is a sign bit.
    ///
    /// When `true`, the sort places the final round's buckets `128..=255`
    /// ahead of buckets `0..=127`.
    const IS_SIGNED: bool;
}
12
// Expands to `true` when the argument is spelled as one of the primitive types
// whose values can be negative (signed integers and floats), `false` otherwise.
//
// Matching is purely on the token spelling, so every signed primitive needs its
// own arm. `isize` was previously missing, which made negative `isize` values
// sort as if they were huge unsigned numbers. A type written any other way
// (an alias, a full path) falls through to `false`.
macro_rules! is_signed {
    (i8) => {
        true
    };
    (i16) => {
        true
    };
    (i32) => {
        true
    };
    (i64) => {
        true
    };
    (i128) => {
        true
    };
    (isize) => {
        true
    };
    (f32) => {
        true
    };
    (f64) => {
        true
    };
    ($_t:ty) => {
        false
    };
}
41
// Implements `RadixType` for a primitive type.
//
// `radix_type!(T)` extracts key bytes from `T` itself. `radix_type!(T, W)` first
// converts the value to `W` with `as`, which is needed for types such as `bool`
// that do not support `>>`. `IS_SIGNED` is decided by `is_signed!` on `T`.
macro_rules! radix_type {
    ($src:ident, $wide:ident) => {
        impl RadixType for $src {
            const IS_SIGNED: bool = is_signed!($src);

            fn key(&self, round: usize) -> u8 {
                // Byte `round` sits `8 * round` bits above the least significant bit.
                let shift = round << 3;
                let widened = *self as $wide;
                (widened >> shift) as u8
            }
        }
    };
    ($src:ident) => {
        radix_type!($src, $src);
    };
}
58
// Generate `RadixType` impls for every primitive integer type, each keyed on its
// own bit pattern via `radix_type!`.
// NOTE(review): assumes `invoke_macro_for_types!` (defined elsewhere in the crate)
// expands to one `radix_type!(T)` call per listed type. Confirm against its definition.
invoke_macro_for_types!(
    radix_type, u8, u16, u32, u64, u128, usize, i8, i16, i32, i64, i128, isize
);

// `bool` does not support `>>`, so it is converted to `u8` (false -> 0, true -> 1)
// before its single key byte is extracted.
radix_type!(bool, u8);
65
66impl RadixType for f32 {
67 const IS_SIGNED: bool = is_signed!(f32);
68 fn key(&self, round: usize) -> u8 {
69 (self.to_bits() >> (round << 3)) as u8
72 }
73}
74
75impl RadixType for f64 {
76 const IS_SIGNED: bool = is_signed!(f64);
77 fn key(&self, round: usize) -> u8 {
78 (self.to_bits() >> (round << 3)) as u8
81 }
82}
83pub trait Sort {
84 fn rdx_sort(&mut self);
85}
86
87impl<T: 'static + RadixType> Sort for Vec<T> {
88 fn rdx_sort(&mut self) {
89 let mut output = vec![T::default(); self.len()];
91 let rounds = mem::size_of::<T>();
92
93 let mut histogram_table = Vec::<Vec<usize>>::new();
95 histogram_table.resize(rounds, Vec::with_capacity(256));
96 for histogram in &mut histogram_table {
97 histogram.resize(256, 0);
98 }
99 self.iter().for_each(|num| {
100 for k in 0..rounds {
101 let radix = num.key(k);
102 unsafe {
103 *histogram_table
104 .get_unchecked_mut(k)
105 .get_unchecked_mut(radix as usize) += 1;
106 }
107 }
108 });
109
110 let mut skip_table = Vec::with_capacity(rounds);
111 skip_table.resize(rounds, false);
112
113 for k in 0..rounds {
114 let mut prev = match T::IS_SIGNED && k == rounds - 1 {
117 true => histogram_table[k].iter().skip(128).sum(),
119 false => 0,
120 };
121
122 if T::IS_SIGNED && k == rounds - 1 {
123 (0..128).for_each(|i| unsafe {
127 skip_table[k] = skip_table[k]
128 || *histogram_table.get_unchecked_mut(k).get_unchecked_mut(i) == self.len();
129 let temp = *histogram_table.get_unchecked_mut(k).get_unchecked_mut(i);
130 *histogram_table.get_unchecked_mut(k).get_unchecked_mut(i) = prev;
131 prev += temp;
132 });
133 prev = 0;
134 for i in (128..256).rev() {
135 unsafe {
137 skip_table[k] = skip_table[k]
138 || *histogram_table.get_unchecked_mut(k).get_unchecked_mut(i)
139 == self.len();
140 let temp = *histogram_table.get_unchecked_mut(k).get_unchecked_mut(i);
141 *histogram_table.get_unchecked_mut(k).get_unchecked_mut(i) = prev;
142 prev += temp;
143 }
144 }
145 } else {
146 skip_table[k] = histogram_table[k][0] == self.len();
148
149 (0..256).for_each(|i| unsafe {
151 skip_table[k] = skip_table[k]
152 || *histogram_table.get_unchecked_mut(k).get_unchecked_mut(i) == self.len();
153 let temp = *histogram_table.get_unchecked_mut(k).get_unchecked_mut(i);
154 *histogram_table.get_unchecked_mut(k).get_unchecked_mut(i) = prev;
155 prev += temp;
156 });
157 }
158 }
159
160 for (k, skip_round) in skip_table.iter().enumerate().take(rounds) {
162 if *skip_round {
163 continue;
165 }
166
167 self.iter().for_each(|num| {
169 let radix = num.key(k);
170 unsafe {
171 let target = *histogram_table
173 .get_unchecked(k)
174 .get_unchecked(radix as usize);
175 *output.get_unchecked_mut(target) = *num;
176 *histogram_table
177 .get_unchecked_mut(k)
178 .get_unchecked_mut(radix as usize) += 1;
179 }
180 });
181
182 core::mem::swap(self, &mut output);
183 }
184 }
185}
186
#[cfg(test)]
mod tests {
    use super::Sort;
    use rand::Rng;

    /// Asserts that every adjacent pair in `values` satisfies `ordered`.
    fn assert_pairwise<T>(values: &[T], ordered: impl Fn(&T, &T) -> bool) {
        for pair in values.windows(2) {
            assert!(ordered(&pair[0], &pair[1]));
        }
    }

    #[test]
    fn tenknumbers() {
        let mut rng = rand::rng();
        let mut list: Vec<u64> = (0..10_000).map(|_| rng.random::<u64>()).collect();
        list.rdx_sort();
        assert_pairwise(&list, |a, b| a <= b);
    }

    #[test]
    fn invertedrun() {
        let mut list: Vec<i32> = (0..10_000).rev().collect();
        assert_pairwise(&list, |a, b| a > b);

        list.rdx_sort();
        assert_pairwise(&list, |a, b| a < b);
    }

    #[test]
    fn sort_bools() {
        let mut bits: Vec<bool> = vec![false, false, false, false, false, true, false, true];
        bits.rdx_sort();
        let expected = vec![false, false, false, false, false, false, true, true];
        assert_eq!(bits, expected);
    }

    #[test]
    fn sort_f32() {
        let mut floats: Vec<f32> = vec![1.0, 4.0, 3.2415, 0.0, 26.6, 14.32, 1.23, 0.12];
        floats.rdx_sort();
        let expected = vec![0.0, 0.12, 1.0, 1.23, 3.2415, 4.0, 14.32, 26.6];
        assert_eq!(floats, expected);
    }

    #[test]
    fn sort_f64() {
        let mut floats: Vec<f64> = vec![1.0, 4.0, 3.2415, 0.0, 26.6, 14.32, 1.23, 0.12];
        floats.rdx_sort();
        let expected = vec![0.0, 0.12, 1.0, 1.23, 3.2415, 4.0, 14.32, 26.6];
        assert_eq!(floats, expected);
    }

    #[test]
    fn sort_i32() {
        let mut input: Vec<i32> = vec![0, 128, -1, 170, 45, 75, 90, -127, 280, -4, 24, 1, 2, 66];
        input.rdx_sort();
        let expected = vec![-127, -4, -1, 0, 1, 2, 24, 45, 66, 75, 90, 128, 170, 280];
        assert_eq!(input, expected);
    }

    #[test]
    fn sort_i64() {
        let mut input: Vec<i64> = vec![0, 128, -1, 170, 45, 75, 90, -127, 280, -4, 24, 1, 2, 66];
        input.rdx_sort();
        let expected = vec![-127, -4, -1, 0, 1, 2, 24, 45, 66, 75, 90, 128, 170, 280];
        assert_eq!(input, expected);
    }
}