// tf_idf_vectorizer/utils/sort.rs

1use core::mem;
2
3/// Fast u32-key radix sort for SoA (inds/vals).
4/// - Sorts by inds ascending
5/// - Reorders vals accordingly
6/// - Requires `N: Copy` for speed (good for f16/f32 etc.)
7///
8/// Complexity: 4 passes, each O(n + 256)
9pub fn radix_sort_u32_soa<N: Copy>(inds: &mut [u32], vals: &mut [N]) {
10    assert_eq!(inds.len(), vals.len());
11    let n = inds.len();
12    if n <= 1 {
13        return;
14    }
15
16    // Small sizes: insertion sort is often faster than allocating scratch.
17    if n <= 32 {
18        insertion_sort_u32_soa(inds, vals);
19        return;
20    }
21
22    // Scratch buffers (allocate once per call)
23    let mut inds_tmp = vec![0u32; n];
24    let mut vals_tmp: Vec<N> = vec![unsafe { mem::zeroed() }; n];
25
26    // Alternate between (src -> dst)
27    let mut src_inds: &mut [u32] = inds;
28    let mut src_vals: &mut [N] = vals;
29    let mut dst_inds: &mut [u32] = &mut inds_tmp;
30    let mut dst_vals: &mut [N] = &mut vals_tmp;
31
32    // 4 passes: byte 0..3 (LSD)
33    for shift in [0u32, 8, 16, 24] {
34        let mut count = [0usize; 256];
35
36        // Count
37        for &k in src_inds.iter() {
38            count[((k >> shift) & 0xFF) as usize] += 1;
39        }
40
41        // Prefix sum -> starting positions
42        let mut sum = 0usize;
43        for c in count.iter_mut() {
44            let tmp = *c;
45            *c = sum;
46            sum += tmp;
47        }
48
49        // Distribute (stable)
50        // NOTE: writing with raw pointers helps the optimizer a bit.
51        unsafe {
52            let s_i = src_inds.as_ptr();
53            let s_v = src_vals.as_ptr();
54            let d_i = dst_inds.as_mut_ptr();
55            let d_v = dst_vals.as_mut_ptr();
56
57            for idx in 0..n {
58                let k = *s_i.add(idx);
59                let b = ((k >> shift) & 0xFF) as usize;
60                let pos = count[b];
61                count[b] = pos + 1;
62
63                *d_i.add(pos) = k;
64                *d_v.add(pos) = *s_v.add(idx);
65            }
66        }
67
68        // swap src/dst for next pass
69        mem::swap(&mut src_inds, &mut dst_inds);
70        mem::swap(&mut src_vals, &mut dst_vals);
71    }
72
73    // After 4 passes, result is in `src_*`.
74    // Since we did 4 passes (even), `src_*` points back to original (inds/vals).
75    // If you change pass count, you may need a copy-back.
76    // (Here no-op.)
77}
78
/// Stable in-place insertion sort of a short SoA pair (`inds` keys, `vals` payload).
///
/// Each key is walked leftwards past strictly greater predecessors. The matching
/// entry of `vals` is swapped in lockstep, and equal keys never trade places.
#[inline]
fn insertion_sort_u32_soa<N: Copy>(inds: &mut [u32], vals: &mut [N]) {
    for end in 1..inds.len() {
        let mut cur = end;
        while cur > 0 {
            let prev = cur - 1;
            if inds[prev] <= inds[cur] {
                // Already ordered relative to its left neighbour.
                break;
            }
            inds.swap(prev, cur);
            vals.swap(prev, cur);
            cur = prev;
        }
    }
}
92
#[cfg(test)]
mod tests {
    use super::*;
    use core::fmt::Debug;

    /// Reference result: sort `(key, value)` pairs by key with the standard library's
    /// stable sort, so entries with equal keys keep their original relative order.
    fn baseline_stable_sort<N: Copy>(inds: &[u32], vals: &[N]) -> (Vec<u32>, Vec<N>) {
        let mut pairs: Vec<(u32, N)> = inds.iter().copied().zip(vals.iter().copied()).collect();
        // `sort_by_key` is stable, which is exactly the contract the radix sort must match.
        pairs.sort_by_key(|&(k, _)| k);
        pairs.into_iter().unzip()
    }

    fn assert_sorted(keys: &[u32]) {
        for (i, w) in keys.windows(2).enumerate() {
            assert!(w[0] <= w[1], "not sorted at {}: {} > {}", i + 1, w[0], w[1]);
        }
    }

