tf_idf_vectorizer/utils/
sort.rs

1use core::mem;
2
3/// Fast u32-key radix sort for SoA (inds/vals).
4/// - Sorts by inds ascending
5/// - Reorders vals accordingly
6/// - Requires `N: Copy` for speed (good for f16/f32 etc.)
7///
8/// Complexity: 4 passes, each O(n + 256)
9#[inline(always)]
10pub fn radix_sort_u32_soa<N: Copy>(inds: &mut [u32], vals: &mut [N]) {
11    assert_eq!(inds.len(), vals.len());
12    let n = inds.len();
13    if n <= 1 {
14        return;
15    }
16
17    // Small sizes: insertion sort is often faster than allocating scratch.
18    if n <= 32 {
19        insertion_sort_u32_soa(inds, vals);
20        return;
21    }
22
23    // Scratch buffers (allocate once per call)
24    let mut inds_tmp = vec![0u32; n];
25    let mut vals_tmp: Vec<N> = vec![unsafe { mem::zeroed() }; n];
26
27    // Alternate between (src -> dst)
28    let mut src_inds: &mut [u32] = inds;
29    let mut src_vals: &mut [N] = vals;
30    let mut dst_inds: &mut [u32] = &mut inds_tmp;
31    let mut dst_vals: &mut [N] = &mut vals_tmp;
32
33    // 4 passes: byte 0..3 (LSD)
34    for shift in [0u32, 8, 16, 24] {
35        let mut count = [0usize; 256];
36
37        // Count
38        for &k in src_inds.iter() {
39            count[((k >> shift) & 0xFF) as usize] += 1;
40        }
41
42        // Prefix sum -> starting positions
43        let mut sum = 0usize;
44        for c in count.iter_mut() {
45            let tmp = *c;
46            *c = sum;
47            sum += tmp;
48        }
49
50        // Distribute (stable)
51        // NOTE: writing with raw pointers helps the optimizer a bit.
52        unsafe {
53            let s_i = src_inds.as_ptr();
54            let s_v = src_vals.as_ptr();
55            let d_i = dst_inds.as_mut_ptr();
56            let d_v = dst_vals.as_mut_ptr();
57
58            for idx in 0..n {
59                let k = *s_i.add(idx);
60                let b = ((k >> shift) & 0xFF) as usize;
61                let pos = count[b];
62                count[b] = pos + 1;
63
64                *d_i.add(pos) = k;
65                *d_v.add(pos) = *s_v.add(idx);
66            }
67        }
68
69        // swap src/dst for next pass
70        mem::swap(&mut src_inds, &mut dst_inds);
71        mem::swap(&mut src_vals, &mut dst_vals);
72    }
73
74    // After 4 passes, result is in `src_*`.
75    // Since we did 4 passes (even), `src_*` points back to original (inds/vals).
76    // If you change pass count, you may need a copy-back.
77    // (Here no-op.)
78}
79
/// Insertion sort over parallel slices, intended for short inputs.
///
/// Orders `inds` ascending and mirrors every swap onto `vals`, so each value stays
/// aligned with its key. Only strictly smaller keys are moved left, which keeps
/// entries with equal keys in their input order.
#[inline(always)]
fn insertion_sort_u32_soa<N: Copy>(inds: &mut [u32], vals: &mut [N]) {
    for end in 1..inds.len() {
        // Sink the element that starts at `end` leftwards, one neighbour at a time.
        for pos in (1..=end).rev() {
            if inds[pos - 1] <= inds[pos] {
                // Left neighbour is not larger: the prefix `0..=end` is now sorted.
                break;
            }
            inds.swap(pos, pos - 1);
            vals.swap(pos, pos - 1);
        }
    }
}
93
#[cfg(test)]
mod tests {
    use super::*;

    /// Reference implementation used to validate the radix sort.
    ///
    /// Produces the keys in ascending order together with their paired values;
    /// entries that share a key come out in their original input order.
    fn baseline_stable_sort<N: Copy>(inds: &[u32], vals: &[N]) -> (Vec<u32>, Vec<N>) {
        let mut order: Vec<usize> = (0..inds.len()).collect();
        // `sort_by_key` is stable, so ties keep ascending original index.
        order.sort_by_key(|&at| inds[at]);
        order.into_iter().map(|at| (inds[at], vals[at])).unzip()
    }

