tf_idf_vectorizer/utils/
sort.rs1use core::mem;
2
3#[inline(always)]
10pub fn radix_sort_u32_soa<N: Copy>(inds: &mut [u32], vals: &mut [N]) {
11 assert_eq!(inds.len(), vals.len());
12 let n = inds.len();
13 if n <= 1 {
14 return;
15 }
16
17 if n <= 32 {
19 insertion_sort_u32_soa(inds, vals);
20 return;
21 }
22
23 let mut inds_tmp = vec![0u32; n];
25 let mut vals_tmp: Vec<N> = vec![unsafe { mem::zeroed() }; n];
26
27 let mut src_inds: &mut [u32] = inds;
29 let mut src_vals: &mut [N] = vals;
30 let mut dst_inds: &mut [u32] = &mut inds_tmp;
31 let mut dst_vals: &mut [N] = &mut vals_tmp;
32
33 for shift in [0u32, 8, 16, 24] {
35 let mut count = [0usize; 256];
36
37 for &k in src_inds.iter() {
39 count[((k >> shift) & 0xFF) as usize] += 1;
40 }
41
42 let mut sum = 0usize;
44 for c in count.iter_mut() {
45 let tmp = *c;
46 *c = sum;
47 sum += tmp;
48 }
49
50 unsafe {
53 let s_i = src_inds.as_ptr();
54 let s_v = src_vals.as_ptr();
55 let d_i = dst_inds.as_mut_ptr();
56 let d_v = dst_vals.as_mut_ptr();
57
58 for idx in 0..n {
59 let k = *s_i.add(idx);
60 let b = ((k >> shift) & 0xFF) as usize;
61 let pos = count[b];
62 count[b] = pos + 1;
63
64 *d_i.add(pos) = k;
65 *d_v.add(pos) = *s_v.add(idx);
66 }
67 }
68
69 mem::swap(&mut src_inds, &mut dst_inds);
71 mem::swap(&mut src_vals, &mut dst_vals);
72 }
73
74 }
79
/// Stable in-place insertion sort of `inds` (ascending) that mirrors every
/// swap onto `vals`, keeping `vals[i]` paired with `inds[i]`.
///
/// `radix_sort_u32_soa` uses it for slices of at most 32 elements. The loop is
/// driven by `inds.len()`; `vals` is expected to have the same length.
#[inline(always)]
fn insertion_sort_u32_soa<N: Copy>(inds: &mut [u32], vals: &mut [N]) {
    for sorted_len in 1..inds.len() {
        // Walk the newly added element leftwards. Stopping as soon as the
        // left neighbour is <= it (equality included) keeps equal keys in
        // their input order, which makes the sort stable.
        for pos in (1..=sorted_len).rev() {
            if inds[pos - 1] <= inds[pos] {
                break;
            }
            inds.swap(pos - 1, pos);
            vals.swap(pos - 1, pos);
        }
    }
}
93
#[cfg(test)]
mod tests {
    use super::*;

    /// Reference result: the (key, value) pairs sorted stably by key using
    /// the standard library, split back into separate key and value vectors.
    fn baseline_stable_sort<N: Copy>(inds: &[u32], vals: &[N]) -> (Vec<u32>, Vec<N>) {
        let mut pairs: Vec<(u32, N)> = inds.iter().copied().zip(vals.iter().copied()).collect();
        // `sort_by_key` is stable, so equal keys keep their input order. This
        // matches an unstable sort tie-broken on the original index.
        pairs.sort_by_key(|&(key, _)| key);
        pairs.into_iter().unzip()
    }

    /// Panics, reporting the position of the first descent, unless `keys` is
    /// non-decreasing.
    fn assert_sorted(keys: &[u32]) {
        for (offset, w) in keys.windows(2).enumerate() {
            let i = offset + 1;
            assert!(w[0] <= w[1], "not sorted at {i}: {} > {}", w[0], w[1]);
        }
    }

    /// Minimal xorshift32 generator so test inputs are reproducible without
    /// external crates.
    struct Rng(u32);

    impl Rng {
        fn new(seed: u32) -> Self {
            Rng(seed)
        }

        fn next_u32(&mut self) -> u32 {
            let x = self.0;
            let x = x ^ (x << 13);
            let x = x ^ (x >> 17);
            let x = x ^ (x << 5);
            self.0 = x;
            x
        }
    }

    #[test]
    fn radix_sort_handles_empty_and_single() {
        let mut inds: Vec<u32> = Vec::new();
        let mut vals: Vec<u16> = Vec::new();
        radix_sort_u32_soa(&mut inds, &mut vals);
        assert!(inds.is_empty());
        assert!(vals.is_empty());

        let (mut inds, mut vals) = (vec![42u32], vec![7u16]);
        radix_sort_u32_soa(&mut inds, &mut vals);
        assert_eq!(inds, vec![42u32]);
        assert_eq!(vals, vec![7u16]);
    }

    #[test]
    fn radix_sort_works_on_duplicates_and_preserves_pairing() {
        let mut inds = vec![3u32, 1, 3, 2, 1, 3, 0];
        // Each value is its key's starting position, so any reordering of
        // equal keys shows up in `vals`.
        let mut vals: Vec<u32> = (0..inds.len()).map(|pos| pos as u32).collect();

        let (want_keys, want_vals) = baseline_stable_sort(&inds, &vals);
        radix_sort_u32_soa(&mut inds, &mut vals);

        assert_sorted(&inds);
        assert_eq!(inds, want_keys);
        assert_eq!(vals, want_vals, "radix sort should be stable in this implementation");
    }

    #[test]
    fn radix_sort_matches_baseline_many_sizes() {
        let mut rng = Rng::new(0x1234_5678);
        let sizes = [0usize, 1, 2, 3, 4, 7, 8, 15, 16, 31, 32, 33, 63, 64, 65, 127, 128, 129, 1024];

        for n in sizes {
            // Keys keep only the low 24 bits; values are the position XOR a
            // fixed mask so they never coincide with the keys.
            let (mut inds, mut vals): (Vec<u32>, Vec<u32>) = (0..n)
                .map(|i| (rng.next_u32() & 0x00FF_FFFF, (i as u32) ^ 0xA5A5_5A5A))
                .unzip();

            let (want_keys, want_vals) = baseline_stable_sort(&inds, &vals);
            radix_sort_u32_soa(&mut inds, &mut vals);

            assert_sorted(&inds);
            assert_eq!(inds, want_keys, "keys mismatch at n={n}");
            assert_eq!(vals, want_vals, "vals mismatch at n={n}");
        }
    }

    #[test]
    fn radix_sort_extremes() {
        let mut inds = vec![0u32, u32::MAX, 1, u32::MAX - 1, 0, 2, u32::MAX];
        let mut vals: Vec<u32> = (0..inds.len()).map(|pos| pos as u32).collect();

        let (want_keys, want_vals) = baseline_stable_sort(&inds, &vals);
        radix_sort_u32_soa(&mut inds, &mut vals);

        assert_sorted(&inds);
        assert_eq!(inds, want_keys);
        assert_eq!(vals, want_vals);
    }
}
221