1use crate::error::{StatsError, StatsResult};
7use scirs2_core::ndarray::{ArrayBase, Data, DataMut, Ix1};
8use scirs2_core::numeric::{Float, NumCast};
9use scirs2_core::simd_ops::{AutoOptimizer, SimdUnifiedOps};
10
11#[allow(dead_code)]
15pub fn quickselect_simd<F>(arr: &mut [F], k: usize) -> F
16where
17 F: Float + NumCast + SimdUnifiedOps + std::fmt::Display,
18{
19 if arr.len() == 1 {
20 return arr[0];
21 }
22
23 let mut left = 0;
24 let mut right = arr.len() - 1;
25 let optimizer = AutoOptimizer::new();
26
27 while left < right {
28 let pivot_idx = partition_simd(arr, left, right, &optimizer);
29
30 if k == pivot_idx {
31 return arr[k];
32 } else if k < pivot_idx {
33 right = pivot_idx - 1;
34 } else {
35 left = pivot_idx + 1;
36 }
37 }
38
39 arr[k]
40}
41
42#[allow(dead_code)]
44fn partition_simd<F>(arr: &mut [F], left: usize, right: usize, optimizer: &AutoOptimizer) -> usize
45where
46 F: Float + NumCast + SimdUnifiedOps + std::fmt::Display,
47{
48 let mid = left + (right - left) / 2;
50 let pivot = median_of_three(arr[left], arr[mid], arr[right]);
51
52 let mut i = left;
53 let mut j = right;
54
55 let use_simd = optimizer.should_use_simd(right - left + 1);
57
58 loop {
59 if use_simd && j - i > 8 {
60 while i < j {
63 let chunksize = (j - i).min(8);
64 let mut found = false;
65
66 for offset in 0..chunksize {
67 if arr[i + offset] >= pivot {
68 i += offset;
69 found = true;
70 break;
71 }
72 }
73
74 if !found {
75 i += chunksize;
76 } else {
77 break;
78 }
79 }
80
81 while i < j {
83 let chunksize = (j - i).min(8);
84 let mut found = false;
85
86 for offset in 0..chunksize {
87 if arr[j - offset] <= pivot {
88 j -= offset;
89 found = true;
90 break;
91 }
92 }
93
94 if !found {
95 j -= chunksize;
96 } else {
97 break;
98 }
99 }
100 } else {
101 while i < j && arr[i] < pivot {
103 i += 1;
104 }
105 while i < j && arr[j] > pivot {
106 j -= 1;
107 }
108 }
109
110 if i >= j {
111 break;
112 }
113
114 arr.swap(i, j);
115 i += 1;
116 j -= 1;
117 }
118
119 i
120}
121
122#[allow(dead_code)]
124fn median_of_three<F: Float>(a: F, b: F, c: F) -> F {
125 if a <= b {
126 if b <= c {
127 b
128 } else if a <= c {
129 c
130 } else {
131 a
132 }
133 } else if a <= c {
134 a
135 } else if b <= c {
136 c
137 } else {
138 b
139 }
140}
141
142#[allow(dead_code)]
157pub fn quantile_simd<F, D>(x: &mut ArrayBase<D, Ix1>, q: F, method: &str) -> StatsResult<F>
158where
159 F: Float + NumCast + SimdUnifiedOps + std::fmt::Display,
160 D: DataMut<Elem = F>,
161{
162 let n = x.len();
163 if n == 0 {
164 return Err(StatsError::invalid_argument(
165 "Cannot compute quantile of empty array",
166 ));
167 }
168
169 if q < F::zero() || q > F::one() {
170 return Err(StatsError::invalid_argument(
171 "Quantile must be between 0 and 1",
172 ));
173 }
174
175 if n == 1 {
177 return Ok(x[0]);
178 }
179 if q == F::zero() {
180 return Ok(*x
181 .iter()
182 .min_by(|a, b| a.partial_cmp(b).expect("Operation failed"))
183 .expect("Operation failed"));
184 }
185 if q == F::one() {
186 return Ok(*x
187 .iter()
188 .max_by(|a, b| a.partial_cmp(b).expect("Operation failed"))
189 .expect("Operation failed"));
190 }
191
192 let data = x.as_slice_mut().expect("Operation failed");
194
195 let pos = q * F::from(n - 1).expect("Failed to convert to float");
197 let lower_idx = pos.floor().to_usize().expect("Operation failed");
198 let upper_idx = pos.ceil().to_usize().expect("Operation failed");
199 let fraction = pos - pos.floor();
200
201 if lower_idx == upper_idx {
203 Ok(quickselect_simd(data, lower_idx))
204 } else {
205 let lower_val = quickselect_simd(data, lower_idx);
206 let upper_val = quickselect_simd(data, upper_idx);
207
208 match method {
209 "linear" => Ok(lower_val + fraction * (upper_val - lower_val)),
210 "lower" => Ok(lower_val),
211 "higher" => Ok(upper_val),
212 "midpoint" => Ok((lower_val + upper_val)
213 / F::from(2.0).expect("Failed to convert constant to float")),
214 "nearest" => {
215 if fraction < F::from(0.5).expect("Failed to convert constant to float") {
216 Ok(lower_val)
217 } else {
218 Ok(upper_val)
219 }
220 }
221 _ => Err(StatsError::invalid_argument(format!(
222 "Unknown interpolation method: {}",
223 method
224 ))),
225 }
226 }
227}
228
229#[allow(dead_code)]
243pub fn quantiles_simd<F, D1, D2>(
244 x: &mut ArrayBase<D1, Ix1>,
245 quantiles: &ArrayBase<D2, Ix1>,
246 method: &str,
247) -> StatsResult<scirs2_core::ndarray::Array1<F>>
248where
249 F: Float + NumCast + SimdUnifiedOps + std::fmt::Display,
250 D1: DataMut<Elem = F>,
251 D2: Data<Elem = F>,
252{
253 let n = x.len();
254 if n == 0 {
255 return Err(StatsError::invalid_argument(
256 "Cannot compute quantiles of empty array",
257 ));
258 }
259
260 for &q in quantiles.iter() {
262 if q < F::zero() || q > F::one() {
263 return Err(StatsError::invalid_argument(
264 "All quantiles must be between 0 and 1",
265 ));
266 }
267 }
268
269 let mut results = scirs2_core::ndarray::Array1::zeros(quantiles.len());
270
271 if quantiles.len() > 1 {
273 let data = x.as_slice_mut().expect("Operation failed");
275 simd_sort(data);
276
277 for (i, &q) in quantiles.iter().enumerate() {
279 results[i] = compute_quantile_from_sorted(data, q, method)?;
280 }
281 } else {
282 results[0] = quantile_simd(x, quantiles[0], method)?;
284 }
285
286 Ok(results)
287}
288
289pub(crate) fn simd_sort<F>(data: &mut [F])
293where
294 F: Float + NumCast + SimdUnifiedOps + std::fmt::Display,
295{
296 let n = data.len();
297 let optimizer = AutoOptimizer::new();
298
299 if n <= 1 {
300 return;
301 }
302
303 if n <= 32 {
305 insertion_sort(data);
306 return;
307 }
308
309 let max_depth = (n.ilog2() * 2) as usize;
311 introsort_simd(data, 0, n - 1, max_depth, &optimizer);
312}
313
314#[allow(dead_code)]
316fn insertion_sort<F: Float>(data: &mut [F]) {
317 for i in 1..data.len() {
318 let key = data[i];
319 let mut j = i;
320
321 while j > 0 && data[j - 1] > key {
322 data[j] = data[j - 1];
323 j -= 1;
324 }
325
326 data[j] = key;
327 }
328}
329
330#[allow(dead_code)]
332fn introsort_simd<F>(
333 data: &mut [F],
334 left: usize,
335 right: usize,
336 depth_limit: usize,
337 optimizer: &AutoOptimizer,
338) where
339 F: Float + NumCast + SimdUnifiedOps + std::fmt::Display,
340{
341 if right <= left {
342 return;
343 }
344
345 let size = right - left + 1;
346
347 if size <= 16 {
349 insertion_sort(&mut data[left..=right]);
350 return;
351 }
352
353 if depth_limit == 0 {
355 heapsort(&mut data[left..=right]);
356 return;
357 }
358
359 let pivot_idx = partition_simd(data, left, right, optimizer);
361
362 if pivot_idx > left {
363 introsort_simd(data, left, pivot_idx - 1, depth_limit - 1, optimizer);
364 }
365 if pivot_idx < right {
366 introsort_simd(data, pivot_idx + 1, right, depth_limit - 1, optimizer);
367 }
368}
369
370#[allow(dead_code)]
372fn heapsort<F: Float>(data: &mut [F]) {
373 let n = data.len();
374
375 for i in (0..n / 2).rev() {
377 heapify(data, n, i);
378 }
379
380 for i in (1..n).rev() {
382 data.swap(0, i);
383 heapify(data, i, 0);
384 }
385}
386
387#[allow(dead_code)]
388fn heapify<F: Float>(data: &mut [F], n: usize, i: usize) {
389 let mut largest = i;
390 let left = 2 * i + 1;
391 let right = 2 * i + 2;
392
393 if left < n && data[left] > data[largest] {
394 largest = left;
395 }
396
397 if right < n && data[right] > data[largest] {
398 largest = right;
399 }
400
401 if largest != i {
402 data.swap(i, largest);
403 heapify(data, n, largest);
404 }
405}
406
407#[allow(dead_code)]
409fn compute_quantile_from_sorted<F>(sorteddata: &[F], q: F, method: &str) -> StatsResult<F>
410where
411 F: Float + NumCast + std::fmt::Display,
412{
413 let n = sorteddata.len();
414
415 if q == F::zero() {
416 return Ok(sorteddata[0]);
417 }
418 if q == F::one() {
419 return Ok(sorteddata[n - 1]);
420 }
421
422 let pos = q * F::from(n - 1).expect("Failed to convert to float");
423 let lower_idx = pos.floor().to_usize().expect("Operation failed");
424 let upper_idx = pos.ceil().to_usize().expect("Operation failed");
425 let fraction = pos - pos.floor();
426
427 if lower_idx == upper_idx {
428 Ok(sorteddata[lower_idx])
429 } else {
430 let lower_val = sorteddata[lower_idx];
431 let upper_val = sorteddata[upper_idx];
432
433 match method {
434 "linear" => Ok(lower_val + fraction * (upper_val - lower_val)),
435 "lower" => Ok(lower_val),
436 "higher" => Ok(upper_val),
437 "midpoint" => Ok((lower_val + upper_val)
438 / F::from(2.0).expect("Failed to convert constant to float")),
439 "nearest" => {
440 if fraction < F::from(0.5).expect("Failed to convert constant to float") {
441 Ok(lower_val)
442 } else {
443 Ok(upper_val)
444 }
445 }
446 _ => Err(StatsError::invalid_argument(format!(
447 "Unknown interpolation method: {}",
448 method
449 ))),
450 }
451 }
452}
453
454#[allow(dead_code)]
458pub fn median_simd<F, D>(x: &mut ArrayBase<D, Ix1>) -> StatsResult<F>
459where
460 F: Float + NumCast + SimdUnifiedOps + std::fmt::Display,
461 D: DataMut<Elem = F>,
462{
463 quantile_simd(
464 x,
465 F::from(0.5).expect("Failed to convert constant to float"),
466 "linear",
467 )
468}
469
470#[allow(dead_code)]
474pub fn percentile_simd<F, D>(x: &mut ArrayBase<D, Ix1>, p: F, method: &str) -> StatsResult<F>
475where
476 F: Float + NumCast + SimdUnifiedOps + std::fmt::Display,
477 D: DataMut<Elem = F>,
478{
479 if p < F::zero() || p > F::from(100.0).expect("Failed to convert constant to float") {
480 return Err(StatsError::invalid_argument(
481 "Percentile must be between 0 and 100",
482 ));
483 }
484
485 quantile_simd(
486 x,
487 p / F::from(100.0).expect("Failed to convert constant to float"),
488 method,
489 )
490}
491
#[cfg(test)]
mod tests {
    use super::*;
    use approx::assert_relative_eq;
    use scirs2_core::ndarray::array;

