toolbox_rs/
rdx_sort.rs

1use core::mem;
2
3use crate::invoke_macro_for_types;
4
/// A value that can be sorted by [`Sort::rdx_sort`].
///
/// The sort treats each value as a sequence of one-byte digits, obtained from
/// [`RadixType::key`] for rounds `0..size_of::<Self>()`, and processes them
/// starting at round `0`.
pub trait RadixType: Clone + Copy + Default {
    /// `true` if the key byte of the last round (`size_of::<Self>() - 1`)
    /// carries a sign bit, in which case the sort must handle that final
    /// round specially instead of treating the byte as a plain unsigned digit.
    const IS_SIGNED: bool;

    /// Returns the radix digit (one byte) of `self` for the given `round`.
    ///
    /// The sort calls this once while counting and again while placing
    /// elements, so it should return the same byte for the same `self` and
    /// `round` every time.
    fn key(&self, round: usize) -> u8;
}
12
// Tells whether a built-in number type is signed. Signed types (including the
// floats) need special handling in the final sorting round, where the most
// significant key byte carries the sign bit. Any type not listed here, e.g.
// the unsigned integers and `bool`, is reported as unsigned.
macro_rules! is_signed {
    (i8) => {
        true
    };
    (i16) => {
        true
    };
    (i32) => {
        true
    };
    (i64) => {
        true
    };
    (i128) => {
        true
    };
    // Without this arm `isize` would fall through to `false`, and its negative
    // values would be sorted as large unsigned numbers (after non-negatives).
    (isize) => {
        true
    };
    (f32) => {
        true
    };
    (f64) => {
        true
    };
    ($_t:ty) => {
        false
    };
}
41
// Generates a `RadixType` implementation for a built-in type.
//
// `radix_type!(T)` extracts key bytes by shifting `T` itself, which suits the
// primitive integers. `radix_type!(T, W)` first casts the value to `W` and
// shifts that instead, for types such as `bool` that cannot be shifted.
// `IS_SIGNED` is looked up via `is_signed!` on the original type `T`.
macro_rules! radix_type {
    ($a:ident) => {
        // the type is its own working type
        radix_type!($a, $a);
    };
    ($a:ident, $b:ident) => {
        impl RadixType for $a {
            const IS_SIGNED: bool = is_signed!($a);
            fn key(&self, round: usize) -> u8 {
                let widened = *self as $b;
                // each round consumes one byte, i.e. eight bits
                let shift = round << 3;
                (widened >> shift) as u8
            }
        }
    };
}
58
59// define built-in number types (integers, floats, bool) as RadixType
60invoke_macro_for_types!(
61    radix_type, u8, u16, u32, u64, u128, usize, i8, i16, i32, i64, i128, isize
62);
63
64radix_type!(bool, u8);
65
66impl RadixType for f32 {
67    const IS_SIGNED: bool = is_signed!(f32);
68    fn key(&self, round: usize) -> u8 {
69        // Interpret the bits of a float as if they were an integer.
70        // This relies on the floats being in IEEE-754 format to work.
71        (self.to_bits() >> (round << 3)) as u8
72    }
73}
74
75impl RadixType for f64 {
76    const IS_SIGNED: bool = is_signed!(f64);
77    fn key(&self, round: usize) -> u8 {
78        // Interpret the bits of a float as if they were an integer.
79        // This relies on the floats being in IEEE-754 format to work.
80        (self.to_bits() >> (round << 3)) as u8
81    }
82}
83pub trait Sort {
84    fn rdx_sort(&mut self);
85}
86
87impl<T: 'static + RadixType> Sort for Vec<T> {
88    fn rdx_sort(&mut self) {
89        // TODO(dl): Add an explanation of how radix sort works
90        let mut output = vec![T::default(); self.len()];
91        let rounds = mem::size_of::<T>();
92
93        // implementation of Friend's optimization: Compute all frequencies at once
94        let mut histogram_table = Vec::<Vec<usize>>::new();
95        histogram_table.resize(rounds, Vec::with_capacity(256));
96        for histogram in &mut histogram_table {
97            histogram.resize(256, 0);
98        }
99        self.iter().for_each(|num| {
100            for k in 0..rounds {
101                let radix = num.key(k);
102                unsafe {
103                    *histogram_table
104                        .get_unchecked_mut(k)
105                        .get_unchecked_mut(radix as usize) += 1;
106                }
107            }
108        });
109
110        let mut skip_table = Vec::with_capacity(rounds);
111        skip_table.resize(rounds, false);
112
113        for k in 0..rounds {
114            // TODO: there must be a more elegant way to do this!
115
116            let mut prev = match T::IS_SIGNED && k == rounds - 1 {
117                // add offset to non-negative entries making room for negative ones
118                true => histogram_table[k].iter().skip(128).sum(),
119                false => 0,
120            };
121
122            if T::IS_SIGNED && k == rounds - 1 {
123                // last round for signed numbers needs to handle negatives
124                // note that the sign-bit is in the MSB
125
126                (0..128).for_each(|i| unsafe {
127                    skip_table[k] = skip_table[k]
128                        || *histogram_table.get_unchecked_mut(k).get_unchecked_mut(i) == self.len();
129                    let temp = *histogram_table.get_unchecked_mut(k).get_unchecked_mut(i);
130                    *histogram_table.get_unchecked_mut(k).get_unchecked_mut(i) = prev;
131                    prev += temp;
132                });
133                prev = 0;
134                for i in (128..256).rev() {
135                    // build prefix sums for negative numbers from the right
136                    unsafe {
137                        skip_table[k] = skip_table[k]
138                            || *histogram_table.get_unchecked_mut(k).get_unchecked_mut(i)
139                                == self.len();
140                        let temp = *histogram_table.get_unchecked_mut(k).get_unchecked_mut(i);
141                        *histogram_table.get_unchecked_mut(k).get_unchecked_mut(i) = prev;
142                        prev += temp;
143                    }
144                }
145            } else {
146                // a round can be skipped if all entries fall into the same bucket
147                skip_table[k] = histogram_table[k][0] == self.len();
148
149                // let mut prev = 0;
150                (0..256).for_each(|i| unsafe {
151                    skip_table[k] = skip_table[k]
152                        || *histogram_table.get_unchecked_mut(k).get_unchecked_mut(i) == self.len();
153                    let temp = *histogram_table.get_unchecked_mut(k).get_unchecked_mut(i);
154                    *histogram_table.get_unchecked_mut(k).get_unchecked_mut(i) = prev;
155                    prev += temp;
156                });
157            }
158        }
159
160        // permutation rounds
161        for (k, skip_round) in skip_table.iter().enumerate().take(rounds) {
162            if *skip_round {
163                // skipping round {k} since all of input falls into exactly one bucket
164                continue;
165            }
166
167            // place values into their slot
168            self.iter().for_each(|num| {
169                let radix = num.key(k);
170                unsafe {
171                    // performance optimization
172                    let target = *histogram_table
173                        .get_unchecked(k)
174                        .get_unchecked(radix as usize);
175                    *output.get_unchecked_mut(target) = *num;
176                    *histogram_table
177                        .get_unchecked_mut(k)
178                        .get_unchecked_mut(radix as usize) += 1;
179                }
180            });
181
182            core::mem::swap(self, &mut output);
183        }
184    }
185}
186
#[cfg(test)]
mod tests {
    use rand::Rng;

