1use crate::error::{StatsError, StatsResult};
7use scirs2_core::ndarray::{Array1, Array2, ArrayView1, ArrayView2};
8use scirs2_core::numeric::{Float, NumCast, One, Zero};
9use scirs2_core::random::Rng;
10use scirs2_core::{parallel_ops::*, simd_ops::SimdUnifiedOps, validation::*};
11use std::sync::Arc;
12
/// Tuning knobs for [`EnhancedParallelProcessor`].
#[derive(Clone, Debug)]
pub struct EnhancedParallelConfig {
    /// Thread count assumed when sizing chunks; `None` means `num_cpus::get()`.
    /// In this module it is only read by the chunk-size heuristic, not used to
    /// configure a thread pool.
    pub num_threads: Option<usize>,
    /// Inputs shorter than this are processed sequentially; it is also the
    /// smallest chunk length the parallel paths will request.
    pub min_chunksize: usize,
    /// Cap on the number of chunks per input.
    /// NOTE(review): confirm the processor's chunk-size heuristic honors it.
    pub max_chunks: usize,
    /// NUMA-awareness flag. NOTE(review): not read by the processor in this module.
    pub numa_aware: bool,
    /// Work-stealing flag. NOTE(review): not read by the processor in this module.
    pub work_stealing: bool,
}
27
28impl Default for EnhancedParallelConfig {
29 fn default() -> Self {
30 Self {
31 num_threads: None,
32 min_chunksize: 1000,
33 max_chunks: num_cpus::get() * 4,
34 numa_aware: true,
35 work_stealing: true,
36 }
37 }
38}
39
40pub struct EnhancedParallelProcessor<F> {
42 config: EnhancedParallelConfig,
43 _phantom: std::marker::PhantomData<F>,
44}
45
46impl<F> Default for EnhancedParallelProcessor<F>
47where
48 F: Float
49 + NumCast
50 + SimdUnifiedOps
51 + Zero
52 + One
53 + PartialOrd
54 + Copy
55 + Send
56 + Sync
57 + std::fmt::Display
58 + std::iter::Sum<F>
59 + scirs2_core::numeric::FromPrimitive,
60{
61 fn default() -> Self {
62 Self::new()
63 }
64}
65
66impl<F> EnhancedParallelProcessor<F>
67where
68 F: Float
69 + NumCast
70 + SimdUnifiedOps
71 + Zero
72 + One
73 + PartialOrd
74 + Copy
75 + Send
76 + Sync
77 + std::fmt::Display
78 + std::iter::Sum<F>
79 + scirs2_core::numeric::FromPrimitive,
80{
81 pub fn new() -> Self {
83 Self {
84 config: EnhancedParallelConfig::default(),
85 _phantom: std::marker::PhantomData,
86 }
87 }
88
89 pub fn with_config(config: EnhancedParallelConfig) -> Self {
91 Self {
92 config,
93 _phantom: std::marker::PhantomData,
94 }
95 }
96
97 pub fn mean_parallel_enhanced(&self, data: &ArrayView1<F>) -> StatsResult<F> {
99 checkarray_finite(data, "data")?;
100
101 if data.is_empty() {
102 return Err(StatsError::InvalidArgument(
103 "Data cannot be empty".to_string(),
104 ));
105 }
106
107 let n = data.len();
108
109 if n < self.config.min_chunksize {
110 return Ok(data.mean().expect("Operation failed"));
112 }
113
114 let chunksize = self.calculate_optimal_chunksize(n);
116 let result = data
117 .as_slice()
118 .expect("Operation failed")
119 .par_chunks(chunksize)
120 .map(|chunk| {
121 let sum: F = chunk.iter().copied().sum();
123 let count = chunk.len();
124 (sum, count)
125 })
126 .reduce(
127 || (F::zero(), 0),
128 |(sum1, count1), (sum2, count2)| {
129 (sum1 + sum2, count1 + count2)
131 },
132 );
133
134 let (total_sum, total_count) = result;
135 Ok(total_sum / F::from(total_count).expect("Failed to convert to float"))
136 }
137
138 pub fn variance_parallel_enhanced(&self, data: &ArrayView1<F>, ddof: usize) -> StatsResult<F> {
140 checkarray_finite(data, "data")?;
141
142 if data.is_empty() {
143 return Err(StatsError::InvalidArgument(
144 "Data cannot be empty".to_string(),
145 ));
146 }
147
148 let n = data.len();
149
150 if n < self.config.min_chunksize {
151 let mean = data.mean().expect("Operation failed");
153 let sum_sq_diff: F = data.iter().map(|&x| (x - mean) * (x - mean)).sum();
154 return Ok(sum_sq_diff / F::from(n.saturating_sub(ddof)).expect("Operation failed"));
155 }
156
157 let mean = self.mean_parallel_enhanced(data)?;
159
160 let chunksize = self.calculate_optimal_chunksize(n);
162 let result = data
163 .as_slice()
164 .expect("Operation failed")
165 .par_chunks(chunksize)
166 .map(|chunk| {
167 let sum_sq_diff: F = chunk.iter().map(|&x| (x - mean) * (x - mean)).sum();
169 let count = chunk.len();
170 (sum_sq_diff, count)
171 })
172 .reduce(
173 || (F::zero(), 0),
174 |(sum1, count1), (sum2, count2)| {
175 (sum1 + sum2, count1 + count2)
177 },
178 );
179
180 let (total_sum_sq_diff, total_count) = result;
181 let denominator = total_count.saturating_sub(ddof);
182
183 if denominator == 0 {
184 return Err(StatsError::InvalidArgument(
185 "Insufficient degrees of freedom".to_string(),
186 ));
187 }
188
189 Ok(total_sum_sq_diff / F::from(denominator).expect("Failed to convert to float"))
190 }
191
192 pub fn correlation_matrix_parallel(&self, matrix: &ArrayView2<F>) -> StatsResult<Array2<F>> {
194 checkarray_finite(matrix, "matrix")?;
195
196 let (_n_samples_, n_features) = matrix.dim();
197
198 if n_features < 2 {
199 return Err(StatsError::InvalidArgument(
200 "At least 2 features required for correlation matrix".to_string(),
201 ));
202 }
203
204 let means = parallel_map_collect(0..n_features, |i| {
206 let col = matrix.column(i);
207 self.mean_parallel_enhanced(&col).expect("Operation failed")
208 });
209
210 let mut corr_matrix = Array2::zeros((n_features, n_features));
212 let pairs: Vec<(usize, usize)> = (0..n_features)
213 .flat_map(|i| (i..n_features).map(move |j| (i, j)))
214 .collect();
215
216 let correlations = parallel_map_collect(&pairs, |&(i, j)| {
217 if i == j {
218 (i, j, F::one())
219 } else {
220 let col_i = matrix.column(i);
221 let col_j = matrix.column(j);
222 let corr = self
223 .correlation_coefficient(&col_i, &col_j, means[i], means[j])
224 .expect("Operation failed");
225 (i, j, corr)
226 }
227 });
228
229 for (i, j, corr) in correlations {
231 corr_matrix[[i, j]] = corr;
232 if i != j {
233 corr_matrix[[j, i]] = corr;
234 }
235 }
236
237 Ok(corr_matrix)
238 }
239
240 pub fn bootstrap_parallel_enhanced(
242 &self,
243 data: &ArrayView1<F>,
244 n_bootstrap: usize,
245 statistic_fn: impl Fn(&ArrayView1<F>) -> F + Send + Sync,
246 seed: Option<u64>,
247 ) -> StatsResult<Array1<F>> {
248 checkarray_finite(data, "data")?;
249 check_positive(n_bootstrap, "n_bootstrap")?;
250
251 let statistic_fn = Arc::new(statistic_fn);
252 let data_arc = Arc::new(data.to_owned());
253
254 let results = parallel_map_collect(0..n_bootstrap, |i| {
255 use scirs2_core::random::Random;
256 let mut rng = match seed {
257 Some(s) => Random::seed(s.wrapping_add(i as u64)),
258 None => Random::seed(i as u64), };
260
261 let n = data_arc.len();
263 let mut bootstrap_sample = Array1::zeros(n);
264 for j in 0..n {
265 let idx = rng.random_range(0..n);
266 bootstrap_sample[j] = data_arc[idx];
267 }
268
269 statistic_fn(&bootstrap_sample.view())
271 });
272
273 Ok(Array1::from_vec(results))
274 }
275
276 pub fn matrix_operations_parallel(
278 &self,
279 matrix: &ArrayView2<F>,
280 ) -> StatsResult<MatrixParallelResult<F>> {
281 checkarray_finite(matrix, "matrix")?;
282
283 let (rows, cols) = matrix.dim();
284
285 let row_means = parallel_map_collect(0..rows, |i| {
287 let row = matrix.row(i);
288 self.mean_parallel_enhanced(&row).expect("Operation failed")
289 });
290
291 let row_vars = parallel_map_collect(0..rows, |i| {
292 let row = matrix.row(i);
293 self.variance_parallel_enhanced(&row, 1)
294 .expect("Operation failed")
295 });
296
297 let col_means = parallel_map_collect(0..cols, |j| {
299 let col = matrix.column(j);
300 self.mean_parallel_enhanced(&col).expect("Operation failed")
301 });
302
303 let col_vars = parallel_map_collect(0..cols, |j| {
304 let col = matrix.column(j);
305 self.variance_parallel_enhanced(&col, 1)
306 .expect("Operation failed")
307 });
308
309 let flattened = matrix.iter().copied().collect::<Array1<F>>();
311 let overall_mean = self.mean_parallel_enhanced(&flattened.view())?;
312 let overall_var = self.variance_parallel_enhanced(&flattened.view(), 1)?;
313
314 Ok(MatrixParallelResult {
315 row_means: Array1::from_vec(row_means),
316 row_vars: Array1::from_vec(row_vars),
317 col_means: Array1::from_vec(col_means),
318 col_vars: Array1::from_vec(col_vars),
319 overall_mean,
320 overall_var,
321 shape: (rows, cols),
322 })
323 }
324
325 pub fn quantiles_parallel(
327 &self,
328 data: &ArrayView1<F>,
329 quantiles: &[F],
330 ) -> StatsResult<Array1<F>> {
331 checkarray_finite(data, "data")?;
332
333 if data.is_empty() {
334 return Err(StatsError::InvalidArgument(
335 "Data cannot be empty".to_string(),
336 ));
337 }
338
339 for &q in quantiles {
340 if q < F::zero() || q > F::one() {
341 return Err(StatsError::InvalidArgument(
342 "Quantiles must be in [0, 1]".to_string(),
343 ));
344 }
345 }
346
347 let mut sorteddata = data.to_owned();
349
350 if sorteddata.len() >= self.config.min_chunksize {
352 sorteddata
353 .as_slice_mut()
354 .expect("Operation failed")
355 .par_sort_by(|a, b| a.partial_cmp(b).expect("Operation failed"));
356 } else {
357 sorteddata
358 .as_slice_mut()
359 .expect("Operation failed")
360 .sort_by(|a, b| a.partial_cmp(b).expect("Operation failed"));
361 }
362
363 let n = sorteddata.len();
365 let results = quantiles
366 .iter()
367 .map(|&q| {
368 let index = (q * F::from(n - 1).expect("Failed to convert to float"))
369 .to_f64()
370 .expect("Operation failed");
371 let lower = index.floor() as usize;
372 let upper = index.ceil() as usize;
373 let weight = F::from(index - index.floor()).expect("Operation failed");
374
375 if lower == upper {
376 sorteddata[lower]
377 } else {
378 sorteddata[lower] * (F::one() - weight) + sorteddata[upper] * weight
379 }
380 })
381 .collect::<Vec<F>>();
382
383 Ok(Array1::from_vec(results))
384 }
385
386 fn calculate_optimal_chunksize(&self, datalen: usize) -> usize {
388 let num_threads = self.config.num_threads.unwrap_or_else(num_cpus::get);
389 let ideal_chunks = num_threads * 2; let chunksize = (datalen / ideal_chunks).max(self.config.min_chunksize);
391 chunksize.min(datalen)
392 }
393
394 fn correlation_coefficient(
396 &self,
397 x: &ArrayView1<F>,
398 y: &ArrayView1<F>,
399 mean_x: F,
400 mean_y: F,
401 ) -> StatsResult<F> {
402 if x.len() != y.len() {
403 return Err(StatsError::DimensionMismatch(
404 "Arrays must have the same length".to_string(),
405 ));
406 }
407
408 let n = x.len();
409 if n < 2 {
410 return Ok(F::zero());
411 }
412
413 let chunksize = self.calculate_optimal_chunksize(n);
414 let result = parallel_map_reduce_indexed(
415 0..n,
416 chunksize,
417 |indices| {
418 let mut sum_xy = F::zero();
419 let mut sum_x2 = F::zero();
420 let mut sum_y2 = F::zero();
421
422 for &i in indices {
423 let dx = x[i] - mean_x;
424 let dy = y[i] - mean_y;
425 sum_xy = sum_xy + dx * dy;
426 sum_x2 = sum_x2 + dx * dx;
427 sum_y2 = sum_y2 + dy * dy;
428 }
429
430 (sum_xy, sum_x2, sum_y2)
431 },
432 |(xy1, x2_1, y2_1), (xy2, x2_2, y2_2)| (xy1 + xy2, x2_1 + x2_2, y2_1 + y2_2),
433 );
434
435 let (sum_xy, sum_x2, sum_y2) = result;
436 let denom = (sum_x2 * sum_y2).sqrt();
437
438 if denom > F::zero() {
439 Ok(sum_xy / denom)
440 } else {
441 Ok(F::zero())
442 }
443 }
444}
445
446#[derive(Debug, Clone)]
448pub struct MatrixParallelResult<F> {
449 pub row_means: Array1<F>,
450 pub row_vars: Array1<F>,
451 pub col_means: Array1<F>,
452 pub col_vars: Array1<F>,
453 pub overall_mean: F,
454 pub overall_var: F,
455 pub shape: (usize, usize),
456}
457
458#[allow(dead_code)]
460pub fn mean_parallel_advanced<F>(data: &ArrayView1<F>) -> StatsResult<F>
461where
462 F: Float
463 + NumCast
464 + SimdUnifiedOps
465 + Zero
466 + One
467 + PartialOrd
468 + Copy
469 + Send
470 + Sync
471 + std::fmt::Display
472 + std::iter::Sum<F>
473 + scirs2_core::numeric::FromPrimitive,
474{
475 let processor = EnhancedParallelProcessor::<F>::new();
476 processor.mean_parallel_enhanced(data)
477}
478
479#[allow(dead_code)]
480pub fn variance_parallel_advanced<F>(data: &ArrayView1<F>, ddof: usize) -> StatsResult<F>
481where
482 F: Float
483 + NumCast
484 + SimdUnifiedOps
485 + Zero
486 + One
487 + PartialOrd
488 + Copy
489 + Send
490 + Sync
491 + std::fmt::Display
492 + std::iter::Sum<F>
493 + scirs2_core::numeric::FromPrimitive,
494{
495 let processor = EnhancedParallelProcessor::<F>::new();
496 processor.variance_parallel_enhanced(data, ddof)
497}
498
499#[allow(dead_code)]
500pub fn correlation_matrix_parallel_advanced<F>(matrix: &ArrayView2<F>) -> StatsResult<Array2<F>>
501where
502 F: Float
503 + NumCast
504 + SimdUnifiedOps
505 + Zero
506 + One
507 + PartialOrd
508 + Copy
509 + Send
510 + Sync
511 + std::fmt::Display
512 + std::iter::Sum<F>
513 + scirs2_core::numeric::FromPrimitive,
514{
515 let processor = EnhancedParallelProcessor::<F>::new();
516 processor.correlation_matrix_parallel(matrix)
517}
518
519#[allow(dead_code)]
520pub fn bootstrap_parallel_advanced<F>(
521 data: &ArrayView1<F>,
522 n_bootstrap: usize,
523 statistic_fn: impl Fn(&ArrayView1<F>) -> F + Send + Sync,
524 seed: Option<u64>,
525) -> StatsResult<Array1<F>>
526where
527 F: Float
528 + NumCast
529 + SimdUnifiedOps
530 + Zero
531 + One
532 + PartialOrd
533 + Copy
534 + Send
535 + Sync
536 + std::fmt::Display
537 + std::iter::Sum<F>
538 + scirs2_core::numeric::FromPrimitive,
539{
540 let processor = EnhancedParallelProcessor::<F>::new();
541 processor.bootstrap_parallel_enhanced(data, n_bootstrap, statistic_fn, seed)
542}