scirs2_stats/
simd_optimized_v2.rs

1//! Enhanced SIMD optimizations for v1.0.0
2//!
3//! This module provides improved SIMD implementations that:
4//! - Avoid temporary array allocations
5//! - Use efficient SIMD patterns
6//! - Provide automatic fallback to scalar code
7//! - Support multiple data types
8
9use crate::error::StatsResult;
10use scirs2_core::ndarray::{ArrayBase, Data, Ix1};
11use scirs2_core::numeric::{Float, NumCast};
12use scirs2_core::simd_ops::{PlatformCapabilities, SimdUnifiedOps};
13use scirs2_core::validation::check_not_empty;
14
/// Tuning knobs for the SIMD-accelerated statistics routines.
///
/// The routines in this module compare the input length against
/// [`SimdConfig::minsize`] to choose between a scalar and a SIMD code path.
#[derive(Debug, Clone, Copy)]
pub struct SimdConfig {
    /// Inputs shorter than this many elements are handled by scalar code.
    pub minsize: usize,
    /// Request aligned loads/stores.
    // NOTE(review): not read by any routine visible in this file — confirm usage elsewhere.
    pub use_aligned: bool,
    /// Upper bound on loop unrolling.
    // NOTE(review): not read by any routine visible in this file — confirm usage elsewhere.
    pub unroll_factor: usize,
}

