tf_idf_vectorizer/utils/
sort.rs1use core::mem;
2
/// Sorts `inds` in ascending order and applies the same permutation to `vals`.
///
/// The two slices form a structure of arrays: `vals[i]` is the payload that
/// belongs to key `inds[i]`, and that pairing is preserved by the sort. The
/// sort is stable: entries with equal keys keep their original relative order.
///
/// Inputs of 32 elements or fewer use insertion sort. Larger inputs use a
/// least-significant-digit radix sort over the four bytes of each key. That
/// path allocates one scratch buffer of `n` keys and one of `n` values.
///
/// # Panics
///
/// Panics if `inds.len() != vals.len()`.
pub fn radix_sort_u32_soa<N: Copy>(inds: &mut [u32], vals: &mut [N]) {
    assert_eq!(inds.len(), vals.len());
    let n = inds.len();
    if n <= 1 {
        return;
    }

    if n <= 32 {
        // For tiny inputs, insertion sort beats the fixed cost of four
        // counting passes over 256 buckets.
        insertion_sort_u32_soa(inds, vals);
        return;
    }

    // Scratch buffers. The value buffer is seeded by copying `vals` rather than
    // with `mem::zeroed()`: an all-zero bit pattern is not a valid value for
    // every `Copy` type (e.g. `&T`, `NonZeroU32`), so zeroing would be UB for
    // such `N`. Every slot is overwritten by the first scatter pass anyway.
    let mut inds_tmp = vec![0u32; n];
    let mut vals_tmp: Vec<N> = vals.to_vec();

    let mut src_inds: &mut [u32] = inds;
    let mut src_vals: &mut [N] = vals;
    let mut dst_inds: &mut [u32] = &mut inds_tmp;
    let mut dst_vals: &mut [N] = &mut vals_tmp;

    // One stable counting-sort pass per byte, least significant first. The
    // buffers ping-pong each pass. After four passes (an even number), the
    // fully sorted data is back in the caller's `inds` / `vals`.
    for shift in [0u32, 8, 16, 24] {
        // Histogram of the current byte.
        let mut count = [0usize; 256];
        for &k in src_inds.iter() {
            count[((k >> shift) & 0xFF) as usize] += 1;
        }

        // Exclusive prefix sum: `count[b]` becomes the first output slot for bucket `b`.
        let mut sum = 0usize;
        for c in count.iter_mut() {
            let tmp = *c;
            *c = sum;
            sum += tmp;
        }

        // Scatter in source order. Visiting equal digits in source order is
        // what makes each pass, and therefore the whole sort, stable.
        let d_i = dst_inds.as_mut_ptr();
        let d_v = dst_vals.as_mut_ptr();
        for (&k, &v) in src_inds.iter().zip(src_vals.iter()) {
            let b = ((k >> shift) & 0xFF) as usize;
            let pos = count[b];
            count[b] = pos + 1;
            debug_assert!(pos < n);
            // SAFETY: both destination slices have length `n`. After the prefix
            // sum, bucket `b` owns the slots `[start_b, start_b + len_b)`, where
            // `len_b` is the number of source keys with digit `b`. Because the
            // bucket lengths sum to `n`, every `pos` produced here is `< n`, so
            // both writes are in bounds. `N: Copy` has no destructor, so
            // overwriting the old slot contents with `write` leaks nothing.
            unsafe {
                d_i.add(pos).write(k);
                d_v.add(pos).write(v);
            }
        }

        mem::swap(&mut src_inds, &mut dst_inds);
        mem::swap(&mut src_vals, &mut dst_vals);
    }
}

/// Stable in-place insertion sort of `inds`, mirroring every swap into `vals`.
///
/// Used for short inputs, where the radix sort's fixed per-pass cost dominates.
/// The caller (`radix_sort_u32_soa`) has already asserted that both slices
/// have the same length.
#[inline]
fn insertion_sort_u32_soa<N: Copy>(inds: &mut [u32], vals: &mut [N]) {
    let n = inds.len();
    for i in 1..n {
        let mut j = i;
        // The strict `<` stops at an equal key, which keeps the sort stable.
        while j > 0 && inds[j] < inds[j - 1] {
            inds.swap(j, j - 1);
            vals.swap(j, j - 1);
            j -= 1;
        }
    }
}
92
#[cfg(test)]
mod tests {
    use super::*;

    /// Reference implementation: sorts (key, original index, value) triples by
    /// key, then by original index, which is equivalent to a stable sort by key.
    fn baseline_stable_sort<N: Copy>(inds: &[u32], vals: &[N]) -> (Vec<u32>, Vec<N>) {
        let mut pairs: Vec<(u32, usize, N)> = inds
            .iter()
            .copied()
            .enumerate()
            .map(|(i, k)| (k, i, vals[i]))
            .collect();

        pairs.sort_unstable_by(|a, b| a.0.cmp(&b.0).then_with(|| a.1.cmp(&b.1)));

        let mut out_k = Vec::with_capacity(pairs.len());
        let mut out_v = Vec::with_capacity(pairs.len());
        for (k, _i, v) in pairs {
            out_k.push(k);
            out_v.push(v);
        }
        (out_k, out_v)
    }

