1use crate::error::{StatsError, StatsResult};
8use scirs2_core::ndarray::{Array1, Array2, ArrayView1, ArrayView2};
9use scirs2_core::numeric::{Float, NumCast, One, Zero};
10use scirs2_core::Rng;
11use scirs2_core::{
12 simd_ops::{PlatformCapabilities, SimdUnifiedOps},
13 validation::*,
14};
15use std::marker::PhantomData;
16
17#[derive(Debug, Clone)]
19pub struct AdvancedSimdConfig {
20 pub capabilities: PlatformCapabilities,
22 pub chunksize: usize,
24 pub parallel_enabled: bool,
26 pub simd_threshold: usize,
28}
29
30impl Default for AdvancedSimdConfig {
31 fn default() -> Self {
32 let capabilities = PlatformCapabilities::detect();
33 let chunksize = if capabilities.avx512_available {
34 16 } else if capabilities.avx2_available {
36 8 } else if capabilities.simd_available {
38 4 } else {
40 1 };
42
43 Self {
44 capabilities,
45 chunksize,
46 parallel_enabled: true,
47 simd_threshold: 64,
48 }
49 }
50}
51
52pub struct AdvancedSimdStatistics<F> {
54 config: AdvancedSimdConfig,
55 _phantom: PhantomData<F>,
56}
57
58impl<F> AdvancedSimdStatistics<F>
59where
60 F: Float
61 + NumCast
62 + SimdUnifiedOps
63 + Zero
64 + One
65 + PartialOrd
66 + Copy
67 + Send
68 + Sync
69 + std::fmt::Display
70 + std::iter::Sum<F>,
71{
72 pub fn new() -> Self {
74 Self {
75 config: AdvancedSimdConfig::default(),
76 _phantom: PhantomData,
77 }
78 }
79
80 pub fn with_config(config: AdvancedSimdConfig) -> Self {
82 Self {
83 config,
84 _phantom: PhantomData,
85 }
86 }
87
88 pub fn comprehensive_stats_advanced(
90 &self,
91 data: &ArrayView1<F>,
92 ) -> StatsResult<ComprehensiveStats<F>> {
93 checkarray_finite(data, "data")?;
94
95 if data.is_empty() {
96 return Err(StatsError::InvalidArgument(
97 "Data cannot be empty".to_string(),
98 ));
99 }
100
101 let n = data.len();
102
103 if n >= self.config.simd_threshold && self.config.chunksize > 1 {
105 self.compute_simd_comprehensive(data)
106 } else {
107 self.compute_scalar_comprehensive(data)
108 }
109 }
110
111 fn compute_simd_comprehensive(
113 &self,
114 data: &ArrayView1<F>,
115 ) -> StatsResult<ComprehensiveStats<F>> {
116 let n = data.len();
117 let chunksize = self.config.chunksize;
118 let n_chunks = n / chunksize;
119 let remainder = n % chunksize;
120
121 let mut sum_acc = F::zero();
123 let mut sum_sq_acc = F::zero();
124 let mut sum_cube_acc = F::zero();
125 let mut sum_quad_acc = F::zero();
126 let mut min_val = F::infinity();
127 let mut max_val = F::neg_infinity();
128
129 for i in 0..n_chunks {
131 let start = i * chunksize;
132 let end = start + chunksize;
133 let chunk = data.slice(scirs2_core::ndarray::s![start..end]);
134
135 let chunk_sum = F::simd_sum(&chunk);
137 let chunk_sq = F::simd_mul(&chunk, &chunk);
138 let chunk_sum_sq = F::simd_sum(&chunk_sq.view());
139 let chunk_cube = F::simd_mul(&chunk_sq.view(), &chunk);
140 let chunk_sum_cube = F::simd_sum(&chunk_cube.view());
141 let chunk_quad = F::simd_mul(&chunk_sq.view(), &chunk_sq.view());
142 let chunk_sum_quad = F::simd_sum(&chunk_quad.view());
143 let chunk_min = F::simd_min_element(&chunk);
144 let chunk_max = F::simd_max_element(&chunk);
145
146 sum_acc = sum_acc + chunk_sum;
147 sum_sq_acc = sum_sq_acc + chunk_sum_sq;
148 sum_cube_acc = sum_cube_acc + chunk_sum_cube;
149 sum_quad_acc = sum_quad_acc + chunk_sum_quad;
150 min_val = if chunk_min < min_val {
151 chunk_min
152 } else {
153 min_val
154 };
155 max_val = if chunk_max > max_val {
156 chunk_max
157 } else {
158 max_val
159 };
160 }
161
162 if remainder > 0 {
164 let start = n_chunks * chunksize;
165 for i in start..n {
166 let val = data[i];
167 sum_acc = sum_acc + val;
168 sum_sq_acc = sum_sq_acc + val * val;
169 sum_cube_acc = sum_cube_acc + val * val * val;
170 sum_quad_acc = sum_quad_acc + val * val * val * val;
171 min_val = if val < min_val { val } else { min_val };
172 max_val = if val > max_val { val } else { max_val };
173 }
174 }
175
176 let n_f = F::from(n).expect("Failed to convert to float");
178 let mean = sum_acc / n_f;
179 let variance = (sum_sq_acc / n_f) - (mean * mean);
180 let std_dev = variance.sqrt();
181
182 let m2 = sum_sq_acc / n_f - mean * mean;
184 let m3 = sum_cube_acc / n_f
185 - F::from(3).expect("Failed to convert constant to float") * mean * m2
186 - mean * mean * mean;
187 let m4 = sum_quad_acc / n_f
188 - F::from(4).expect("Failed to convert constant to float") * mean * m3
189 - F::from(6).expect("Failed to convert constant to float") * mean * mean * m2
190 - mean * mean * mean * mean;
191
192 let skewness = if m2 > F::zero() {
193 m3 / (m2 * m2.sqrt())
194 } else {
195 F::zero()
196 };
197
198 let kurtosis = if m2 > F::zero() {
199 m4 / (m2 * m2) - F::from(3).expect("Failed to convert constant to float")
200 } else {
201 F::zero()
202 };
203
204 Ok(ComprehensiveStats {
205 mean,
206 variance,
207 std_dev,
208 skewness,
209 kurtosis,
210 min: min_val,
211 max: max_val,
212 range: max_val - min_val,
213 count: n,
214 })
215 }
216
217 fn compute_scalar_comprehensive(
219 &self,
220 data: &ArrayView1<F>,
221 ) -> StatsResult<ComprehensiveStats<F>> {
222 let n = data.len();
223 let n_f = F::from(n).expect("Failed to convert to float");
224
225 let sum: F = data.iter().copied().sum();
226 let mean = sum / n_f;
227
228 let mut sum_sq = F::zero();
229 let mut sum_cube = F::zero();
230 let mut sum_quad = F::zero();
231 let mut min_val = F::infinity();
232 let mut max_val = F::neg_infinity();
233
234 for &val in data.iter() {
235 let diff = val - mean;
236 sum_sq = sum_sq + diff * diff;
237 sum_cube = sum_cube + diff * diff * diff;
238 sum_quad = sum_quad + diff * diff * diff * diff;
239 min_val = if val < min_val { val } else { min_val };
240 max_val = if val > max_val { val } else { max_val };
241 }
242
243 let variance = sum_sq / n_f;
244 let std_dev = variance.sqrt();
245
246 let m2 = variance;
247 let m3 = sum_cube / n_f;
248 let m4 = sum_quad / n_f;
249
250 let skewness = if m2 > F::zero() {
251 m3 / (m2 * m2.sqrt())
252 } else {
253 F::zero()
254 };
255
256 let kurtosis = if m2 > F::zero() {
257 m4 / (m2 * m2) - F::from(3).expect("Failed to convert constant to float")
258 } else {
259 F::zero()
260 };
261
262 Ok(ComprehensiveStats {
263 mean,
264 variance,
265 std_dev,
266 skewness,
267 kurtosis,
268 min: min_val,
269 max: max_val,
270 range: max_val - min_val,
271 count: n,
272 })
273 }
274
275 pub fn matrix_stats_advanced(
277 &self,
278 matrix: &ArrayView2<F>,
279 ) -> StatsResult<MatrixStatsResult<F>> {
280 checkarray_finite(matrix, "matrix")?;
281
282 if matrix.is_empty() {
283 return Err(StatsError::InvalidArgument(
284 "Matrix cannot be empty".to_string(),
285 ));
286 }
287
288 let (rows, cols) = matrix.dim();
289
290 let mut row_stats = Vec::with_capacity(rows);
292 for i in 0..rows {
293 let row = matrix.row(i);
294 let stats = self.comprehensive_stats_advanced(&row)?;
295 row_stats.push(stats);
296 }
297
298 let mut col_stats = Vec::with_capacity(cols);
300 for j in 0..cols {
301 let col = matrix.column(j);
302 let stats = self.comprehensive_stats_advanced(&col)?;
303 col_stats.push(stats);
304 }
305
306 let flattened = matrix.iter().copied().collect::<Array1<F>>();
308 let overall_stats = self.comprehensive_stats_advanced(&flattened.view())?;
309
310 Ok(MatrixStatsResult {
311 row_stats,
312 col_stats,
313 overall_stats,
314 shape: (rows, cols),
315 })
316 }
317
318 pub fn correlation_matrix_advanced(&self, matrix: &ArrayView2<F>) -> StatsResult<Array2<F>> {
320 checkarray_finite(matrix, "matrix")?;
321
322 let (_n_samples_, n_features) = matrix.dim();
323
324 if n_features < 2 {
325 return Err(StatsError::InvalidArgument(
326 "At least 2 features required for correlation matrix".to_string(),
327 ));
328 }
329
330 let mut corr_matrix = Array2::zeros((n_features, n_features));
331
332 let mut means = Array1::zeros(n_features);
334 for j in 0..n_features {
335 let col = matrix.column(j);
336 means[j] = F::simd_mean(&col);
337 }
338
339 for i in 0..n_features {
341 for j in i..n_features {
342 if i == j {
343 corr_matrix[[i, j]] = F::one();
344 } else {
345 let col_i = matrix.column(i);
346 let col_j = matrix.column(j);
347
348 let _n = F::from(col_i.len()).expect("Operation failed");
350 let mean_i_vec = Array1::from_elem(col_i.len(), means[i]);
351 let mean_j_vec = Array1::from_elem(col_j.len(), means[j]);
352
353 let dev_i = F::simd_sub(&col_i, &mean_i_vec.view());
354 let dev_j = F::simd_sub(&col_j, &mean_j_vec.view());
355
356 let numerator = F::simd_sum(&F::simd_mul(&dev_i.view(), &dev_j.view()).view());
357 let sum_sq_i = F::simd_sum(&F::simd_mul(&dev_i.view(), &dev_i.view()).view());
358 let sum_sq_j = F::simd_sum(&F::simd_mul(&dev_j.view(), &dev_j.view()).view());
359
360 let denominator = (sum_sq_i * sum_sq_j).sqrt();
361 let corr = if denominator > F::zero() {
362 numerator / denominator
363 } else {
364 F::zero()
365 };
366
367 corr_matrix[[i, j]] = corr;
368 corr_matrix[[j, i]] = corr;
369 }
370 }
371 }
372
373 Ok(corr_matrix)
374 }
375
376 pub fn bootstrap_stats_advanced(
378 &self,
379 data: &ArrayView1<F>,
380 n_bootstrap: usize,
381 seed: Option<u64>,
382 ) -> StatsResult<BootstrapResult<F>> {
383 checkarray_finite(data, "data")?;
384 check_positive(n_bootstrap, "n_bootstrap")?;
385
386 let n = data.len();
387 let mut rng = create_rng(seed);
388
389 let mut bootstrap_means = Array1::zeros(n_bootstrap);
390 let mut bootstrap_vars = Array1::zeros(n_bootstrap);
391 let mut bootstrap_stds = Array1::zeros(n_bootstrap);
392
393 for i in 0..n_bootstrap {
395 let mut bootstrap_sample = Array1::zeros(n);
397 for j in 0..n {
398 let idx = rng.random_range(0..n);
399 bootstrap_sample[j] = data[idx];
400 }
401
402 let stats = self.comprehensive_stats_advanced(&bootstrap_sample.view())?;
404 bootstrap_means[i] = stats.mean;
405 bootstrap_vars[i] = stats.variance;
406 bootstrap_stds[i] = stats.std_dev;
407 }
408
409 let mut sorted_means = bootstrap_means.to_owned();
411 sorted_means
412 .as_slice_mut()
413 .expect("Operation failed")
414 .sort_by(|a, b| a.partial_cmp(b).expect("Operation failed"));
415
416 let alpha = F::from(0.05).expect("Failed to convert constant to float"); let lower_idx = ((alpha / F::from(2).expect("Failed to convert constant to float"))
418 * F::from(n_bootstrap).expect("Failed to convert to float"))
419 .to_usize()
420 .expect("Operation failed");
421 let upper_idx = ((F::one()
422 - alpha / F::from(2).expect("Failed to convert constant to float"))
423 * F::from(n_bootstrap).expect("Failed to convert to float"))
424 .to_usize()
425 .expect("Operation failed");
426
427 let mean_ci = (
428 sorted_means[lower_idx],
429 sorted_means[upper_idx.min(n_bootstrap - 1)],
430 );
431
432 Ok(BootstrapResult {
433 original_stats: self.comprehensive_stats_advanced(data)?,
434 bootstrap_means,
435 bootstrap_vars,
436 bootstrap_stds,
437 mean_ci,
438 n_bootstrap,
439 })
440 }
441}
442
/// Summary statistics for a one-dimensional sample, as produced by
/// `AdvancedSimdStatistics::comprehensive_stats_advanced`.
#[derive(Debug, Clone)]
pub struct ComprehensiveStats<F> {
    /// Arithmetic mean.
    pub mean: F,
    /// Population variance (sum of squared deviations divided by `count`).
    pub variance: F,
    /// Square root of `variance`.
    pub std_dev: F,
    /// Moment-based (biased) skewness `m3 / m2^(3/2)`; zero for constant data.
    pub skewness: F,
    /// Excess kurtosis `m4 / m2² − 3`; zero for constant data.
    pub kurtosis: F,
    /// Smallest observation.
    pub min: F,
    /// Largest observation.
    pub max: F,
    /// `max - min`.
    pub range: F,
    /// Number of observations.
    pub count: usize,
}
456
457#[derive(Debug, Clone)]
459pub struct MatrixStatsResult<F> {
460 pub row_stats: Vec<ComprehensiveStats<F>>,
461 pub col_stats: Vec<ComprehensiveStats<F>>,
462 pub overall_stats: ComprehensiveStats<F>,
463 pub shape: (usize, usize),
464}
465
466#[derive(Debug, Clone)]
468pub struct BootstrapResult<F> {
469 pub original_stats: ComprehensiveStats<F>,
470 pub bootstrap_means: Array1<F>,
471 pub bootstrap_vars: Array1<F>,
472 pub bootstrap_stds: Array1<F>,
473 pub mean_ci: (F, F),
474 pub n_bootstrap: usize,
475}
476
477pub trait AdvancedSimdOps<F>: SimdUnifiedOps
479where
480 F: Float
481 + NumCast
482 + Zero
483 + One
484 + PartialOrd
485 + Copy
486 + Send
487 + Sync
488 + std::fmt::Display
489 + std::iter::Sum<F>,
490{
491 fn simd_sum_cubes(data: &ArrayView1<F>) -> F {
493 data.iter().map(|&x| x * x * x).sum()
494 }
495
496 fn simd_sum_quads(data: &ArrayView1<F>) -> F {
498 data.iter().map(|&x| x * x * x * x).sum()
499 }
500
501 fn simd_correlation(x: &ArrayView1<F>, y: &ArrayView1<F>, mean_x: F, meany: F) -> F {
503 let n = x.len();
504 if n != y.len() {
505 return F::zero();
506 }
507
508 let _n_f = F::from(n).expect("Failed to convert to float");
509 let mut sum_xy = F::zero();
510 let mut sum_x2 = F::zero();
511 let mut sum_y2 = F::zero();
512
513 for i in 0..n {
514 let dx = x[i] - mean_x;
515 let dy = y[i] - meany;
516 sum_xy = sum_xy + dx * dy;
517 sum_x2 = sum_x2 + dx * dx;
518 sum_y2 = sum_y2 + dy * dy;
519 }
520
521 let denom = (sum_x2 * sum_y2).sqrt();
522 if denom > F::zero() {
523 sum_xy / denom
524 } else {
525 F::zero()
526 }
527 }
528}
529
530impl AdvancedSimdOps<f32> for f32 {}
532impl AdvancedSimdOps<f64> for f64 {}
533
534#[allow(dead_code)]
536pub fn advanced_mean_simd<F>(data: &ArrayView1<F>) -> StatsResult<F>
537where
538 F: Float
539 + NumCast
540 + SimdUnifiedOps
541 + Zero
542 + One
543 + PartialOrd
544 + Copy
545 + Send
546 + Sync
547 + std::fmt::Display
548 + std::iter::Sum<F>,
549{
550 let computer = AdvancedSimdStatistics::<F>::new();
551 let stats = computer.comprehensive_stats_advanced(data)?;
552 Ok(stats.mean)
553}
554
555#[allow(dead_code)]
556pub fn advanced_std_simd<F>(data: &ArrayView1<F>) -> StatsResult<F>
557where
558 F: Float
559 + NumCast
560 + SimdUnifiedOps
561 + Zero
562 + One
563 + PartialOrd
564 + Copy
565 + Send
566 + Sync
567 + std::fmt::Display
568 + std::iter::Sum<F>,
569{
570 let computer = AdvancedSimdStatistics::<F>::new();
571 let stats = computer.comprehensive_stats_advanced(data)?;
572 Ok(stats.std_dev)
573}
574
575#[allow(dead_code)]
576pub fn advanced_comprehensive_simd<F>(data: &ArrayView1<F>) -> StatsResult<ComprehensiveStats<F>>
577where
578 F: Float
579 + NumCast
580 + SimdUnifiedOps
581 + Zero
582 + One
583 + PartialOrd
584 + Copy
585 + Send
586 + Sync
587 + std::fmt::Display
588 + std::iter::Sum<F>,
589{
590 let computer = AdvancedSimdStatistics::<F>::new();
591 computer.comprehensive_stats_advanced(data)
592}
593
594#[allow(dead_code)]
596fn create_rng(seed: Option<u64>) -> impl Rng {
597 use scirs2_core::random::{rngs::StdRng, SeedableRng};
598 match seed {
599 Some(s) => StdRng::seed_from_u64(s),
600 None => {
601 use std::time::{SystemTime, UNIX_EPOCH};
602 let s = SystemTime::now()
603 .duration_since(UNIX_EPOCH)
604 .unwrap_or_default()
605 .as_secs();
606 StdRng::seed_from_u64(s)
607 }
608 }
609}