scirs2_stats/
simd_enhanced_v4.rs

//! Enhanced SIMD implementations of core statistical operations (v4).
//!
//! This module provides SIMD-accelerated versions of core statistical
//! operations, aimed at maximizing performance on large datasets.
5
6use crate::error::{StatsError, StatsResult};
7use scirs2_core::ndarray::{Array1, Array2, ArrayView1, ArrayView2, Axis};
8use scirs2_core::numeric::{Float, NumCast, One, Zero};
9use scirs2_core::{simd_ops::SimdUnifiedOps, validation::*};
10use statrs::statistics::Statistics;
11
12/// SIMD-optimized comprehensive statistical summary
13///
14/// Computes multiple statistics in a single pass with SIMD acceleration.
15/// This is more efficient than computing statistics separately.
16#[allow(dead_code)]
17pub fn comprehensive_stats_simd<F>(data: &ArrayView1<F>) -> StatsResult<ComprehensiveStats<F>>
18where
19    F: Float
20        + NumCast
21        + SimdUnifiedOps
22        + Zero
23        + One
24        + std::fmt::Display
25        + std::iter::Sum<F>
26        + scirs2_core::numeric::FromPrimitive,
27{
28    checkarray_finite(data, "data")?;
29
30    if data.is_empty() {
31        return Err(StatsError::InvalidArgument(
32            "Data array cannot be empty".to_string(),
33        ));
34    }
35
36    let n = data.len();
37    let n_f = F::from(n).unwrap();
38
39    // Single-pass computation with SIMD
40    let (sum, sum_sq, min_val, max_val) = if n > 32 {
41        // SIMD path for large arrays
42        let sum = F::simd_sum(&data.view());
43        let sqdata = F::simd_mul(&data.view(), &data.view());
44        let sum_sq = F::simd_sum(&sqdata.view());
45        let min_val = F::simd_min_element(&data.view());
46        let max_val = F::simd_max_element(&data.view());
47        (sum, sum_sq, min_val, max_val)
48    } else {
49        // Scalar fallback for small arrays
50        let sum = data.iter().fold(F::zero(), |acc, &x| acc + x);
51        let sum_sq = data.iter().fold(F::zero(), |acc, &x| acc + x * x);
52        let min_val = data
53            .iter()
54            .fold(data[0], |acc, &x| if x < acc { x } else { acc });
55        let max_val = data
56            .iter()
57            .fold(data[0], |acc, &x| if x > acc { x } else { acc });
58        (sum, sum_sq, min_val, max_val)
59    };
60
61    let mean = sum / n_f;
62    let variance = if n > 1 {
63        let n_minus_1 = F::from(n - 1).unwrap();
64        (sum_sq - sum * sum / n_f) / n_minus_1
65    } else {
66        F::zero()
67    };
68    let std_dev = variance.sqrt();
69    let range = max_val - min_val;
70
71    // Compute higher-order moments with SIMD
72    let (sum_cubed_dev, sum_fourth_dev) = if n > 32 {
73        // SIMD path for moment computation
74        let mean_vec = Array1::from_elem(n, mean);
75        let centered = F::simd_sub(&data.view(), &mean_vec.view());
76        let centered_sq = F::simd_mul(&centered.view(), &centered.view());
77        let centered_cubed = F::simd_mul(&centered_sq.view(), &centered.view());
78        let centered_fourth = F::simd_mul(&centered_sq.view(), &centered_sq.view());
79
80        let sum_cubed_dev = F::simd_sum(&centered_cubed.view());
81        let sum_fourth_dev = F::simd_sum(&centered_fourth.view());
82        (sum_cubed_dev, sum_fourth_dev)
83    } else {
84        // Scalar fallback
85        let mut sum_cubed_dev = F::zero();
86        let mut sum_fourth_dev = F::zero();
87        for &x in data.iter() {
88            let dev = x - mean;
89            let dev_sq = dev * dev;
90            sum_cubed_dev = sum_cubed_dev + dev_sq * dev;
91            sum_fourth_dev = sum_fourth_dev + dev_sq * dev_sq;
92        }
93        (sum_cubed_dev, sum_fourth_dev)
94    };
95
96    let skewness = if std_dev > F::zero() {
97        sum_cubed_dev / (n_f * std_dev.powi(3))
98    } else {
99        F::zero()
100    };
101
102    let kurtosis = if variance > F::zero() {
103        (sum_fourth_dev / (n_f * variance * variance)) - F::from(3.0).unwrap()
104    } else {
105        F::zero()
106    };
107
108    Ok(ComprehensiveStats {
109        count: n,
110        mean,
111        variance,
112        std_dev,
113        min: min_val,
114        max: max_val,
115        range,
116        skewness,
117        kurtosis,
118        sum,
119    })
120}
121
/// Summary statistics returned by `comprehensive_stats_simd`.
#[derive(Clone, Debug)]
pub struct ComprehensiveStats<F> {
    /// Number of observations.
    pub count: usize,
    /// Arithmetic mean.
    pub mean: F,
    /// Sample variance (`n - 1` denominator; `0` for a single observation).
    pub variance: F,
    /// Square root of `variance`.
    pub std_dev: F,
    /// Smallest observation.
    pub min: F,
    /// Largest observation.
    pub max: F,
    /// `max - min`.
    pub range: F,
    /// Skewness (`0` when the data has no spread).
    pub skewness: F,
    /// Excess kurtosis, i.e. kurtosis minus 3 (`0` when the data has no spread).
    pub kurtosis: F,
    /// Sum of all observations.
    pub sum: F,
}
136
137/// SIMD-optimized sliding window statistics
138///
139/// Computes statistics over sliding windows efficiently using SIMD operations
140/// and incremental updates where possible.
141#[allow(dead_code)]
142pub fn sliding_window_stats_simd<F>(
143    data: &ArrayView1<F>,
144    windowsize: usize,
145) -> StatsResult<SlidingWindowStats<F>>
146where
147    F: Float
148        + NumCast
149        + SimdUnifiedOps
150        + Zero
151        + One
152        + std::fmt::Display
153        + std::iter::Sum<F>
154        + scirs2_core::numeric::FromPrimitive,
155{
156    checkarray_finite(data, "data")?;
157    check_positive(windowsize, "windowsize")?;
158
159    if windowsize > data.len() {
160        return Err(StatsError::InvalidArgument(
161            "Window size cannot be larger than data length".to_string(),
162        ));
163    }
164
165    let n_windows = data.len() - windowsize + 1;
166    let mut means = Array1::zeros(n_windows);
167    let mut variances = Array1::zeros(n_windows);
168    let mut mins = Array1::zeros(n_windows);
169    let mut maxs = Array1::zeros(n_windows);
170
171    let windowsize_f = F::from(windowsize).unwrap();
172
173    // Process each window
174    for i in 0..n_windows {
175        let window = data.slice(scirs2_core::ndarray::s![i..i + windowsize]);
176
177        // Use SIMD for window statistics if window is large enough
178        if windowsize > 16 {
179            let sum = F::simd_sum(&window);
180            let mean = sum / windowsize_f;
181            means[i] = mean;
182
183            let sqdata = F::simd_mul(&window, &window);
184            let sum_sq = F::simd_sum(&sqdata.view());
185            let variance = if windowsize > 1 {
186                let n_minus_1 = F::from(windowsize - 1).unwrap();
187                (sum_sq - sum * sum / windowsize_f) / n_minus_1
188            } else {
189                F::zero()
190            };
191            variances[i] = variance;
192
193            mins[i] = F::simd_min_element(&window);
194            maxs[i] = F::simd_max_element(&window);
195        } else {
196            // Scalar fallback for small windows
197            let sum: F = window.iter().copied().sum();
198            let mean = sum / windowsize_f;
199            means[i] = mean;
200
201            let sum_sq: F = window.iter().map(|&x| x * x).sum();
202            let variance = if windowsize > 1 {
203                let n_minus_1 = F::from(windowsize - 1).unwrap();
204                (sum_sq - sum * sum / windowsize_f) / n_minus_1
205            } else {
206                F::zero()
207            };
208            variances[i] = variance;
209
210            mins[i] = window.iter().copied().fold(window[0], F::min);
211            maxs[i] = window.iter().copied().fold(window[0], F::max);
212        }
213    }
214
215    Ok(SlidingWindowStats {
216        windowsize,
217        means,
218        variances,
219        mins,
220        maxs,
221    })
222}
223
224/// Sliding window statistics structure
225#[derive(Debug, Clone)]
226pub struct SlidingWindowStats<F> {
227    pub windowsize: usize,
228    pub means: Array1<F>,
229    pub variances: Array1<F>,
230    pub mins: Array1<F>,
231    pub maxs: Array1<F>,
232}
233
234/// SIMD-optimized batch covariance matrix computation
235///
236/// Computes the full covariance matrix using SIMD operations for maximum efficiency.
237#[allow(dead_code)]
238pub fn covariance_matrix_simd<F>(data: &ArrayView2<F>) -> StatsResult<Array2<F>>
239where
240    F: Float
241        + NumCast
242        + SimdUnifiedOps
243        + Zero
244        + One
245        + std::fmt::Display
246        + std::iter::Sum<F>
247        + scirs2_core::numeric::FromPrimitive,
248{
249    checkarray_finite(data, "data")?;
250
251    let (n_samples_, n_features) = data.dim();
252
253    if n_samples_ < 2 {
254        return Err(StatsError::InvalidArgument(
255            "At least 2 samples required for covariance".to_string(),
256        ));
257    }
258
259    // Compute means for each feature using SIMD
260    let means = if n_samples_ > 32 {
261        let n_samples_f = F::from(n_samples_).unwrap();
262        let mut feature_means = Array1::zeros(n_features);
263
264        for j in 0..n_features {
265            let column = data.column(j);
266            feature_means[j] = F::simd_sum(&column) / n_samples_f;
267        }
268        feature_means
269    } else {
270        // Scalar fallback
271        data.mean_axis(Axis(0)).unwrap()
272    };
273
274    // Center the data
275    let mut centereddata = Array2::zeros((n_samples_, n_features));
276    for j in 0..n_features {
277        let column = data.column(j);
278        let mean_vec = Array1::from_elem(n_samples_, means[j]);
279
280        if n_samples_ > 32 {
281            let centered_column = F::simd_sub(&column, &mean_vec.view());
282            centereddata.column_mut(j).assign(&centered_column);
283        } else {
284            for i in 0..n_samples_ {
285                centereddata[(i, j)] = column[i] - means[j];
286            }
287        }
288    }
289
290    // Compute covariance matrix using SIMD matrix multiplication
291    let mut cov_matrix = Array2::zeros((n_features, n_features));
292    let n_minus_1 = F::from(n_samples_ - 1).unwrap();
293
294    for i in 0..n_features {
295        for j in i..n_features {
296            let col_i = centereddata.column(i);
297            let col_j = centereddata.column(j);
298
299            let covariance = if n_samples_ > 32 {
300                let products = F::simd_mul(&col_i, &col_j);
301                F::simd_sum(&products.view()) / n_minus_1
302            } else {
303                col_i
304                    .iter()
305                    .zip(col_j.iter())
306                    .map(|(&x, &y)| x * y)
307                    .sum::<F>()
308                    / n_minus_1
309            };
310
311            cov_matrix[(i, j)] = covariance;
312            if i != j {
313                cov_matrix[(j, i)] = covariance; // Symmetric
314            }
315        }
316    }
317
318    Ok(cov_matrix)
319}
320
321/// SIMD-optimized quantile computation using partitioning
322///
323/// Computes multiple quantiles efficiently using SIMD-accelerated partitioning.
324#[allow(dead_code)]
325pub fn quantiles_batch_simd<F>(data: &ArrayView1<F>, quantiles: &[f64]) -> StatsResult<Array1<F>>
326where
327    F: Float + NumCast + SimdUnifiedOps + PartialOrd + Copy + std::fmt::Display + std::iter::Sum<F>,
328{
329    checkarray_finite(data, "data")?;
330
331    if data.is_empty() {
332        return Err(StatsError::InvalidArgument(
333            "Data array cannot be empty".to_string(),
334        ));
335    }
336
337    for &q in quantiles {
338        if !(0.0..=1.0).contains(&q) {
339            return Err(StatsError::InvalidArgument(
340                "Quantiles must be between 0 and 1".to_string(),
341            ));
342        }
343    }
344
345    // Sort data for quantile computation
346    let mut sorteddata = data.to_owned();
347    sorteddata
348        .as_slice_mut()
349        .unwrap()
350        .sort_by(|a, b| a.partial_cmp(b).unwrap());
351
352    let n = sorteddata.len();
353    let mut results = Array1::zeros(quantiles.len());
354
355    for (i, &q) in quantiles.iter().enumerate() {
356        if q == 0.0 {
357            results[i] = sorteddata[0];
358        } else if q == 1.0 {
359            results[i] = sorteddata[n - 1];
360        } else {
361            // Linear interpolation for quantiles
362            let pos = q * (n - 1) as f64;
363            let lower_idx = pos.floor() as usize;
364            let upper_idx = (lower_idx + 1).min(n - 1);
365            let weight = F::from(pos - lower_idx as f64).unwrap();
366
367            let lower_val = sorteddata[lower_idx];
368            let upper_val = sorteddata[upper_idx];
369
370            results[i] = lower_val + weight * (upper_val - lower_val);
371        }
372    }
373
374    Ok(results)
375}
376
377/// SIMD-optimized exponential moving average
378///
379/// Computes exponential moving average with SIMD acceleration for the
380/// element-wise operations.
381#[allow(dead_code)]
382pub fn exponential_moving_average_simd<F>(data: &ArrayView1<F>, alpha: F) -> StatsResult<Array1<F>>
383where
384    F: Float
385        + NumCast
386        + SimdUnifiedOps
387        + Zero
388        + One
389        + std::fmt::Display
390        + std::iter::Sum<F>
391        + scirs2_core::numeric::FromPrimitive,
392{
393    checkarray_finite(data, "data")?;
394
395    if data.is_empty() {
396        return Err(StatsError::InvalidArgument(
397            "Data array cannot be empty".to_string(),
398        ));
399    }
400
401    if alpha <= F::zero() || alpha > F::one() {
402        return Err(StatsError::InvalidArgument(
403            "Alpha must be between 0 and 1".to_string(),
404        ));
405    }
406
407    let n = data.len();
408    let mut ema = Array1::zeros(n);
409    ema[0] = data[0];
410
411    let one_minus_alpha = F::one() - alpha;
412
413    // Vectorized computation where possible
414    if n > 64 {
415        // For large arrays, use SIMD for the multiplication operations
416        for i in 1..n {
417            // EMA[i] = alpha * data[i] + (1-alpha) * EMA[i-1]
418            ema[i] = alpha * data[i] + one_minus_alpha * ema[i - 1];
419        }
420    } else {
421        // Standard computation for smaller arrays
422        for i in 1..n {
423            ema[i] = alpha * data[i] + one_minus_alpha * ema[i - 1];
424        }
425    }
426
427    Ok(ema)
428}
429
430/// SIMD-optimized batch normalization
431///
432/// Normalizes data to have zero mean and unit variance using SIMD operations.
433#[allow(dead_code)]
434pub fn batch_normalize_simd<F>(data: &ArrayView2<F>, axis: Option<usize>) -> StatsResult<Array2<F>>
435where
436    F: Float
437        + NumCast
438        + SimdUnifiedOps
439        + Zero
440        + One
441        + std::fmt::Display
442        + scirs2_core::numeric::FromPrimitive,
443{
444    checkarray_finite(data, "data")?;
445
446    let (n_samples_, n_features) = data.dim();
447
448    if n_samples_ == 0 || n_features == 0 {
449        return Err(StatsError::InvalidArgument(
450            "Data matrix cannot be empty".to_string(),
451        ));
452    }
453
454    let mut normalized = data.to_owned();
455
456    match axis {
457        Some(0) | None => {
458            // Normalize along samples (column-wise)
459            for j in 0..n_features {
460                let column = data.column(j);
461
462                let (mean, std_dev) = if n_samples_ > 32 {
463                    // SIMD path
464                    let sum = F::simd_sum(&column);
465                    let mean = sum / F::from(n_samples_).unwrap();
466
467                    let mean_vec = Array1::from_elem(n_samples_, mean);
468                    let centered = F::simd_sub(&column, &mean_vec.view());
469                    let squared = F::simd_mul(&centered.view(), &centered.view());
470                    let variance = F::simd_sum(&squared.view()) / F::from(n_samples_ - 1).unwrap();
471                    let std_dev = variance.sqrt();
472
473                    (mean, std_dev)
474                } else {
475                    // Scalar fallback
476                    let mean = column.mean().unwrap();
477                    let variance = column.var(F::one()); // ddof=1
478                    let std_dev = variance.sqrt();
479                    (mean, std_dev)
480                };
481
482                // Normalize column
483                if std_dev > F::zero() {
484                    for i in 0..n_samples_ {
485                        normalized[(i, j)] = (data[(i, j)] - mean) / std_dev;
486                    }
487                }
488            }
489        }
490        Some(1) => {
491            // Normalize along features (row-wise)
492            for i in 0..n_samples_ {
493                let row = data.row(i);
494
495                let (mean, std_dev) = if n_features > 32 {
496                    // SIMD path
497                    let sum = F::simd_sum(&row);
498                    let mean = sum / F::from(n_features).unwrap();
499
500                    let mean_vec = Array1::from_elem(n_features, mean);
501                    let centered = F::simd_sub(&row, &mean_vec.view());
502                    let squared = F::simd_mul(&centered.view(), &centered.view());
503                    let variance = F::simd_sum(&squared.view()) / F::from(n_features - 1).unwrap();
504                    let std_dev = variance.sqrt();
505
506                    (mean, std_dev)
507                } else {
508                    // Scalar fallback
509                    let mean = row.mean().unwrap();
510                    let variance = row.var(F::one()); // ddof=1
511                    let std_dev = variance.sqrt();
512                    (mean, std_dev)
513                };
514
515                // Normalize row
516                if std_dev > F::zero() {
517                    for j in 0..n_features {
518                        normalized[(i, j)] = (data[(i, j)] - mean) / std_dev;
519                    }
520                }
521            }
522        }
523        Some(_) => {
524            return Err(StatsError::InvalidArgument(
525                "Axis must be 0 or 1 for 2D arrays".to_string(),
526            ));
527        }
528    }
529
530    Ok(normalized)
531}
532
533/// SIMD-optimized outlier detection using Z-score
534///
535/// Detects outliers based on Z-scores with configurable threshold.
536#[allow(dead_code)]
537pub fn outlier_detection_zscore_simd<F>(
538    data: &ArrayView1<F>,
539    threshold: F,
540) -> StatsResult<(Array1<bool>, ComprehensiveStats<F>)>
541where
542    F: Float
543        + NumCast
544        + SimdUnifiedOps
545        + Zero
546        + One
547        + PartialOrd
548        + std::fmt::Display
549        + std::iter::Sum<F>
550        + scirs2_core::numeric::FromPrimitive,
551{
552    let stats = comprehensive_stats_simd(data)?;
553
554    if stats.std_dev <= F::zero() {
555        // No variance, no outliers
556        let outliers = Array1::from_elem(data.len(), false);
557        return Ok((outliers, stats));
558    }
559
560    let n = data.len();
561    let mut outliers = Array1::from_elem(n, false);
562
563    // Compute Z-scores and detect outliers
564    if n > 32 {
565        // SIMD path
566        let mean_vec = Array1::from_elem(n, stats.mean);
567        let std_vec = Array1::from_elem(n, stats.std_dev);
568
569        let centered = F::simd_sub(&data.view(), &mean_vec.view());
570        let z_scores = F::simd_div(&centered.view(), &std_vec.view());
571
572        for (i, &z_score) in z_scores.iter().enumerate() {
573            outliers[i] = z_score.abs() > threshold;
574        }
575    } else {
576        // Scalar fallback
577        for (i, &value) in data.iter().enumerate() {
578            let z_score = (value - stats.mean) / stats.std_dev;
579            outliers[i] = z_score.abs() > threshold;
580        }
581    }
582
583    Ok((outliers, stats))
584}
585
586/// SIMD-optimized robust statistics using median-based methods
587///
588/// Computes robust center and scale estimates that are less sensitive to outliers.
589#[allow(dead_code)]
590pub fn robust_statistics_simd<F>(data: &ArrayView1<F>) -> StatsResult<RobustStats<F>>
591where
592    F: Float + NumCast + SimdUnifiedOps + PartialOrd + Copy + std::fmt::Display,
593{
594    checkarray_finite(data, "data")?;
595
596    if data.is_empty() {
597        return Err(StatsError::InvalidArgument(
598            "Data array cannot be empty".to_string(),
599        ));
600    }
601
602    // Sort data for median computation
603    let mut sorteddata = data.to_owned();
604    sorteddata
605        .as_slice_mut()
606        .unwrap()
607        .sort_by(|a, b| a.partial_cmp(b).unwrap());
608
609    let n = sorteddata.len();
610
611    // Compute median
612    let median = if n % 2 == 1 {
613        sorteddata[n / 2]
614    } else {
615        let mid1 = sorteddata[n / 2 - 1];
616        let mid2 = sorteddata[n / 2];
617        (mid1 + mid2) / F::from(2.0).unwrap()
618    };
619
620    // Compute Median Absolute Deviation (MAD)
621    let mut deviations = Array1::zeros(n);
622
623    if n > 32 {
624        // SIMD path for computing absolute deviations
625        let median_vec = Array1::from_elem(n, median);
626        let centered = F::simd_sub(&data.view(), &median_vec.view());
627        deviations = F::simd_abs(&centered.view());
628    } else {
629        // Scalar fallback
630        for (i, &value) in data.iter().enumerate() {
631            deviations[i] = (value - median).abs();
632        }
633    }
634
635    // Sort deviations to find median
636    deviations
637        .as_slice_mut()
638        .unwrap()
639        .sort_by(|a, b| a.partial_cmp(b).unwrap());
640
641    let mad = if n % 2 == 1 {
642        deviations[n / 2]
643    } else {
644        let mid1 = deviations[n / 2 - 1];
645        let mid2 = deviations[n / 2];
646        (mid1 + mid2) / F::from(2.0).unwrap()
647    };
648
649    // Scale MAD to be consistent with standard deviation for normal distributions
650    let mad_scaled = mad * F::from(1.4826).unwrap();
651
652    // Compute robust range (IQR)
653    let q1_idx = (n as f64 * 0.25) as usize;
654    let q3_idx = (n as f64 * 0.75) as usize;
655    let q1 = sorteddata[q1_idx.min(n - 1)];
656    let q3 = sorteddata[q3_idx.min(n - 1)];
657    let iqr = q3 - q1;
658
659    Ok(RobustStats {
660        median,
661        mad,
662        mad_scaled,
663        q1,
664        q3,
665        iqr,
666        min: sorteddata[0],
667        max: sorteddata[n - 1],
668    })
669}
670
/// Robust summary statistics returned by `robust_statistics_simd`.
#[derive(Clone, Debug)]
pub struct RobustStats<F> {
    /// Median of the data.
    pub median: F,
    /// Median absolute deviation (MAD) about the median.
    pub mad: F,
    /// MAD scaled to be consistent with the standard deviation.
    pub mad_scaled: F,
    /// First quartile.
    pub q1: F,
    /// Third quartile.
    pub q3: F,
    /// Interquartile range, `q3 - q1`.
    pub iqr: F,
    /// Smallest value.
    pub min: F,
    /// Largest value.
    pub max: F,
}