    /// Sorts copies of `inds`/`vals` with `radix_sort_u32_soa` and checks the result
    /// against the stable baseline. `label` identifies the case in failure messages.
    fn check_against_baseline<N: Copy + PartialEq + Debug>(
        inds: &[u32],
        vals: &[N],
        label: &str,
    ) {
        let (base_k, base_v) = baseline_stable_sort(inds, vals);

        let mut k = inds.to_vec();
        let mut v = vals.to_vec();
        radix_sort_u32_soa(&mut k, &mut v);

        assert_sorted(&k);
        assert_eq!(k, base_k, "keys mismatch: {label}");
        assert_eq!(v, base_v, "vals mismatch: {label}");
    }

    /// Tiny deterministic PRNG (xorshift32).
    struct Rng(u32);
    impl Rng {
        fn new(seed: u32) -> Self {
            Self(seed)
        }
        fn next_u32(&mut self) -> u32 {
            let mut x = self.0;
            x ^= x << 13;
            x ^= x >> 17;
            x ^= x << 5;
            self.0 = x;
            x
        }
    }

    #[test]
    fn radix_sort_handles_empty_and_single() {
        // empty
        let mut inds: Vec<u32> = vec![];
        let mut vals: Vec<u16> = vec![];
        radix_sort_u32_soa(&mut inds, &mut vals);
        assert!(inds.is_empty());
        assert!(vals.is_empty());

        // single
        let mut inds = vec![42u32];
        let mut vals = vec![7u16];
        radix_sort_u32_soa(&mut inds, &mut vals);
        assert_eq!(inds, vec![42u32]);
        assert_eq!(vals, vec![7u16]);
    }

    #[test]
    fn radix_sort_works_on_duplicates_and_preserves_pairing() {
        // Keys have duplicates. Vals encode original position, so matching the stable
        // baseline also proves stability. One repetition (7 keys) exercises the
        // insertion-sort path; ten repetitions (70 keys) exercise the radix path.
        let pattern = [3u32, 1, 3, 2, 1, 3, 0];
        for reps in [1usize, 10] {
            let inds: Vec<u32> = pattern.iter().copied().cycle().take(pattern.len() * reps).collect();
            let vals: Vec<u32> = (0..inds.len() as u32).collect();
            check_against_baseline(&inds, &vals, &format!("duplicates reps={reps}"));
        }
    }

    #[test]
    fn radix_sort_matches_baseline_many_sizes() {
        let mut rng = Rng::new(0x1234_5678);

        // A range of sizes, including the insertion-sort threshold area and larger sizes.
        let sizes = [0usize, 1, 2, 3, 4, 7, 8, 15, 16, 31, 32, 33, 63, 64, 65, 127, 128, 129, 1024];

        // Key generators:
        // - "masked": 24-bit keys, so duplicates are likely and the top byte is always 0.
        // - "full": full-range keys, so every radix pass sees non-trivial digits.
        // - "top-byte": keys differing only in the top byte, so only the last pass reorders.
        let key_kinds: [(&str, fn(u32) -> u32); 3] = [
            ("masked", |r| r & 0x00FF_FFFF),
            ("full", |r| r),
            ("top-byte", |r| r & 0xFF00_0000),
        ];

        for (kind, make_key) in key_kinds {
            for &n in &sizes {
                let inds: Vec<u32> = (0..n).map(|_| make_key(rng.next_u32())).collect();
                // Value carries identity to verify pairing.
                let vals: Vec<u32> = (0..n).map(|i| (i as u32) ^ 0xA5A5_5A5A).collect();
                check_against_baseline(&inds, &vals, &format!("{kind} n={n}"));
            }
        }
    }

    #[test]
    fn radix_sort_extremes() {
        // The bare pattern (7 keys) goes through the insertion-sort path. Repeating it
        // pushes the input past the 32-element threshold, so the radix passes see
        // extreme keys in every byte position.
        let pattern = [0u32, u32::MAX, 1, u32::MAX - 1, 0, 2, u32::MAX];
        for reps in [1usize, 8] {
            let inds: Vec<u32> = pattern.iter().copied().cycle().take(pattern.len() * reps).collect();
            let vals: Vec<u32> = (0..inds.len() as u32).collect();
            check_against_baseline(&inds, &vals, &format!("extremes reps={reps}"));
        }
    }

    #[test]
    #[should_panic]
    fn radix_sort_rejects_mismatched_lengths() {
        let mut inds = vec![1u32, 2, 3];
        let mut vals = vec![1u8, 2];
        radix_sort_u32_soa(&mut inds, &mut vals);
    }
}
220