    /// Index 4 (0-based) of a shuffled 1..=9 is the value 5.
    #[test]
    fn test_quickselect_simd() {
        let mut data = vec![5.0, 3.0, 7.0, 1.0, 9.0, 2.0, 8.0, 4.0, 6.0];
        let result = quickselect_simd(&mut data, 4);
        assert_relative_eq!(result, 5.0, epsilon = 1e-10);
    }

    /// Median and quartiles of 1..=9 with linear interpolation.
    #[test]
    fn test_quantile_simd() {
        let mut data = array![1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0];

        // (quantile, expected value), evaluated in this order on the same array.
        for (q, expected) in [(0.5, 5.0), (0.25, 3.0), (0.75, 7.0)] {
            let value =
                quantile_simd(&mut data.view_mut(), q, "linear").expect("Operation failed");
            assert_relative_eq!(value, expected, epsilon = 1e-10);
        }
    }

    /// Several quantiles of 1..=10 computed in a single call.
    #[test]
    fn test_quantiles_simd() {
        let mut data = array![1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0];
        let quantiles = array![0.1, 0.25, 0.5, 0.75, 0.9];

        let results = quantiles_simd(&mut data.view_mut(), &quantiles.view(), "linear")
            .expect("Operation failed");

        let expected = [1.9, 3.25, 5.5, 7.75, 9.1];
        for (idx, want) in expected.iter().enumerate() {
            assert_relative_eq!(results[idx], *want, epsilon = 1e-10);
        }
    }

    /// A shuffled 1..=9 comes back in ascending order.
    #[test]
    fn test_simd_sort() {
        let mut data = vec![9.0, 3.0, 7.0, 1.0, 5.0, 8.0, 2.0, 6.0, 4.0];
        simd_sort(&mut data);

        let expected = [1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0];
        for (got, want) in data.iter().zip(expected.iter()) {
            assert_relative_eq!(got, want, epsilon = 1e-10);
        }
    }
}