Skip to main content

sklears_simd/
sorting.rs

1//! SIMD-optimized sorting algorithms
2//!
3//! This module provides vectorized implementations of sorting algorithms
4//! including quicksort, bitonic sort, and specialized sorting operations.
5
6/// Generic quicksort function for raw pointer interface
7pub fn quicksort(data: &mut [f32]) -> Result<(), crate::traits::SimdError> {
8    quicksort_f32_simd(data);
9    Ok(())
10}
11
12/// SIMD-optimized quicksort for f32 arrays
13/// Uses vectorized partitioning and parallel processing for optimal performance
14pub fn quicksort_f32_simd(arr: &mut [f32]) {
15    if arr.len() <= 1 {
16        return;
17    }
18
19    #[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
20    {
21        #[cfg(not(feature = "no-std"))]
22        if crate::simd_feature_detected!("avx2") && arr.len() >= 16 {
23            unsafe { quicksort_avx2(arr) };
24            return;
25        }
26        if crate::simd_feature_detected!("sse2") && arr.len() >= 8 {
27            unsafe { quicksort_sse2(arr) };
28            return;
29        }
30    }
31
32    // Fall back to scalar quicksort for small arrays or unsupported platforms
33    quicksort_scalar(arr);
34}
35
36fn quicksort_scalar(arr: &mut [f32]) {
37    if arr.len() <= 1 {
38        return;
39    }
40
41    let pivot_index = partition_scalar(arr);
42    quicksort_scalar(&mut arr[0..pivot_index]);
43    quicksort_scalar(&mut arr[pivot_index + 1..]);
44}
45
/// Partitions `arr` around its last element and returns the pivot's final index.
///
/// After the call every element left of the returned index compared `<=` to
/// the pivot during the scan, and the pivot sits at the returned index.
/// `arr` must be non-empty; an empty slice panics on the index computation.
fn partition_scalar(arr: &mut [f32]) -> usize {
    let last = arr.len() - 1;
    let pivot = arr[last];

    // `boundary` is the first slot not yet known to hold a `<= pivot` value.
    let mut boundary = 0;
    let mut scan = 0;
    while scan < last {
        if arr[scan] <= pivot {
            arr.swap(boundary, scan);
            boundary += 1;
        }
        scan += 1;
    }

    // Drop the pivot between the two groups.
    arr.swap(boundary, last);
    boundary
}
61
62#[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
63#[target_feature(enable = "sse2")]
64unsafe fn quicksort_sse2(arr: &mut [f32]) {
65    // SSE2 lacks variable-index vector permute (_mm_permutevar_ps requires AVX),
66    // so we use the proven scalar Lomuto partition while keeping the SIMD
67    // insertion-sort base case for small sub-arrays.
68    if arr.len() <= 8 {
69        insertion_sort_simd_sse2(arr);
70        return;
71    }
72    let pivot_index = partition_scalar(arr);
73    quicksort_sse2(&mut arr[0..pivot_index]);
74    quicksort_sse2(&mut arr[pivot_index + 1..]);
75}
76
// Compile-time permutation table for the AVX2 compress step.
// For every 8-bit mask, the entry is a permutation of the lane indices 0..8:
// lanes whose mask bit is set come first (in increasing lane order), followed
// by the lanes whose bit is clear (also in increasing lane order).
// Intended as the index operand of _mm256_permutevar8x32_ps so that the
// selected lanes end up contiguous at the front of the vector.
#[cfg(all(
    any(target_arch = "x86", target_arch = "x86_64"),
    not(feature = "no-std")
))]
const fn build_compress_lut() -> [[u32; 8]; 256] {
    let mut table = [[0u32; 8]; 256];
    let mut mask: usize = 0;
    while mask < 256 {
        // Set-bit lanes fill from slot 0; clear-bit lanes start right after
        // the last set-bit slot, i.e. at popcount(mask).
        let mut next_set: usize = 0;
        let mut next_clear: usize = mask.count_ones() as usize;
        let mut lane: usize = 0;
        while lane < 8 {
            let selected = (mask >> lane) & 1 == 1;
            if selected {
                table[mask][next_set] = lane as u32;
                next_set += 1;
            } else {
                table[mask][next_clear] = lane as u32;
                next_clear += 1;
            }
            lane += 1;
        }
        mask += 1;
    }
    table
}

