scirs2_stats/
simd_enhanced_v6.rs

1//! Next-generation SIMD optimizations for statistical operations (v6)
2//!
3//! This module provides comprehensive SIMD optimizations that fully leverage
4//! scirs2-core's unified SIMD infrastructure, with advanced vectorization
5//! strategies and platform-specific optimizations.
6
7use crate::error::{StatsError, StatsResult};
8use scirs2_core::ndarray::{Array1, Array2, ArrayView1, ArrayView2};
9use scirs2_core::numeric::{Float, NumCast, One, Zero};
10use scirs2_core::Rng;
11use scirs2_core::{
12    simd_ops::{PlatformCapabilities, SimdUnifiedOps},
13    validation::*,
14};
15use std::marker::PhantomData;
16
17/// Advanced SIMD configuration with platform detection
18#[derive(Debug, Clone)]
19pub struct AdvancedSimdConfig {
20    /// Platform capabilities detected at runtime
21    pub capabilities: PlatformCapabilities,
22    /// Chunk size for SIMD operations (auto-determined based on platform)
23    pub chunksize: usize,
24    /// Whether to use parallel SIMD processing
25    pub parallel_enabled: bool,
26    /// Minimum data size for SIMD processing
27    pub simd_threshold: usize,
28}
29
30impl Default for AdvancedSimdConfig {
31    fn default() -> Self {
32        let capabilities = PlatformCapabilities::detect();
33        let chunksize = if capabilities.avx512_available {
34            16 // 512-bit / 32-bit = 16 elements for f32
35        } else if capabilities.avx2_available {
36            8 // 256-bit / 32-bit = 8 elements for f32
37        } else if capabilities.simd_available {
38            4 // 128-bit / 32-bit = 4 elements for f32
39        } else {
40            1 // Scalar fallback
41        };
42
43        Self {
44            capabilities,
45            chunksize,
46            parallel_enabled: true,
47            simd_threshold: 64,
48        }
49    }
50}
51
52/// Advanced-optimized SIMD statistics computer
53pub struct AdvancedSimdStatistics<F> {
54    config: AdvancedSimdConfig,
55    _phantom: PhantomData<F>,
56}
57
58impl<F> AdvancedSimdStatistics<F>
59where
60    F: Float
61        + NumCast
62        + SimdUnifiedOps
63        + Zero
64        + One
65        + PartialOrd
66        + Copy
67        + Send
68        + Sync
69        + std::fmt::Display
70        + std::iter::Sum<F>,
71{
72    /// Create new advanced-optimized SIMD statistics computer
73    pub fn new() -> Self {
74        Self {
75            config: AdvancedSimdConfig::default(),
76            _phantom: PhantomData,
77        }
78    }
79
80    /// Create with custom configuration
81    pub fn with_config(config: AdvancedSimdConfig) -> Self {
82        Self {
83            config,
84            _phantom: PhantomData,
85        }
86    }
87
88    /// Compute comprehensive statistics using advanced SIMD
89    pub fn comprehensive_stats_advanced(
90        &self,
91        data: &ArrayView1<F>,
92    ) -> StatsResult<ComprehensiveStats<F>> {
93        checkarray_finite(data, "data")?;
94
95        if data.is_empty() {
96            return Err(StatsError::InvalidArgument(
97                "Data cannot be empty".to_string(),
98            ));
99        }
100
101        let n = data.len();
102
103        // Use SIMD if data is large enough and SIMD is available
104        if n >= self.config.simd_threshold && self.config.chunksize > 1 {
105            self.compute_simd_comprehensive(data)
106        } else {
107            self.compute_scalar_comprehensive(data)
108        }
109    }
110
111    /// SIMD-optimized comprehensive statistics computation
112    fn compute_simd_comprehensive(
113        &self,
114        data: &ArrayView1<F>,
115    ) -> StatsResult<ComprehensiveStats<F>> {
116        let n = data.len();
117        let chunksize = self.config.chunksize;
118        let n_chunks = n / chunksize;
119        let remainder = n % chunksize;
120
121        // Initialize accumulators
122        let mut sum_acc = F::zero();
123        let mut sum_sq_acc = F::zero();
124        let mut sum_cube_acc = F::zero();
125        let mut sum_quad_acc = F::zero();
126        let mut min_val = F::infinity();
127        let mut max_val = F::neg_infinity();
128
129        // Process chunks with SIMD
130        for i in 0..n_chunks {
131            let start = i * chunksize;
132            let end = start + chunksize;
133            let chunk = data.slice(scirs2_core::ndarray::s![start..end]);
134
135            // Use SIMD operations from scirs2-core
136            let chunk_sum = F::simd_sum(&chunk);
137            let chunk_sq = F::simd_mul(&chunk, &chunk);
138            let chunk_sum_sq = F::simd_sum(&chunk_sq.view());
139            let chunk_cube = F::simd_mul(&chunk_sq.view(), &chunk);
140            let chunk_sum_cube = F::simd_sum(&chunk_cube.view());
141            let chunk_quad = F::simd_mul(&chunk_sq.view(), &chunk_sq.view());
142            let chunk_sum_quad = F::simd_sum(&chunk_quad.view());
143            let chunk_min = F::simd_min_element(&chunk);
144            let chunk_max = F::simd_max_element(&chunk);
145
146            sum_acc = sum_acc + chunk_sum;
147            sum_sq_acc = sum_sq_acc + chunk_sum_sq;
148            sum_cube_acc = sum_cube_acc + chunk_sum_cube;
149            sum_quad_acc = sum_quad_acc + chunk_sum_quad;
150            min_val = if chunk_min < min_val {
151                chunk_min
152            } else {
153                min_val
154            };
155            max_val = if chunk_max > max_val {
156                chunk_max
157            } else {
158                max_val
159            };
160        }
161
162        // Handle remainder with scalar operations
163        if remainder > 0 {
164            let start = n_chunks * chunksize;
165            for i in start..n {
166                let val = data[i];
167                sum_acc = sum_acc + val;
168                sum_sq_acc = sum_sq_acc + val * val;
169                sum_cube_acc = sum_cube_acc + val * val * val;
170                sum_quad_acc = sum_quad_acc + val * val * val * val;
171                min_val = if val < min_val { val } else { min_val };
172                max_val = if val > max_val { val } else { max_val };
173            }
174        }
175
176        // Compute final statistics
177        let n_f = F::from(n).unwrap();
178        let mean = sum_acc / n_f;
179        let variance = (sum_sq_acc / n_f) - (mean * mean);
180        let std_dev = variance.sqrt();
181
182        // Compute higher moments
183        let m2 = sum_sq_acc / n_f - mean * mean;
184        let m3 = sum_cube_acc / n_f - F::from(3).unwrap() * mean * m2 - mean * mean * mean;
185        let m4 = sum_quad_acc / n_f
186            - F::from(4).unwrap() * mean * m3
187            - F::from(6).unwrap() * mean * mean * m2
188            - mean * mean * mean * mean;
189
190        let skewness = if m2 > F::zero() {
191            m3 / (m2 * m2.sqrt())
192        } else {
193            F::zero()
194        };
195
196        let kurtosis = if m2 > F::zero() {
197            m4 / (m2 * m2) - F::from(3).unwrap()
198        } else {
199            F::zero()
200        };
201
202        Ok(ComprehensiveStats {
203            mean,
204            variance,
205            std_dev,
206            skewness,
207            kurtosis,
208            min: min_val,
209            max: max_val,
210            range: max_val - min_val,
211            count: n,
212        })
213    }
214
215    /// Scalar fallback for comprehensive statistics
216    fn compute_scalar_comprehensive(
217        &self,
218        data: &ArrayView1<F>,
219    ) -> StatsResult<ComprehensiveStats<F>> {
220        let n = data.len();
221        let n_f = F::from(n).unwrap();
222
223        let sum: F = data.iter().copied().sum();
224        let mean = sum / n_f;
225
226        let mut sum_sq = F::zero();
227        let mut sum_cube = F::zero();
228        let mut sum_quad = F::zero();
229        let mut min_val = F::infinity();
230        let mut max_val = F::neg_infinity();
231
232        for &val in data.iter() {
233            let diff = val - mean;
234            sum_sq = sum_sq + diff * diff;
235            sum_cube = sum_cube + diff * diff * diff;
236            sum_quad = sum_quad + diff * diff * diff * diff;
237            min_val = if val < min_val { val } else { min_val };
238            max_val = if val > max_val { val } else { max_val };
239        }
240
241        let variance = sum_sq / n_f;
242        let std_dev = variance.sqrt();
243
244        let m2 = variance;
245        let m3 = sum_cube / n_f;
246        let m4 = sum_quad / n_f;
247
248        let skewness = if m2 > F::zero() {
249            m3 / (m2 * m2.sqrt())
250        } else {
251            F::zero()
252        };
253
254        let kurtosis = if m2 > F::zero() {
255            m4 / (m2 * m2) - F::from(3).unwrap()
256        } else {
257            F::zero()
258        };
259
260        Ok(ComprehensiveStats {
261            mean,
262            variance,
263            std_dev,
264            skewness,
265            kurtosis,
266            min: min_val,
267            max: max_val,
268            range: max_val - min_val,
269            count: n,
270        })
271    }
272
273    /// Optimized SIMD-optimized matrix operations
274    pub fn matrix_stats_advanced(
275        &self,
276        matrix: &ArrayView2<F>,
277    ) -> StatsResult<MatrixStatsResult<F>> {
278        checkarray_finite(matrix, "matrix")?;
279
280        if matrix.is_empty() {
281            return Err(StatsError::InvalidArgument(
282                "Matrix cannot be empty".to_string(),
283            ));
284        }
285
286        let (rows, cols) = matrix.dim();
287
288        // Compute row-wise statistics using SIMD
289        let mut row_stats = Vec::with_capacity(rows);
290        for i in 0..rows {
291            let row = matrix.row(i);
292            let stats = self.comprehensive_stats_advanced(&row)?;
293            row_stats.push(stats);
294        }
295
296        // Compute column-wise statistics using SIMD
297        let mut col_stats = Vec::with_capacity(cols);
298        for j in 0..cols {
299            let col = matrix.column(j);
300            let stats = self.comprehensive_stats_advanced(&col)?;
301            col_stats.push(stats);
302        }
303
304        // Compute overall matrix statistics
305        let flattened = matrix.iter().copied().collect::<Array1<F>>();
306        let overall_stats = self.comprehensive_stats_advanced(&flattened.view())?;
307
308        Ok(MatrixStatsResult {
309            row_stats,
310            col_stats,
311            overall_stats,
312            shape: (rows, cols),
313        })
314    }
315
316    /// SIMD-optimized correlation matrix computation
317    pub fn correlation_matrix_advanced(&self, matrix: &ArrayView2<F>) -> StatsResult<Array2<F>> {
318        checkarray_finite(matrix, "matrix")?;
319
320        let (_n_samples_, n_features) = matrix.dim();
321
322        if n_features < 2 {
323            return Err(StatsError::InvalidArgument(
324                "At least 2 features required for correlation matrix".to_string(),
325            ));
326        }
327
328        let mut corr_matrix = Array2::zeros((n_features, n_features));
329
330        // Compute means using SIMD
331        let mut means = Array1::zeros(n_features);
332        for j in 0..n_features {
333            let col = matrix.column(j);
334            means[j] = F::simd_mean(&col);
335        }
336
337        // Compute correlation coefficients using SIMD
338        for i in 0..n_features {
339            for j in i..n_features {
340                if i == j {
341                    corr_matrix[[i, j]] = F::one();
342                } else {
343                    let col_i = matrix.column(i);
344                    let col_j = matrix.column(j);
345
346                    // Compute correlation using SIMD operations
347                    let _n = F::from(col_i.len()).unwrap();
348                    let mean_i_vec = Array1::from_elem(col_i.len(), means[i]);
349                    let mean_j_vec = Array1::from_elem(col_j.len(), means[j]);
350
351                    let dev_i = F::simd_sub(&col_i, &mean_i_vec.view());
352                    let dev_j = F::simd_sub(&col_j, &mean_j_vec.view());
353
354                    let numerator = F::simd_sum(&F::simd_mul(&dev_i.view(), &dev_j.view()).view());
355                    let sum_sq_i = F::simd_sum(&F::simd_mul(&dev_i.view(), &dev_i.view()).view());
356                    let sum_sq_j = F::simd_sum(&F::simd_mul(&dev_j.view(), &dev_j.view()).view());
357
358                    let denominator = (sum_sq_i * sum_sq_j).sqrt();
359                    let corr = if denominator > F::zero() {
360                        numerator / denominator
361                    } else {
362                        F::zero()
363                    };
364
365                    corr_matrix[[i, j]] = corr;
366                    corr_matrix[[j, i]] = corr;
367                }
368            }
369        }
370
371        Ok(corr_matrix)
372    }
373
374    /// SIMD-optimized bootstrap sampling with statistics
375    pub fn bootstrap_stats_advanced(
376        &self,
377        data: &ArrayView1<F>,
378        n_bootstrap: usize,
379        seed: Option<u64>,
380    ) -> StatsResult<BootstrapResult<F>> {
381        checkarray_finite(data, "data")?;
382        check_positive(n_bootstrap, "n_bootstrap")?;
383
384        let n = data.len();
385        let mut rng = create_rng(seed);
386
387        let mut bootstrap_means = Array1::zeros(n_bootstrap);
388        let mut bootstrap_vars = Array1::zeros(n_bootstrap);
389        let mut bootstrap_stds = Array1::zeros(n_bootstrap);
390
391        // Perform _bootstrap sampling with SIMD statistics
392        for i in 0..n_bootstrap {
393            // Generate _bootstrap sample
394            let mut bootstrap_sample = Array1::zeros(n);
395            for j in 0..n {
396                let idx = rng.gen_range(0..n);
397                bootstrap_sample[j] = data[idx];
398            }
399
400            // Compute statistics using SIMD
401            let stats = self.comprehensive_stats_advanced(&bootstrap_sample.view())?;
402            bootstrap_means[i] = stats.mean;
403            bootstrap_vars[i] = stats.variance;
404            bootstrap_stds[i] = stats.std_dev;
405        }
406
407        // Compute confidence intervals
408        let mut sorted_means = bootstrap_means.to_owned();
409        sorted_means
410            .as_slice_mut()
411            .unwrap()
412            .sort_by(|a, b| a.partial_cmp(b).unwrap());
413
414        let alpha = F::from(0.05).unwrap(); // 95% confidence
415        let lower_idx = ((alpha / F::from(2).unwrap()) * F::from(n_bootstrap).unwrap())
416            .to_usize()
417            .unwrap();
418        let upper_idx = ((F::one() - alpha / F::from(2).unwrap()) * F::from(n_bootstrap).unwrap())
419            .to_usize()
420            .unwrap();
421
422        let mean_ci = (
423            sorted_means[lower_idx],
424            sorted_means[upper_idx.min(n_bootstrap - 1)],
425        );
426
427        Ok(BootstrapResult {
428            original_stats: self.comprehensive_stats_advanced(data)?,
429            bootstrap_means,
430            bootstrap_vars,
431            bootstrap_stds,
432            mean_ci,
433            n_bootstrap,
434        })
435    }
436}
437
438/// Comprehensive statistics result
439#[derive(Debug, Clone)]
440pub struct ComprehensiveStats<F> {
441    pub mean: F,
442    pub variance: F,
443    pub std_dev: F,
444    pub skewness: F,
445    pub kurtosis: F,
446    pub min: F,
447    pub max: F,
448    pub range: F,
449    pub count: usize,
450}
451
452/// Matrix statistics result
453#[derive(Debug, Clone)]
454pub struct MatrixStatsResult<F> {
455    pub row_stats: Vec<ComprehensiveStats<F>>,
456    pub col_stats: Vec<ComprehensiveStats<F>>,
457    pub overall_stats: ComprehensiveStats<F>,
458    pub shape: (usize, usize),
459}
460
461/// Bootstrap analysis result
462#[derive(Debug, Clone)]
463pub struct BootstrapResult<F> {
464    pub original_stats: ComprehensiveStats<F>,
465    pub bootstrap_means: Array1<F>,
466    pub bootstrap_vars: Array1<F>,
467    pub bootstrap_stds: Array1<F>,
468    pub mean_ci: (F, F),
469    pub n_bootstrap: usize,
470}
471
472/// Specialized SIMD operations for advanced statistics
473pub trait AdvancedSimdOps<F>: SimdUnifiedOps
474where
475    F: Float
476        + NumCast
477        + Zero
478        + One
479        + PartialOrd
480        + Copy
481        + Send
482        + Sync
483        + std::fmt::Display
484        + std::iter::Sum<F>,
485{
486    /// SIMD-optimized sum of cubes
487    fn simd_sum_cubes(data: &ArrayView1<F>) -> F {
488        data.iter().map(|&x| x * x * x).sum()
489    }
490
491    /// SIMD-optimized sum of fourth powers
492    fn simd_sum_quads(data: &ArrayView1<F>) -> F {
493        data.iter().map(|&x| x * x * x * x).sum()
494    }
495
496    /// SIMD-optimized correlation coefficient
497    fn simd_correlation(x: &ArrayView1<F>, y: &ArrayView1<F>, mean_x: F, meany: F) -> F {
498        let n = x.len();
499        if n != y.len() {
500            return F::zero();
501        }
502
503        let _n_f = F::from(n).unwrap();
504        let mut sum_xy = F::zero();
505        let mut sum_x2 = F::zero();
506        let mut sum_y2 = F::zero();
507
508        for i in 0..n {
509            let dx = x[i] - mean_x;
510            let dy = y[i] - meany;
511            sum_xy = sum_xy + dx * dy;
512            sum_x2 = sum_x2 + dx * dx;
513            sum_y2 = sum_y2 + dy * dy;
514        }
515
516        let denom = (sum_x2 * sum_y2).sqrt();
517        if denom > F::zero() {
518            sum_xy / denom
519        } else {
520            F::zero()
521        }
522    }
523}
524
525// Implement the advanced SIMD operations for supported types
526impl AdvancedSimdOps<f32> for f32 {}
527impl AdvancedSimdOps<f64> for f64 {}
528
529/// High-level convenience functions
530#[allow(dead_code)]
531pub fn advanced_mean_simd<F>(data: &ArrayView1<F>) -> StatsResult<F>
532where
533    F: Float
534        + NumCast
535        + SimdUnifiedOps
536        + Zero
537        + One
538        + PartialOrd
539        + Copy
540        + Send
541        + Sync
542        + std::fmt::Display
543        + std::iter::Sum<F>,
544{
545    let computer = AdvancedSimdStatistics::<F>::new();
546    let stats = computer.comprehensive_stats_advanced(data)?;
547    Ok(stats.mean)
548}
549
550#[allow(dead_code)]
551pub fn advanced_std_simd<F>(data: &ArrayView1<F>) -> StatsResult<F>
552where
553    F: Float
554        + NumCast
555        + SimdUnifiedOps
556        + Zero
557        + One
558        + PartialOrd
559        + Copy
560        + Send
561        + Sync
562        + std::fmt::Display
563        + std::iter::Sum<F>,
564{
565    let computer = AdvancedSimdStatistics::<F>::new();
566    let stats = computer.comprehensive_stats_advanced(data)?;
567    Ok(stats.std_dev)
568}
569
570#[allow(dead_code)]
571pub fn advanced_comprehensive_simd<F>(data: &ArrayView1<F>) -> StatsResult<ComprehensiveStats<F>>
572where
573    F: Float
574        + NumCast
575        + SimdUnifiedOps
576        + Zero
577        + One
578        + PartialOrd
579        + Copy
580        + Send
581        + Sync
582        + std::fmt::Display
583        + std::iter::Sum<F>,
584{
585    let computer = AdvancedSimdStatistics::<F>::new();
586    computer.comprehensive_stats_advanced(data)
587}
588
589/// Create RNG with optional seed
590#[allow(dead_code)]
591fn create_rng(seed: Option<u64>) -> impl Rng {
592    use scirs2_core::random::{rngs::StdRng, SeedableRng};
593    match seed {
594        Some(s) => StdRng::seed_from_u64(s),
595        None => {
596            use std::time::{SystemTime, UNIX_EPOCH};
597            let s = SystemTime::now()
598                .duration_since(UNIX_EPOCH)
599                .unwrap_or_default()
600                .as_secs();
601            StdRng::seed_from_u64(s)
602        }
603    }
604}