use std::cmp;
use std::mem::MaybeUninit;
use std::thread;

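/// Maps a value to a `u64` key whose unsigned order matches the value's
/// natural order, letting the sorts below compare raw integers instead of
/// going through `Ord`/`PartialOrd`.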
pub trait SortKey {
    fn sort_key(&self) -> u64;
}

impl SortKey for u32 {
    #[inline(always)]
    fn sort_key(&self) -> u64 {
        *self as u64
    }
}

impl SortKey for u64 {
    #[inline(always)]
    fn sort_key(&self) -> u64 {
        *self
    }
}

impl SortKey for i32 {
    #[inline(always)]
    fn sort_key(&self) -> u64 {
        // Flipping the sign bit maps signed order onto unsigned order.
        (*self as u32 ^ 0x8000_0000) as u64
    }
}

impl SortKey for i64 {
    #[inline(always)]
    fn sort_key(&self) -> u64 {
        *self as u64 ^ 0x8000_0000_0000_0000
    }
}

impl SortKey for f32 {
    #[inline(always)]
    fn sort_key(&self) -> u64 {
        // Standard total-order bit trick: negative floats get all bits
        // flipped, non-negative floats only the sign bit, so unsigned key
        // order matches float order (NaNs land at the extremes).
        let bits = self.to_bits();
        let sign_mask = ((bits as i32) >> 31) as u32;
        (bits ^ (sign_mask | 0x8000_0000)) as u64
    }
}

impl SortKey for f64 {
    #[inline(always)]
    fn sort_key(&self) -> u64 {
        let bits = self.to_bits();
        let sign_mask = ((bits as i64) >> 63) as u64;
        bits ^ (sign_mask | 0x8000_0000_0000_0000)
    }
}

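/// Insertion sort over raw pointers, keyed by `sort_key`. Meant for tiny
/// slices (callers cap it at 16 elements): each element is lifted out once
/// and larger predecessors are shifted with single reads/writes rather
/// than pairwise swaps.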
#[inline(always)]
pub fn custom_insertion_sort<T: SortKey>(arr: &mut [T]) {
    let len = arr.len();
    if len <= 1 {
        return;
    }

    let base_ptr = arr.as_mut_ptr();
    for i in 1..len {
        unsafe {
            // Lift the element out, shift larger predecessors right, then
            // drop it into the gap.
            let val_ptr = base_ptr.add(i);
            let val = std::ptr::read(val_ptr);
            let val_key = val.sort_key();
            let mut j = i;
            while j > 0 {
                let prev_ptr = base_ptr.add(j - 1);
                if (*prev_ptr).sort_key() > val_key {
                    std::ptr::write(base_ptr.add(j), std::ptr::read(prev_ptr));
                    j -= 1;
                } else {
                    break;
                }
            }
            std::ptr::write(base_ptr.add(j), val);
        }
    }
}

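/// Insertion sort for spilled elements, ordered by (chunk id, key) so the
/// gather phase can merge each chunk's overflow back in one forward pass.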
#[inline(always)]
fn sort_overflow<T: SortKey>(arr: &mut [(usize, MaybeUninit<T>)]) {
    let len = arr.len();
    if len <= 1 {
        return;
    }

    let base_ptr = arr.as_mut_ptr();
    for i in 1..len {
        unsafe {
            let val_ptr = base_ptr.add(i);
            let val = std::ptr::read(val_ptr);
            let val_chunk = val.0;
            let val_key = val.1.assume_init_ref().sort_key();
            let mut j = i;
            while j > 0 {
                let prev_ptr = base_ptr.add(j - 1);
                let prev_chunk = (*prev_ptr).0;
                let prev_key = (*prev_ptr).1.assume_init_ref().sort_key();
                if prev_chunk > val_chunk || (prev_chunk == val_chunk && prev_key > val_key) {
                    std::ptr::write(base_ptr.add(j), std::ptr::read(prev_ptr));
                    j -= 1;
                } else {
                    break;
                }
            }
            std::ptr::write(base_ptr.add(j), val);
        }
    }
}

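/// One bucket of the local sort: a 64-byte-aligned block of 16
/// uninitialised slots.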
#[repr(C, align(64))]
struct ChunkData<T> {
    data: [MaybeUninit<T>; 16],
}

impl<T> Default for ChunkData<T> {
    fn default() -> Self {
        // SAFETY: the struct contains only `MaybeUninit` slots, so an
        // uninitialised value is valid.
        unsafe { MaybeUninit::uninit().assume_init() }
    }
}

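/// Bookkeeping for one chunk: which slots are occupied (`bitmap`), how
/// many (`occupancy`), and whether a collision pushed an element out of
/// its natural slot (`is_dirty`), in which case the chunk needs a
/// clean-up sort before being drained.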
#[derive(Clone, Copy, Default)]
struct ChunkMeta {
    bitmap: u16,
    occupancy: u8,
    is_dirty: bool,
}

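/// Reusable scratch buffers for `zan_sort_local`, hoisted out so repeated
/// calls (one per bucket in the parallel path) don't reallocate.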
struct Workspace<T> {
    datas: Vec<ChunkData<T>>,
    metas: Vec<ChunkMeta>,
    overflow: Vec<(usize, MaybeUninit<T>)>,
}

impl<T> Workspace<T> {
    fn new() -> Self {
        Self {
            datas: Vec::new(),
            metas: Vec::new(),
            overflow: Vec::new(),
        }
    }

    #[inline(always)]
    fn prepare(&mut self, c: usize) {
        self.metas.clear();
        self.metas.resize(c, ChunkMeta::default());
        self.datas.clear();
        self.datas.reserve(c);
        // SAFETY: `ChunkData` holds only `MaybeUninit` slots, so exposing
        // uninitialised capacity as length is fine; the bitmaps in `metas`
        // track which slots are actually written.
        unsafe {
            self.datas.set_len(c);
        }
        self.overflow.clear();
    }
}

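/// Sorts `data` in place given precomputed key bounds. Elements are
/// scattered into ~n/4 chunks of 16 slots each by linear interpolation of
/// their keys; collisions fall into the first free slot (marking the
/// chunk dirty) and full chunks spill into `overflow`. Draining the
/// chunks in order, sorting the dirty ones, and merging the overflow back
/// yields the sorted result.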
fn zan_sort_local<T: SortKey>(data: &mut [T], min_key: u64, max_key: u64, ws: &mut Workspace<T>) {
    let n = data.len();
    if n <= 1 {
        return;
    }
    let range = max_key.saturating_sub(min_key);
    if range == 0 {
        return;
    }

    // n/4 chunks of 16 slots each gives ~4x headroom over the input size.
    let c = cmp::max(1, n / 4);
    let m = (c * 16 - 1) as u64;
    // Fixed-point multiplier mapping a key offset in [0, range] to a slot
    // index in [0, m].
    let multiplier = (((m as u128) << 32) / (range as u128)) as u64;

    ws.prepare(c);
    let metas = &mut ws.metas;
    let datas = &mut ws.datas;
    let overflow = &mut ws.overflow;

    // Scatter phase: place each element near its interpolated position.
    for i in 0..n {
        unsafe {
            let v = std::ptr::read(data.as_ptr().add(i));
            let v_key = v.sort_key();
            let v_diff = v_key - min_key;
            let i_v = ((v_diff as u128 * multiplier as u128) >> 32) as usize;

            let chunk_id = cmp::min(i_v >> 4, c - 1);
            let offset = i_v & 15;

            let meta = &mut metas[chunk_id];
            let data_chunk = &mut datas[chunk_id];

            if meta.occupancy < 16 {
                let bit = 1 << offset;
                if (meta.bitmap & bit) == 0 {
                    data_chunk.data[offset].write(v);
                    meta.bitmap |= bit;
                    meta.occupancy += 1;
                } else {
                    // Natural slot taken: drop into the first free slot and
                    // mark the chunk for a clean-up sort.
                    meta.is_dirty = true;
                    let empty_offset = (!meta.bitmap).trailing_zeros() as usize;
                    data_chunk.data[empty_offset].write(v);
                    meta.bitmap |= 1 << empty_offset;
                    meta.occupancy += 1;
                }
            } else {
                overflow.push((chunk_id, MaybeUninit::new(v)));
            }
        }
    }

    if overflow.len() > 1 {
        sort_overflow(overflow);
    }

    // Gather phase: drain chunks in order, fixing up dirty chunks and
    // merging in their overflow.
    let mut overflow_idx = 0;
    let mut write_ptr = 0;

    for id in 0..c {
        let meta = &metas[id];
        let data_chunk = &mut datas[id];
        let has_overflow = overflow_idx < overflow.len() && overflow[overflow_idx].0 == id;

        if meta.occupancy == 0 && !has_overflow {
            continue;
        }

        // SAFETY: an array of `MaybeUninit` may be left uninitialised.
        let mut local: [MaybeUninit<T>; 16] = unsafe { MaybeUninit::uninit().assume_init() };
        let mut local_len = 0;
        let mut bmp = meta.bitmap;

        // Compact the occupied slots; bitmap order equals slot order,
        // which is key order unless the chunk is dirty.
        while bmp != 0 {
            let offset = bmp.trailing_zeros() as usize;
            unsafe {
                local[local_len].write(data_chunk.data[offset].assume_init_read());
            }
            local_len += 1;
            bmp &= bmp - 1;
        }

        if meta.is_dirty && local_len > 1 {
            unsafe {
                let slice = std::slice::from_raw_parts_mut(local.as_mut_ptr() as *mut T, local_len);
                custom_insertion_sort(slice);
            }
        }

        if !has_overflow {
            unsafe {
                let dst = data.as_mut_ptr().add(write_ptr);
                let src = local.as_ptr() as *const T;
                std::ptr::copy_nonoverlapping(src, dst, local_len);
            }
            write_ptr += local_len;
        } else {
            // Two-way merge of the chunk's sorted slots with its sorted
            // overflow run.
            let mut l_idx = 0;
            loop {
                let has_local = l_idx < local_len;
                let has_over = overflow_idx < overflow.len() && overflow[overflow_idx].0 == id;

                if has_local && has_over {
                    unsafe {
                        let l_key = (*(local.as_ptr().add(l_idx) as *const T)).sort_key();
                        let o_key = overflow[overflow_idx].1.assume_init_ref().sort_key();
                        if l_key <= o_key {
                            let l_val = local.as_ptr().add(l_idx).cast::<T>().read();
                            data.as_mut_ptr().add(write_ptr).write(l_val);
                            l_idx += 1;
                        } else {
                            let o_val = overflow[overflow_idx].1.assume_init_read();
                            data.as_mut_ptr().add(write_ptr).write(o_val);
                            overflow_idx += 1;
                        }
                    }
                    write_ptr += 1;
                } else if has_local {
                    unsafe {
                        let l_val = local.as_ptr().add(l_idx).cast::<T>().read();
                        data.as_mut_ptr().add(write_ptr).write(l_val);
                    }
                    l_idx += 1;
                    write_ptr += 1;
                } else if has_over {
                    unsafe {
                        let o_val = overflow[overflow_idx].1.assume_init_read();
                        data.as_mut_ptr().add(write_ptr).write(o_val);
                    }
                    overflow_idx += 1;
                    write_ptr += 1;
                } else {
                    break;
                }
            }
        }
    }
}

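/// Sorts a slice by its `sort_key`. Dispatches on size: insertion sort up
/// to 16 elements, `sort_unstable_by_key` up to 5000 (unless the `pure`
/// feature disables the std fallback), single-threaded `zan_sort_local`
/// up to 16384, and otherwise a parallel three-pass bucket sort
/// (histogram, scatter, per-bucket sort).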
pub fn zan_sort<T: SortKey + Send>(data: &mut [T]) {
    let n = data.len();
    if n <= 1 {
        return;
    }

    // Small inputs: fall back to simpler sorts.
    #[cfg(not(feature = "pure"))]
    {
        if n <= 16 {
            custom_insertion_sort(data);
            return;
        } else if n <= 5000 {
            data.sort_unstable_by_key(|item| item.sort_key());
            return;
        }
    }

    #[cfg(feature = "pure")]
    {
        if n <= 16 {
            custom_insertion_sort(data);
            return;
        }
    }

    let mut min_key = u64::MAX;
    let mut max_key = u64::MIN;
    for item in data.iter() {
        let key = item.sort_key();
        if key < min_key {
            min_key = key;
        }
        if key > max_key {
            max_key = key;
        }
    }

    if min_key == max_key {
        return;
    }

    if n <= 16384 {
        let mut ws = Workspace::new();
        zan_sort_local(data, min_key, max_key, &mut ws);
        return;
    }

    let num_buckets = (n / 32768).next_power_of_two().clamp(16, 16384);

    // Scale the key range down to 32 bits so `scaled_diff * multiplier`
    // below cannot overflow a u64.
    let range = max_key.saturating_sub(min_key);
    let shift_bits = if range > (u32::MAX as u64) {
        64 - range.leading_zeros() - 32
    } else {
        0
    };
    let scaled_range = range >> shift_bits;
    let multiplier = ((num_buckets as u64) << 32) / (scaled_range + 1);

    let num_threads = thread::available_parallelism()
        .map(|p| p.get())
        .unwrap_or(4);
    let chunk_size = n.div_ceil(num_threads);

    // Pass 1: per-thread histograms of bucket sizes.
    let mut local_counts = vec![vec![0usize; num_buckets]; num_threads];
    thread::scope(|s| {
        for (chunk, counts) in data.chunks_mut(chunk_size).zip(local_counts.iter_mut()) {
            s.spawn(move || {
                for item in chunk {
                    let v_diff = item.sort_key() - min_key;
                    let scaled_diff = v_diff >> shift_bits;
                    let bucket = ((scaled_diff * multiplier) >> 32) as usize;
                    counts[bucket] += 1;
                }
            });
        }
    });

    // Prefix sums: a global offset per bucket, plus each thread's private
    // write cursor inside that bucket.
    let mut bucket_offsets = vec![0usize; num_buckets];
    let mut local_offsets = vec![vec![0usize; num_buckets]; num_threads];
    let mut global_counts = vec![0usize; num_buckets];
    let mut sum = 0;

    for b in 0..num_buckets {
        bucket_offsets[b] = sum;
        for t in 0..num_threads {
            local_offsets[t][b] = sum;
            sum += local_counts[t][b];
            global_counts[b] += local_counts[t][b];
        }
    }

    let mut buffer: Vec<MaybeUninit<T>> = Vec::with_capacity(n);
    unsafe {
        buffer.set_len(n);
    }

    // Smuggle the raw pointers across the scopes as integers; the disjoint
    // offsets guarantee each thread writes a distinct region.
    let data_ptr = data.as_mut_ptr() as usize;
    let buffer_ptr = buffer.as_mut_ptr() as usize;

    // Pass 2: scatter elements into their buckets through small software
    // write-combining buffers (16 elements per bucket per thread).
    thread::scope(|s| {
        for (t_id, mut offsets) in local_offsets.into_iter().enumerate() {
            let chunk_start = t_id * chunk_size;
            let chunk_end = cmp::min(chunk_start + chunk_size, n);

            s.spawn(move || unsafe {
                let d_ptr = data_ptr as *mut T;
                let b_ptr = buffer_ptr as *mut MaybeUninit<T>;

                const BUF_SIZE: usize = 16;
                let mut local_buf: Vec<[MaybeUninit<T>; BUF_SIZE]> =
                    Vec::with_capacity(num_buckets);
                local_buf.set_len(num_buckets);

                let mut local_idx = vec![0usize; num_buckets];

                for i in chunk_start..chunk_end {
                    let v_ptr = d_ptr.add(i);
                    let v_key = (*v_ptr).sort_key();
                    let v_diff = v_key - min_key;
                    let scaled_diff = v_diff >> shift_bits;
                    let bucket = ((scaled_diff * multiplier) >> 32) as usize;

                    let idx = local_idx[bucket];
                    local_buf[bucket][idx] = std::ptr::read(v_ptr as *const MaybeUninit<T>);
                    local_idx[bucket] = idx + 1;

                    // Flush a full staging buffer with one contiguous copy.
                    if idx + 1 == BUF_SIZE {
                        let dst = b_ptr.add(offsets[bucket]);
                        std::ptr::copy_nonoverlapping(local_buf[bucket].as_ptr(), dst, BUF_SIZE);
                        offsets[bucket] += BUF_SIZE;
                        local_idx[bucket] = 0;
                    }
                }

                // Flush the partially filled buffers.
                for b in 0..num_buckets {
                    let remain = local_idx[b];
                    if remain > 0 {
                        let dst = b_ptr.add(offsets[b]);
                        std::ptr::copy_nonoverlapping(local_buf[b].as_ptr(), dst, remain);
                        offsets[b] += remain;
                    }
                }
            });
        }
    });

    // Size each thread's workspace for the largest bucket it will sort.
    let buckets_per_thread = num_buckets.div_ceil(num_threads);

    let workspaces: Vec<Workspace<T>> = (0..num_threads)
        .map(|t_id| {
            let start_b = t_id * buckets_per_thread;
            let end_b = cmp::min(start_b + buckets_per_thread, num_buckets);
            let max_bucket_count = (start_b..end_b)
                .map(|b| global_counts[b])
                .max()
                .unwrap_or(0);

            let mut ws = Workspace::new();
            if max_bucket_count > 0 {
                ws.prepare(cmp::max(1, max_bucket_count / 4));
            }
            ws
        })
        .collect();

    let mut ws_iter = workspaces.into_iter();

    // Pass 3: sort each bucket inside the buffer, then copy it back into
    // `data` at its final position.
    thread::scope(|s| {
        for t_id in 0..num_threads {
            let start_b = t_id * buckets_per_thread;
            let end_b = cmp::min(start_b + buckets_per_thread, num_buckets);
            #[allow(unused_mut, unused_variables)]
            let mut ws = ws_iter.next().unwrap();
            let g_counts = &global_counts;
            let b_offsets = &bucket_offsets;

            s.spawn(move || unsafe {
                let d_ptr = data_ptr as *mut T;
                let b_ptr = buffer_ptr as *mut MaybeUninit<T>;

                for b in start_b..end_b {
                    let count = g_counts[b];
                    if count == 0 {
                        continue;
                    }

                    let offset = b_offsets[b];
                    let block_ptr = b_ptr.add(offset) as *mut T;
                    let block = std::slice::from_raw_parts_mut(block_ptr, count);

                    if count <= 16 {
                        custom_insertion_sort(block);
                    } else {
                        #[cfg(not(feature = "pure"))]
                        {
                            if count <= 5000 {
                                block.sort_unstable_by_key(|item| item.sort_key());
                                std::ptr::copy_nonoverlapping(block_ptr, d_ptr.add(offset), count);
                                continue;
                            }
                        }

                        let (mut l_min, mut l_max) = (u64::MAX, u64::MIN);
                        for item in block.iter() {
                            let key = item.sort_key();
                            if key < l_min {
                                l_min = key;
                            }
                            if key > l_max {
                                l_max = key;
                            }
                        }
                        if l_min != l_max {
                            zan_sort_local(block, l_min, l_max, &mut ws);
                        }
                    }

                    std::ptr::copy_nonoverlapping(block_ptr, d_ptr.add(offset), count);
                }
            });
        }
    });
}
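
// A minimal sanity check added for illustration: it exercises the public
// entry points against the standard sort. The input size and the SplitMix64
// constants are arbitrary choices, not part of the original algorithm.
#[cfg(test)]
mod tests {
    use super::*;

    // SplitMix64 step: a tiny deterministic generator so the tests need no
    // external crates.
    fn next_u64(state: &mut u64) -> u64 {
        *state = state.wrapping_add(0x9E37_79B9_7F4A_7C15);
        let mut z = *state;
        z = (z ^ (z >> 30)).wrapping_mul(0xBF58_476D_1CE4_E5B9);
        z = (z ^ (z >> 27)).wrapping_mul(0x94D0_49BB_1331_11EB);
        z ^ (z >> 31)
    }

    #[test]
    fn matches_std_sort_on_u64() {
        // Large enough to take the parallel bucket path (> 16384).
        let mut state = 42u64;
        let mut data: Vec<u64> = (0..100_000).map(|_| next_u64(&mut state)).collect();
        let mut expected = data.clone();
        expected.sort_unstable();
        zan_sort(&mut data);
        assert_eq!(data, expected);
    }

    #[test]
    fn orders_floats_including_signed_zero() {
        let mut data = vec![3.5f64, -0.0, f64::INFINITY, -2.25, 0.0, f64::NEG_INFINITY, 1.0];
        zan_sort(&mut data);
        // The result must be non-decreasing in key order.
        let keys: Vec<u64> = data.iter().map(|v| v.sort_key()).collect();
        assert!(keys.windows(2).all(|w| w[0] <= w[1]));
    }
}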