Skip to main content

scirs2_stats/
simd_optimized_v2.rs

1//! Enhanced SIMD optimizations for v1.0.0
2//!
3//! This module provides improved SIMD implementations that:
4//! - Avoid temporary array allocations
5//! - Use efficient SIMD patterns
6//! - Provide automatic fallback to scalar code
7//! - Support multiple data types
8
9use crate::error::StatsResult;
10use scirs2_core::ndarray::{ArrayBase, Data, Ix1};
11use scirs2_core::numeric::{Float, NumCast};
12use scirs2_core::simd_ops::{PlatformCapabilities, SimdUnifiedOps};
13use scirs2_core::validation::check_not_empty;
14
/// Tunables that decide when, and how aggressively, the SIMD code paths are used.
#[derive(Debug, Clone, Copy)]
pub struct SimdConfig {
    /// Inputs with fewer elements than this take the scalar code path.
    pub minsize: usize,
    /// Whether aligned loads/stores should be used (not consulted by the routines in this module).
    pub use_aligned: bool,
    /// Maximum loop-unroll factor (not consulted by the routines in this module).
    pub unroll_factor: usize,
}

impl Default for SimdConfig {
    /// Equivalent to [`SimdConfig::detect`].
    fn default() -> Self {
        Self::detect()
    }
}

impl SimdConfig {
    /// Configuration used when no recognised SIMD extension is available.
    const SCALAR: Self = Self {
        minsize: 32,
        use_aligned: false,
        unroll_factor: 1,
    };

    /// Choose a configuration from the CPU's capabilities at runtime.
    ///
    /// - x86-64: the first of AVX-512F, AVX2, SSE4.2 reported by
    ///   `is_x86_feature_detected!` selects unroll 8 / 4 / 2. That macro performs a
    ///   runtime CPUID check, so no `-C target-cpu=native` flag is needed. If none is
    ///   present, the scalar fallback is returned.
    /// - AArch64: NEON is part of the baseline ISA, so a fixed unroll of 4 is used.
    /// - Any other target: the scalar fallback (`minsize` 32, unroll 1).
    pub fn detect() -> Self {
        Self::detect_for_arch().unwrap_or(Self::SCALAR)
    }

    /// x86-64 tiers, widest extension first; `None` means "use the scalar fallback".
    #[cfg(target_arch = "x86_64")]
    fn detect_for_arch() -> Option<Self> {
        if is_x86_feature_detected!("avx512f") {
            // 512-bit registers: 8 f64 lanes.
            Some(Self {
                minsize: 256,
                use_aligned: true,
                unroll_factor: 8,
            })
        } else if is_x86_feature_detected!("avx2") {
            // 256-bit registers: 4 f64 lanes.
            Some(Self {
                minsize: 128,
                use_aligned: true,
                unroll_factor: 4,
            })
        } else if is_x86_feature_detected!("sse4.2") {
            // 128-bit registers: 2 f64 lanes.
            Some(Self {
                minsize: 64,
                use_aligned: false,
                unroll_factor: 2,
            })
        } else {
            None
        }
    }

    /// AArch64 always has NEON (128-bit, 2 f64 lanes). A 4-wide unroll is used to keep
    /// the out-of-order core busy.
    #[cfg(target_arch = "aarch64")]
    fn detect_for_arch() -> Option<Self> {
        Some(Self {
            minsize: 64,
            use_aligned: false,
            unroll_factor: 4,
        })
    }

