
// zan_sort/core.rs

use std::cmp;
use std::mem::MaybeUninit;
use std::thread;

/// The core trait of `zan-sort`.
/// It maps an arbitrary data type into a strictly ordered, one-dimensional `u64` space.
/// In this algorithm, the `u64` representation is the single source of truth for ordering.
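///
/// # Example
///
/// A sketch of a user-defined key. The `Event` type and the import path
/// (which assumes a crate-root re-export) are illustrative, not part of this file:
///
/// ```
/// # use zan_sort::SortKey;
/// struct Event { timestamp_ms: u64 }
///
/// impl SortKey for Event {
///     fn sort_key(&self) -> u64 {
///         self.timestamp_ms
///     }
/// }
/// ```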
pub trait SortKey {
    fn sort_key(&self) -> u64;
}

// --- Primitive Type Implementations ---

impl SortKey for u32 {
    #[inline(always)]
    fn sort_key(&self) -> u64 {
        *self as u64
    }
}

impl SortKey for u64 {
    #[inline(always)]
    fn sort_key(&self) -> u64 {
        *self
    }
}

// --- Signed Integers ---
// Branchless two's complement mapping:
// XOR-ing the sign bit shifts the signed range so that negative values sort
// below positive values in the unsigned u64 space.
impl SortKey for i32 {
    #[inline(always)]
    fn sort_key(&self) -> u64 {
        (*self as u32 ^ 0x8000_0000) as u64
    }
}

impl SortKey for i64 {
    #[inline(always)]
    fn sort_key(&self) -> u64 {
        *self as u64 ^ 0x8000_0000_0000_0000
    }
}
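
// Worked example (for illustration): i32::MIN maps to key 0x0000_0000, -1 to
// 0x7FFF_FFFF, 0 to 0x8000_0000, and i32::MAX to 0xFFFF_FFFF, so plain
// unsigned comparison of the keys reproduces signed order.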

// --- Floating-Point Numbers ---
// Branchless IEEE 754 bit-hack mapping:
// An arithmetic right shift replicates the sign bit into a full-width mask, so
// negative floats get every bit flipped and positive floats get only the sign
// bit flipped, without CPU branch-prediction stalls.
impl SortKey for f32 {
    #[inline(always)]
    fn sort_key(&self) -> u64 {
        let bits = self.to_bits();
        let sign_mask = ((bits as i32) >> 31) as u32;
        (bits ^ (sign_mask | 0x8000_0000)) as u64
    }
}

impl SortKey for f64 {
    #[inline(always)]
    fn sort_key(&self) -> u64 {
        let bits = self.to_bits();
        let sign_mask = ((bits as i64) >> 63) as u64;
        bits ^ (sign_mask | 0x8000_0000_0000_0000)
    }
}
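
// Worked example (for illustration): -0.0f32 maps to 0x7FFF_FFFF and +0.0f32
// to 0x8000_0000, so -0.0 sorts strictly before +0.0. NaNs also get a total
// order: negative NaNs land below -inf and positive NaNs above +inf.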

/// Micro-optimized in-place insertion sort for extremely small arrays (N <= 16).
/// By using raw `ptr::read` and `ptr::write` instead of `slice::swap`,
/// it coerces LLVM into register-level element shifting rather than memory-to-memory copies.
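///
/// # Example
///
/// A quick usage sketch (the import path assumes a crate-root re-export):
///
/// ```
/// # use zan_sort::custom_insertion_sort;
/// let mut v = [4u64, 1, 3, 2];
/// custom_insertion_sort(&mut v);
/// assert_eq!(v, [1, 2, 3, 4]);
/// ```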
#[inline(always)]
pub fn custom_insertion_sort<T: SortKey>(arr: &mut [T]) {
    let len = arr.len();
    if len <= 1 {
        return;
    }

    let base_ptr = arr.as_mut_ptr();
    for i in 1..len {
        unsafe {
            let val_ptr = base_ptr.add(i);
            let val = std::ptr::read(val_ptr);
            let val_key = val.sort_key();
            let mut j = i;
            while j > 0 {
                let prev_ptr = base_ptr.add(j - 1);
                if (*prev_ptr).sort_key() > val_key {
                    std::ptr::write(base_ptr.add(j), std::ptr::read(prev_ptr));
                    j -= 1;
                } else {
                    break;
                }
            }
            std::ptr::write(base_ptr.add(j), val);
        }
    }
}

/// Helper that insertion-sorts overflow entries.
/// Like `custom_insertion_sort`, but orders `(chunk_id, value)` pairs
/// lexicographically: by chunk ID first, then by sort key.
#[inline(always)]
fn sort_overflow<T: SortKey>(arr: &mut [(usize, MaybeUninit<T>)]) {
    let len = arr.len();
    if len <= 1 {
        return;
    }

    let base_ptr = arr.as_mut_ptr();
    for i in 1..len {
        unsafe {
            let val_ptr = base_ptr.add(i);
            let val = std::ptr::read(val_ptr);
            let val_chunk = val.0;
            let val_key = val.1.assume_init_ref().sort_key();
            let mut j = i;
            while j > 0 {
                let prev_ptr = base_ptr.add(j - 1);
                let prev_chunk = (*prev_ptr).0;
                let prev_key = (*prev_ptr).1.assume_init_ref().sort_key();
                if prev_chunk > val_chunk || (prev_chunk == val_chunk && prev_key > val_key) {
                    std::ptr::write(base_ptr.add(j), std::ptr::read(prev_ptr));
                    j -= 1;
                } else {
                    break;
                }
            }
            std::ptr::write(base_ptr.add(j), val);
        }
    }
}

// --- Structure of Arrays (SoA) Definitions for the Micro/Mid Phase ---

// 64-byte (cache-line) aligned storage for one 16-slot chunk.
#[repr(C, align(64))]
struct ChunkData<T> {
    data: [MaybeUninit<T>; 16],
}

impl<T> Default for ChunkData<T> {
    fn default() -> Self {
        // SAFETY: an array of `MaybeUninit<T>` is valid in any byte state.
        unsafe { MaybeUninit::uninit().assume_init() }
    }
}

#[derive(Clone, Copy, Default)]
struct ChunkMeta {
    bitmap: u16,
    occupancy: u8,
    is_dirty: bool,
}
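
// Example (for illustration): bitmap 0b0000_0000_0001_0101 marks slots 0, 2,
// and 4 as occupied, so occupancy == 3; is_dirty is set once a collision
// forces an element out of its computed slot.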

/// A per-thread, reusable memory arena.
/// Reusing these vectors avoids `malloc`/`free` traffic inside hot execution loops.
struct Workspace<T> {
    datas: Vec<ChunkData<T>>,
    metas: Vec<ChunkMeta>,
    overflow: Vec<(usize, MaybeUninit<T>)>,
}

impl<T> Workspace<T> {
    fn new() -> Self {
        Self {
            datas: Vec::new(),
            metas: Vec::new(),
            overflow: Vec::new(),
        }
    }

    #[inline(always)]
    fn prepare(&mut self, c: usize) {
        self.metas.clear();
        self.metas.resize(c, ChunkMeta::default());
        self.datas.clear();
        self.datas.reserve(c);
        // SAFETY: `ChunkData` holds only `MaybeUninit` slots, so exposing
        // uninitialized elements is sound; `metas` tracks actual occupancy.
        unsafe {
            self.datas.set_len(c);
        }
        self.overflow.clear();
    }
}

/// The core O(N) arithmetic routing algorithm for mid-scale, localized processing.
/// Maps elements into the SoA structure (`ChunkData` / `ChunkMeta`) using
/// 128-bit fixed-point linear interpolation over the key range.
fn zan_sort_local<T: SortKey>(data: &mut [T], min_key: u64, max_key: u64, ws: &mut Workspace<T>) {
    let n = data.len();
    if n <= 1 {
        return;
    }
    let range = max_key.saturating_sub(min_key);
    if range == 0 {
        return;
    }

    let c = cmp::max(1, n / 4);
    let m = (c * 16 - 1) as u64;
    // Pre-calculate the routing multiplier to avoid division in the loop
    let multiplier = (((m as u128) << 32) / (range as u128)) as u64;
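    // In effect, i_v = floor(v_diff * m / range) in 32.32 fixed point: since
    // v_diff <= range, the product (v_diff * multiplier) >> 32 never exceeds
    // m = c * 16 - 1, so every element lands in a valid chunk/slot index.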

    ws.prepare(c);
    let metas = &mut ws.metas;
    let datas = &mut ws.datas;
    let overflow = &mut ws.overflow;

    // Phase 1: Arithmetic O(1) Routing
    for i in 0..n {
        unsafe {
            let v = std::ptr::read(data.as_ptr().add(i));
            let v_key = v.sort_key();
            let v_diff = v_key - min_key;
            let i_v = ((v_diff as u128 * multiplier as u128) >> 32) as usize;

            let chunk_id = cmp::min(i_v >> 4, c - 1);
            let offset = i_v & 15;

            let meta = &mut metas[chunk_id];
            let data_chunk = &mut datas[chunk_id];

            if meta.occupancy < 16 {
                let bit = 1 << offset;
                if (meta.bitmap & bit) == 0 {
                    // Fast path: the computed slot is free.
                    data_chunk.data[offset].write(v);
                    meta.bitmap |= bit;
                    meta.occupancy += 1;
                } else {
                    // Collision: divert to the first free slot (trailing zeros
                    // of the inverted bitmap) and mark the chunk dirty so it
                    // is re-sorted during write-back.
                    meta.is_dirty = true;
                    let empty_offset = (!meta.bitmap).trailing_zeros() as usize;
                    data_chunk.data[empty_offset].write(v);
                    meta.bitmap |= 1 << empty_offset;
                    meta.occupancy += 1;
                }
            } else {
                // If a 16-element chunk is full, push to the fallback overflow buffer
                overflow.push((chunk_id, MaybeUninit::new(v)));
            }
        }
    }

    if overflow.len() > 1 {
        sort_overflow(overflow);
    }

    // Phase 2: Sequential Write-back & Micro-sorting
    let mut overflow_idx = 0;
    let mut write_ptr = 0;

    for id in 0..c {
        let meta = &metas[id];
        let data_chunk = &mut datas[id];
        let has_overflow = overflow_idx < overflow.len() && overflow[overflow_idx].0 == id;

        if meta.occupancy == 0 && !has_overflow {
            continue;
        }

        let mut local: [MaybeUninit<T>; 16] = unsafe { MaybeUninit::uninit().assume_init() };
        let mut local_len = 0;
        let mut bmp = meta.bitmap;

        // Extract occupied elements densely via trailing zeros
        while bmp != 0 {
            let offset = bmp.trailing_zeros() as usize;
            unsafe {
                local[local_len].write(data_chunk.data[offset].assume_init_read());
            }
            local_len += 1;
            bmp &= bmp - 1; // clear the lowest set bit
        }

        // Sort dirty chunks (where collisions forced elements out of exact alignment)
        if meta.is_dirty && local_len > 1 {
            unsafe {
                let slice = std::slice::from_raw_parts_mut(local.as_mut_ptr() as *mut T, local_len);
                custom_insertion_sort(slice);
            }
        }

        // Merge local chunk data with potential overflow elements, writing back to the original slice
        if !has_overflow {
            unsafe {
                let dst = data.as_mut_ptr().add(write_ptr);
                let src = local.as_ptr() as *const T;
                std::ptr::copy_nonoverlapping(src, dst, local_len);
            }
            write_ptr += local_len;
        } else {
            let mut l_idx = 0;
            loop {
                let has_local = l_idx < local_len;
                let has_over = overflow_idx < overflow.len() && overflow[overflow_idx].0 == id;

                if has_local && has_over {
                    unsafe {
                        let l_key = (*(local.as_ptr().add(l_idx) as *const T)).sort_key();
                        let o_key = overflow[overflow_idx].1.assume_init_ref().sort_key();
                        if l_key <= o_key {
                            let l_val = local.as_ptr().add(l_idx).cast::<T>().read();
                            data.as_mut_ptr().add(write_ptr).write(l_val);
                            l_idx += 1;
                        } else {
                            let o_val = overflow[overflow_idx].1.assume_init_read();
                            data.as_mut_ptr().add(write_ptr).write(o_val);
                            overflow_idx += 1;
                        }
                    }
                    write_ptr += 1;
                } else if has_local {
                    unsafe {
                        let l_val = local.as_ptr().add(l_idx).cast::<T>().read();
                        data.as_mut_ptr().add(write_ptr).write(l_val);
                    }
                    l_idx += 1;
                    write_ptr += 1;
                } else if has_over {
                    unsafe {
                        let o_val = overflow[overflow_idx].1.assume_init_read();
                        data.as_mut_ptr().add(write_ptr).write(o_val);
                    }
                    overflow_idx += 1;
                    write_ptr += 1;
                } else {
                    break;
                }
            }
        }
    }
}

// --- Public API ---

/// High-performance, O(N) generic hybrid sort.
/// It dynamically scales from standard `pdqsort` (for mid-sized arrays) up to
/// `Ordex`-inspired lock-free parallel arithmetic routing for massive arrays.
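///
/// # Example
///
/// A usage sketch (the import path assumes a crate-root re-export):
///
/// ```
/// # use zan_sort::zan_sort;
/// let mut v = vec![3i64, -1, 2, -7];
/// zan_sort(&mut v);
/// assert_eq!(v, vec![-7, -1, 2, 3]);
/// ```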
pub fn zan_sort<T: SortKey + Send>(data: &mut [T]) {
    let n = data.len();
    if n <= 1 {
        return;
    }

    // Dynamic Entry Point / Threshold Detection
    #[cfg(not(feature = "pure"))]
    {
        if n <= 16 {
            custom_insertion_sort(data);
            return;
        } else if n <= 5000 {
            data.sort_unstable_by_key(|item| item.sort_key());
            return;
        }
    }

    #[cfg(feature = "pure")]
    {
        if n <= 16 {
            custom_insertion_sort(data);
            return;
        }
    }

    // Determine the global bounds
    let mut min_key = u64::MAX;
    let mut max_key = u64::MIN;
    for item in data.iter() {
        let key = item.sort_key();
        if key < min_key {
            min_key = key;
        }
        if key > max_key {
            max_key = key;
        }
    }

    if min_key == max_key {
        return;
    }

    // Route mid-scale data to single-threaded SoA bucketing; larger datasets fall through to the parallel Macro Phase.
    if n <= 16384 {
        let mut ws = Workspace::new();
        zan_sort_local(data, min_key, max_key, &mut ws);
        return;
    }

    // --- Macro Phase: Dynamic Multi-Bucket Routing ---

    // Clamp the bucket count between 16 and 16384 to avoid both under- and over-partitioning
    let target_num_buckets = (n / 32768).next_power_of_two().clamp(16, 16384);
    let num_buckets = target_num_buckets;

    let range = max_key.saturating_sub(min_key);
    let shift_bits = if range > (u32::MAX as u64) {
        64 - range.leading_zeros() - 32
    } else {
        0
    };
    let scaled_range = range >> shift_bits;
    let multiplier = ((num_buckets as u64) << 32) / (scaled_range + 1);
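    // Scaling the range down to at most 32 bits keeps `scaled_diff * multiplier`
    // strictly below `num_buckets << 32`, so the product fits in a u64 and the
    // resulting bucket index is always < num_buckets without an extra clamp.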

    let num_threads = thread::available_parallelism()
        .map(|n| n.get())
        .unwrap_or(4);
    let chunk_size = n.div_ceil(num_threads);

    // Step 1: Parallel Local Histograms
    let mut local_counts = vec![vec![0usize; num_buckets]; num_threads];
    thread::scope(|s| {
        for (chunk, counts) in data.chunks_mut(chunk_size).zip(local_counts.iter_mut()) {
            s.spawn(move || {
                for item in chunk {
                    let v_diff = item.sort_key() - min_key;
                    let scaled_diff = v_diff >> shift_bits;
                    let bucket = ((scaled_diff * multiplier) >> 32) as usize;
                    counts[bucket] += 1;
                }
            });
        }
    });

    // Step 2: Global Prefix Sums
    let mut bucket_offsets = vec![0usize; num_buckets];
    let mut local_offsets = vec![vec![0usize; num_buckets]; num_threads];
    let mut global_counts = vec![0usize; num_buckets];
    let mut sum = 0;

    for b in 0..num_buckets {
        bucket_offsets[b] = sum;
        for t in 0..num_threads {
            local_offsets[t][b] = sum;
            sum += local_counts[t][b];
            global_counts[b] += local_counts[t][b];
        }
    }
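    // Resulting layout: bucket b occupies the half-open range
    // bucket_offsets[b] .. bucket_offsets[b] + global_counts[b] of `buffer`,
    // and thread t writes its share of bucket b starting at local_offsets[t][b],
    // so the scatter below needs no locks.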

    let mut buffer: Vec<MaybeUninit<T>> = Vec::with_capacity(n);
    unsafe {
        buffer.set_len(n);
    }

    // Cast the pointers to usize so the addresses can be moved into the
    // spawned closures; each thread writes only its own disjoint regions.
    let data_ptr = data.as_mut_ptr() as usize;
    let buffer_ptr = buffer.as_mut_ptr() as usize;

    // Step 3: Lock-free Parallel Scatter with Heap-allocated Local Buffers
    thread::scope(|s| {
        for (t_id, mut offsets) in local_offsets.into_iter().enumerate() {
            let chunk_start = t_id * chunk_size;
            let chunk_end = cmp::min(chunk_start + chunk_size, n);

            s.spawn(move || unsafe {
                let d_ptr = data_ptr as *mut T;
                let b_ptr = buffer_ptr as *mut MaybeUninit<T>;

                const BUF_SIZE: usize = 16;
                // Leave the staging buffers uninitialized: `MaybeUninit` slots
                // need no initialization and impose no `Clone`/`Default` bound.
                // Batching 16 elements per bucket amortizes writes to the
                // shared output buffer.
                let mut local_buf: Vec<[MaybeUninit<T>; BUF_SIZE]> =
                    Vec::with_capacity(num_buckets);
                local_buf.set_len(num_buckets);

                let mut local_idx = vec![0usize; num_buckets];

                for i in chunk_start..chunk_end {
                    let v_ptr = d_ptr.add(i);
                    let v_key = (*v_ptr).sort_key();
                    let v_diff = v_key - min_key;
                    let scaled_diff = v_diff >> shift_bits;
                    let bucket = ((scaled_diff * multiplier) >> 32) as usize;

                    let idx = local_idx[bucket];
                    local_buf[bucket][idx] = std::ptr::read(v_ptr as *const MaybeUninit<T>);
                    local_idx[bucket] = idx + 1;

                    if idx + 1 == BUF_SIZE {
                        let dst = b_ptr.add(offsets[bucket]);
                        std::ptr::copy_nonoverlapping(local_buf[bucket].as_ptr(), dst, BUF_SIZE);
                        offsets[bucket] += BUF_SIZE;
                        local_idx[bucket] = 0;
                    }
                }

                // Flush any partially filled staging buffers.
                for b in 0..num_buckets {
                    let remain = local_idx[b];
                    if remain > 0 {
                        let dst = b_ptr.add(offsets[b]);
                        std::ptr::copy_nonoverlapping(local_buf[b].as_ptr(), dst, remain);
                        offsets[b] += remain;
                    }
                }
            });
        }
    });

    // Step 4: Ahead-of-Time Allocation & Parallel Recursive Sort
    let buckets_per_thread = num_buckets.div_ceil(num_threads);

    let workspaces: Vec<Workspace<T>> = (0..num_threads)
        .map(|t_id| {
            let start_b = t_id * buckets_per_thread;
            let end_b = cmp::min(start_b + buckets_per_thread, num_buckets);
            let max_bucket_count = (start_b..end_b)
                .map(|b| global_counts[b])
                .max()
                .unwrap_or(0);

            let mut ws = Workspace::new();
            if max_bucket_count > 0 {
                // Pre-size for the largest bucket; matches `c = n / 4` in `zan_sort_local`.
                ws.prepare(cmp::max(1, max_bucket_count / 4));
            }
            ws
        })
        .collect();

    let mut ws_iter = workspaces.into_iter();

    thread::scope(|s| {
        for t_id in 0..num_threads {
            let start_b = t_id * buckets_per_thread;
            let end_b = cmp::min(start_b + buckets_per_thread, num_buckets);
            #[allow(unused_mut, unused_variables)]
            let mut ws = ws_iter.next().unwrap();
            let g_counts = &global_counts;
            let b_offsets = &bucket_offsets;

            s.spawn(move || unsafe {
                let d_ptr = data_ptr as *mut T;
                let b_ptr = buffer_ptr as *mut MaybeUninit<T>;

                for b in start_b..end_b {
                    let count = g_counts[b];
                    if count == 0 {
                        continue;
                    }

                    let offset = b_offsets[b];
                    let block_ptr = b_ptr.add(offset) as *mut T;
                    let block = std::slice::from_raw_parts_mut(block_ptr, count);

                    if count <= 16 {
                        custom_insertion_sort(block);
                    } else {
                        #[cfg(not(feature = "pure"))]
                        {
                            if count <= 5000 {
                                block.sort_unstable_by_key(|item| item.sort_key());
                                std::ptr::copy_nonoverlapping(block_ptr, d_ptr.add(offset), count);
                                continue;
                            }
                        }

                        // Pure arithmetic routing fallback
                        let (mut l_min, mut l_max) = (u64::MAX, u64::MIN);
                        for item in block.iter() {
                            let key = item.sort_key();
                            if key < l_min {
                                l_min = key;
                            }
                            if key > l_max {
                                l_max = key;
                            }
                        }
                        if l_min != l_max {
                            zan_sort_local(block, l_min, l_max, &mut ws);
                        }
                    }

                    // Copy the finished bucket from the buffer back into `data`.
                    std::ptr::copy_nonoverlapping(block_ptr, d_ptr.add(offset), count);
                }
            });
        }
    });
}
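
// A minimal sanity-check suite, added here as an illustrative sketch (the test
// names and cases are not from the original file). It exercises the key
// mappings and the small- and mid-size paths of `zan_sort`.
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn signed_keys_preserve_order() {
        // XOR-ing the sign bit must make unsigned key order match signed order.
        assert!(i32::MIN.sort_key() < (-1i32).sort_key());
        assert!((-1i32).sort_key() < 0i32.sort_key());
        assert!(0i32.sort_key() < i32::MAX.sort_key());
    }

    #[test]
    fn float_keys_preserve_order() {
        assert!((-1.5f64).sort_key() < (-0.0f64).sort_key());
        assert!((-0.0f64).sort_key() < 0.0f64.sort_key());
        assert!(0.0f64.sort_key() < 1.5f64.sort_key());
    }

    #[test]
    fn sorts_small_and_mid_inputs() {
        let mut small = vec![5u64, 3, 9, 1];
        zan_sort(&mut small);
        assert_eq!(small, vec![1, 3, 5, 9]);

        // 17..=5000 elements take the pdqsort path by default, or the
        // arithmetic-routing path under the `pure` feature.
        let mut mid: Vec<i64> = (0..1000).rev().collect();
        zan_sort(&mut mid);
        assert!(mid.windows(2).all(|w| w[0] <= w[1]));
    }
}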