1use crate::error::StatsResult;
10use scirs2_core::ndarray::{ArrayBase, Data, Ix1};
11use scirs2_core::numeric::{Float, NumCast};
12use scirs2_core::simd_ops::{PlatformCapabilities, SimdUnifiedOps};
13use scirs2_core::validation::check_not_empty;
14
/// Tuning parameters for the SIMD-accelerated statistics in this module.
///
/// Obtain a host-appropriate value with [`SimdConfig::detect`] (which is also
/// what `Default` returns), or build one by hand to force a particular path.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct SimdConfig {
    /// Minimum input length for the SIMD code paths; shorter inputs are
    /// processed by plain scalar loops.
    pub minsize: usize,
    /// Whether aligned memory access is preferred.
    /// NOTE(review): not read by the reductions in this module — confirm
    /// whether anything else consumes it.
    pub use_aligned: bool,
    /// Suggested loop-unrolling factor (detected configurations use >= 1).
    /// NOTE(review): not read by the reductions in this module.
    pub unroll_factor: usize,
}
25
26impl Default for SimdConfig {
27 fn default() -> Self {
28 SimdConfig::detect()
29 }
30}
31
32impl SimdConfig {
33 pub fn detect() -> Self {
39 #[cfg(target_arch = "x86_64")]
40 {
41 if is_x86_feature_detected!("avx512f") {
43 return Self {
45 minsize: 256,
46 use_aligned: true,
47 unroll_factor: 8,
48 };
49 }
50 if is_x86_feature_detected!("avx2") {
51 return Self {
53 minsize: 128,
54 use_aligned: true,
55 unroll_factor: 4,
56 };
57 }
58 if is_x86_feature_detected!("sse4.2") {
59 return Self {
61 minsize: 64,
62 use_aligned: false,
63 unroll_factor: 2,
64 };
65 }
66 }
67
68 #[cfg(target_arch = "aarch64")]
69 {
70 return Self {
73 minsize: 64,
74 use_aligned: false,
75 unroll_factor: 4,
76 };
77 }
78
79 #[allow(unreachable_code)]
81 Self {
82 minsize: 32,
83 use_aligned: false,
84 unroll_factor: 1,
85 }
86 }
87}
88
89#[allow(dead_code)]
94pub fn mean_simd_optimized<F, D>(
95 x: &ArrayBase<D, Ix1>,
96 config: Option<SimdConfig>,
97) -> StatsResult<F>
98where
99 F: Float + NumCast + SimdUnifiedOps,
100 D: Data<Elem = F>,
101{
102 check_not_empty(x, "x").map_err(|_| {
104 crate::error::StatsError::invalid_argument("Cannot compute mean of empty array")
105 })?;
106
107 let config = config.unwrap_or_default();
108 let n = x.len();
109
110 if n < config.minsize {
111 let sum = x.iter().fold(F::zero(), |acc, &val| acc + val);
113 return Ok(sum / F::from(n).expect("Failed to convert to float"));
114 }
115
116 let sum = chunked_simd_sum(x, &config)?;
118 Ok(sum / F::from(n).expect("Failed to convert to float"))
119}
120
121#[allow(dead_code)]
125pub fn variance_simd_optimized<F, D>(
126 x: &ArrayBase<D, Ix1>,
127 ddof: usize,
128 config: Option<SimdConfig>,
129) -> StatsResult<F>
130where
131 F: Float + NumCast + SimdUnifiedOps,
132 D: Data<Elem = F>,
133{
134 let n = x.len();
135 if n <= ddof {
136 return Err(crate::error::StatsError::invalid_argument(
137 "Not enough data points for the given degrees of freedom",
138 ));
139 }
140
141 let config = config.unwrap_or_default();
142
143 if n < config.minsize {
144 return variance_scalar_welford(x, ddof);
146 }
147
148 let mean = mean_simd_optimized(x, Some(config))?;
150 let sum_sq_dev = chunked_simd_sum_squared_deviations(x, mean, &config)?;
151
152 Ok(sum_sq_dev / F::from(n - ddof).expect("Failed to convert to float"))
153}
154
155#[allow(dead_code)]
159pub fn stats_simd_single_pass<F, D>(
160 x: &ArrayBase<D, Ix1>,
161 config: Option<SimdConfig>,
162) -> StatsResult<(F, F, F, F, F, F)>
163where
164 F: Float + NumCast + SimdUnifiedOps,
165 D: Data<Elem = F>,
166{
167 if x.is_empty() {
168 return Err(crate::error::StatsError::invalid_argument(
169 "Cannot compute statistics of empty array",
170 ));
171 }
172
173 let config = config.unwrap_or_default();
174 let n = x.len();
175 let n_f = F::from(n).expect("Failed to convert to float");
176
177 if n < config.minsize {
178 return stats_scalar_single_pass(x);
180 }
181
182 let capabilities = PlatformCapabilities::detect();
184 let simd_width = if capabilities.simd_available { 8 } else { 1 };
185
186 let mut m1 = F::zero(); let mut m2 = F::zero(); let mut m3 = F::zero(); let mut m4 = F::zero(); let mut min = x[0];
192 let mut max = x[0];
193
194 let chunks = x.len() / simd_width;
196 let _remainder = x.len() % simd_width;
197
198 for chunk_idx in 0..chunks {
199 let start = chunk_idx * simd_width;
200 let chunk = x.slice(scirs2_core::ndarray::s![start..start + simd_width]);
201
202 let chunk_sum = F::simd_sum(&chunk);
204 m1 = m1 + chunk_sum;
205
206 let chunk_min = F::simd_min_element(&chunk);
208 let chunk_max = F::simd_max_element(&chunk);
209 if chunk_min < min {
210 min = chunk_min;
211 }
212 if chunk_max > max {
213 max = chunk_max;
214 }
215 }
216
217 let remainder_start = chunks * simd_width;
219 for i in remainder_start..x.len() {
220 let val = x[i];
221 m1 = m1 + val;
222 if val < min {
223 min = val;
224 }
225 if val > max {
226 max = val;
227 }
228 }
229
230 let mean = m1 / n_f;
232
233 for chunk_idx in 0..chunks {
235 let start = chunk_idx * simd_width;
236 let chunk = x.slice(scirs2_core::ndarray::s![start..start + simd_width]);
237
238 for &val in chunk.iter() {
240 let dev = val - mean;
241 let dev2 = dev * dev;
242 let dev3 = dev2 * dev;
243 let dev4 = dev3 * dev;
244
245 m2 = m2 + dev2;
246 m3 = m3 + dev3;
247 m4 = m4 + dev4;
248 }
249 }
250
251 for i in remainder_start..x.len() {
253 let dev = x[i] - mean;
254 let dev2 = dev * dev;
255 let dev3 = dev2 * dev;
256 let dev4 = dev3 * dev;
257
258 m2 = m2 + dev2;
259 m3 = m3 + dev3;
260 m4 = m4 + dev4;
261 }
262
263 let variance = m2 / F::from(n - 1).expect("Failed to convert to float");
265 let std_dev = variance.sqrt();
266
267 let skewness = if std_dev > F::epsilon() {
268 (m3 / n_f) / (std_dev * std_dev * std_dev)
269 } else {
270 F::zero()
271 };
272
273 let kurtosis = if variance > F::epsilon() {
274 (m4 / n_f) / (variance * variance)
275 - F::from(3).expect("Failed to convert constant to float")
276 } else {
277 F::zero()
278 };
279
280 Ok((mean, variance, min, max, skewness, kurtosis))
281}
282
283#[allow(dead_code)]
285fn chunked_simd_sum<F, D>(x: &ArrayBase<D, Ix1>, config: &SimdConfig) -> StatsResult<F>
286where
287 F: Float + NumCast + SimdUnifiedOps,
288 D: Data<Elem = F>,
289{
290 let capabilities = PlatformCapabilities::detect();
291 let _simd_width = if capabilities.simd_available { 8 } else { 1 };
292
293 const CHUNK_SIZE: usize = 1024;
295 let mut total_sum = F::zero();
296
297 for chunk in x.windows(CHUNK_SIZE) {
298 let chunk_sum = F::simd_sum(&chunk.view());
299 total_sum = total_sum + chunk_sum;
300 }
301
302 let processed = (x.len() / CHUNK_SIZE) * CHUNK_SIZE;
304 if processed < x.len() {
305 let remainder = x.slice(scirs2_core::ndarray::s![processed..]);
306 let remainder_sum = F::simd_sum(&remainder);
307 total_sum = total_sum + remainder_sum;
308 }
309
310 Ok(total_sum)
311}
312
313#[allow(dead_code)]
315fn chunked_simd_sum_squared_deviations<F, D>(
316 x: &ArrayBase<D, Ix1>,
317 mean: F,
318 config: &SimdConfig,
319) -> StatsResult<F>
320where
321 F: Float + NumCast + SimdUnifiedOps,
322 D: Data<Elem = F>,
323{
324 const CHUNK_SIZE: usize = 1024;
325 let mut total_sum = F::zero();
326
327 for chunk in x.windows(CHUNK_SIZE) {
329 let chunk_sum = chunk
330 .iter()
331 .map(|&val| {
332 let dev = val - mean;
333 dev * dev
334 })
335 .fold(F::zero(), |acc, val| acc + val);
336 total_sum = total_sum + chunk_sum;
337 }
338
339 let processed = (x.len() / CHUNK_SIZE) * CHUNK_SIZE;
341 if processed < x.len() {
342 for i in processed..x.len() {
343 let dev = x[i] - mean;
344 total_sum = total_sum + dev * dev;
345 }
346 }
347
348 Ok(total_sum)
349}
350
351#[allow(dead_code)]
353fn variance_scalar_welford<F, D>(x: &ArrayBase<D, Ix1>, ddof: usize) -> StatsResult<F>
354where
355 F: Float + NumCast,
356 D: Data<Elem = F>,
357{
358 let mut mean = F::zero();
359 let mut m2 = F::zero();
360 let mut count = 0;
361
362 for &val in x.iter() {
363 count += 1;
364 let delta = val - mean;
365 mean = mean + delta / F::from(count).expect("Failed to convert to float");
366 let delta2 = val - mean;
367 m2 = m2 + delta * delta2;
368 }
369
370 Ok(m2 / F::from(count - ddof).expect("Failed to convert to float"))
371}
372
373#[allow(dead_code)]
375fn stats_scalar_single_pass<F, D>(x: &ArrayBase<D, Ix1>) -> StatsResult<(F, F, F, F, F, F)>
376where
377 F: Float + NumCast,
378 D: Data<Elem = F>,
379{
380 let n = x.len();
381 let n_f = F::from(n).expect("Failed to convert to float");
382
383 let mean = x.iter().fold(F::zero(), |acc, &val| acc + val) / n_f;
385
386 let mut m2 = F::zero();
388 let mut m3 = F::zero();
389 let mut m4 = F::zero();
390 let mut min = x[0];
391 let mut max = x[0];
392
393 for &val in x.iter() {
394 let dev = val - mean;
395 let dev2 = dev * dev;
396 let dev3 = dev2 * dev;
397 let dev4 = dev3 * dev;
398
399 m2 = m2 + dev2;
400 m3 = m3 + dev3;
401 m4 = m4 + dev4;
402
403 if val < min {
404 min = val;
405 }
406 if val > max {
407 max = val;
408 }
409 }
410
411 let variance = m2 / F::from(n - 1).expect("Failed to convert to float");
412 let std_dev = variance.sqrt();
413
414 let skewness = if std_dev > F::epsilon() {
415 (m3 / n_f) / (std_dev * std_dev * std_dev)
416 } else {
417 F::zero()
418 };
419
420 let kurtosis = if variance > F::epsilon() {
421 (m4 / n_f) / (variance * variance)
422 - F::from(3).expect("Failed to convert constant to float")
423 } else {
424 F::zero()
425 };
426
427 Ok((mean, variance, min, max, skewness, kurtosis))
428}
429
#[cfg(test)]
mod tests {
    use super::*;
    use scirs2_core::ndarray::{array, Array1};

