Skip to main content

scirs2_stats/
simd_optimized_v2.rs

1//! Enhanced SIMD optimizations for v1.0.0
2//!
3//! This module provides improved SIMD implementations that:
4//! - Avoid temporary array allocations
5//! - Use efficient SIMD patterns
6//! - Provide automatic fallback to scalar code
7//! - Support multiple data types
8
9use crate::error::StatsResult;
10use scirs2_core::ndarray::{ArrayBase, Data, Ix1};
11use scirs2_core::numeric::{Float, NumCast};
12use scirs2_core::simd_ops::{PlatformCapabilities, SimdUnifiedOps};
13use scirs2_core::validation::check_not_empty;
14
/// Tuning parameters that steer the SIMD-oriented routines in this module.
#[derive(Debug, Clone, Copy)]
pub struct SimdConfig {
    /// Element-count threshold: inputs shorter than this are routed to the
    /// scalar fallbacks instead of the chunked code paths.
    pub minsize: usize,
    /// Whether to use aligned loads/stores.
    ///
    /// NOTE(review): no routine in the visible part of this module reads this
    /// field — confirm where it is consumed.
    pub use_aligned: bool,
    /// Maximum unroll factor.
    ///
    /// NOTE(review): no routine in the visible part of this module reads this
    /// field — confirm where it is consumed.
    pub unroll_factor: usize,
}