    /// Fails the test if any adjacent pair of keys is out of order.
    fn assert_sorted(keys: &[u32]) {
        for (left_at, pair) in keys.windows(2).enumerate() {
            let i = left_at + 1;
            assert!(pair[0] <= pair[1], "not sorted at {i}: {} > {}", pair[0], pair[1]);
        }
    }

    /// Minimal deterministic generator (xorshift32) so test inputs are reproducible.
    struct XorShift32 {
        state: u32,
    }

    impl XorShift32 {
        fn new(seed: u32) -> Self {
            Self { state: seed }
        }

        fn next_u32(&mut self) -> u32 {
            self.state ^= self.state << 13;
            self.state ^= self.state >> 17;
            self.state ^= self.state << 5;
            self.state
        }
    }

    #[test]
    fn radix_sort_handles_empty_and_single() {
        // Zero elements: nothing to do, nothing may change.
        let mut no_keys: Vec<u32> = Vec::new();
        let mut no_vals: Vec<u16> = Vec::new();
        radix_sort_u32_soa(&mut no_keys, &mut no_vals);
        assert!(no_keys.is_empty());
        assert!(no_vals.is_empty());

        // One element: already sorted.
        let mut one_key = vec![42u32];
        let mut one_val = vec![7u16];
        radix_sort_u32_soa(&mut one_key, &mut one_val);
        assert_eq!(one_key, vec![42u32]);
        assert_eq!(one_val, vec![7u16]);
    }

    #[test]
    fn radix_sort_works_on_duplicates_and_preserves_pairing() {
        // Repeated keys; each value records the slot its key started in.
        let mut inds = vec![3u32, 1, 3, 2, 1, 3, 0];
        let mut vals: Vec<u32> = (0u32..).take(inds.len()).collect();

        let (want_keys, want_vals) = baseline_stable_sort(&inds, &vals);

        radix_sort_u32_soa(&mut inds, &mut vals);

        assert_sorted(&inds);
        assert_eq!(inds, want_keys);
        assert_eq!(vals, want_vals, "radix sort should be stable in this implementation");
    }

    #[test]
    fn radix_sort_matches_baseline_many_sizes() {
        let mut rng = XorShift32::new(0x1234_5678);

        // Sizes straddle the insertion-sort threshold and reach into the radix path.
        for n in [0usize, 1, 2, 3, 4, 7, 8, 15, 16, 31, 32, 33, 63, 64, 65, 127, 128, 129, 1024] {
            // Masking the top byte makes duplicate keys somewhat likely.
            let mut inds: Vec<u32> = (0..n).map(|_| rng.next_u32() & 0x00FF_FFFF).collect();
            // Each value encodes its original position so pairing can be verified.
            let mut vals: Vec<u32> = (0..n).map(|at| (at as u32) ^ 0xA5A5_5A5A).collect();

            let (want_keys, want_vals) = baseline_stable_sort(&inds, &vals);

            radix_sort_u32_soa(&mut inds, &mut vals);

            assert_sorted(&inds);
            assert_eq!(inds, want_keys, "keys mismatch at n={n}");
            assert_eq!(vals, want_vals, "vals mismatch at n={n}");
        }
    }

    #[test]
    fn radix_sort_extremes() {
        // Smallest and largest representable keys, with repeats of both.
        let mut inds = vec![0u32, u32::MAX, 1, u32::MAX - 1, 0, 2, u32::MAX];
        let mut vals: Vec<u32> = (0u32..).take(inds.len()).collect();

        let (want_keys, want_vals) = baseline_stable_sort(&inds, &vals);

        radix_sort_u32_soa(&mut inds, &mut vals);

        assert_sorted(&inds);
        assert_eq!(inds, want_keys);
        assert_eq!(vals, want_vals);
    }
}
221