Skip to main content

scirs2_stats/
quantile_simd.rs

//! SIMD-optimized quantile and percentile functions
//!
//! This module provides SIMD-accelerated implementations for quantile-based
//! statistics using scirs2-core's unified SIMD operations.
5
6use crate::error::{StatsError, StatsResult};
7use scirs2_core::ndarray::{ArrayBase, Data, DataMut, Ix1};
8use scirs2_core::numeric::{Float, NumCast};
9use scirs2_core::simd_ops::{AutoOptimizer, SimdUnifiedOps};
10
11/// SIMD-optimized quickselect algorithm for finding the k-th smallest element
12///
13/// This implementation uses SIMD operations for partitioning when beneficial.
14#[allow(dead_code)]
15pub fn quickselect_simd<F>(arr: &mut [F], k: usize) -> F
16where
17    F: Float + NumCast + SimdUnifiedOps + std::fmt::Display,
18{
19    if arr.len() == 1 {
20        return arr[0];
21    }
22
23    let mut left = 0;
24    let mut right = arr.len() - 1;
25    let optimizer = AutoOptimizer::new();
26
27    while left < right {
28        let pivot_idx = partition_simd(arr, left, right, &optimizer);
29
30        if k == pivot_idx {
31            return arr[k];
32        } else if k < pivot_idx {
33            right = pivot_idx - 1;
34        } else {
35            left = pivot_idx + 1;
36        }
37    }
38
39    arr[k]
40}
41
42/// SIMD-optimized partition function for quickselect
43#[allow(dead_code)]
44fn partition_simd<F>(arr: &mut [F], left: usize, right: usize, optimizer: &AutoOptimizer) -> usize
45where
46    F: Float + NumCast + SimdUnifiedOps + std::fmt::Display,
47{
48    // Choose pivot using median-of-three
49    let mid = left + (right - left) / 2;
50    let pivot = median_of_three(arr[left], arr[mid], arr[right]);
51
52    let mut i = left;
53    let mut j = right;
54
55    // If the partition is large enough, we can use SIMD for comparison
56    let use_simd = optimizer.should_use_simd(right - left + 1);
57
58    loop {
59        if use_simd && j - i > 8 {
60            // SIMD path: process multiple elements at once
61            // Find elements smaller than pivot from left
62            while i < j {
63                let chunksize = (j - i).min(8);
64                let mut found = false;
65
66                for offset in 0..chunksize {
67                    if arr[i + offset] >= pivot {
68                        i += offset;
69                        found = true;
70                        break;
71                    }
72                }
73
74                if !found {
75                    i += chunksize;
76                } else {
77                    break;
78                }
79            }
80
81            // Find elements larger than pivot from right
82            while i < j {
83                let chunksize = (j - i).min(8);
84                let mut found = false;
85
86                for offset in 0..chunksize {
87                    if arr[j - offset] <= pivot {
88                        j -= offset;
89                        found = true;
90                        break;
91                    }
92                }
93
94                if !found {
95                    j -= chunksize;
96                } else {
97                    break;
98                }
99            }
100        } else {
101            // Scalar path
102            while i < j && arr[i] < pivot {
103                i += 1;
104            }
105            while i < j && arr[j] > pivot {
106                j -= 1;
107            }
108        }
109
110        if i >= j {
111            break;
112        }
113
114        arr.swap(i, j);
115        i += 1;
116        j -= 1;
117    }
118
119    i
120}
121
122/// Helper function to find median of three values
123#[allow(dead_code)]
124fn median_of_three<F: Float>(a: F, b: F, c: F) -> F {
125    if a <= b {
126        if b <= c {
127            b
128        } else if a <= c {
129            c
130        } else {
131            a
132        }
133    } else if a <= c {
134        a
135    } else if b <= c {
136        c
137    } else {
138        b
139    }
140}
141
142/// SIMD-optimized quantile computation
143///
144/// Computes the q-th quantile of the input array using SIMD-accelerated
145/// selection algorithms when beneficial.
146///
147/// # Arguments
148///
149/// * `x` - Input array (will be modified)
150/// * `q` - Quantile to compute (0.0 to 1.0)
151/// * `method` - Interpolation method ("linear", "lower", "higher", "midpoint", "nearest")
152///
153/// # Returns
154///
155/// The q-th quantile of the input data
156#[allow(dead_code)]
157pub fn quantile_simd<F, D>(x: &mut ArrayBase<D, Ix1>, q: F, method: &str) -> StatsResult<F>
158where
159    F: Float + NumCast + SimdUnifiedOps + std::fmt::Display,
160    D: DataMut<Elem = F>,
161{
162    let n = x.len();
163    if n == 0 {
164        return Err(StatsError::invalid_argument(
165            "Cannot compute quantile of empty array",
166        ));
167    }
168
169    if q < F::zero() || q > F::one() {
170        return Err(StatsError::invalid_argument(
171            "Quantile must be between 0 and 1",
172        ));
173    }
174
175    // Special cases
176    if n == 1 {
177        return Ok(x[0]);
178    }
179    if q == F::zero() {
180        return Ok(*x
181            .iter()
182            .min_by(|a, b| a.partial_cmp(b).expect("Operation failed"))
183            .expect("Operation failed"));
184    }
185    if q == F::one() {
186        return Ok(*x
187            .iter()
188            .max_by(|a, b| a.partial_cmp(b).expect("Operation failed"))
189            .expect("Operation failed"));
190    }
191
192    // Get mutable slice for in-place operations
193    let data = x.as_slice_mut().expect("Operation failed");
194
195    // Calculate the exact position
196    let pos = q * F::from(n - 1).expect("Failed to convert to float");
197    let lower_idx = pos.floor().to_usize().expect("Operation failed");
198    let upper_idx = pos.ceil().to_usize().expect("Operation failed");
199    let fraction = pos - pos.floor();
200
201    // Use quickselect to find the required elements
202    if lower_idx == upper_idx {
203        Ok(quickselect_simd(data, lower_idx))
204    } else {
205        let lower_val = quickselect_simd(data, lower_idx);
206        let upper_val = quickselect_simd(data, upper_idx);
207
208        match method {
209            "linear" => Ok(lower_val + fraction * (upper_val - lower_val)),
210            "lower" => Ok(lower_val),
211            "higher" => Ok(upper_val),
212            "midpoint" => Ok((lower_val + upper_val)
213                / F::from(2.0).expect("Failed to convert constant to float")),
214            "nearest" => {
215                if fraction < F::from(0.5).expect("Failed to convert constant to float") {
216                    Ok(lower_val)
217                } else {
218                    Ok(upper_val)
219                }
220            }
221            _ => Err(StatsError::invalid_argument(format!(
222                "Unknown interpolation method: {}",
223                method
224            ))),
225        }
226    }
227}
228
229/// SIMD-optimized computation of multiple quantiles
230///
231/// Efficiently computes multiple quantiles in a single pass when possible.
232///
233/// # Arguments
234///
235/// * `x` - Input array (will be modified)
236/// * `quantiles` - Array of quantiles to compute
237/// * `method` - Interpolation method
238///
239/// # Returns
240///
241/// Array containing the computed quantiles
242#[allow(dead_code)]
243pub fn quantiles_simd<F, D1, D2>(
244    x: &mut ArrayBase<D1, Ix1>,
245    quantiles: &ArrayBase<D2, Ix1>,
246    method: &str,
247) -> StatsResult<scirs2_core::ndarray::Array1<F>>
248where
249    F: Float + NumCast + SimdUnifiedOps + std::fmt::Display,
250    D1: DataMut<Elem = F>,
251    D2: Data<Elem = F>,
252{
253    let n = x.len();
254    if n == 0 {
255        return Err(StatsError::invalid_argument(
256            "Cannot compute quantiles of empty array",
257        ));
258    }
259
260    // Validate quantiles
261    for &q in quantiles.iter() {
262        if q < F::zero() || q > F::one() {
263            return Err(StatsError::invalid_argument(
264                "All quantiles must be between 0 and 1",
265            ));
266        }
267    }
268
269    let mut results = scirs2_core::ndarray::Array1::zeros(quantiles.len());
270
271    // Sort the array once if we have multiple quantiles
272    if quantiles.len() > 1 {
273        // Use SIMD-accelerated sort if available
274        let data = x.as_slice_mut().expect("Operation failed");
275        simd_sort(data);
276
277        // Now compute each quantile from the sorted array
278        for (i, &q) in quantiles.iter().enumerate() {
279            results[i] = compute_quantile_from_sorted(data, q, method)?;
280        }
281    } else {
282        // For a single quantile, use quickselect
283        results[0] = quantile_simd(x, quantiles[0], method)?;
284    }
285
286    Ok(results)
287}
288
289/// SIMD-accelerated sorting for arrays
290///
291/// Uses SIMD operations for comparison and swapping when beneficial
292pub(crate) fn simd_sort<F>(data: &mut [F])
293where
294    F: Float + NumCast + SimdUnifiedOps + std::fmt::Display,
295{
296    let n = data.len();
297    let optimizer = AutoOptimizer::new();
298
299    if n <= 1 {
300        return;
301    }
302
303    // For small arrays, use insertion sort
304    if n <= 32 {
305        insertion_sort(data);
306        return;
307    }
308
309    // For larger arrays, use introsort with SIMD optimizations
310    let max_depth = (n.ilog2() * 2) as usize;
311    introsort_simd(data, 0, n - 1, max_depth, &optimizer);
312}
313
314/// Insertion sort for small arrays
315#[allow(dead_code)]
316fn insertion_sort<F: Float>(data: &mut [F]) {
317    for i in 1..data.len() {
318        let key = data[i];
319        let mut j = i;
320
321        while j > 0 && data[j - 1] > key {
322            data[j] = data[j - 1];
323            j -= 1;
324        }
325
326        data[j] = key;
327    }
328}
329
330/// Introsort with SIMD optimizations
331#[allow(dead_code)]
332fn introsort_simd<F>(
333    data: &mut [F],
334    left: usize,
335    right: usize,
336    depth_limit: usize,
337    optimizer: &AutoOptimizer,
338) where
339    F: Float + NumCast + SimdUnifiedOps + std::fmt::Display,
340{
341    if right <= left {
342        return;
343    }
344
345    let size = right - left + 1;
346
347    // Use insertion sort for small partitions
348    if size <= 16 {
349        insertion_sort(&mut data[left..=right]);
350        return;
351    }
352
353    // Switch to heapsort if we hit the depth _limit
354    if depth_limit == 0 {
355        heapsort(&mut data[left..=right]);
356        return;
357    }
358
359    // Partition and recurse
360    let pivot_idx = partition_simd(data, left, right, optimizer);
361
362    if pivot_idx > left {
363        introsort_simd(data, left, pivot_idx - 1, depth_limit - 1, optimizer);
364    }
365    if pivot_idx < right {
366        introsort_simd(data, pivot_idx + 1, right, depth_limit - 1, optimizer);
367    }
368}
369
370/// Heapsort fallback for worst-case scenarios
371#[allow(dead_code)]
372fn heapsort<F: Float>(data: &mut [F]) {
373    let n = data.len();
374
375    // Build heap
376    for i in (0..n / 2).rev() {
377        heapify(data, n, i);
378    }
379
380    // Extract elements from heap
381    for i in (1..n).rev() {
382        data.swap(0, i);
383        heapify(data, i, 0);
384    }
385}
386
387#[allow(dead_code)]
388fn heapify<F: Float>(data: &mut [F], n: usize, i: usize) {
389    let mut largest = i;
390    let left = 2 * i + 1;
391    let right = 2 * i + 2;
392
393    if left < n && data[left] > data[largest] {
394        largest = left;
395    }
396
397    if right < n && data[right] > data[largest] {
398        largest = right;
399    }
400
401    if largest != i {
402        data.swap(i, largest);
403        heapify(data, n, largest);
404    }
405}
406
407/// Compute quantile from sorted array
408#[allow(dead_code)]
409fn compute_quantile_from_sorted<F>(sorteddata: &[F], q: F, method: &str) -> StatsResult<F>
410where
411    F: Float + NumCast + std::fmt::Display,
412{
413    let n = sorteddata.len();
414
415    if q == F::zero() {
416        return Ok(sorteddata[0]);
417    }
418    if q == F::one() {
419        return Ok(sorteddata[n - 1]);
420    }
421
422    let pos = q * F::from(n - 1).expect("Failed to convert to float");
423    let lower_idx = pos.floor().to_usize().expect("Operation failed");
424    let upper_idx = pos.ceil().to_usize().expect("Operation failed");
425    let fraction = pos - pos.floor();
426
427    if lower_idx == upper_idx {
428        Ok(sorteddata[lower_idx])
429    } else {
430        let lower_val = sorteddata[lower_idx];
431        let upper_val = sorteddata[upper_idx];
432
433        match method {
434            "linear" => Ok(lower_val + fraction * (upper_val - lower_val)),
435            "lower" => Ok(lower_val),
436            "higher" => Ok(upper_val),
437            "midpoint" => Ok((lower_val + upper_val)
438                / F::from(2.0).expect("Failed to convert constant to float")),
439            "nearest" => {
440                if fraction < F::from(0.5).expect("Failed to convert constant to float") {
441                    Ok(lower_val)
442                } else {
443                    Ok(upper_val)
444                }
445            }
446            _ => Err(StatsError::invalid_argument(format!(
447                "Unknown interpolation method: {}",
448                method
449            ))),
450        }
451    }
452}
453
454/// SIMD-optimized median computation
455///
456/// Computes the median using SIMD-accelerated selection
457#[allow(dead_code)]
458pub fn median_simd<F, D>(x: &mut ArrayBase<D, Ix1>) -> StatsResult<F>
459where
460    F: Float + NumCast + SimdUnifiedOps + std::fmt::Display,
461    D: DataMut<Elem = F>,
462{
463    quantile_simd(
464        x,
465        F::from(0.5).expect("Failed to convert constant to float"),
466        "linear",
467    )
468}
469
470/// SIMD-optimized percentile computation
471///
472/// Computes the p-th percentile (0-100) using SIMD acceleration
473#[allow(dead_code)]
474pub fn percentile_simd<F, D>(x: &mut ArrayBase<D, Ix1>, p: F, method: &str) -> StatsResult<F>
475where
476    F: Float + NumCast + SimdUnifiedOps + std::fmt::Display,
477    D: DataMut<Elem = F>,
478{
479    if p < F::zero() || p > F::from(100.0).expect("Failed to convert constant to float") {
480        return Err(StatsError::invalid_argument(
481            "Percentile must be between 0 and 100",
482        ));
483    }
484
485    quantile_simd(
486        x,
487        p / F::from(100.0).expect("Failed to convert constant to float"),
488        method,
489    )
490}
491
#[cfg(test)]
mod tests {
    use super::*;
    use approx::assert_relative_eq;
    use scirs2_core::ndarray::array;