impl Default for SimdConfig {
    /// Returns fixed, platform-independent settings: a threshold of 128
    /// elements, aligned access enabled, and an unroll factor of 4.
    fn default() -> Self {
        // TODO: Implement proper platform capability detection when available.
        // Until then the values are deliberately conservative constants.
        Self {
            minsize: 128,
            use_aligned: true,
            unroll_factor: 4,
        }
    }
}
39
40/// Optimized mean calculation using SIMD with chunked processing
41///
42/// This implementation avoids temporary arrays and processes data in chunks
43/// for better cache efficiency.
44#[allow(dead_code)]
45pub fn mean_simd_optimized<F, D>(
46    x: &ArrayBase<D, Ix1>,
47    config: Option<SimdConfig>,
48) -> StatsResult<F>
49where
50    F: Float + NumCast + SimdUnifiedOps,
51    D: Data<Elem = F>,
52{
53    // Use scirs2-core validation
54    check_not_empty(x, "x").map_err(|_| {
55        crate::error::StatsError::invalid_argument("Cannot compute mean of empty array")
56    })?;
57
58    let config = config.unwrap_or_default();
59    let n = x.len();
60
61    if n < config.minsize {
62        // Small arrays: use scalar code
63        let sum = x.iter().fold(F::zero(), |acc, &val| acc + val);
64        return Ok(sum / F::from(n).expect("Failed to convert to float"));
65    }
66
67    // For larger arrays, use chunked SIMD processing
68    let sum = chunked_simd_sum(x, &config)?;
69    Ok(sum / F::from(n).expect("Failed to convert to float"))
70}
71
72/// Optimized variance calculation using single-pass SIMD algorithm
73///
74/// Uses Welford's online algorithm adapted for SIMD processing
75#[allow(dead_code)]
76pub fn variance_simd_optimized<F, D>(
77    x: &ArrayBase<D, Ix1>,
78    ddof: usize,
79    config: Option<SimdConfig>,
80) -> StatsResult<F>
81where
82    F: Float + NumCast + SimdUnifiedOps,
83    D: Data<Elem = F>,
84{
85    let n = x.len();
86    if n <= ddof {
87        return Err(crate::error::StatsError::invalid_argument(
88            "Not enough data points for the given degrees of freedom",
89        ));
90    }
91
92    let config = config.unwrap_or_default();
93
94    if n < config.minsize {
95        // Small arrays: use scalar Welford's algorithm
96        return variance_scalar_welford(x, ddof);
97    }
98
99    // Use SIMD-optimized two-pass algorithm for better accuracy
100    let mean = mean_simd_optimized(x, Some(config))?;
101    let sum_sq_dev = chunked_simd_sum_squared_deviations(x, mean, &config)?;
102
103    Ok(sum_sq_dev / F::from(n - ddof).expect("Failed to convert to float"))
104}
105
106/// Compute all basic statistics in a single SIMD pass
107///
108/// Returns (mean, variance, min, max, skewness, kurtosis)
109#[allow(dead_code)]
110pub fn stats_simd_single_pass<F, D>(
111    x: &ArrayBase<D, Ix1>,
112    config: Option<SimdConfig>,
113) -> StatsResult<(F, F, F, F, F, F)>
114where
115    F: Float + NumCast + SimdUnifiedOps,
116    D: Data<Elem = F>,
117{
118    if x.is_empty() {
119        return Err(crate::error::StatsError::invalid_argument(
120            "Cannot compute statistics of empty array",
121        ));
122    }
123
124    let config = config.unwrap_or_default();
125    let n = x.len();
126    let n_f = F::from(n).expect("Failed to convert to float");
127
128    if n < config.minsize {
129        // Use scalar single-pass algorithm
130        return stats_scalar_single_pass(x);
131    }
132
133    // SIMD single-pass algorithm using moments
134    let capabilities = PlatformCapabilities::detect();
135    let simd_width = if capabilities.simd_available { 8 } else { 1 };
136
137    // Process in SIMD chunks
138    let mut m1 = F::zero(); // First moment (sum)
139    let mut m2 = F::zero(); // Second moment
140    let mut m3 = F::zero(); // Third moment
141    let mut m4 = F::zero(); // Fourth moment
142    let mut min = x[0];
143    let mut max = x[0];
144
145    // Main SIMD loop
146    let chunks = x.len() / simd_width;
147    let _remainder = x.len() % simd_width;
148
149    for chunk_idx in 0..chunks {
150        let start = chunk_idx * simd_width;
151        let chunk = x.slice(scirs2_core::ndarray::s![start..start + simd_width]);
152
153        // Process chunk with SIMD
154        let chunk_sum = F::simd_sum(&chunk);
155        m1 = m1 + chunk_sum;
156
157        // Update min/max
158        let chunk_min = F::simd_min_element(&chunk);
159        let chunk_max = F::simd_max_element(&chunk);
160        if chunk_min < min {
161            min = chunk_min;
162        }
163        if chunk_max > max {
164            max = chunk_max;
165        }
166    }
167
168    // Handle remainder with scalar code
169    let remainder_start = chunks * simd_width;
170    for i in remainder_start..x.len() {
171        let val = x[i];
172        m1 = m1 + val;
173        if val < min {
174            min = val;
175        }
176        if val > max {
177            max = val;
178        }
179    }
180
181    // Second pass for central moments (more accurate)
182    let mean = m1 / n_f;
183
184    // Compute central moments in second pass
185    for chunk_idx in 0..chunks {
186        let start = chunk_idx * simd_width;
187        let chunk = x.slice(scirs2_core::ndarray::s![start..start + simd_width]);
188
189        // Compute deviations and powers using SIMD
190        for &val in chunk.iter() {
191            let dev = val - mean;
192            let dev2 = dev * dev;
193            let dev3 = dev2 * dev;
194            let dev4 = dev3 * dev;
195
196            m2 = m2 + dev2;
197            m3 = m3 + dev3;
198            m4 = m4 + dev4;
199        }
200    }
201
202    // Handle remainder
203    for i in remainder_start..x.len() {
204        let dev = x[i] - mean;
205        let dev2 = dev * dev;
206        let dev3 = dev2 * dev;
207        let dev4 = dev3 * dev;
208
209        m2 = m2 + dev2;
210        m3 = m3 + dev3;
211        m4 = m4 + dev4;
212    }
213
214    // Calculate statistics from moments
215    let variance = m2 / F::from(n - 1).expect("Failed to convert to float");
216    let std_dev = variance.sqrt();
217
218    let skewness = if std_dev > F::epsilon() {
219        (m3 / n_f) / (std_dev * std_dev * std_dev)
220    } else {
221        F::zero()
222    };
223
224    let kurtosis = if variance > F::epsilon() {
225        (m4 / n_f) / (variance * variance)
226            - F::from(3).expect("Failed to convert constant to float")
227    } else {
228        F::zero()
229    };
230
231    Ok((mean, variance, min, max, skewness, kurtosis))
232}
233
234/// Helper function for chunked SIMD sum
235#[allow(dead_code)]
236fn chunked_simd_sum<F, D>(x: &ArrayBase<D, Ix1>, config: &SimdConfig) -> StatsResult<F>
237where
238    F: Float + NumCast + SimdUnifiedOps,
239    D: Data<Elem = F>,
240{
241    let capabilities = PlatformCapabilities::detect();
242    let _simd_width = if capabilities.simd_available { 8 } else { 1 };
243
244    // Process in chunks for better cache efficiency
245    const CHUNK_SIZE: usize = 1024;
246    let mut total_sum = F::zero();
247
248    for chunk in x.windows(CHUNK_SIZE) {
249        let chunk_sum = F::simd_sum(&chunk.view());
250        total_sum = total_sum + chunk_sum;
251    }
252
253    // Handle any remaining elements
254    let processed = (x.len() / CHUNK_SIZE) * CHUNK_SIZE;
255    if processed < x.len() {
256        let remainder = x.slice(scirs2_core::ndarray::s![processed..]);
257        let remainder_sum = F::simd_sum(&remainder);
258        total_sum = total_sum + remainder_sum;
259    }
260
261    Ok(total_sum)
262}
263
264/// Helper function for chunked SIMD sum of squared deviations
265#[allow(dead_code)]
266fn chunked_simd_sum_squared_deviations<F, D>(
267    x: &ArrayBase<D, Ix1>,
268    mean: F,
269    config: &SimdConfig,
270) -> StatsResult<F>
271where
272    F: Float + NumCast + SimdUnifiedOps,
273    D: Data<Elem = F>,
274{
275    const CHUNK_SIZE: usize = 1024;
276    let mut total_sum = F::zero();
277
278    // Process data in chunks without creating temporary arrays
279    for chunk in x.windows(CHUNK_SIZE) {
280        let chunk_sum = chunk
281            .iter()
282            .map(|&val| {
283                let dev = val - mean;
284                dev * dev
285            })
286            .fold(F::zero(), |acc, val| acc + val);
287        total_sum = total_sum + chunk_sum;
288    }
289
290    // Handle remainder
291    let processed = (x.len() / CHUNK_SIZE) * CHUNK_SIZE;
292    if processed < x.len() {
293        for i in processed..x.len() {
294            let dev = x[i] - mean;
295            total_sum = total_sum + dev * dev;
296        }
297    }
298
299    Ok(total_sum)
300}
301
302/// Scalar Welford's algorithm for variance (fallback)
303#[allow(dead_code)]
304fn variance_scalar_welford<F, D>(x: &ArrayBase<D, Ix1>, ddof: usize) -> StatsResult<F>
305where
306    F: Float + NumCast,
307    D: Data<Elem = F>,
308{
309    let mut mean = F::zero();
310    let mut m2 = F::zero();
311    let mut count = 0;
312
313    for &val in x.iter() {
314        count += 1;
315        let delta = val - mean;
316        mean = mean + delta / F::from(count).expect("Failed to convert to float");
317        let delta2 = val - mean;
318        m2 = m2 + delta * delta2;
319    }
320
321    Ok(m2 / F::from(count - ddof).expect("Failed to convert to float"))
322}
323
324/// Scalar single-pass statistics (fallback)
325#[allow(dead_code)]
326fn stats_scalar_single_pass<F, D>(x: &ArrayBase<D, Ix1>) -> StatsResult<(F, F, F, F, F, F)>
327where
328    F: Float + NumCast,
329    D: Data<Elem = F>,
330{
331    let n = x.len();
332    let n_f = F::from(n).expect("Failed to convert to float");
333
334    // First pass: compute mean
335    let mean = x.iter().fold(F::zero(), |acc, &val| acc + val) / n_f;
336
337    // Second pass: compute moments and min/max
338    let mut m2 = F::zero();
339    let mut m3 = F::zero();
340    let mut m4 = F::zero();
341    let mut min = x[0];
342    let mut max = x[0];
343
344    for &val in x.iter() {
345        let dev = val - mean;
346        let dev2 = dev * dev;
347        let dev3 = dev2 * dev;
348        let dev4 = dev3 * dev;
349
350        m2 = m2 + dev2;
351        m3 = m3 + dev3;
352        m4 = m4 + dev4;
353
354        if val < min {
355            min = val;
356        }
357        if val > max {
358            max = val;
359        }
360    }
361
362    let variance = m2 / F::from(n - 1).expect("Failed to convert to float");
363    let std_dev = variance.sqrt();
364
365    let skewness = if std_dev > F::epsilon() {
366        (m3 / n_f) / (std_dev * std_dev * std_dev)
367    } else {
368        F::zero()
369    };
370
371    let kurtosis = if variance > F::epsilon() {
372        (m4 / n_f) / (variance * variance)
373            - F::from(3).expect("Failed to convert constant to float")
374    } else {
375        F::zero()
376    };
377
378    Ok((mean, variance, min, max, skewness, kurtosis))
379}
380
#[cfg(test)]
mod tests {
    use super::*;
    use scirs2_core::ndarray::array;