// The table itself, evaluated at compile time.
#[cfg(all(
    any(target_arch = "x86", target_arch = "x86_64"),
    not(feature = "no-std")
))]
static COMPRESS_LUT: [[u32; 8]; 256] = build_compress_lut();
120
121#[cfg(all(
122    any(target_arch = "x86", target_arch = "x86_64"),
123    not(feature = "no-std")
124))]
125#[target_feature(enable = "avx2")]
126unsafe fn quicksort_avx2(arr: &mut [f32]) {
127    if arr.len() <= 1 {
128        return;
129    }
130    let len = arr.len();
131    let mut le_buf: Vec<f32> = Vec::with_capacity(len);
132    let mut gt_buf: Vec<f32> = Vec::with_capacity(len);
133    quicksort_avx2_impl(arr, &mut le_buf, &mut gt_buf);
134}
135
136#[cfg(all(
137    any(target_arch = "x86", target_arch = "x86_64"),
138    not(feature = "no-std")
139))]
140#[target_feature(enable = "avx2")]
141unsafe fn quicksort_avx2_impl(arr: &mut [f32], le_buf: &mut Vec<f32>, gt_buf: &mut Vec<f32>) {
142    let len = arr.len();
143    if len <= 1 {
144        return;
145    }
146    if len <= 16 {
147        insertion_sort_simd_avx2(arr);
148        return;
149    }
150    let pivot_pos = partition_avx2_buffered(arr, le_buf, gt_buf);
151    let (left, rest) = arr.split_at_mut(pivot_pos);
152    let right = &mut rest[1..];
153    quicksort_avx2_impl(left, le_buf, gt_buf);
154    quicksort_avx2_impl(right, le_buf, gt_buf);
155}
156
157// Lomuto-style partition using AVX2 compress: in each 8-lane pass, elements ≤ pivot
158// are gathered contiguously via _mm256_permutevar8x32_ps + LUT, then buffered.
159// The array is reassembled as [≤pivot elements | pivot | >pivot elements].
160// Buffers are allocated once by the entry point and reused across all recursive calls.
161#[cfg(all(
162    any(target_arch = "x86", target_arch = "x86_64"),
163    not(feature = "no-std")
164))]
165#[target_feature(enable = "avx2")]
166unsafe fn partition_avx2_buffered(
167    arr: &mut [f32],
168    le_buf: &mut Vec<f32>,
169    gt_buf: &mut Vec<f32>,
170) -> usize {
171    use core::arch::x86_64::*;
172
173    let len = arr.len();
174    let pivot = arr[len - 1];
175    let pivot_vec = _mm256_set1_ps(pivot);
176
177    le_buf.clear();
178    gt_buf.clear();
179
180    let mut i = 0;
181
182    // Process 8 elements per iteration (all elements before the pivot at arr[len-1])
183    while i + 8 < len {
184        let data_vec = _mm256_loadu_ps(arr.as_ptr().add(i));
185        let cmp = _mm256_cmp_ps(data_vec, pivot_vec, _CMP_LE_OQ);
186        let mask = _mm256_movemask_ps(cmp) as usize;
187        let count_le = mask.count_ones() as usize;
188        let count_gt = 8 - count_le;
189
190        // Permute so the ≤pivot lanes land in the low prefix
191        let le_perm = _mm256_loadu_si256(COMPRESS_LUT[mask].as_ptr() as *const __m256i);
192        let le_result = _mm256_permutevar8x32_ps(data_vec, le_perm);
193        let mut tmp = [0.0f32; 8];
194        _mm256_storeu_ps(tmp.as_mut_ptr(), le_result);
195        le_buf.extend_from_slice(&tmp[..count_le]);
196
197        // Permute so the >pivot lanes land in the low prefix
198        let gt_mask = (!mask) & 0xFF;
199        let gt_perm = _mm256_loadu_si256(COMPRESS_LUT[gt_mask].as_ptr() as *const __m256i);
200        let gt_result = _mm256_permutevar8x32_ps(data_vec, gt_perm);
201        _mm256_storeu_ps(tmp.as_mut_ptr(), gt_result);
202        gt_buf.extend_from_slice(&tmp[..count_gt]);
203
204        i += 8;
205    }
206
207    // Scalar tail for any remaining elements before the pivot
208    while i < len - 1 {
209        if arr[i] <= pivot {
210            le_buf.push(arr[i]);
211        } else {
212            gt_buf.push(arr[i]);
213        }
214        i += 1;
215    }
216
217    // Reassemble: [ ≤pivot elements | pivot | >pivot elements ]
218    let pivot_pos = le_buf.len();
219    arr[..pivot_pos].copy_from_slice(le_buf.as_slice());
220    arr[pivot_pos] = pivot;
221    let gt_start = pivot_pos + 1;
222    arr[gt_start..gt_start + gt_buf.len()].copy_from_slice(gt_buf.as_slice());
223
224    pivot_pos
225}
226
/// Insertion sort (ascending, in place) with an SSE2 "already in place" check.
///
/// For each element with at least four predecessors, the four values directly
/// before it are compared against it in one SSE2 operation. If none of them is
/// greater, the element stays where it is and the scalar shift loop is
/// skipped; otherwise the usual scalar shifting runs. Elements with fewer than
/// four predecessors always use the scalar loop.
///
/// # Safety
///
/// The caller must ensure the CPU supports SSE2.
#[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
#[target_feature(enable = "sse2")]
unsafe fn insertion_sort_simd_sse2(arr: &mut [f32]) {
    // `core::arch::x86_64` only exists on 64-bit targets, so pick the module
    // that matches the architecture this function is compiled for.
    #[cfg(target_arch = "x86")]
    use core::arch::x86::*;
    #[cfg(target_arch = "x86_64")]
    use core::arch::x86_64::*;

    for i in 1..arr.len() {
        let key = arr[i];

        if i >= 4 {
            // Pointer taken from the 4-element sub-slice so the whole load
            // stays inside the range the pointer is derived from.
            let window = _mm_loadu_ps(arr[i - 4..i].as_ptr());
            let greater = _mm_cmpgt_ps(window, _mm_set1_ps(key));
            // The window contains arr[i - 1]; if no lane is greater than the
            // key, the scalar loop below would not move anything.
            if _mm_movemask_ps(greater) == 0 {
                continue;
            }
        }

        // Shift larger elements one slot to the right and insert the key.
        let mut j = i;
        while j > 0 && arr[j - 1] > key {
            arr[j] = arr[j - 1];
            j -= 1;
        }
        arr[j] = key;
    }
}
280
/// Insertion sort (ascending, in place) with an AVX "already in place" check.
///
/// For each element with at least eight predecessors, the eight values
/// directly before it are compared against it in one 256-bit operation. If
/// none of them is greater, the element stays where it is and the scalar shift
/// loop is skipped; otherwise the usual scalar shifting runs. Elements with
/// fewer than eight predecessors always use the scalar loop.
///
/// # Safety
///
/// The caller must ensure the CPU supports AVX2.
#[cfg(all(
    any(target_arch = "x86", target_arch = "x86_64"),
    not(feature = "no-std")
))]
#[target_feature(enable = "avx2")]
unsafe fn insertion_sort_simd_avx2(arr: &mut [f32]) {
    // `core::arch::x86_64` only exists on 64-bit targets, so pick the module
    // that matches the architecture this function is compiled for.
    #[cfg(target_arch = "x86")]
    use core::arch::x86::*;
    #[cfg(target_arch = "x86_64")]
    use core::arch::x86_64::*;

    for i in 1..arr.len() {
        let key = arr[i];

        if i >= 8 {
            // Pointer taken from the 8-element sub-slice so the whole load
            // stays inside the range the pointer is derived from.
            let window = _mm256_loadu_ps(arr[i - 8..i].as_ptr());
            let greater = _mm256_cmp_ps(window, _mm256_set1_ps(key), _CMP_GT_OQ);
            // The window contains arr[i - 1]; if no lane is greater than the
            // key, the scalar loop below would not move anything.
            if _mm256_movemask_ps(greater) == 0 {
                continue;
            }
        }

        // Shift larger elements one slot to the right and insert the key.
        let mut j = i;
        while j > 0 && arr[j - 1] > key {
            arr[j] = arr[j - 1];
            j -= 1;
        }
        arr[j] = key;
    }
}
336
337/// Bitonic sort for power-of-2 sized arrays
338/// Optimal for small fixed-size arrays with SIMD processing
339pub fn bitonic_sort_f32_simd(arr: &mut [f32], ascending: bool) {
340    let len = arr.len();
341
342    // Ensure length is a power of 2
343    assert!(
344        len.is_power_of_two(),
345        "Bitonic sort requires power-of-2 length"
346    );
347
348    if len <= 1 {
349        return;
350    }
351
352    #[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
353    {
354        if crate::simd_feature_detected!("avx2") && len >= 8 {
355            unsafe { bitonic_sort_avx2(arr, ascending) };
356            return;
357        } else if crate::simd_feature_detected!("sse2") && len >= 4 {
358            unsafe { bitonic_sort_sse2(arr, ascending) };
359            return;
360        }
361    }
362
363    bitonic_sort_scalar(arr, ascending);
364}
365
366fn bitonic_sort_scalar(arr: &mut [f32], ascending: bool) {
367    let len = arr.len();
368
369    if len <= 1 {
370        return;
371    }
372
373    if len == 2 {
374        if (arr[0] > arr[1]) == ascending {
375            arr.swap(0, 1);
376        }
377        return;
378    }
379
380    let mid = len / 2;
381
382    // Sort first half in ascending order
383    bitonic_sort_scalar(&mut arr[0..mid], true);
384
385    // Sort second half in descending order
386    bitonic_sort_scalar(&mut arr[mid..], false);
387
388    // Merge the bitonic sequence
389    bitonic_merge_scalar(arr, ascending);
390}
391
/// Scalar bitonic merge step.
///
/// Compares each element of the first half with its partner `len / 2`
/// positions later, swapping the pair when `(a > b) == ascending`, then merges
/// both halves recursively in the same direction.
fn bitonic_merge_scalar(arr: &mut [f32], ascending: bool) {
    let half = arr.len() / 2;
    // Covers lengths 0 and 1.
    if half == 0 {
        return;
    }

    let (lower, upper) = arr.split_at_mut(half);
    for (a, b) in lower.iter_mut().zip(upper.iter_mut()) {
        if (*a > *b) == ascending {
            core::mem::swap(a, b);
        }
    }

    if half > 1 {
        bitonic_merge_scalar(lower, ascending);
        bitonic_merge_scalar(upper, ascending);
    }
}
412
413#[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
414#[target_feature(enable = "sse2")]
415unsafe fn bitonic_sort_sse2(arr: &mut [f32], ascending: bool) {
416    let len = arr.len();
417
418    if len <= 4 {
419        bitonic_sort_4_sse2(arr, ascending);
420        return;
421    }
422
423    let mid = len / 2;
424
425    // Sort halves recursively
426    bitonic_sort_sse2(&mut arr[0..mid], true);
427    bitonic_sort_sse2(&mut arr[mid..], false);
428
429    // Merge
430    bitonic_merge_sse2(arr, ascending);
431}
432
433#[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
434#[target_feature(enable = "sse2")]
435unsafe fn bitonic_sort_4_sse2(arr: &mut [f32], ascending: bool) {
436    use core::arch::x86_64::*;
437
438    if arr.len() != 4 {
439        bitonic_sort_scalar(arr, ascending);
440        return;
441    }
442
443    // Implement 4-element bitonic sort with SSE2
444    // This is a simplified version - a full implementation would be more complex
445    let temp = [arr[0], arr[1], arr[2], arr[3]];
446    let mut sorted = temp;
447    sorted.sort_by(|a, b| {
448        if ascending {
449            a.partial_cmp(b).expect("operation should succeed")
450        } else {
451            b.partial_cmp(a).expect("operation should succeed")
452        }
453    });
454
455    let vec = _mm_loadu_ps(sorted.as_ptr());
456    _mm_storeu_ps(arr.as_mut_ptr(), vec);
457}
458
459#[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
460#[target_feature(enable = "sse2")]
461unsafe fn bitonic_merge_sse2(arr: &mut [f32], ascending: bool) {
462    use core::arch::x86_64::*;
463
464    let len = arr.len();
465
466    if len <= 4 {
467        bitonic_merge_scalar(arr, ascending);
468        return;
469    }
470
471    let step = len / 2;
472
473    // SIMD-accelerated comparison and swapping
474    let mut i = 0;
475    while i + 4 <= step {
476        let vec1 = _mm_loadu_ps(&arr[i]);
477        let vec2 = _mm_loadu_ps(&arr[i + step]);
478
479        let cmp = if ascending {
480            _mm_cmpgt_ps(vec1, vec2)
481        } else {
482            _mm_cmplt_ps(vec1, vec2)
483        };
484
485        let mask = _mm_movemask_ps(cmp);
486
487        // Handle swaps based on mask (simplified approach)
488        for j in 0..4 {
489            if (mask & (1 << j)) != 0 {
490                arr.swap(i + j, i + j + step);
491            }
492        }
493
494        i += 4;
495    }
496
497    // Handle remaining elements
498    while i < step {
499        if (arr[i] > arr[i + step]) == ascending {
500            arr.swap(i, i + step);
501        }
502        i += 1;
503    }
504
505    if step > 1 {
506        bitonic_merge_sse2(&mut arr[0..step], ascending);
507        bitonic_merge_sse2(&mut arr[step..], ascending);
508    }
509}
510
511#[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
512#[target_feature(enable = "avx2")]
513unsafe fn bitonic_sort_avx2(arr: &mut [f32], ascending: bool) {
514    let len = arr.len();
515
516    if len <= 8 {
517        bitonic_sort_8_avx2(arr, ascending);
518        return;
519    }
520
521    let mid = len / 2;
522
523    // Sort halves recursively
524    bitonic_sort_avx2(&mut arr[0..mid], true);
525    bitonic_sort_avx2(&mut arr[mid..], false);
526
527    // Merge
528    bitonic_merge_avx2(arr, ascending);
529}
530
531#[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
532#[target_feature(enable = "avx2")]
533unsafe fn bitonic_sort_8_avx2(arr: &mut [f32], ascending: bool) {
534    use core::arch::x86_64::*;
535
536    if arr.len() != 8 {
537        bitonic_sort_scalar(arr, ascending);
538        return;
539    }
540
541    // Simplified 8-element sort using AVX2
542    let temp = [
543        arr[0], arr[1], arr[2], arr[3], arr[4], arr[5], arr[6], arr[7],
544    ];
545    let mut sorted = temp;
546    sorted.sort_by(|a, b| {
547        if ascending {
548            a.partial_cmp(b).expect("operation should succeed")
549        } else {
550            b.partial_cmp(a).expect("operation should succeed")
551        }
552    });
553
554    let vec = _mm256_loadu_ps(sorted.as_ptr());
555    _mm256_storeu_ps(arr.as_mut_ptr(), vec);
556}
557
558#[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
559#[target_feature(enable = "avx2")]
560unsafe fn bitonic_merge_avx2(arr: &mut [f32], ascending: bool) {
561    use core::arch::x86_64::*;
562
563    let len = arr.len();
564
565    if len <= 8 {
566        bitonic_merge_scalar(arr, ascending);
567        return;
568    }
569
570    let step = len / 2;
571
572    // SIMD-accelerated comparison and swapping
573    let mut i = 0;
574    while i + 8 <= step {
575        let vec1 = _mm256_loadu_ps(&arr[i]);
576        let vec2 = _mm256_loadu_ps(&arr[i + step]);
577
578        let cmp = if ascending {
579            _mm256_cmp_ps(vec1, vec2, _CMP_GT_OQ)
580        } else {
581            _mm256_cmp_ps(vec1, vec2, _CMP_LT_OQ)
582        };
583
584        let mask = _mm256_movemask_ps(cmp);
585
586        // Handle swaps based on mask (simplified approach)
587        for j in 0..8 {
588            if (mask & (1 << j)) != 0 {
589                arr.swap(i + j, i + j + step);
590            }
591        }
592
593        i += 8;
594    }
595
596    // Handle remaining elements
597    while i < step {
598        if (arr[i] > arr[i + step]) == ascending {
599            arr.swap(i, i + step);
600        }
601        i += 1;
602    }
603
604    if step > 1 {
605        bitonic_merge_avx2(&mut arr[0..step], ascending);
606        bitonic_merge_avx2(&mut arr[step..], ascending);
607    }
608}
609
610/// SIMD-optimized median computation using quickselect
611pub fn median_f32_simd(arr: &mut [f32]) -> Option<f32> {
612    if arr.is_empty() {
613        return None;
614    }
615
616    let len = arr.len();
617    let mid = len / 2;
618
619    if len % 2 == 1 {
620        Some(quickselect_f32_simd(arr, mid))
621    } else {
622        let left_mid = quickselect_f32_simd(arr, mid - 1);
623        let right_mid = quickselect_f32_simd(arr, mid);
624        Some((left_mid + right_mid) / 2.0)
625    }
626}
627
628/// SIMD-optimized quickselect for k-th smallest element
629pub fn quickselect_f32_simd(arr: &mut [f32], k: usize) -> f32 {
630    assert!(k < arr.len(), "k must be less than array length");
631
632    let mut left = 0;
633    let mut right = arr.len() - 1;
634
635    loop {
636        if left == right {
637            return arr[left];
638        }
639
640        let pivot_index = partition_range(arr, left, right);
641
642        if k == pivot_index {
643            return arr[k];
644        } else if k < pivot_index {
645            right = pivot_index - 1;
646        } else {
647            left = pivot_index + 1;
648        }
649    }
650}
651
/// Partitions `arr[left..=right]` around `arr[right]` and returns the pivot's
/// final index.
///
/// Elements of the range that compared `<=` to the pivot end up before the
/// returned index; elements outside the range are not touched.
fn partition_range(arr: &mut [f32], left: usize, right: usize) -> usize {
    let pivot = arr[right];

    // `store` is the first slot not yet known to hold a `<= pivot` value.
    let mut store = left;
    let mut scan = left;
    while scan < right {
        if arr[scan] <= pivot {
            arr.swap(store, scan);
            store += 1;
        }
        scan += 1;
    }

    // Drop the pivot between the two groups.
    arr.swap(store, right);
    store
}
666
/// Unit tests for the sorting, median and quickselect routines.
///
/// The module is only built when the `no-std` feature is disabled, so `vec!`
/// and `Vec` come from the standard prelude and no `alloc` import is needed.
#[cfg(all(test, not(feature = "no-std")))]
mod tests {
    use super::*;
    use scirs2_core::random::prelude::*;

