1use crate::error::{StatsError, StatsResult};
8use scirs2_core::ndarray::{Array1, Array2, ArrayView1, ArrayView2};
9use scirs2_core::numeric::{Float, NumCast, One, Zero};
10use scirs2_core::Rng;
11use scirs2_core::{
12 simd_ops::{PlatformCapabilities, SimdUnifiedOps},
13 validation::*,
14};
15use std::marker::PhantomData;
16
/// Tuning knobs for [`AdvancedSimdStatistics`].
///
/// The `Default` implementation fills these in from the platform
/// capabilities detected at runtime.
#[derive(Debug, Clone)]
pub struct AdvancedSimdConfig {
    /// Platform SIMD capabilities; the default is obtained from
    /// `PlatformCapabilities::detect()`.
    pub capabilities: PlatformCapabilities,
    /// Number of elements handed to the SIMD primitives per chunk.
    /// The default is 16 / 8 / 4 / 1 for AVX-512 / AVX2 / generic SIMD / none.
    /// The SIMD path is only taken when this is greater than 1.
    pub chunksize: usize,
    /// Defaults to `true`.
    /// NOTE(review): not read by any code visible in this file — confirm
    /// whether it is consumed elsewhere.
    pub parallel_enabled: bool,
    /// Minimum input length at which the SIMD path is chosen over the scalar
    /// path (default 64).
    pub simd_threshold: usize,
}
29
30impl Default for AdvancedSimdConfig {
31 fn default() -> Self {
32 let capabilities = PlatformCapabilities::detect();
33 let chunksize = if capabilities.avx512_available {
34 16 } else if capabilities.avx2_available {
36 8 } else if capabilities.simd_available {
38 4 } else {
40 1 };
42
43 Self {
44 capabilities,
45 chunksize,
46 parallel_enabled: true,
47 simd_threshold: 64,
48 }
49 }
50}
51
/// Statistics calculator that dispatches between SIMD and scalar code paths
/// according to its [`AdvancedSimdConfig`].
pub struct AdvancedSimdStatistics<F> {
    /// Settings that decide chunk size and when the SIMD path is used.
    config: AdvancedSimdConfig,
    /// Ties the calculator to the element type `F` without storing a value.
    _phantom: PhantomData<F>,
}
57
58impl<F> AdvancedSimdStatistics<F>
59where
60 F: Float
61 + NumCast
62 + SimdUnifiedOps
63 + Zero
64 + One
65 + PartialOrd
66 + Copy
67 + Send
68 + Sync
69 + std::fmt::Display
70 + std::iter::Sum<F>,
71{
72 pub fn new() -> Self {
74 Self {
75 config: AdvancedSimdConfig::default(),
76 _phantom: PhantomData,
77 }
78 }
79
80 pub fn with_config(config: AdvancedSimdConfig) -> Self {
82 Self {
83 config,
84 _phantom: PhantomData,
85 }
86 }
87
88 pub fn comprehensive_stats_advanced(
90 &self,
91 data: &ArrayView1<F>,
92 ) -> StatsResult<ComprehensiveStats<F>> {
93 checkarray_finite(data, "data")?;
94
95 if data.is_empty() {
96 return Err(StatsError::InvalidArgument(
97 "Data cannot be empty".to_string(),
98 ));
99 }
100
101 let n = data.len();
102
103 if n >= self.config.simd_threshold && self.config.chunksize > 1 {
105 self.compute_simd_comprehensive(data)
106 } else {
107 self.compute_scalar_comprehensive(data)
108 }
109 }
110
111 fn compute_simd_comprehensive(
113 &self,
114 data: &ArrayView1<F>,
115 ) -> StatsResult<ComprehensiveStats<F>> {
116 let n = data.len();
117 let chunksize = self.config.chunksize;
118 let n_chunks = n / chunksize;
119 let remainder = n % chunksize;
120
121 let mut sum_acc = F::zero();
123 let mut sum_sq_acc = F::zero();
124 let mut sum_cube_acc = F::zero();
125 let mut sum_quad_acc = F::zero();
126 let mut min_val = F::infinity();
127 let mut max_val = F::neg_infinity();
128
129 for i in 0..n_chunks {
131 let start = i * chunksize;
132 let end = start + chunksize;
133 let chunk = data.slice(scirs2_core::ndarray::s![start..end]);
134
135 let chunk_sum = F::simd_sum(&chunk);
137 let chunk_sq = F::simd_mul(&chunk, &chunk);
138 let chunk_sum_sq = F::simd_sum(&chunk_sq.view());
139 let chunk_cube = F::simd_mul(&chunk_sq.view(), &chunk);
140 let chunk_sum_cube = F::simd_sum(&chunk_cube.view());
141 let chunk_quad = F::simd_mul(&chunk_sq.view(), &chunk_sq.view());
142 let chunk_sum_quad = F::simd_sum(&chunk_quad.view());
143 let chunk_min = F::simd_min_element(&chunk);
144 let chunk_max = F::simd_max_element(&chunk);
145
146 sum_acc = sum_acc + chunk_sum;
147 sum_sq_acc = sum_sq_acc + chunk_sum_sq;
148 sum_cube_acc = sum_cube_acc + chunk_sum_cube;
149 sum_quad_acc = sum_quad_acc + chunk_sum_quad;
150 min_val = if chunk_min < min_val {
151 chunk_min
152 } else {
153 min_val
154 };
155 max_val = if chunk_max > max_val {
156 chunk_max
157 } else {
158 max_val
159 };
160 }
161
162 if remainder > 0 {
164 let start = n_chunks * chunksize;
165 for i in start..n {
166 let val = data[i];
167 sum_acc = sum_acc + val;
168 sum_sq_acc = sum_sq_acc + val * val;
169 sum_cube_acc = sum_cube_acc + val * val * val;
170 sum_quad_acc = sum_quad_acc + val * val * val * val;
171 min_val = if val < min_val { val } else { min_val };
172 max_val = if val > max_val { val } else { max_val };
173 }
174 }
175
176 let n_f = F::from(n).unwrap();
178 let mean = sum_acc / n_f;
179 let variance = (sum_sq_acc / n_f) - (mean * mean);
180 let std_dev = variance.sqrt();
181
182 let m2 = sum_sq_acc / n_f - mean * mean;
184 let m3 = sum_cube_acc / n_f - F::from(3).unwrap() * mean * m2 - mean * mean * mean;
185 let m4 = sum_quad_acc / n_f
186 - F::from(4).unwrap() * mean * m3
187 - F::from(6).unwrap() * mean * mean * m2
188 - mean * mean * mean * mean;
189
190 let skewness = if m2 > F::zero() {
191 m3 / (m2 * m2.sqrt())
192 } else {
193 F::zero()
194 };
195
196 let kurtosis = if m2 > F::zero() {
197 m4 / (m2 * m2) - F::from(3).unwrap()
198 } else {
199 F::zero()
200 };
201
202 Ok(ComprehensiveStats {
203 mean,
204 variance,
205 std_dev,
206 skewness,
207 kurtosis,
208 min: min_val,
209 max: max_val,
210 range: max_val - min_val,
211 count: n,
212 })
213 }
214
215 fn compute_scalar_comprehensive(
217 &self,
218 data: &ArrayView1<F>,
219 ) -> StatsResult<ComprehensiveStats<F>> {
220 let n = data.len();
221 let n_f = F::from(n).unwrap();
222
223 let sum: F = data.iter().copied().sum();
224 let mean = sum / n_f;
225
226 let mut sum_sq = F::zero();
227 let mut sum_cube = F::zero();
228 let mut sum_quad = F::zero();
229 let mut min_val = F::infinity();
230 let mut max_val = F::neg_infinity();
231
232 for &val in data.iter() {
233 let diff = val - mean;
234 sum_sq = sum_sq + diff * diff;
235 sum_cube = sum_cube + diff * diff * diff;
236 sum_quad = sum_quad + diff * diff * diff * diff;
237 min_val = if val < min_val { val } else { min_val };
238 max_val = if val > max_val { val } else { max_val };
239 }
240
241 let variance = sum_sq / n_f;
242 let std_dev = variance.sqrt();
243
244 let m2 = variance;
245 let m3 = sum_cube / n_f;
246 let m4 = sum_quad / n_f;
247
248 let skewness = if m2 > F::zero() {
249 m3 / (m2 * m2.sqrt())
250 } else {
251 F::zero()
252 };
253
254 let kurtosis = if m2 > F::zero() {
255 m4 / (m2 * m2) - F::from(3).unwrap()
256 } else {
257 F::zero()
258 };
259
260 Ok(ComprehensiveStats {
261 mean,
262 variance,
263 std_dev,
264 skewness,
265 kurtosis,
266 min: min_val,
267 max: max_val,
268 range: max_val - min_val,
269 count: n,
270 })
271 }
272
273 pub fn matrix_stats_advanced(
275 &self,
276 matrix: &ArrayView2<F>,
277 ) -> StatsResult<MatrixStatsResult<F>> {
278 checkarray_finite(matrix, "matrix")?;
279
280 if matrix.is_empty() {
281 return Err(StatsError::InvalidArgument(
282 "Matrix cannot be empty".to_string(),
283 ));
284 }
285
286 let (rows, cols) = matrix.dim();
287
288 let mut row_stats = Vec::with_capacity(rows);
290 for i in 0..rows {
291 let row = matrix.row(i);
292 let stats = self.comprehensive_stats_advanced(&row)?;
293 row_stats.push(stats);
294 }
295
296 let mut col_stats = Vec::with_capacity(cols);
298 for j in 0..cols {
299 let col = matrix.column(j);
300 let stats = self.comprehensive_stats_advanced(&col)?;
301 col_stats.push(stats);
302 }
303
304 let flattened = matrix.iter().copied().collect::<Array1<F>>();
306 let overall_stats = self.comprehensive_stats_advanced(&flattened.view())?;
307
308 Ok(MatrixStatsResult {
309 row_stats,
310 col_stats,
311 overall_stats,
312 shape: (rows, cols),
313 })
314 }
315
316 pub fn correlation_matrix_advanced(&self, matrix: &ArrayView2<F>) -> StatsResult<Array2<F>> {
318 checkarray_finite(matrix, "matrix")?;
319
320 let (_n_samples_, n_features) = matrix.dim();
321
322 if n_features < 2 {
323 return Err(StatsError::InvalidArgument(
324 "At least 2 features required for correlation matrix".to_string(),
325 ));
326 }
327
328 let mut corr_matrix = Array2::zeros((n_features, n_features));
329
330 let mut means = Array1::zeros(n_features);
332 for j in 0..n_features {
333 let col = matrix.column(j);
334 means[j] = F::simd_mean(&col);
335 }
336
337 for i in 0..n_features {
339 for j in i..n_features {
340 if i == j {
341 corr_matrix[[i, j]] = F::one();
342 } else {
343 let col_i = matrix.column(i);
344 let col_j = matrix.column(j);
345
346 let _n = F::from(col_i.len()).unwrap();
348 let mean_i_vec = Array1::from_elem(col_i.len(), means[i]);
349 let mean_j_vec = Array1::from_elem(col_j.len(), means[j]);
350
351 let dev_i = F::simd_sub(&col_i, &mean_i_vec.view());
352 let dev_j = F::simd_sub(&col_j, &mean_j_vec.view());
353
354 let numerator = F::simd_sum(&F::simd_mul(&dev_i.view(), &dev_j.view()).view());
355 let sum_sq_i = F::simd_sum(&F::simd_mul(&dev_i.view(), &dev_i.view()).view());
356 let sum_sq_j = F::simd_sum(&F::simd_mul(&dev_j.view(), &dev_j.view()).view());
357
358 let denominator = (sum_sq_i * sum_sq_j).sqrt();
359 let corr = if denominator > F::zero() {
360 numerator / denominator
361 } else {
362 F::zero()
363 };
364
365 corr_matrix[[i, j]] = corr;
366 corr_matrix[[j, i]] = corr;
367 }
368 }
369 }
370
371 Ok(corr_matrix)
372 }
373
374 pub fn bootstrap_stats_advanced(
376 &self,
377 data: &ArrayView1<F>,
378 n_bootstrap: usize,
379 seed: Option<u64>,
380 ) -> StatsResult<BootstrapResult<F>> {
381 checkarray_finite(data, "data")?;
382 check_positive(n_bootstrap, "n_bootstrap")?;
383
384 let n = data.len();
385 let mut rng = create_rng(seed);
386
387 let mut bootstrap_means = Array1::zeros(n_bootstrap);
388 let mut bootstrap_vars = Array1::zeros(n_bootstrap);
389 let mut bootstrap_stds = Array1::zeros(n_bootstrap);
390
391 for i in 0..n_bootstrap {
393 let mut bootstrap_sample = Array1::zeros(n);
395 for j in 0..n {
396 let idx = rng.gen_range(0..n);
397 bootstrap_sample[j] = data[idx];
398 }
399
400 let stats = self.comprehensive_stats_advanced(&bootstrap_sample.view())?;
402 bootstrap_means[i] = stats.mean;
403 bootstrap_vars[i] = stats.variance;
404 bootstrap_stds[i] = stats.std_dev;
405 }
406
407 let mut sorted_means = bootstrap_means.to_owned();
409 sorted_means
410 .as_slice_mut()
411 .unwrap()
412 .sort_by(|a, b| a.partial_cmp(b).unwrap());
413
414 let alpha = F::from(0.05).unwrap(); let lower_idx = ((alpha / F::from(2).unwrap()) * F::from(n_bootstrap).unwrap())
416 .to_usize()
417 .unwrap();
418 let upper_idx = ((F::one() - alpha / F::from(2).unwrap()) * F::from(n_bootstrap).unwrap())
419 .to_usize()
420 .unwrap();
421
422 let mean_ci = (
423 sorted_means[lower_idx],
424 sorted_means[upper_idx.min(n_bootstrap - 1)],
425 );
426
427 Ok(BootstrapResult {
428 original_stats: self.comprehensive_stats_advanced(data)?,
429 bootstrap_means,
430 bootstrap_vars,
431 bootstrap_stds,
432 mean_ci,
433 n_bootstrap,
434 })
435 }
436}
437
/// Descriptive statistics of a one-dimensional sample, as produced by
/// `AdvancedSimdStatistics::comprehensive_stats_advanced`.
#[derive(Debug, Clone)]
pub struct ComprehensiveStats<F> {
    /// Arithmetic mean.
    pub mean: F,
    /// Population variance (divisor is the element count, not count - 1).
    pub variance: F,
    /// Square root of `variance`.
    pub std_dev: F,
    /// Third central moment divided by `variance^(3/2)`; zero when the
    /// variance is not positive.
    pub skewness: F,
    /// Excess kurtosis: fourth central moment divided by `variance^2`,
    /// minus 3; zero when the variance is not positive.
    pub kurtosis: F,
    /// Smallest element.
    pub min: F,
    /// Largest element.
    pub max: F,
    /// `max - min`.
    pub range: F,
    /// Number of elements the statistics were computed from.
    pub count: usize,
}
451
/// Result of `AdvancedSimdStatistics::matrix_stats_advanced`.
#[derive(Debug, Clone)]
pub struct MatrixStatsResult<F> {
    /// Statistics of each row, in row order.
    pub row_stats: Vec<ComprehensiveStats<F>>,
    /// Statistics of each column, in column order.
    pub col_stats: Vec<ComprehensiveStats<F>>,
    /// Statistics over every element of the matrix.
    pub overall_stats: ComprehensiveStats<F>,
    /// `(rows, cols)` of the input matrix.
    pub shape: (usize, usize),
}
460
/// Result of `AdvancedSimdStatistics::bootstrap_stats_advanced`.
#[derive(Debug, Clone)]
pub struct BootstrapResult<F> {
    /// Statistics of the original, un-resampled data.
    pub original_stats: ComprehensiveStats<F>,
    /// Mean of each bootstrap resample, in draw order.
    pub bootstrap_means: Array1<F>,
    /// Population variance of each bootstrap resample, in draw order.
    pub bootstrap_vars: Array1<F>,
    /// Standard deviation of each bootstrap resample, in draw order.
    pub bootstrap_stds: Array1<F>,
    /// `(lower, upper)` percentile bounds for the mean, taken from the
    /// sorted bootstrap means at a fixed alpha of 0.05.
    pub mean_ci: (F, F),
    /// Number of bootstrap resamples that were drawn.
    pub n_bootstrap: usize,
}
471
472pub trait AdvancedSimdOps<F>: SimdUnifiedOps
474where
475 F: Float
476 + NumCast
477 + Zero
478 + One
479 + PartialOrd
480 + Copy
481 + Send
482 + Sync
483 + std::fmt::Display
484 + std::iter::Sum<F>,
485{
486 fn simd_sum_cubes(data: &ArrayView1<F>) -> F {
488 data.iter().map(|&x| x * x * x).sum()
489 }
490
491 fn simd_sum_quads(data: &ArrayView1<F>) -> F {
493 data.iter().map(|&x| x * x * x * x).sum()
494 }
495
496 fn simd_correlation(x: &ArrayView1<F>, y: &ArrayView1<F>, mean_x: F, meany: F) -> F {
498 let n = x.len();
499 if n != y.len() {
500 return F::zero();
501 }
502
503 let _n_f = F::from(n).unwrap();
504 let mut sum_xy = F::zero();
505 let mut sum_x2 = F::zero();
506 let mut sum_y2 = F::zero();
507
508 for i in 0..n {
509 let dx = x[i] - mean_x;
510 let dy = y[i] - meany;
511 sum_xy = sum_xy + dx * dy;
512 sum_x2 = sum_x2 + dx * dx;
513 sum_y2 = sum_y2 + dy * dy;
514 }
515
516 let denom = (sum_x2 * sum_y2).sqrt();
517 if denom > F::zero() {
518 sum_xy / denom
519 } else {
520 F::zero()
521 }
522 }
523}
524
// Opt the two built-in float types into `AdvancedSimdOps`; both rely
// entirely on the trait's default method bodies.
impl AdvancedSimdOps<f32> for f32 {}
impl AdvancedSimdOps<f64> for f64 {}
528
529#[allow(dead_code)]
531pub fn advanced_mean_simd<F>(data: &ArrayView1<F>) -> StatsResult<F>
532where
533 F: Float
534 + NumCast
535 + SimdUnifiedOps
536 + Zero
537 + One
538 + PartialOrd
539 + Copy
540 + Send
541 + Sync
542 + std::fmt::Display
543 + std::iter::Sum<F>,
544{
545 let computer = AdvancedSimdStatistics::<F>::new();
546 let stats = computer.comprehensive_stats_advanced(data)?;
547 Ok(stats.mean)
548}
549
550#[allow(dead_code)]
551pub fn advanced_std_simd<F>(data: &ArrayView1<F>) -> StatsResult<F>
552where
553 F: Float
554 + NumCast
555 + SimdUnifiedOps
556 + Zero
557 + One
558 + PartialOrd
559 + Copy
560 + Send
561 + Sync
562 + std::fmt::Display
563 + std::iter::Sum<F>,
564{
565 let computer = AdvancedSimdStatistics::<F>::new();
566 let stats = computer.comprehensive_stats_advanced(data)?;
567 Ok(stats.std_dev)
568}
569
570#[allow(dead_code)]
571pub fn advanced_comprehensive_simd<F>(data: &ArrayView1<F>) -> StatsResult<ComprehensiveStats<F>>
572where
573 F: Float
574 + NumCast
575 + SimdUnifiedOps
576 + Zero
577 + One
578 + PartialOrd
579 + Copy
580 + Send
581 + Sync
582 + std::fmt::Display
583 + std::iter::Sum<F>,
584{
585 let computer = AdvancedSimdStatistics::<F>::new();
586 computer.comprehensive_stats_advanced(data)
587}
588
589#[allow(dead_code)]
591fn create_rng(seed: Option<u64>) -> impl Rng {
592 use scirs2_core::random::{rngs::StdRng, SeedableRng};
593 match seed {
594 Some(s) => StdRng::seed_from_u64(s),
595 None => {
596 use std::time::{SystemTime, UNIX_EPOCH};
597 let s = SystemTime::now()
598 .duration_since(UNIX_EPOCH)
599 .unwrap_or_default()
600 .as_secs();
601 StdRng::seed_from_u64(s)
602 }
603 }
604}