    /// Absolute tolerance shared by all floating-point comparisons below.
    const TOL: f64 = 1e-10;

    /// Mean of 1..=5 is 3 (input is below the default threshold, so the
    /// scalar path is exercised).
    #[test]
    fn test_mean_simd_optimized() {
        let sample = array![1.0f64, 2.0, 3.0, 4.0, 5.0];
        let got = mean_simd_optimized(&sample.view(), None).expect("Operation failed");
        assert!((got - 3.0).abs() < TOL);
    }

    /// Sample variance (ddof = 1) of 1..=5 is 2.5.
    #[test]
    fn test_variance_simd_optimized() {
        let sample = array![1.0f64, 2.0, 3.0, 4.0, 5.0];
        let got = variance_simd_optimized(&sample.view(), 1, None).expect("Operation failed");
        assert!((got - 2.5).abs() < TOL);
    }

    /// Mean, variance, minimum and maximum of 1..=5; skewness and kurtosis
    /// are returned but not checked here.
    #[test]
    fn test_stats_single_pass() {
        let sample = array![1.0f64, 2.0, 3.0, 4.0, 5.0];
        let (mean, variance, lo, hi, _skew, _kurt) =
            stats_simd_single_pass(&sample.view(), None).expect("Operation failed");

        assert!((mean - 3.0).abs() < TOL);
        assert!((variance - 2.5).abs() < TOL);
        assert!((lo - 1.0).abs() < TOL);
        assert!((hi - 5.0).abs() < TOL);
    }
}