    /// Asserts that `keys` is in non-decreasing order.
    fn assert_sorted(keys: &[u32]) {
        for i in 1..keys.len() {
            assert!(keys[i - 1] <= keys[i], "not sorted at {i}: {} > {}", keys[i - 1], keys[i]);
        }
    }

    /// Minimal deterministic xorshift32 generator for test data.
    /// A seed of 0 would stay 0 forever, so always seed with a non-zero value.
    struct Rng(u32);

    impl Rng {
        fn new(seed: u32) -> Self {
            Self(seed)
        }

        fn next_u32(&mut self) -> u32 {
            let mut x = self.0;
            x ^= x << 13;
            x ^= x >> 17;
            x ^= x << 5;
            self.0 = x;
            x
        }
    }

    #[test]
    fn radix_sort_handles_empty_and_single() {
        let mut inds: Vec<u32> = vec![];
        let mut vals: Vec<u16> = vec![];
        radix_sort_u32_soa(&mut inds, &mut vals);
        assert!(inds.is_empty());
        assert!(vals.is_empty());

        let mut inds = vec![42u32];
        let mut vals = vec![7u16];
        radix_sort_u32_soa(&mut inds, &mut vals);
        assert_eq!(inds, vec![42u32]);
        assert_eq!(vals, vec![7u16]);
    }

    #[test]
    fn radix_sort_works_on_duplicates_and_preserves_pairing() {
        let mut inds = vec![3u32, 1, 3, 2, 1, 3, 0];
        let mut vals: Vec<u32> = (0..inds.len() as u32).collect();

        let (base_k, base_v) = baseline_stable_sort(&inds, &vals);

        radix_sort_u32_soa(&mut inds, &mut vals);

        assert_sorted(&inds);
        assert_eq!(inds, base_k);
        assert_eq!(vals, base_v, "radix sort should be stable in this implementation");
    }

    #[test]
    fn radix_sort_matches_baseline_many_sizes() {
        let mut rng = Rng::new(0x1234_5678);

        for &n in &[0usize, 1, 2, 3, 4, 7, 8, 15, 16, 31, 32, 33, 63, 64, 65, 127, 128, 129, 1024] {
            let mut inds = Vec::with_capacity(n);
            let mut vals = Vec::with_capacity(n);

            for i in 0..n {
                // Limited to 24 bits so many keys collide on the upper bytes.
                let k = rng.next_u32() & 0x00FF_FFFF;
                inds.push(k);
                vals.push((i as u32) ^ 0xA5A5_5A5A);
            }

            let (base_k, base_v) = baseline_stable_sort(&inds, &vals);

            radix_sort_u32_soa(&mut inds, &mut vals);

            assert_sorted(&inds);
            assert_eq!(inds, base_k, "keys mismatch at n={n}");
            assert_eq!(vals, base_v, "vals mismatch at n={n}");
        }
    }

    #[test]
    fn radix_sort_full_range_keys_on_radix_path() {
        // More than 32 elements forces the radix path. The keys span all four
        // bytes, so the top-byte pass does real work; the 24-bit test above
        // never exercises that pass with non-zero digits.
        let mut rng = Rng::new(0xDEAD_BEEF);
        let mut inds: Vec<u32> = (0..200).map(|_| rng.next_u32()).collect();
        inds.extend_from_slice(&[0, u32::MAX, u32::MAX, 0, 1 << 24, (1 << 24) - 1, u32::MAX - 1]);
        let mut vals: Vec<u32> = (0..inds.len() as u32).collect();

        let (base_k, base_v) = baseline_stable_sort(&inds, &vals);

        radix_sort_u32_soa(&mut inds, &mut vals);

        assert_sorted(&inds);
        assert_eq!(inds, base_k);
        assert_eq!(vals, base_v);
    }

    #[test]
    fn radix_sort_extremes() {
        // Only 7 elements, so this covers the insertion-sort path; see
        // `radix_sort_full_range_keys_on_radix_path` for the radix path.
        let mut inds = vec![
            0u32,
            u32::MAX,
            1,
            u32::MAX - 1,
            0,
            2,
            u32::MAX,
        ];
        let mut vals: Vec<u32> = (0..inds.len() as u32).collect();

        let (base_k, base_v) = baseline_stable_sort(&inds, &vals);

        radix_sort_u32_soa(&mut inds, &mut vals);

        assert_sorted(&inds);
        assert_eq!(inds, base_k);
        assert_eq!(vals, base_v);
    }
}
220