impl Default for SimdConfig {
    /// Conservative, platform-independent defaults:
    /// `minsize = 128`, `use_aligned = true`, `unroll_factor = 4`.
    fn default() -> Self {
        // TODO: Implement proper platform capability detection when available
        SimdConfig {
            // Conservative threshold below which SIMD is not attempted.
            minsize: 128,
            use_aligned: true,
            unroll_factor: 4,
        }
    }
}
39
40/// Optimized mean calculation using SIMD with chunked processing
41///
42/// This implementation avoids temporary arrays and processes data in chunks
43/// for better cache efficiency.
44#[allow(dead_code)]
45pub fn mean_simd_optimized<F, D>(
46    x: &ArrayBase<D, Ix1>,
47    config: Option<SimdConfig>,
48) -> StatsResult<F>
49where
50    F: Float + NumCast + SimdUnifiedOps,
51    D: Data<Elem = F>,
52{
53    // Use scirs2-core validation
54    check_not_empty(x, "x").map_err(|_| {
55        crate::error::StatsError::invalid_argument("Cannot compute mean of empty array")
56    })?;
57
58    let config = config.unwrap_or_default();
59    let n = x.len();
60
61    if n < config.minsize {
62        // Small arrays: use scalar code
63        let sum = x.iter().fold(F::zero(), |acc, &val| acc + val);
64        return Ok(sum / F::from(n).unwrap());
65    }
66
67    // For larger arrays, use chunked SIMD processing
68    let sum = chunked_simd_sum(x, &config)?;
69    Ok(sum / F::from(n).unwrap())
70}
71
72/// Optimized variance calculation using single-pass SIMD algorithm
73///
74/// Uses Welford's online algorithm adapted for SIMD processing
75#[allow(dead_code)]
76pub fn variance_simd_optimized<F, D>(
77    x: &ArrayBase<D, Ix1>,
78    ddof: usize,
79    config: Option<SimdConfig>,
80) -> StatsResult<F>
81where
82    F: Float + NumCast + SimdUnifiedOps,
83    D: Data<Elem = F>,
84{
85    let n = x.len();
86    if n <= ddof {
87        return Err(crate::error::StatsError::invalid_argument(
88            "Not enough data points for the given degrees of freedom",
89        ));
90    }
91
92    let config = config.unwrap_or_default();
93
94    if n < config.minsize {
95        // Small arrays: use scalar Welford's algorithm
96        return variance_scalar_welford(x, ddof);
97    }
98
99    // Use SIMD-optimized two-pass algorithm for better accuracy
100    let mean = mean_simd_optimized(x, Some(config))?;
101    let sum_sq_dev = chunked_simd_sum_squared_deviations(x, mean, &config)?;
102
103    Ok(sum_sq_dev / F::from(n - ddof).unwrap())
104}
105
106/// Compute all basic statistics in a single SIMD pass
107///
108/// Returns (mean, variance, min, max, skewness, kurtosis)
109#[allow(dead_code)]
110pub fn stats_simd_single_pass<F, D>(
111    x: &ArrayBase<D, Ix1>,
112    config: Option<SimdConfig>,
113) -> StatsResult<(F, F, F, F, F, F)>
114where
115    F: Float + NumCast + SimdUnifiedOps,
116    D: Data<Elem = F>,
117{
118    if x.is_empty() {
119        return Err(crate::error::StatsError::invalid_argument(
120            "Cannot compute statistics of empty array",
121        ));
122    }
123
124    let config = config.unwrap_or_default();
125    let n = x.len();
126    let n_f = F::from(n).unwrap();
127
128    if n < config.minsize {
129        // Use scalar single-pass algorithm
130        return stats_scalar_single_pass(x);
131    }
132
133    // SIMD single-pass algorithm using moments
134    let capabilities = PlatformCapabilities::detect();
135    let simd_width = if capabilities.simd_available { 8 } else { 1 };
136
137    // Process in SIMD chunks
138    let mut m1 = F::zero(); // First moment (sum)
139    let mut m2 = F::zero(); // Second moment
140    let mut m3 = F::zero(); // Third moment
141    let mut m4 = F::zero(); // Fourth moment
142    let mut min = x[0];
143    let mut max = x[0];
144
145    // Main SIMD loop
146    let chunks = x.len() / simd_width;
147    let _remainder = x.len() % simd_width;
148
149    for chunk_idx in 0..chunks {
150        let start = chunk_idx * simd_width;
151        let chunk = x.slice(scirs2_core::ndarray::s![start..start + simd_width]);
152
153        // Process chunk with SIMD
154        let chunk_sum = F::simd_sum(&chunk);
155        m1 = m1 + chunk_sum;
156
157        // Update min/max
158        let chunk_min = F::simd_min_element(&chunk);
159        let chunk_max = F::simd_max_element(&chunk);
160        if chunk_min < min {
161            min = chunk_min;
162        }
163        if chunk_max > max {
164            max = chunk_max;
165        }
166    }
167
168    // Handle remainder with scalar code
169    let remainder_start = chunks * simd_width;
170    for i in remainder_start..x.len() {
171        let val = x[i];
172        m1 = m1 + val;
173        if val < min {
174            min = val;
175        }
176        if val > max {
177            max = val;
178        }
179    }
180
181    // Second pass for central moments (more accurate)
182    let mean = m1 / n_f;
183
184    // Compute central moments in second pass
185    for chunk_idx in 0..chunks {
186        let start = chunk_idx * simd_width;
187        let chunk = x.slice(scirs2_core::ndarray::s![start..start + simd_width]);
188
189        // Compute deviations and powers using SIMD
190        for &val in chunk.iter() {
191            let dev = val - mean;
192            let dev2 = dev * dev;
193            let dev3 = dev2 * dev;
194            let dev4 = dev3 * dev;
195
196            m2 = m2 + dev2;
197            m3 = m3 + dev3;
198            m4 = m4 + dev4;
199        }
200    }
201
202    // Handle remainder
203    for i in remainder_start..x.len() {
204        let dev = x[i] - mean;
205        let dev2 = dev * dev;
206        let dev3 = dev2 * dev;
207        let dev4 = dev3 * dev;
208
209        m2 = m2 + dev2;
210        m3 = m3 + dev3;
211        m4 = m4 + dev4;
212    }
213
214    // Calculate statistics from moments
215    let variance = m2 / F::from(n - 1).unwrap();
216    let std_dev = variance.sqrt();
217
218    let skewness = if std_dev > F::epsilon() {
219        (m3 / n_f) / (std_dev * std_dev * std_dev)
220    } else {
221        F::zero()
222    };
223
224    let kurtosis = if variance > F::epsilon() {
225        (m4 / n_f) / (variance * variance) - F::from(3).unwrap()
226    } else {
227        F::zero()
228    };
229
230    Ok((mean, variance, min, max, skewness, kurtosis))
231}
232
233/// Helper function for chunked SIMD sum
234#[allow(dead_code)]
235fn chunked_simd_sum<F, D>(x: &ArrayBase<D, Ix1>, config: &SimdConfig) -> StatsResult<F>
236where
237    F: Float + NumCast + SimdUnifiedOps,
238    D: Data<Elem = F>,
239{
240    let capabilities = PlatformCapabilities::detect();
241    let _simd_width = if capabilities.simd_available { 8 } else { 1 };
242
243    // Process in chunks for better cache efficiency
244    const CHUNK_SIZE: usize = 1024;
245    let mut total_sum = F::zero();
246
247    for chunk in x.windows(CHUNK_SIZE) {
248        let chunk_sum = F::simd_sum(&chunk.view());
249        total_sum = total_sum + chunk_sum;
250    }
251
252    // Handle any remaining elements
253    let processed = (x.len() / CHUNK_SIZE) * CHUNK_SIZE;
254    if processed < x.len() {
255        let remainder = x.slice(scirs2_core::ndarray::s![processed..]);
256        let remainder_sum = F::simd_sum(&remainder);
257        total_sum = total_sum + remainder_sum;
258    }
259
260    Ok(total_sum)
261}
262
263/// Helper function for chunked SIMD sum of squared deviations
264#[allow(dead_code)]
265fn chunked_simd_sum_squared_deviations<F, D>(
266    x: &ArrayBase<D, Ix1>,
267    mean: F,
268    config: &SimdConfig,
269) -> StatsResult<F>
270where
271    F: Float + NumCast + SimdUnifiedOps,
272    D: Data<Elem = F>,
273{
274    const CHUNK_SIZE: usize = 1024;
275    let mut total_sum = F::zero();
276
277    // Process data in chunks without creating temporary arrays
278    for chunk in x.windows(CHUNK_SIZE) {
279        let chunk_sum = chunk
280            .iter()
281            .map(|&val| {
282                let dev = val - mean;
283                dev * dev
284            })
285            .fold(F::zero(), |acc, val| acc + val);
286        total_sum = total_sum + chunk_sum;
287    }
288
289    // Handle remainder
290    let processed = (x.len() / CHUNK_SIZE) * CHUNK_SIZE;
291    if processed < x.len() {
292        for i in processed..x.len() {
293            let dev = x[i] - mean;
294            total_sum = total_sum + dev * dev;
295        }
296    }
297
298    Ok(total_sum)
299}
300
301/// Scalar Welford's algorithm for variance (fallback)
302#[allow(dead_code)]
303fn variance_scalar_welford<F, D>(x: &ArrayBase<D, Ix1>, ddof: usize) -> StatsResult<F>
304where
305    F: Float + NumCast,
306    D: Data<Elem = F>,
307{
308    let mut mean = F::zero();
309    let mut m2 = F::zero();
310    let mut count = 0;
311
312    for &val in x.iter() {
313        count += 1;
314        let delta = val - mean;
315        mean = mean + delta / F::from(count).unwrap();
316        let delta2 = val - mean;
317        m2 = m2 + delta * delta2;
318    }
319
320    Ok(m2 / F::from(count - ddof).unwrap())
321}
322
323/// Scalar single-pass statistics (fallback)
324#[allow(dead_code)]
325fn stats_scalar_single_pass<F, D>(x: &ArrayBase<D, Ix1>) -> StatsResult<(F, F, F, F, F, F)>
326where
327    F: Float + NumCast,
328    D: Data<Elem = F>,
329{
330    let n = x.len();
331    let n_f = F::from(n).unwrap();
332
333    // First pass: compute mean
334    let mean = x.iter().fold(F::zero(), |acc, &val| acc + val) / n_f;
335
336    // Second pass: compute moments and min/max
337    let mut m2 = F::zero();
338    let mut m3 = F::zero();
339    let mut m4 = F::zero();
340    let mut min = x[0];
341    let mut max = x[0];
342
343    for &val in x.iter() {
344        let dev = val - mean;
345        let dev2 = dev * dev;
346        let dev3 = dev2 * dev;
347        let dev4 = dev3 * dev;
348
349        m2 = m2 + dev2;
350        m3 = m3 + dev3;
351        m4 = m4 + dev4;
352
353        if val < min {
354            min = val;
355        }
356        if val > max {
357            max = val;
358        }
359    }
360
361    let variance = m2 / F::from(n - 1).unwrap();
362    let std_dev = variance.sqrt();
363
364    let skewness = if std_dev > F::epsilon() {
365        (m3 / n_f) / (std_dev * std_dev * std_dev)
366    } else {
367        F::zero()
368    };
369
370    let kurtosis = if variance > F::epsilon() {
371        (m4 / n_f) / (variance * variance) - F::from(3).unwrap()
372    } else {
373        F::zero()
374    };
375
376    Ok((mean, variance, min, max, skewness, kurtosis))
377}
378
#[cfg(test)]
mod tests {
    use super::*;
    use scirs2_core::ndarray::{array, Array1};