    /// Unknown architectures never select a SIMD tier.
    #[cfg(not(any(target_arch = "x86_64", target_arch = "aarch64")))]
    fn detect_for_arch() -> Option<Self> {
        None
    }
}
88
89/// Optimized mean calculation using SIMD with chunked processing
90///
91/// This implementation avoids temporary arrays and processes data in chunks
92/// for better cache efficiency.
93#[allow(dead_code)]
94pub fn mean_simd_optimized<F, D>(
95    x: &ArrayBase<D, Ix1>,
96    config: Option<SimdConfig>,
97) -> StatsResult<F>
98where
99    F: Float + NumCast + SimdUnifiedOps,
100    D: Data<Elem = F>,
101{
102    // Use scirs2-core validation
103    check_not_empty(x, "x").map_err(|_| {
104        crate::error::StatsError::invalid_argument("Cannot compute mean of empty array")
105    })?;
106
107    let config = config.unwrap_or_default();
108    let n = x.len();
109
110    if n < config.minsize {
111        // Small arrays: use scalar code
112        let sum = x.iter().fold(F::zero(), |acc, &val| acc + val);
113        return Ok(sum / F::from(n).expect("Failed to convert to float"));
114    }
115
116    // For larger arrays, use chunked SIMD processing
117    let sum = chunked_simd_sum(x, &config)?;
118    Ok(sum / F::from(n).expect("Failed to convert to float"))
119}
120
121/// Optimized variance calculation using single-pass SIMD algorithm
122///
123/// Uses Welford's online algorithm adapted for SIMD processing
124#[allow(dead_code)]
125pub fn variance_simd_optimized<F, D>(
126    x: &ArrayBase<D, Ix1>,
127    ddof: usize,
128    config: Option<SimdConfig>,
129) -> StatsResult<F>
130where
131    F: Float + NumCast + SimdUnifiedOps,
132    D: Data<Elem = F>,
133{
134    let n = x.len();
135    if n <= ddof {
136        return Err(crate::error::StatsError::invalid_argument(
137            "Not enough data points for the given degrees of freedom",
138        ));
139    }
140
141    let config = config.unwrap_or_default();
142
143    if n < config.minsize {
144        // Small arrays: use scalar Welford's algorithm
145        return variance_scalar_welford(x, ddof);
146    }
147
148    // Use SIMD-optimized two-pass algorithm for better accuracy
149    let mean = mean_simd_optimized(x, Some(config))?;
150    let sum_sq_dev = chunked_simd_sum_squared_deviations(x, mean, &config)?;
151
152    Ok(sum_sq_dev / F::from(n - ddof).expect("Failed to convert to float"))
153}
154
155/// Compute all basic statistics in a single SIMD pass
156///
157/// Returns (mean, variance, min, max, skewness, kurtosis)
158#[allow(dead_code)]
159pub fn stats_simd_single_pass<F, D>(
160    x: &ArrayBase<D, Ix1>,
161    config: Option<SimdConfig>,
162) -> StatsResult<(F, F, F, F, F, F)>
163where
164    F: Float + NumCast + SimdUnifiedOps,
165    D: Data<Elem = F>,
166{
167    if x.is_empty() {
168        return Err(crate::error::StatsError::invalid_argument(
169            "Cannot compute statistics of empty array",
170        ));
171    }
172
173    let config = config.unwrap_or_default();
174    let n = x.len();
175    let n_f = F::from(n).expect("Failed to convert to float");
176
177    if n < config.minsize {
178        // Use scalar single-pass algorithm
179        return stats_scalar_single_pass(x);
180    }
181
182    // SIMD single-pass algorithm using moments
183    let capabilities = PlatformCapabilities::detect();
184    let simd_width = if capabilities.simd_available { 8 } else { 1 };
185
186    // Process in SIMD chunks
187    let mut m1 = F::zero(); // First moment (sum)
188    let mut m2 = F::zero(); // Second moment
189    let mut m3 = F::zero(); // Third moment
190    let mut m4 = F::zero(); // Fourth moment
191    let mut min = x[0];
192    let mut max = x[0];
193
194    // Main SIMD loop
195    let chunks = x.len() / simd_width;
196    let _remainder = x.len() % simd_width;
197
198    for chunk_idx in 0..chunks {
199        let start = chunk_idx * simd_width;
200        let chunk = x.slice(scirs2_core::ndarray::s![start..start + simd_width]);
201
202        // Process chunk with SIMD
203        let chunk_sum = F::simd_sum(&chunk);
204        m1 = m1 + chunk_sum;
205
206        // Update min/max
207        let chunk_min = F::simd_min_element(&chunk);
208        let chunk_max = F::simd_max_element(&chunk);
209        if chunk_min < min {
210            min = chunk_min;
211        }
212        if chunk_max > max {
213            max = chunk_max;
214        }
215    }
216
217    // Handle remainder with scalar code
218    let remainder_start = chunks * simd_width;
219    for i in remainder_start..x.len() {
220        let val = x[i];
221        m1 = m1 + val;
222        if val < min {
223            min = val;
224        }
225        if val > max {
226            max = val;
227        }
228    }
229
230    // Second pass for central moments (more accurate)
231    let mean = m1 / n_f;
232
233    // Compute central moments in second pass
234    for chunk_idx in 0..chunks {
235        let start = chunk_idx * simd_width;
236        let chunk = x.slice(scirs2_core::ndarray::s![start..start + simd_width]);
237
238        // Compute deviations and powers using SIMD
239        for &val in chunk.iter() {
240            let dev = val - mean;
241            let dev2 = dev * dev;
242            let dev3 = dev2 * dev;
243            let dev4 = dev3 * dev;
244
245            m2 = m2 + dev2;
246            m3 = m3 + dev3;
247            m4 = m4 + dev4;
248        }
249    }
250
251    // Handle remainder
252    for i in remainder_start..x.len() {
253        let dev = x[i] - mean;
254        let dev2 = dev * dev;
255        let dev3 = dev2 * dev;
256        let dev4 = dev3 * dev;
257
258        m2 = m2 + dev2;
259        m3 = m3 + dev3;
260        m4 = m4 + dev4;
261    }
262
263    // Calculate statistics from moments
264    let variance = m2 / F::from(n - 1).expect("Failed to convert to float");
265    let std_dev = variance.sqrt();
266
267    let skewness = if std_dev > F::epsilon() {
268        (m3 / n_f) / (std_dev * std_dev * std_dev)
269    } else {
270        F::zero()
271    };
272
273    let kurtosis = if variance > F::epsilon() {
274        (m4 / n_f) / (variance * variance)
275            - F::from(3).expect("Failed to convert constant to float")
276    } else {
277        F::zero()
278    };
279
280    Ok((mean, variance, min, max, skewness, kurtosis))
281}
282
283/// Helper function for chunked SIMD sum
284#[allow(dead_code)]
285fn chunked_simd_sum<F, D>(x: &ArrayBase<D, Ix1>, config: &SimdConfig) -> StatsResult<F>
286where
287    F: Float + NumCast + SimdUnifiedOps,
288    D: Data<Elem = F>,
289{
290    let capabilities = PlatformCapabilities::detect();
291    let _simd_width = if capabilities.simd_available { 8 } else { 1 };
292
293    // Process in chunks for better cache efficiency
294    const CHUNK_SIZE: usize = 1024;
295    let mut total_sum = F::zero();
296
297    for chunk in x.windows(CHUNK_SIZE) {
298        let chunk_sum = F::simd_sum(&chunk.view());
299        total_sum = total_sum + chunk_sum;
300    }
301
302    // Handle any remaining elements
303    let processed = (x.len() / CHUNK_SIZE) * CHUNK_SIZE;
304    if processed < x.len() {
305        let remainder = x.slice(scirs2_core::ndarray::s![processed..]);
306        let remainder_sum = F::simd_sum(&remainder);
307        total_sum = total_sum + remainder_sum;
308    }
309
310    Ok(total_sum)
311}
312
313/// Helper function for chunked SIMD sum of squared deviations
314#[allow(dead_code)]
315fn chunked_simd_sum_squared_deviations<F, D>(
316    x: &ArrayBase<D, Ix1>,
317    mean: F,
318    config: &SimdConfig,
319) -> StatsResult<F>
320where
321    F: Float + NumCast + SimdUnifiedOps,
322    D: Data<Elem = F>,
323{
324    const CHUNK_SIZE: usize = 1024;
325    let mut total_sum = F::zero();
326
327    // Process data in chunks without creating temporary arrays
328    for chunk in x.windows(CHUNK_SIZE) {
329        let chunk_sum = chunk
330            .iter()
331            .map(|&val| {
332                let dev = val - mean;
333                dev * dev
334            })
335            .fold(F::zero(), |acc, val| acc + val);
336        total_sum = total_sum + chunk_sum;
337    }
338
339    // Handle remainder
340    let processed = (x.len() / CHUNK_SIZE) * CHUNK_SIZE;
341    if processed < x.len() {
342        for i in processed..x.len() {
343            let dev = x[i] - mean;
344            total_sum = total_sum + dev * dev;
345        }
346    }
347
348    Ok(total_sum)
349}
350
351/// Scalar Welford's algorithm for variance (fallback)
352#[allow(dead_code)]
353fn variance_scalar_welford<F, D>(x: &ArrayBase<D, Ix1>, ddof: usize) -> StatsResult<F>
354where
355    F: Float + NumCast,
356    D: Data<Elem = F>,
357{
358    let mut mean = F::zero();
359    let mut m2 = F::zero();
360    let mut count = 0;
361
362    for &val in x.iter() {
363        count += 1;
364        let delta = val - mean;
365        mean = mean + delta / F::from(count).expect("Failed to convert to float");
366        let delta2 = val - mean;
367        m2 = m2 + delta * delta2;
368    }
369
370    Ok(m2 / F::from(count - ddof).expect("Failed to convert to float"))
371}
372
373/// Scalar single-pass statistics (fallback)
374#[allow(dead_code)]
375fn stats_scalar_single_pass<F, D>(x: &ArrayBase<D, Ix1>) -> StatsResult<(F, F, F, F, F, F)>
376where
377    F: Float + NumCast,
378    D: Data<Elem = F>,
379{
380    let n = x.len();
381    let n_f = F::from(n).expect("Failed to convert to float");
382
383    // First pass: compute mean
384    let mean = x.iter().fold(F::zero(), |acc, &val| acc + val) / n_f;
385
386    // Second pass: compute moments and min/max
387    let mut m2 = F::zero();
388    let mut m3 = F::zero();
389    let mut m4 = F::zero();
390    let mut min = x[0];
391    let mut max = x[0];
392
393    for &val in x.iter() {
394        let dev = val - mean;
395        let dev2 = dev * dev;
396        let dev3 = dev2 * dev;
397        let dev4 = dev3 * dev;
398
399        m2 = m2 + dev2;
400        m3 = m3 + dev3;
401        m4 = m4 + dev4;
402
403        if val < min {
404            min = val;
405        }
406        if val > max {
407            max = val;
408        }
409    }
410
411    let variance = m2 / F::from(n - 1).expect("Failed to convert to float");
412    let std_dev = variance.sqrt();
413
414    let skewness = if std_dev > F::epsilon() {
415        (m3 / n_f) / (std_dev * std_dev * std_dev)
416    } else {
417        F::zero()
418    };
419
420    let kurtosis = if variance > F::epsilon() {
421        (m4 / n_f) / (variance * variance)
422            - F::from(3).expect("Failed to convert constant to float")
423    } else {
424        F::zero()
425    };
426
427    Ok((mean, variance, min, max, skewness, kurtosis))
428}
429
#[cfg(test)]
mod tests {
    use super::*;
    use scirs2_core::ndarray::{array, Array1};