    /// Configuration that routes every non-empty input through the SIMD path.
    fn force_simd() -> SimdConfig {
        SimdConfig {
            minsize: 0,
            use_aligned: false,
            unroll_factor: 1,
        }
    }

    /// Configuration that routes every input through the scalar fallback.
    fn force_scalar() -> SimdConfig {
        SimdConfig {
            minsize: usize::MAX,
            use_aligned: false,
            unroll_factor: 1,
        }
    }

    /// Deterministic, non-trivial signal of length `n`.
    fn signal(n: usize) -> Array1<f64> {
        (0..n)
            .map(|i| (i as f64 * 0.37).sin() * 10.0 + i as f64 * 0.1)
            .collect()
    }

    fn assert_close(a: f64, b: f64, what: &str) {
        assert!(
            (a - b).abs() <= 1e-9 * b.abs().max(1.0),
            "{what}: {a} vs {b}"
        );
    }

    #[test]
    fn test_mean_simd_optimized() {
        let data = array![1.0f64, 2.0, 3.0, 4.0, 5.0];
        let mean = mean_simd_optimized(&data.view(), None).expect("Operation failed");
        assert!((mean - 3.0).abs() < 1e-10);
    }

    #[test]
    fn test_mean_empty_is_error() {
        let data = Array1::<f64>::zeros(0);
        assert!(mean_simd_optimized(&data.view(), None).is_err());
    }

