toolbox_rs/
rdx_sort.rs

1pub mod radix {
2    use core::mem;
3
4    use crate::invoke_macro_for_types;
5
6    pub trait RadixType: Clone + Copy + Default {
7        // RadixTypes are sortable by rdx_sort(.)
8        const IS_SIGNED: bool;
9        // signed data requires special handling in the last round
10        fn key(&self, round: usize) -> u8;
11        // the key is the radix of size u8, i.e. one byte
12    }
13
14    macro_rules! is_signed {
15        // convenience macro to derive which built-in number types are signed since
16        // they require special case handling in the final sorting round.
17        (i8) => {
18            true
19        };
20        (i16) => {
21            true
22        };
23        (i32) => {
24            true
25        };
26        (i64) => {
27            true
28        };
29        (i128) => {
30            true
31        };
32        (f32) => {
33            true
34        };
35        (f64) => {
36            true
37        };
38        ($_t:ty) => {
39            false
40        };
41    }
42
43    macro_rules! radix_type {
44        // short-hand to add a default RadixType implementation for the
45        // given input type. Works with built-in types like integers.
46        ($a:ident) => {
47            // forward to the general implementation below
48            radix_type!($a, $a);
49        };
50        ($a:ident, $b:ident) => {
51            impl RadixType for $a {
52                const IS_SIGNED: bool = is_signed!($a);
53                fn key(&self, round: usize) -> u8 {
54                    (*self as $b >> (round << 3)) as u8
55                }
56            }
57        };
58    }
59
60    // define built-in number types (integers, floats, bool) as RadixType
61    invoke_macro_for_types!(
62        radix_type, u8, u16, u32, u64, u128, usize, i8, i16, i32, i64, i128, isize
63    );
64
65    radix_type!(bool, u8);
66
67    impl RadixType for f32 {
68        const IS_SIGNED: bool = is_signed!(f32);
69        fn key(&self, round: usize) -> u8 {
70            // Interpret the bits of a float as if they were an integer.
71            // This relies on the floats being in IEEE-754 format to work.
72            (self.to_bits() >> (round << 3)) as u8
73        }
74    }
75
76    impl RadixType for f64 {
77        const IS_SIGNED: bool = is_signed!(f64);
78        fn key(&self, round: usize) -> u8 {
79            // Interpret the bits of a float as if they were an integer.
80            // This relies on the floats being in IEEE-754 format to work.
81            (self.to_bits() >> (round << 3)) as u8
82        }
83    }
84    pub trait Sort {
85        fn rdx_sort(&mut self);
86    }
87
88    impl<T: 'static + RadixType> Sort for Vec<T> {
89        fn rdx_sort(&mut self) {
90            // TODO(dl): Add an explanation of how radix sort works
91            let mut output = vec![T::default(); self.len()];
92            let rounds = mem::size_of::<T>();
93
94            // implementation of Friend's optimization: Compute all frequencies at once
95            let mut histogram_table = Vec::<Vec<usize>>::new();
96            histogram_table.resize(rounds, Vec::with_capacity(256));
97            for histogram in &mut histogram_table {
98                histogram.resize(256, 0);
99            }
100            self.iter().for_each(|num| {
101                for k in 0..rounds {
102                    let radix = num.key(k);
103                    unsafe {
104                        *histogram_table
105                            .get_unchecked_mut(k)
106                            .get_unchecked_mut(radix as usize) += 1;
107                    }
108                }
109            });
110
111            let mut skip_table = Vec::with_capacity(rounds);
112            skip_table.resize(rounds, false);
113
114            for k in 0..rounds {
115                // TODO: there must be a more elegant way to do this!
116
117                let mut prev = match T::IS_SIGNED && k == rounds - 1 {
118                    // add offset to non-negative entries making room for negative ones
119                    true => histogram_table[k].iter().skip(128).sum(),
120                    false => 0,
121                };
122
123                if T::IS_SIGNED && k == rounds - 1 {
124                    // last round for signed numbers needs to handle negatives
125                    // note that the sign-bit is in the MSB
126
127                    (0..128).for_each(|i| unsafe {
128                        skip_table[k] = skip_table[k]
129                            || *histogram_table.get_unchecked_mut(k).get_unchecked_mut(i)
130                                == self.len();
131                        let temp = *histogram_table.get_unchecked_mut(k).get_unchecked_mut(i);
132                        *histogram_table.get_unchecked_mut(k).get_unchecked_mut(i) = prev;
133                        prev += temp;
134                    });
135                    prev = 0;
136                    for i in (128..256).rev() {
137                        // build prefix sums for negative numbers from the right
138                        unsafe {
139                            skip_table[k] = skip_table[k]
140                                || *histogram_table.get_unchecked_mut(k).get_unchecked_mut(i)
141                                    == self.len();
142                            let temp = *histogram_table.get_unchecked_mut(k).get_unchecked_mut(i);
143                            *histogram_table.get_unchecked_mut(k).get_unchecked_mut(i) = prev;
144                            prev += temp;
145                        }
146                    }
147                } else {
148                    // a round can be skipped if all entries fall into the same bucket
149                    skip_table[k] = histogram_table[k][0] == self.len();
150
151                    // let mut prev = 0;
152                    (0..256).for_each(|i| unsafe {
153                        skip_table[k] = skip_table[k]
154                            || *histogram_table.get_unchecked_mut(k).get_unchecked_mut(i)
155                                == self.len();
156                        let temp = *histogram_table.get_unchecked_mut(k).get_unchecked_mut(i);
157                        *histogram_table.get_unchecked_mut(k).get_unchecked_mut(i) = prev;
158                        prev += temp;
159                    });
160                }
161            }
162
163            // permutation rounds
164            for (k, skip_round) in skip_table.iter().enumerate().take(rounds) {
165                if *skip_round {
166                    // skipping round {k} since all of input falls into exactly one bucket
167                    continue;
168                }
169
170                // place values into their slot
171                self.iter().for_each(|num| {
172                    let radix = num.key(k);
173                    unsafe {
174                        // performance optimization
175                        let target = *histogram_table
176                            .get_unchecked(k)
177                            .get_unchecked(radix as usize);
178                        *output.get_unchecked_mut(target) = *num;
179                        *histogram_table
180                            .get_unchecked_mut(k)
181                            .get_unchecked_mut(radix as usize) += 1;
182                    }
183                });
184
185                core::mem::swap(self, &mut output);
186            }
187        }
188    }
189}
#[cfg(test)]
mod tests {
    use rand::Rng;