    /// Configuration whose `minsize` of 1 routes every non-empty input through the
    /// SIMD code paths, so those paths can be exercised on small data.
    fn forced_simd_config() -> SimdConfig {
        SimdConfig {
            minsize: 1,
            use_aligned: false,
            unroll_factor: 1,
        }
    }

    /// `0.0, 1.0, ..., 499.0`. This is at least as long as every `minsize` that
    /// `SimdConfig::detect` can return (max 256), so the default config takes the SIMD path.
    fn ramp_500() -> Array1<f64> {
        (0..500).map(|i| i as f64).collect()
    }

    #[test]
    fn test_mean_simd_optimized() {
        let data = array![1.0f64, 2.0, 3.0, 4.0, 5.0];
        let mean = mean_simd_optimized(&data.view(), None).expect("Operation failed");
        assert!((mean - 3.0).abs() < 1e-10);
    }

    #[test]
    fn test_mean_simd_path() {
        let data = array![1.0f64, 2.0, 3.0, 4.0, 5.0];
        let mean = mean_simd_optimized(&data.view(), Some(forced_simd_config()))
            .expect("Operation failed");
        assert!((mean - 3.0).abs() < 1e-10);

        let ramp = ramp_500();
        let mean = mean_simd_optimized(&ramp.view(), None).expect("Operation failed");
        assert!((mean - 249.5).abs() < 1e-10);
    }