    /// Absolute tolerance for floating-point comparisons in these tests.
    const TOL: f64 = 1e-9;

    /// Builds `[1.0, 2.0, ..., len]`.
    ///
    /// For this sequence the mean is `(len + 1) / 2` and the sample variance
    /// (ddof = 1) is `len * (len + 1) / 12`.
    fn ramp(len: usize) -> Array1<f64> {
        Array1::from_iter((1..=len).map(|v| v as f64))
    }

    #[test]
    fn test_mean_simd_optimized() {
        let data = array![1.0f64, 2.0, 3.0, 4.0, 5.0];
        let mean = mean_simd_optimized(&data.view(), None).unwrap();
        assert!((mean - 3.0).abs() < TOL);
    }

    #[test]
    fn test_mean_simd_optimized_above_threshold() {
        // 200 elements exceeds the default `minsize` of 128, so the
        // non-scalar branch is taken.
        let data = ramp(200);
        let mean = mean_simd_optimized(&data.view(), None).unwrap();
        assert!((mean - 100.5).abs() < TOL);
    }

    #[test]
    fn test_mean_simd_optimized_empty_is_error() {
        let data = Array1::<f64>::zeros(0);
        assert!(mean_simd_optimized(&data.view(), None).is_err());
    }

    #[test]
    fn test_variance_simd_optimized() {
        let data = array![1.0f64, 2.0, 3.0, 4.0, 5.0];
        let var = variance_simd_optimized(&data.view(), 1, None).unwrap();
        assert!((var - 2.5).abs() < TOL);
    }