    /// Returns `true` when `arr` is ordered in the requested direction
    /// (equal neighbours are allowed).
    fn is_sorted(arr: &[f32], ascending: bool) -> bool {
        for i in 1..arr.len() {
            if ascending && arr[i - 1] > arr[i] {
                return false;
            }
            if !ascending && arr[i - 1] < arr[i] {
                return false;
            }
        }
        true
    }

    /// Returns `true` when both slices hold the same values with the same
    /// multiplicities, compared by bit pattern.
    fn multiset_eq(a: &[f32], b: &[f32]) -> bool {
        let mut va: Vec<u32> = a.iter().map(|x| x.to_bits()).collect();
        let mut vb: Vec<u32> = b.iter().map(|x| x.to_bits()).collect();
        va.sort_unstable();
        vb.sort_unstable();
        va == vb
    }

    #[test]
    fn test_quicksort_simd() {
        let mut arr = vec![3.0, 1.0, 4.0, 1.0, 5.0, 9.0, 2.0, 6.0];
        quicksort_f32_simd(&mut arr);
        assert_eq!(arr, vec![1.0, 1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 9.0]);
    }

    #[test]
    fn test_quicksort_random() {
        let mut rng = thread_rng();
        let mut arr: Vec<f32> = (0..100).map(|_| rng.random_range(0.0..100.0)).collect();
        let original = arr.clone();

        quicksort_f32_simd(&mut arr);
        assert!(is_sorted(&arr, true));
        assert!(multiset_eq(&arr, &original));
    }