    #[test]
    fn test_variance_simd_optimized() {
        let data = array![1.0f64, 2.0, 3.0, 4.0, 5.0];
        let var = variance_simd_optimized(&data.view(), 1, None).expect("Operation failed");
        assert!((var - 2.5).abs() < 1e-10);
    }

    #[test]
    fn test_variance_insufficient_dof_is_error() {
        let data = array![1.0f64];
        assert!(variance_simd_optimized(&data.view(), 1, None).is_err());
        assert!(variance_simd_optimized(&data.view(), 2, None).is_err());
    }

    #[test]
    fn test_stats_single_pass() {
        let data = array![1.0f64, 2.0, 3.0, 4.0, 5.0];
        let (mean, var, min, max, skew, kurt) =
            stats_simd_single_pass(&data.view(), None).expect("Operation failed");

        assert!((mean - 3.0).abs() < 1e-10);
        assert!((var - 2.5).abs() < 1e-10);
        assert!((min - 1.0).abs() < 1e-10);
        assert!((max - 5.0).abs() < 1e-10);
        // Symmetric data: the third central moment is exactly zero.
        assert!(skew.abs() < 1e-10);
        // Evenly spaced values are flatter than a normal distribution.
        assert!(kurt.is_finite() && kurt < 0.0);
    }