    use super::Sort;

    #[test]
    fn tenknumbers() {
        let mut rng = rand::rng();
        let mut list: Vec<u64> = (0..10_000).map(|_| rng.random::<u64>()).collect();
        list.rdx_sort();
        // every adjacent pair must be in non-decreasing order
        assert!(list.windows(2).all(|pair| pair[0] <= pair[1]));
    }

    #[test]
    fn invertedrun() {
        let mut list: Vec<i32> = (0..10_000).rev().collect();
        // precondition: the input starts out strictly descending
        assert!(list.windows(2).all(|pair| pair[0] > pair[1]));

        list.rdx_sort();

        // afterwards the distinct values must be strictly ascending
        assert!(list.windows(2).all(|pair| pair[0] < pair[1]));
    }

    #[test]
    fn sort_bools() {
        // relies on false < true
        let mut flags: Vec<bool> = vec![false, false, false, false, false, true, false, true];
        flags.rdx_sort();
        let expected = vec![false, false, false, false, false, false, true, true];
        assert_eq!(flags, expected)
    }

    #[test]
    fn sort_f32() {
        let mut values: Vec<f32> = vec![1.0, 4.0, 3.2415, 0.0, 26.6, 14.32, 1.23, 0.12];
        values.rdx_sort();
        let expected = vec![0.0, 0.12, 1.0, 1.23, 3.2415, 4.0, 14.32, 26.6];
        assert_eq!(values, expected)
    }

    #[test]
    fn sort_f64() {
        let mut values: Vec<f64> = vec![1.0, 4.0, 3.2415, 0.0, 26.6, 14.32, 1.23, 0.12];
        values.rdx_sort();
        let expected = vec![0.0, 0.12, 1.0, 1.23, 3.2415, 4.0, 14.32, 26.6];
        assert_eq!(values, expected)
    }

    #[test]
    fn sort_i32() {
        let mut values: Vec<i32> = vec![0, 128, -1, 170, 45, 75, 90, -127, 280, -4, 24, 1, 2, 66];
        values.rdx_sort();
        let expected = vec![-127, -4, -1, 0, 1, 2, 24, 45, 66, 75, 90, 128, 170, 280];
        assert_eq!(values, expected);
    }

    #[test]
    fn sort_i64() {
        let mut values: Vec<i64> = vec![0, 128, -1, 170, 45, 75, 90, -127, 280, -4, 24, 1, 2, 66];
        values.rdx_sort();
        let expected = vec![-127, -4, -1, 0, 1, 2, 24, 45, 66, 75, 90, 128, 170, 280];
        assert_eq!(values, expected);
    }
}
269}