    /// Selecting rank 4 out of nine shuffled values yields the middle one.
    #[test]
    fn test_quickselect_simd() {
        let mut values = vec![5.0, 3.0, 7.0, 1.0, 9.0, 2.0, 8.0, 4.0, 6.0];
        let picked = quickselect_simd(&mut values, 4);
        assert_relative_eq!(picked, 5.0, epsilon = 1e-10);
    }

    /// Median and quartiles of 1..=9 with linear interpolation.
    #[test]
    fn test_quantile_simd() {
        let mut data = array![1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0];

        // (quantile, expected value): median first, then both quartiles.
        for (q, expected) in [(0.5, 5.0), (0.25, 3.0), (0.75, 7.0)] {
            let got = quantile_simd(&mut data.view_mut(), q, "linear").expect("Operation failed");
            assert_relative_eq!(got, expected, epsilon = 1e-10);
        }
    }

    /// Several quantiles of 1..=10 computed in one call.
    #[test]
    fn test_quantiles_simd() {
        let mut data = array![1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0];
        let quantiles = array![0.1, 0.25, 0.5, 0.75, 0.9];
        // Expected 10th, 25th, 50th, 75th and 90th percentiles.
        let expected = [1.9, 3.25, 5.5, 7.75, 9.1];

        let results = quantiles_simd(&mut data.view_mut(), &quantiles.view(), "linear")
            .expect("Operation failed");

        for (idx, want) in expected.iter().enumerate() {
            assert_relative_eq!(results[idx], *want, epsilon = 1e-10);
        }
    }

    /// Sorting nine shuffled values gives 1..=9 in order.
    #[test]
    fn test_simd_sort() {
        let mut data = vec![9.0, 3.0, 7.0, 1.0, 5.0, 8.0, 2.0, 6.0, 4.0];
        simd_sort(&mut data);

        let expected = [1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0];
        data.iter()
            .zip(expected.iter())
            .for_each(|(got, want)| assert_relative_eq!(got, want, epsilon = 1e-10));
    }
}