    #[test]
    fn test_stats_empty_is_error() {
        let data = Array1::<f64>::zeros(0);
        assert!(stats_simd_single_pass(&data.view(), None).is_err());
    }

    #[test]
    fn test_simd_and_scalar_paths_agree() {
        let data = signal(100);
        let view = data.view();

        let mean_simd = mean_simd_optimized(&view, Some(force_simd())).expect("simd mean");
        let mean_scalar = mean_simd_optimized(&view, Some(force_scalar())).expect("scalar mean");
        assert_close(mean_simd, mean_scalar, "mean");

        let var_simd =
            variance_simd_optimized(&view, 1, Some(force_simd())).expect("simd variance");
        let var_scalar =
            variance_simd_optimized(&view, 1, Some(force_scalar())).expect("scalar variance");
        assert_close(var_simd, var_scalar, "variance");

        let simd = stats_simd_single_pass(&view, Some(force_simd())).expect("simd stats");
        let scalar = stats_simd_single_pass(&view, Some(force_scalar())).expect("scalar stats");
        let pairs = [
            (simd.0, scalar.0, "mean"),
            (simd.1, scalar.1, "variance"),
            (simd.2, scalar.2, "min"),
            (simd.3, scalar.3, "max"),
            (simd.4, scalar.4, "skewness"),
            (simd.5, scalar.5, "kurtosis"),
        ];
        for &(a, b, what) in pairs.iter() {
            assert_close(a, b, what);
        }
    }

    #[test]
    fn test_simd_config_detect_no_panic() {
        // Runtime feature detection must succeed on any host.
        let cfg = SimdConfig::detect();
        assert!(cfg.minsize > 0, "detected minsize must be > 0");
    }

    #[test]
    fn test_simd_config_unroll_factor_geq_1() {
        let cfg = SimdConfig::detect();
        assert!(
            cfg.unroll_factor >= 1,
            "unroll_factor must be >= 1, got {}",
            cfg.unroll_factor
        );
    }

    #[test]
    fn test_simd_config_default_valid() {
        let cfg = SimdConfig::default();
        assert!(
            cfg.unroll_factor >= 1,
            "default unroll_factor must be >= 1, got {}",
            cfg.unroll_factor
        );
        assert!(cfg.minsize > 0, "minsize must be > 0");

        // `Default` is documented as the detected configuration.
        let detected = SimdConfig::detect();
        assert_eq!(cfg.minsize, detected.minsize);
        assert_eq!(cfg.use_aligned, detected.use_aligned);
        assert_eq!(cfg.unroll_factor, detected.unroll_factor);
    }
}