    #[test]
    fn test_bitonic_sort_small() {
        let mut arr = vec![4.0, 2.0, 7.0, 1.0];
        bitonic_sort_f32_simd(&mut arr, true);
        assert_eq!(arr, vec![1.0, 2.0, 4.0, 7.0]);

        let mut arr = vec![4.0, 2.0, 7.0, 1.0];
        bitonic_sort_f32_simd(&mut arr, false);
        assert_eq!(arr, vec![7.0, 4.0, 2.0, 1.0]);
    }

    #[test]
    fn test_bitonic_sort_power_of_2() {
        let mut arr = vec![8.0, 4.0, 2.0, 1.0, 3.0, 6.0, 5.0, 7.0];
        bitonic_sort_f32_simd(&mut arr, true);
        assert_eq!(arr, vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0]);
    }

    #[test]
    fn test_median_odd() {
        let mut arr = vec![3.0, 1.0, 4.0, 1.0, 5.0];
        let median = median_f32_simd(&mut arr);
        assert_eq!(median, Some(3.0));
    }

    #[test]
    fn test_median_even() {
        let mut arr = vec![3.0, 1.0, 4.0, 2.0];
        let median = median_f32_simd(&mut arr);
        assert_eq!(median, Some(2.5)); // (2.0 + 3.0) / 2.0
    }

    #[test]
    fn test_quickselect() {
        let mut arr = vec![3.0, 1.0, 4.0, 1.0, 5.0, 9.0, 2.0, 6.0];

        // Find the 3rd smallest element (0-indexed)
        let third_smallest = quickselect_f32_simd(&mut arr, 2);

        // Sort to verify
        let mut sorted = arr.clone();
        sorted.sort_by(|a, b| a.partial_cmp(b).expect("operation should succeed"));
        assert_eq!(third_smallest, sorted[2]);
    }

    #[test]
    fn test_empty_median() {
        let mut arr: Vec<f32> = vec![];
        let median = median_f32_simd(&mut arr);
        assert_eq!(median, None);
    }

    #[test]
    fn test_single_element() {
        let mut arr = vec![42.0];
        quicksort_f32_simd(&mut arr);
        assert_eq!(arr, vec![42.0]);

        let median = median_f32_simd(&mut arr);
        assert_eq!(median, Some(42.0));
    }

    #[test]
    fn test_quicksort_already_sorted() {
        let mut arr: Vec<f32> = (0..50).map(|i| i as f32).collect();
        let original = arr.clone();
        quicksort_f32_simd(&mut arr);
        assert!(is_sorted(&arr, true));
        assert!(multiset_eq(&arr, &original));
    }

    #[test]
    fn test_quicksort_reverse_sorted() {
        let mut arr: Vec<f32> = (0..50).rev().map(|i| i as f32).collect();
        let original = arr.clone();
        quicksort_f32_simd(&mut arr);
        assert!(is_sorted(&arr, true));
        assert!(multiset_eq(&arr, &original));
    }

    #[test]
    fn test_quicksort_all_equal() {
        let mut arr = vec![7.0f32; 100];
        let original = arr.clone();
        quicksort_f32_simd(&mut arr);
        assert!(is_sorted(&arr, true));
        assert!(multiset_eq(&arr, &original));
    }

    #[test]
    fn test_quicksort_heavy_duplicates() {
        let mut rng = thread_rng();
        // Only 3 distinct values among 200 elements: lots of ties in the partition
        let mut arr: Vec<f32> = (0..200)
            .map(|_| [1.0f32, 2.0, 3.0][rng.random_range(0usize..3)])
            .collect();
        let original = arr.clone();
        quicksort_f32_simd(&mut arr);
        assert!(is_sorted(&arr, true));
        assert!(multiset_eq(&arr, &original));
    }

    #[test]
    fn test_quicksort_non_multiple_of_8() {
        let mut rng = thread_rng();
        // Sizes that don't align to the 8-lane AVX2 width
        for size in [17usize, 23, 31, 41, 97, 103] {
            let mut arr: Vec<f32> = (0..size)
                .map(|_| rng.random_range(0.0f32..1000.0))
                .collect();
            let original = arr.clone();
            quicksort_f32_simd(&mut arr);
            assert!(is_sorted(&arr, true), "size {size} not sorted");
            assert!(multiset_eq(&arr, &original), "size {size} multiset changed");
        }
    }

    #[test]
    fn test_quicksort_large() {
        let mut rng = thread_rng();
        let mut arr: Vec<f32> = (0..1000)
            .map(|_| rng.random_range(0.0f32..10000.0))
            .collect();
        let original = arr.clone();
        quicksort_f32_simd(&mut arr);
        assert!(is_sorted(&arr, true));
        assert!(multiset_eq(&arr, &original));
    }
}
840}