    #[test]
    fn test_mean_empty_is_error() {
        let empty = Array1::<f64>::zeros(0);
        assert!(mean_simd_optimized(&empty.view(), None).is_err());
    }

    #[test]
    fn test_variance_simd_optimized() {
        let data = array![1.0f64, 2.0, 3.0, 4.0, 5.0];
        let var = variance_simd_optimized(&data.view(), 1, None).expect("Operation failed");
        assert!((var - 2.5).abs() < 1e-10);
    }

    #[test]
    fn test_variance_simd_path() {
        let data = array![1.0f64, 2.0, 3.0, 4.0, 5.0];
        let var = variance_simd_optimized(&data.view(), 1, Some(forced_simd_config()))
            .expect("Operation failed");
        assert!((var - 2.5).abs() < 1e-10);

        // Sample variance (ddof = 1) of 0..n is n(n + 1) / 12 = 500 * 501 / 12.
        let ramp = ramp_500();
        let var = variance_simd_optimized(&ramp.view(), 1, None).expect("Operation failed");
        assert!((var - 20875.0).abs() < 1e-8);
    }

    #[test]
    fn test_variance_rejects_insufficient_dof() {
        let data = array![1.0f64, 2.0];
        assert!(variance_simd_optimized(&data.view(), 2, None).is_err());

        let empty = Array1::<f64>::zeros(0);
        assert!(variance_simd_optimized(&empty.view(), 0, None).is_err());
    }