    use super::radix::Sort;

    /// Sorting 10,000 random `u64` values must leave every neighbour pair in order.
    #[test]
    fn tenknumbers() {
        let mut rng = rand::rng();
        let mut values: Vec<u64> = (0..10_000).map(|_| rng.random::<u64>()).collect();

        values.rdx_sort();

        for pair in values.windows(2) {
            assert!(pair[0] <= pair[1]);
        }
    }

    /// A strictly descending run of `i32` values must come out strictly ascending.
    #[test]
    fn invertedrun() {
        let mut values: Vec<i32> = (0..10_000).rev().collect();
        // precondition: the input really is strictly descending
        for pair in values.windows(2) {
            assert!(pair[0] > pair[1]);
        }

        values.rdx_sort();

        // all values are distinct, so the result must be strictly ascending
        for pair in values.windows(2) {
            assert!(pair[0] < pair[1]);
        }
    }

    /// `false` is expected to sort before `true`.
    #[test]
    fn sort_bools() {
        let mut flags: Vec<bool> = vec![false, false, false, false, false, true, false, true];
        flags.rdx_sort();
        let expected: Vec<bool> = vec![false, false, false, false, false, false, true, true];
        assert_eq!(flags, expected)
    }

    #[test]
    fn sort_f32() {
        let mut floats: Vec<f32> = vec![1.0, 4.0, 3.2415, 0.0, 26.6, 14.32, 1.23, 0.12];
        floats.rdx_sort();
        let expected: Vec<f32> = vec![0.0, 0.12, 1.0, 1.23, 3.2415, 4.0, 14.32, 26.6];
        assert_eq!(floats, expected)
    }

    #[test]
    fn sort_f64() {
        let mut floats: Vec<f64> = vec![1.0, 4.0, 3.2415, 0.0, 26.6, 14.32, 1.23, 0.12];
        floats.rdx_sort();
        let expected: Vec<f64> = vec![0.0, 0.12, 1.0, 1.23, 3.2415, 4.0, 14.32, 26.6];
        assert_eq!(floats, expected)
    }

    #[test]
    fn sort_i32() {
        let mut numbers: Vec<i32> = vec![0, 128, -1, 170, 45, 75, 90, -127, 280, -4, 24, 1, 2, 66];
        numbers.rdx_sort();
        let expected: Vec<i32> = vec![-127, -4, -1, 0, 1, 2, 24, 45, 66, 75, 90, 128, 170, 280];
        assert_eq!(numbers, expected);
    }

    #[test]
    fn sort_i64() {
        let mut numbers: Vec<i64> = vec![0, 128, -1, 170, 45, 75, 90, -127, 280, -4, 24, 1, 2, 66];
        numbers.rdx_sort();
        let expected: Vec<i64> = vec![-127, -4, -1, 0, 1, 2, 24, 45, 66, 75, 90, 128, 170, 280];
        assert_eq!(numbers, expected);
    }
}