    #[test]
    fn test_variance_simd_optimized_above_threshold() {
        let data = ramp(200);
        let var = variance_simd_optimized(&data.view(), 1, None).unwrap();
        // 200 * 201 / 12
        assert!((var - 3350.0).abs() < TOL);
    }

    #[test]
    fn test_variance_simd_optimized_insufficient_data_is_error() {
        let data = array![1.0f64];
        assert!(variance_simd_optimized(&data.view(), 1, None).is_err());
    }

    #[test]
    fn test_stats_single_pass() {
        let data = array![1.0f64, 2.0, 3.0, 4.0, 5.0];
        let (mean, var, min, max, skew, kurt) =
            stats_simd_single_pass(&data.view(), None).unwrap();

        assert!((mean - 3.0).abs() < TOL);
        assert!((var - 2.5).abs() < TOL);
        assert!((min - 1.0).abs() < TOL);
        assert!((max - 5.0).abs() < TOL);
        // Symmetric data: the cubed deviations cancel.
        assert!(skew.abs() < TOL);
        // (m4 / n) / var^2 - 3 = 6.8 / 6.25 - 3
        assert!((kurt - (-1.912)).abs() < TOL);
    }

    #[test]
    fn test_stats_single_pass_above_threshold() {
        let data = ramp(200);
        let (mean, var, min, max, skew, _kurt) =
            stats_simd_single_pass(&data.view(), None).unwrap();

        assert!((mean - 100.5).abs() < TOL);
        assert!((var - 3350.0).abs() < TOL);
        assert!((min - 1.0).abs() < TOL);
        assert!((max - 200.0).abs() < TOL);
        assert!(skew.abs() < TOL);
    }

    #[test]
    fn test_stats_single_pass_empty_is_error() {
        let data = Array1::<f64>::zeros(0);
        assert!(stats_simd_single_pass(&data.view(), None).is_err());
    }
}