    #[test]
    fn test_stats_single_pass() {
        let data = array![1.0f64, 2.0, 3.0, 4.0, 5.0];
        let (mean, var, min, max, skew, kurt) =
            stats_simd_single_pass(&data.view(), None).expect("Operation failed");

        assert!((mean - 3.0).abs() < 1e-10);
        assert!((var - 2.5).abs() < 1e-10);
        assert!((min - 1.0).abs() < 1e-10);
        assert!((max - 5.0).abs() < 1e-10);
        // Symmetric data: the third central moment vanishes.
        assert!(skew.abs() < 1e-10);
        // (Σd⁴ / n) / s⁴ - 3 = (34 / 5) / 2.5² - 3 = -1.912
        assert!((kurt - (-1.912)).abs() < 1e-10);
    }

    #[test]
    fn test_stats_simd_path_matches_scalar_path() {
        let data = array![1.0f64, 2.0, 3.0, 4.0, 5.0];
        let simd = stats_simd_single_pass(&data.view(), Some(forced_simd_config()))
            .expect("Operation failed");
        let scalar = stats_scalar_single_pass(&data.view()).expect("Operation failed");

        let pairs = [
            (simd.0, scalar.0),
            (simd.1, scalar.1),
            (simd.2, scalar.2),
            (simd.3, scalar.3),
            (simd.4, scalar.4),
            (simd.5, scalar.5),
        ];
        for &(a, b) in pairs.iter() {
            assert!((a - b).abs() < 1e-10, "{} vs {}", a, b);
        }

        let ramp = ramp_500();
        let (mean, var, min, max, skew, _kurt) =
            stats_simd_single_pass(&ramp.view(), None).expect("Operation failed");
        assert!((mean - 249.5).abs() < 1e-10);
        assert!((var - 20875.0).abs() < 1e-8);
        assert!((min - 0.0).abs() < 1e-10);
        assert!((max - 499.0).abs() < 1e-10);
        assert!(skew.abs() < 1e-10);
    }

    #[test]
    fn test_stats_empty_is_error() {
        let empty = Array1::<f64>::zeros(0);
        assert!(stats_simd_single_pass(&empty.view(), None).is_err());
    }

    // --- Tests for SimdConfig::detect() ---

    /// detect() must not panic on any supported platform.
    #[test]
    fn test_simd_config_detect_no_panic() {
        let cfg = SimdConfig::detect();
        // Just verify we got a value without panicking.
        let _ = cfg;
    }

    /// On every platform the unroll_factor must be at least 1 (scalar fallback).
    #[test]
    fn test_simd_config_unroll_factor_geq_1() {
        let cfg = SimdConfig::detect();
        assert!(
            cfg.unroll_factor >= 1,
            "unroll_factor must be >= 1, got {}",
            cfg.unroll_factor
        );
    }

    /// The default SimdConfig must also not panic and have unroll_factor >= 1.
    #[test]
    fn test_simd_config_default_valid() {
        let cfg = SimdConfig::default();
        assert!(
            cfg.unroll_factor >= 1,
            "default unroll_factor must be >= 1, got {}",
            cfg.unroll_factor
        );
        assert!(cfg.minsize > 0, "minsize must be > 0");
    }
}