Skip to main content

scirs2_stats/
simd_enhanced_v6.rs

1//! Next-generation SIMD optimizations for statistical operations (v6)
2//!
3//! This module provides comprehensive SIMD optimizations that fully leverage
4//! scirs2-core's unified SIMD infrastructure, with advanced vectorization
5//! strategies and platform-specific optimizations.
6
7use crate::error::{StatsError, StatsResult};
8use scirs2_core::ndarray::{Array1, Array2, ArrayView1, ArrayView2};
9use scirs2_core::numeric::{Float, NumCast, One, Zero};
10use scirs2_core::Rng;
11use scirs2_core::{
12    simd_ops::{PlatformCapabilities, SimdUnifiedOps},
13    validation::*,
14};
15use std::marker::PhantomData;
16
17/// Advanced SIMD configuration with platform detection
18#[derive(Debug, Clone)]
19pub struct AdvancedSimdConfig {
20    /// Platform capabilities detected at runtime
21    pub capabilities: PlatformCapabilities,
22    /// Chunk size for SIMD operations (auto-determined based on platform)
23    pub chunksize: usize,
24    /// Whether to use parallel SIMD processing
25    pub parallel_enabled: bool,
26    /// Minimum data size for SIMD processing
27    pub simd_threshold: usize,
28}
29
30impl Default for AdvancedSimdConfig {
31    fn default() -> Self {
32        let capabilities = PlatformCapabilities::detect();
33        let chunksize = if capabilities.avx512_available {
34            16 // 512-bit / 32-bit = 16 elements for f32
35        } else if capabilities.avx2_available {
36            8 // 256-bit / 32-bit = 8 elements for f32
37        } else if capabilities.simd_available {
38            4 // 128-bit / 32-bit = 4 elements for f32
39        } else {
40            1 // Scalar fallback
41        };
42
43        Self {
44            capabilities,
45            chunksize,
46            parallel_enabled: true,
47            simd_threshold: 64,
48        }
49    }
50}
51
52/// Advanced-optimized SIMD statistics computer
53pub struct AdvancedSimdStatistics<F> {
54    config: AdvancedSimdConfig,
55    _phantom: PhantomData<F>,
56}
57
58impl<F> AdvancedSimdStatistics<F>
59where
60    F: Float
61        + NumCast
62        + SimdUnifiedOps
63        + Zero
64        + One
65        + PartialOrd
66        + Copy
67        + Send
68        + Sync
69        + std::fmt::Display
70        + std::iter::Sum<F>,
71{
72    /// Create new advanced-optimized SIMD statistics computer
73    pub fn new() -> Self {
74        Self {
75            config: AdvancedSimdConfig::default(),
76            _phantom: PhantomData,
77        }
78    }
79
80    /// Create with custom configuration
81    pub fn with_config(config: AdvancedSimdConfig) -> Self {
82        Self {
83            config,
84            _phantom: PhantomData,
85        }
86    }
87
88    /// Compute comprehensive statistics using advanced SIMD
89    pub fn comprehensive_stats_advanced(
90        &self,
91        data: &ArrayView1<F>,
92    ) -> StatsResult<ComprehensiveStats<F>> {
93        checkarray_finite(data, "data")?;
94
95        if data.is_empty() {
96            return Err(StatsError::InvalidArgument(
97                "Data cannot be empty".to_string(),
98            ));
99        }
100
101        let n = data.len();
102
103        // Use SIMD if data is large enough and SIMD is available
104        if n >= self.config.simd_threshold && self.config.chunksize > 1 {
105            self.compute_simd_comprehensive(data)
106        } else {
107            self.compute_scalar_comprehensive(data)
108        }
109    }
110
111    /// SIMD-optimized comprehensive statistics computation
112    fn compute_simd_comprehensive(
113        &self,
114        data: &ArrayView1<F>,
115    ) -> StatsResult<ComprehensiveStats<F>> {
116        let n = data.len();
117        let chunksize = self.config.chunksize;
118        let n_chunks = n / chunksize;
119        let remainder = n % chunksize;
120
121        // Initialize accumulators
122        let mut sum_acc = F::zero();
123        let mut sum_sq_acc = F::zero();
124        let mut sum_cube_acc = F::zero();
125        let mut sum_quad_acc = F::zero();
126        let mut min_val = F::infinity();
127        let mut max_val = F::neg_infinity();
128
129        // Process chunks with SIMD
130        for i in 0..n_chunks {
131            let start = i * chunksize;
132            let end = start + chunksize;
133            let chunk = data.slice(scirs2_core::ndarray::s![start..end]);
134
135            // Use SIMD operations from scirs2-core
136            let chunk_sum = F::simd_sum(&chunk);
137            let chunk_sq = F::simd_mul(&chunk, &chunk);
138            let chunk_sum_sq = F::simd_sum(&chunk_sq.view());
139            let chunk_cube = F::simd_mul(&chunk_sq.view(), &chunk);
140            let chunk_sum_cube = F::simd_sum(&chunk_cube.view());
141            let chunk_quad = F::simd_mul(&chunk_sq.view(), &chunk_sq.view());
142            let chunk_sum_quad = F::simd_sum(&chunk_quad.view());
143            let chunk_min = F::simd_min_element(&chunk);
144            let chunk_max = F::simd_max_element(&chunk);
145
146            sum_acc = sum_acc + chunk_sum;
147            sum_sq_acc = sum_sq_acc + chunk_sum_sq;
148            sum_cube_acc = sum_cube_acc + chunk_sum_cube;
149            sum_quad_acc = sum_quad_acc + chunk_sum_quad;
150            min_val = if chunk_min < min_val {
151                chunk_min
152            } else {
153                min_val
154            };
155            max_val = if chunk_max > max_val {
156                chunk_max
157            } else {
158                max_val
159            };
160        }
161
162        // Handle remainder with scalar operations
163        if remainder > 0 {
164            let start = n_chunks * chunksize;
165            for i in start..n {
166                let val = data[i];
167                sum_acc = sum_acc + val;
168                sum_sq_acc = sum_sq_acc + val * val;
169                sum_cube_acc = sum_cube_acc + val * val * val;
170                sum_quad_acc = sum_quad_acc + val * val * val * val;
171                min_val = if val < min_val { val } else { min_val };
172                max_val = if val > max_val { val } else { max_val };
173            }
174        }
175
176        // Compute final statistics
177        let n_f = F::from(n).expect("Failed to convert to float");
178        let mean = sum_acc / n_f;
179        let variance = (sum_sq_acc / n_f) - (mean * mean);
180        let std_dev = variance.sqrt();
181
182        // Compute higher moments
183        let m2 = sum_sq_acc / n_f - mean * mean;
184        let m3 = sum_cube_acc / n_f
185            - F::from(3).expect("Failed to convert constant to float") * mean * m2
186            - mean * mean * mean;
187        let m4 = sum_quad_acc / n_f
188            - F::from(4).expect("Failed to convert constant to float") * mean * m3
189            - F::from(6).expect("Failed to convert constant to float") * mean * mean * m2
190            - mean * mean * mean * mean;
191
192        let skewness = if m2 > F::zero() {
193            m3 / (m2 * m2.sqrt())
194        } else {
195            F::zero()
196        };
197
198        let kurtosis = if m2 > F::zero() {
199            m4 / (m2 * m2) - F::from(3).expect("Failed to convert constant to float")
200        } else {
201            F::zero()
202        };
203
204        Ok(ComprehensiveStats {
205            mean,
206            variance,
207            std_dev,
208            skewness,
209            kurtosis,
210            min: min_val,
211            max: max_val,
212            range: max_val - min_val,
213            count: n,
214        })
215    }
216
217    /// Scalar fallback for comprehensive statistics
218    fn compute_scalar_comprehensive(
219        &self,
220        data: &ArrayView1<F>,
221    ) -> StatsResult<ComprehensiveStats<F>> {
222        let n = data.len();
223        let n_f = F::from(n).expect("Failed to convert to float");
224
225        let sum: F = data.iter().copied().sum();
226        let mean = sum / n_f;
227
228        let mut sum_sq = F::zero();
229        let mut sum_cube = F::zero();
230        let mut sum_quad = F::zero();
231        let mut min_val = F::infinity();
232        let mut max_val = F::neg_infinity();
233
234        for &val in data.iter() {
235            let diff = val - mean;
236            sum_sq = sum_sq + diff * diff;
237            sum_cube = sum_cube + diff * diff * diff;
238            sum_quad = sum_quad + diff * diff * diff * diff;
239            min_val = if val < min_val { val } else { min_val };
240            max_val = if val > max_val { val } else { max_val };
241        }
242
243        let variance = sum_sq / n_f;
244        let std_dev = variance.sqrt();
245
246        let m2 = variance;
247        let m3 = sum_cube / n_f;
248        let m4 = sum_quad / n_f;
249
250        let skewness = if m2 > F::zero() {
251            m3 / (m2 * m2.sqrt())
252        } else {
253            F::zero()
254        };
255
256        let kurtosis = if m2 > F::zero() {
257            m4 / (m2 * m2) - F::from(3).expect("Failed to convert constant to float")
258        } else {
259            F::zero()
260        };
261
262        Ok(ComprehensiveStats {
263            mean,
264            variance,
265            std_dev,
266            skewness,
267            kurtosis,
268            min: min_val,
269            max: max_val,
270            range: max_val - min_val,
271            count: n,
272        })
273    }
274
275    /// Optimized SIMD-optimized matrix operations
276    pub fn matrix_stats_advanced(
277        &self,
278        matrix: &ArrayView2<F>,
279    ) -> StatsResult<MatrixStatsResult<F>> {
280        checkarray_finite(matrix, "matrix")?;
281
282        if matrix.is_empty() {
283            return Err(StatsError::InvalidArgument(
284                "Matrix cannot be empty".to_string(),
285            ));
286        }
287
288        let (rows, cols) = matrix.dim();
289
290        // Compute row-wise statistics using SIMD
291        let mut row_stats = Vec::with_capacity(rows);
292        for i in 0..rows {
293            let row = matrix.row(i);
294            let stats = self.comprehensive_stats_advanced(&row)?;
295            row_stats.push(stats);
296        }
297
298        // Compute column-wise statistics using SIMD
299        let mut col_stats = Vec::with_capacity(cols);
300        for j in 0..cols {
301            let col = matrix.column(j);
302            let stats = self.comprehensive_stats_advanced(&col)?;
303            col_stats.push(stats);
304        }
305
306        // Compute overall matrix statistics
307        let flattened = matrix.iter().copied().collect::<Array1<F>>();
308        let overall_stats = self.comprehensive_stats_advanced(&flattened.view())?;
309
310        Ok(MatrixStatsResult {
311            row_stats,
312            col_stats,
313            overall_stats,
314            shape: (rows, cols),
315        })
316    }
317
318    /// SIMD-optimized correlation matrix computation
319    pub fn correlation_matrix_advanced(&self, matrix: &ArrayView2<F>) -> StatsResult<Array2<F>> {
320        checkarray_finite(matrix, "matrix")?;
321
322        let (_n_samples_, n_features) = matrix.dim();
323
324        if n_features < 2 {
325            return Err(StatsError::InvalidArgument(
326                "At least 2 features required for correlation matrix".to_string(),
327            ));
328        }
329
330        let mut corr_matrix = Array2::zeros((n_features, n_features));
331
332        // Compute means using SIMD
333        let mut means = Array1::zeros(n_features);
334        for j in 0..n_features {
335            let col = matrix.column(j);
336            means[j] = F::simd_mean(&col);
337        }
338
339        // Compute correlation coefficients using SIMD
340        for i in 0..n_features {
341            for j in i..n_features {
342                if i == j {
343                    corr_matrix[[i, j]] = F::one();
344                } else {
345                    let col_i = matrix.column(i);
346                    let col_j = matrix.column(j);
347
348                    // Compute correlation using SIMD operations
349                    let _n = F::from(col_i.len()).expect("Operation failed");
350                    let mean_i_vec = Array1::from_elem(col_i.len(), means[i]);
351                    let mean_j_vec = Array1::from_elem(col_j.len(), means[j]);
352
353                    let dev_i = F::simd_sub(&col_i, &mean_i_vec.view());
354                    let dev_j = F::simd_sub(&col_j, &mean_j_vec.view());
355
356                    let numerator = F::simd_sum(&F::simd_mul(&dev_i.view(), &dev_j.view()).view());
357                    let sum_sq_i = F::simd_sum(&F::simd_mul(&dev_i.view(), &dev_i.view()).view());
358                    let sum_sq_j = F::simd_sum(&F::simd_mul(&dev_j.view(), &dev_j.view()).view());
359
360                    let denominator = (sum_sq_i * sum_sq_j).sqrt();
361                    let corr = if denominator > F::zero() {
362                        numerator / denominator
363                    } else {
364                        F::zero()
365                    };
366
367                    corr_matrix[[i, j]] = corr;
368                    corr_matrix[[j, i]] = corr;
369                }
370            }
371        }
372
373        Ok(corr_matrix)
374    }
375
376    /// SIMD-optimized bootstrap sampling with statistics
377    pub fn bootstrap_stats_advanced(
378        &self,
379        data: &ArrayView1<F>,
380        n_bootstrap: usize,
381        seed: Option<u64>,
382    ) -> StatsResult<BootstrapResult<F>> {
383        checkarray_finite(data, "data")?;
384        check_positive(n_bootstrap, "n_bootstrap")?;
385
386        let n = data.len();
387        let mut rng = create_rng(seed);
388
389        let mut bootstrap_means = Array1::zeros(n_bootstrap);
390        let mut bootstrap_vars = Array1::zeros(n_bootstrap);
391        let mut bootstrap_stds = Array1::zeros(n_bootstrap);
392
393        // Perform _bootstrap sampling with SIMD statistics
394        for i in 0..n_bootstrap {
395            // Generate _bootstrap sample
396            let mut bootstrap_sample = Array1::zeros(n);
397            for j in 0..n {
398                let idx = rng.random_range(0..n);
399                bootstrap_sample[j] = data[idx];
400            }
401
402            // Compute statistics using SIMD
403            let stats = self.comprehensive_stats_advanced(&bootstrap_sample.view())?;
404            bootstrap_means[i] = stats.mean;
405            bootstrap_vars[i] = stats.variance;
406            bootstrap_stds[i] = stats.std_dev;
407        }
408
409        // Compute confidence intervals
410        let mut sorted_means = bootstrap_means.to_owned();
411        sorted_means
412            .as_slice_mut()
413            .expect("Operation failed")
414            .sort_by(|a, b| a.partial_cmp(b).expect("Operation failed"));
415
416        let alpha = F::from(0.05).expect("Failed to convert constant to float"); // 95% confidence
417        let lower_idx = ((alpha / F::from(2).expect("Failed to convert constant to float"))
418            * F::from(n_bootstrap).expect("Failed to convert to float"))
419        .to_usize()
420        .expect("Operation failed");
421        let upper_idx = ((F::one()
422            - alpha / F::from(2).expect("Failed to convert constant to float"))
423            * F::from(n_bootstrap).expect("Failed to convert to float"))
424        .to_usize()
425        .expect("Operation failed");
426
427        let mean_ci = (
428            sorted_means[lower_idx],
429            sorted_means[upper_idx.min(n_bootstrap - 1)],
430        );
431
432        Ok(BootstrapResult {
433            original_stats: self.comprehensive_stats_advanced(data)?,
434            bootstrap_means,
435            bootstrap_vars,
436            bootstrap_stds,
437            mean_ci,
438            n_bootstrap,
439        })
440    }
441}
442
/// Summary statistics of a one-dimensional sample.
#[derive(Clone, Debug)]
pub struct ComprehensiveStats<F> {
    /// Arithmetic mean.
    pub mean: F,
    /// Population variance (second central moment, divided by `count`).
    pub variance: F,
    /// Square root of `variance`.
    pub std_dev: F,
    /// Third central moment divided by `variance^1.5`; zero when the variance
    /// is not positive.
    pub skewness: F,
    /// Excess kurtosis: fourth central moment divided by `variance^2`, minus 3;
    /// zero when the variance is not positive.
    pub kurtosis: F,
    /// Smallest observed value.
    pub min: F,
    /// Largest observed value.
    pub max: F,
    /// `max - min`.
    pub range: F,
    /// Number of observations.
    pub count: usize,
}
456
457/// Matrix statistics result
458#[derive(Debug, Clone)]
459pub struct MatrixStatsResult<F> {
460    pub row_stats: Vec<ComprehensiveStats<F>>,
461    pub col_stats: Vec<ComprehensiveStats<F>>,
462    pub overall_stats: ComprehensiveStats<F>,
463    pub shape: (usize, usize),
464}
465
466/// Bootstrap analysis result
467#[derive(Debug, Clone)]
468pub struct BootstrapResult<F> {
469    pub original_stats: ComprehensiveStats<F>,
470    pub bootstrap_means: Array1<F>,
471    pub bootstrap_vars: Array1<F>,
472    pub bootstrap_stds: Array1<F>,
473    pub mean_ci: (F, F),
474    pub n_bootstrap: usize,
475}
476
477/// Specialized SIMD operations for advanced statistics
478pub trait AdvancedSimdOps<F>: SimdUnifiedOps
479where
480    F: Float
481        + NumCast
482        + Zero
483        + One
484        + PartialOrd
485        + Copy
486        + Send
487        + Sync
488        + std::fmt::Display
489        + std::iter::Sum<F>,
490{
491    /// SIMD-optimized sum of cubes
492    fn simd_sum_cubes(data: &ArrayView1<F>) -> F {
493        data.iter().map(|&x| x * x * x).sum()
494    }
495
496    /// SIMD-optimized sum of fourth powers
497    fn simd_sum_quads(data: &ArrayView1<F>) -> F {
498        data.iter().map(|&x| x * x * x * x).sum()
499    }
500
501    /// SIMD-optimized correlation coefficient
502    fn simd_correlation(x: &ArrayView1<F>, y: &ArrayView1<F>, mean_x: F, meany: F) -> F {
503        let n = x.len();
504        if n != y.len() {
505            return F::zero();
506        }
507
508        let _n_f = F::from(n).expect("Failed to convert to float");
509        let mut sum_xy = F::zero();
510        let mut sum_x2 = F::zero();
511        let mut sum_y2 = F::zero();
512
513        for i in 0..n {
514            let dx = x[i] - mean_x;
515            let dy = y[i] - meany;
516            sum_xy = sum_xy + dx * dy;
517            sum_x2 = sum_x2 + dx * dx;
518            sum_y2 = sum_y2 + dy * dy;
519        }
520
521        let denom = (sum_x2 * sum_y2).sqrt();
522        if denom > F::zero() {
523            sum_xy / denom
524        } else {
525            F::zero()
526        }
527    }
528}
529
530// Implement the advanced SIMD operations for supported types
531impl AdvancedSimdOps<f32> for f32 {}
532impl AdvancedSimdOps<f64> for f64 {}
533
534/// High-level convenience functions
535#[allow(dead_code)]
536pub fn advanced_mean_simd<F>(data: &ArrayView1<F>) -> StatsResult<F>
537where
538    F: Float
539        + NumCast
540        + SimdUnifiedOps
541        + Zero
542        + One
543        + PartialOrd
544        + Copy
545        + Send
546        + Sync
547        + std::fmt::Display
548        + std::iter::Sum<F>,
549{
550    let computer = AdvancedSimdStatistics::<F>::new();
551    let stats = computer.comprehensive_stats_advanced(data)?;
552    Ok(stats.mean)
553}
554
555#[allow(dead_code)]
556pub fn advanced_std_simd<F>(data: &ArrayView1<F>) -> StatsResult<F>
557where
558    F: Float
559        + NumCast
560        + SimdUnifiedOps
561        + Zero
562        + One
563        + PartialOrd
564        + Copy
565        + Send
566        + Sync
567        + std::fmt::Display
568        + std::iter::Sum<F>,
569{
570    let computer = AdvancedSimdStatistics::<F>::new();
571    let stats = computer.comprehensive_stats_advanced(data)?;
572    Ok(stats.std_dev)
573}
574
575#[allow(dead_code)]
576pub fn advanced_comprehensive_simd<F>(data: &ArrayView1<F>) -> StatsResult<ComprehensiveStats<F>>
577where
578    F: Float
579        + NumCast
580        + SimdUnifiedOps
581        + Zero
582        + One
583        + PartialOrd
584        + Copy
585        + Send
586        + Sync
587        + std::fmt::Display
588        + std::iter::Sum<F>,
589{
590    let computer = AdvancedSimdStatistics::<F>::new();
591    computer.comprehensive_stats_advanced(data)
592}
593
594/// Create RNG with optional seed
595#[allow(dead_code)]
596fn create_rng(seed: Option<u64>) -> impl Rng {
597    use scirs2_core::random::{rngs::StdRng, SeedableRng};
598    match seed {
599        Some(s) => StdRng::seed_from_u64(s),
600        None => {
601            use std::time::{SystemTime, UNIX_EPOCH};
602            let s = SystemTime::now()
603                .duration_since(UNIX_EPOCH)
604                .unwrap_or_default()
605                .as_secs();
606            StdRng::seed_from_u64(s)
607        }
608    }
609}