use crate::error::{StatsError, StatsResult};
use scirs2_core::ndarray::{Array1, Array2, ArrayView1, ArrayView2};
use scirs2_core::numeric::{Float, NumCast, One, Zero};
use scirs2_core::random::Rng;
use scirs2_core::{parallel_ops::*, simd_ops::SimdUnifiedOps, validation::*};
use std::sync::Arc;

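/// Configuration for the enhanced parallel statistical routines.
///
/// `num_threads` caps the worker threads (`None` uses all available cores) and
/// `min_chunksize` is the smallest input size worth parallelising; `max_chunks`,
/// `numa_aware`, and `work_stealing` are additional tuning knobs.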
#[derive(Debug, Clone)]
pub struct EnhancedParallelConfig {
    pub num_threads: Option<usize>,
    pub min_chunksize: usize,
    pub max_chunks: usize,
    pub numa_aware: bool,
    pub work_stealing: bool,
}

impl Default for EnhancedParallelConfig {
    fn default() -> Self {
        Self {
            num_threads: None,
            min_chunksize: 1000,
            max_chunks: num_cpus::get() * 4,
            numa_aware: true,
            work_stealing: true,
        }
    }
}

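/// Parallel statistical processor that chunks work across threads according to
/// an [`EnhancedParallelConfig`].
///
/// A minimal usage sketch (not compiled as a doctest; assumes `f64` satisfies
/// the trait bounds below):
///
/// ```ignore
/// use scirs2_core::ndarray::Array1;
///
/// let data: Array1<f64> = Array1::from_vec((0..10_000).map(|i| i as f64).collect());
/// let processor = EnhancedParallelProcessor::<f64>::new();
/// let mean = processor.mean_parallel_enhanced(&data.view()).unwrap();
/// ```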
pub struct EnhancedParallelProcessor<F> {
    config: EnhancedParallelConfig,
    _phantom: std::marker::PhantomData<F>,
}

impl<F> Default for EnhancedParallelProcessor<F>
where
    F: Float
        + NumCast
        + SimdUnifiedOps
        + Zero
        + One
        + PartialOrd
        + Copy
        + Send
        + Sync
        + std::fmt::Display
        + std::iter::Sum<F>
        + scirs2_core::numeric::FromPrimitive,
{
    fn default() -> Self {
        Self::new()
    }
}

impl<F> EnhancedParallelProcessor<F>
where
    F: Float
        + NumCast
        + SimdUnifiedOps
        + Zero
        + One
        + PartialOrd
        + Copy
        + Send
        + Sync
        + std::fmt::Display
        + std::iter::Sum<F>
        + scirs2_core::numeric::FromPrimitive,
{
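    /// Creates a processor with the default configuration.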
    pub fn new() -> Self {
        Self {
            config: EnhancedParallelConfig::default(),
            _phantom: std::marker::PhantomData,
        }
    }

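    /// Creates a processor with a custom configuration.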
    pub fn with_config(config: EnhancedParallelConfig) -> Self {
        Self {
            config,
            _phantom: std::marker::PhantomData,
        }
    }

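    /// Computes the arithmetic mean with a chunked parallel map/reduce.
    ///
    /// Inputs shorter than `min_chunksize` fall back to a sequential mean.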
    pub fn mean_parallel_enhanced(&self, data: &ArrayView1<F>) -> StatsResult<F> {
        checkarray_finite(data, "data")?;

        if data.is_empty() {
            return Err(StatsError::InvalidArgument(
                "Data cannot be empty".to_string(),
            ));
        }

        let n = data.len();

        // Small inputs are not worth the parallel dispatch overhead.
        if n < self.config.min_chunksize {
            return Ok(data.mean().unwrap());
        }

        // Non-contiguous views (e.g. column views of a row-major matrix) have no
        // backing slice, so copy them into a contiguous buffer before chunking.
        let contiguous;
        let slice = match data.as_slice() {
            Some(s) => s,
            None => {
                contiguous = data.to_owned();
                contiguous.as_slice().unwrap()
            }
        };

        // Each chunk yields a partial (sum, count); the reduce step combines them.
        let chunksize = self.calculate_optimal_chunksize(n);
        let result = slice
            .par_chunks(chunksize)
            .map(|chunk| {
                let sum: F = chunk.iter().copied().sum();
                let count = chunk.len();
                (sum, count)
            })
            .reduce(
                || (F::zero(), 0),
                |(sum1, count1), (sum2, count2)| (sum1 + sum2, count1 + count2),
            );

        let (total_sum, total_count) = result;
        Ok(total_sum / F::from(total_count).unwrap())
    }

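    /// Computes the variance with `ddof` delta degrees of freedom (0 for the
    /// population variance, 1 for the sample variance).
    ///
    /// The mean is computed first, then squared deviations are accumulated per
    /// chunk in parallel; inputs shorter than `min_chunksize` use a sequential path.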
    pub fn variance_parallel_enhanced(&self, data: &ArrayView1<F>, ddof: usize) -> StatsResult<F> {
        checkarray_finite(data, "data")?;

        if data.is_empty() {
            return Err(StatsError::InvalidArgument(
                "Data cannot be empty".to_string(),
            ));
        }

        let n = data.len();

        // Sequential path for small inputs, with the same degrees-of-freedom
        // check as the parallel path below.
        if n < self.config.min_chunksize {
            let denominator = n.saturating_sub(ddof);
            if denominator == 0 {
                return Err(StatsError::InvalidArgument(
                    "Insufficient degrees of freedom".to_string(),
                ));
            }
            let mean = data.mean().unwrap();
            let sum_sq_diff: F = data.iter().map(|&x| (x - mean) * (x - mean)).sum();
            return Ok(sum_sq_diff / F::from(denominator).unwrap());
        }

        let mean = self.mean_parallel_enhanced(data)?;

        // As in `mean_parallel_enhanced`, copy non-contiguous views before chunking.
        let contiguous;
        let slice = match data.as_slice() {
            Some(s) => s,
            None => {
                contiguous = data.to_owned();
                contiguous.as_slice().unwrap()
            }
        };

        let chunksize = self.calculate_optimal_chunksize(n);
        let result = slice
            .par_chunks(chunksize)
            .map(|chunk| {
                let sum_sq_diff: F = chunk.iter().map(|&x| (x - mean) * (x - mean)).sum();
                let count = chunk.len();
                (sum_sq_diff, count)
            })
            .reduce(
                || (F::zero(), 0),
                |(sum1, count1), (sum2, count2)| (sum1 + sum2, count1 + count2),
            );

        let (total_sum_sq_diff, total_count) = result;
        let denominator = total_count.saturating_sub(ddof);

        if denominator == 0 {
            return Err(StatsError::InvalidArgument(
                "Insufficient degrees of freedom".to_string(),
            ));
        }

        Ok(total_sum_sq_diff / F::from(denominator).unwrap())
    }

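    /// Computes the Pearson correlation matrix of `matrix`, treating columns as
    /// features and rows as samples.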
    pub fn correlation_matrix_parallel(&self, matrix: &ArrayView2<F>) -> StatsResult<Array2<F>> {
        checkarray_finite(matrix, "matrix")?;

        let (_n_samples, n_features) = matrix.dim();

        if n_features < 2 {
            return Err(StatsError::InvalidArgument(
                "At least 2 features required for correlation matrix".to_string(),
            ));
        }

        // Column (feature) means, computed in parallel.
        let means = parallel_map_collect(0..n_features, |i| {
            let col = matrix.column(i);
            self.mean_parallel_enhanced(&col).unwrap()
        });

        // The matrix is symmetric, so only the upper triangle (including the
        // diagonal) is computed and then mirrored below.
        let mut corr_matrix = Array2::zeros((n_features, n_features));
        let pairs: Vec<(usize, usize)> = (0..n_features)
            .flat_map(|i| (i..n_features).map(move |j| (i, j)))
            .collect();

        let correlations = parallel_map_collect(&pairs, |&(i, j)| {
            if i == j {
                (i, j, F::one())
            } else {
                let col_i = matrix.column(i);
                let col_j = matrix.column(j);
                let corr = self
                    .correlation_coefficient(&col_i, &col_j, means[i], means[j])
                    .unwrap();
                (i, j, corr)
            }
        });

        for (i, j, corr) in correlations {
            corr_matrix[[i, j]] = corr;
            if i != j {
                corr_matrix[[j, i]] = corr;
            }
        }

        Ok(corr_matrix)
    }

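    /// Draws `n_bootstrap` resamples of `data` (with replacement) in parallel and
    /// applies `statistic_fn` to each, returning one statistic per resample.
    /// Supplying a `seed` makes the whole run reproducible.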
    pub fn bootstrap_parallel_enhanced(
        &self,
        data: &ArrayView1<F>,
        n_bootstrap: usize,
        statistic_fn: impl Fn(&ArrayView1<F>) -> F + Send + Sync,
        seed: Option<u64>,
    ) -> StatsResult<Array1<F>> {
        checkarray_finite(data, "data")?;
        check_positive(n_bootstrap, "n_bootstrap")?;

        let statistic_fn = Arc::new(statistic_fn);
        let data_arc = Arc::new(data.to_owned());

        let results = parallel_map_collect(0..n_bootstrap, |i| {
            use scirs2_core::random::Random;
            // Derive a per-replicate seed so replicates differ but the run as a
            // whole is reproducible when a seed is supplied.
            let mut rng = match seed {
                Some(s) => Random::seed(s.wrapping_add(i as u64)),
                None => Random::seed(i as u64),
            };

            // Resample with replacement, then evaluate the statistic.
            let n = data_arc.len();
            let mut bootstrap_sample = Array1::zeros(n);
            for j in 0..n {
                let idx = rng.gen_range(0..n);
                bootstrap_sample[j] = data_arc[idx];
            }

            statistic_fn(&bootstrap_sample.view())
        });

        Ok(Array1::from_vec(results))
    }

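    /// Computes per-row and per-column means and sample variances, plus overall
    /// statistics across all elements of `matrix`.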
    pub fn matrix_operations_parallel(
        &self,
        matrix: &ArrayView2<F>,
    ) -> StatsResult<MatrixParallelResult<F>> {
        checkarray_finite(matrix, "matrix")?;

        let (rows, cols) = matrix.dim();

        // Row-wise statistics.
        let row_means = parallel_map_collect(0..rows, |i| {
            let row = matrix.row(i);
            self.mean_parallel_enhanced(&row).unwrap()
        });

        let row_vars = parallel_map_collect(0..rows, |i| {
            let row = matrix.row(i);
            self.variance_parallel_enhanced(&row, 1).unwrap()
        });

        // Column-wise statistics.
        let col_means = parallel_map_collect(0..cols, |j| {
            let col = matrix.column(j);
            self.mean_parallel_enhanced(&col).unwrap()
        });

        let col_vars = parallel_map_collect(0..cols, |j| {
            let col = matrix.column(j);
            self.variance_parallel_enhanced(&col, 1).unwrap()
        });

        // Overall statistics over the flattened matrix.
        let flattened = matrix.iter().copied().collect::<Array1<F>>();
        let overall_mean = self.mean_parallel_enhanced(&flattened.view())?;
        let overall_var = self.variance_parallel_enhanced(&flattened.view(), 1)?;

        Ok(MatrixParallelResult {
            row_means: Array1::from_vec(row_means),
            row_vars: Array1::from_vec(row_vars),
            col_means: Array1::from_vec(col_means),
            col_vars: Array1::from_vec(col_vars),
            overall_mean,
            overall_var,
            shape: (rows, cols),
        })
    }

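    /// Computes the requested quantiles (each in `[0, 1]`) using linear
    /// interpolation between order statistics; the data is sorted first, in
    /// parallel for large inputs.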
    pub fn quantiles_parallel(
        &self,
        data: &ArrayView1<F>,
        quantiles: &[F],
    ) -> StatsResult<Array1<F>> {
        checkarray_finite(data, "data")?;

        if data.is_empty() {
            return Err(StatsError::InvalidArgument(
                "Data cannot be empty".to_string(),
            ));
        }

        for &q in quantiles {
            if q < F::zero() || q > F::one() {
                return Err(StatsError::InvalidArgument(
                    "Quantiles must be in [0, 1]".to_string(),
                ));
            }
        }

        // Sort a private copy of the data, in parallel for large inputs.
        let mut sorted_data = data.to_owned();

        if sorted_data.len() >= self.config.min_chunksize {
            sorted_data
                .as_slice_mut()
                .unwrap()
                .par_sort_by(|a, b| a.partial_cmp(b).unwrap());
        } else {
            sorted_data
                .as_slice_mut()
                .unwrap()
                .sort_by(|a, b| a.partial_cmp(b).unwrap());
        }

        // Linear interpolation between the two order statistics that bracket
        // each requested quantile.
        let n = sorted_data.len();
        let results = quantiles
            .iter()
            .map(|&q| {
                let index = (q * F::from(n - 1).unwrap()).to_f64().unwrap();
                let lower = index.floor() as usize;
                let upper = index.ceil() as usize;
                let weight = F::from(index - index.floor()).unwrap();

                if lower == upper {
                    sorted_data[lower]
                } else {
                    sorted_data[lower] * (F::one() - weight) + sorted_data[upper] * weight
                }
            })
            .collect::<Vec<F>>();

        Ok(Array1::from_vec(results))
    }

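    /// Picks a chunk size that yields roughly two chunks per worker thread,
    /// bounded below by `min_chunksize` and above by the data length.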
    fn calculate_optimal_chunksize(&self, data_len: usize) -> usize {
        let num_threads = self.config.num_threads.unwrap_or_else(num_cpus::get);
        let ideal_chunks = num_threads * 2;
        let chunksize = (data_len / ideal_chunks).max(self.config.min_chunksize);
        chunksize.min(data_len)
    }

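    /// Computes the Pearson correlation coefficient between `x` and `y` given
    /// their precomputed means, returning 0 when either input has zero variance.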
    fn correlation_coefficient(
        &self,
        x: &ArrayView1<F>,
        y: &ArrayView1<F>,
        mean_x: F,
        mean_y: F,
    ) -> StatsResult<F> {
        if x.len() != y.len() {
            return Err(StatsError::DimensionMismatch(
                "Arrays must have the same length".to_string(),
            ));
        }

        let n = x.len();
        if n < 2 {
            return Ok(F::zero());
        }

        // Accumulate the cross product and both sums of squared deviations per
        // chunk, then combine the partial sums.
        let chunksize = self.calculate_optimal_chunksize(n);
        let result = parallel_map_reduce_indexed(
            0..n,
            chunksize,
            |indices| {
                let mut sum_xy = F::zero();
                let mut sum_x2 = F::zero();
                let mut sum_y2 = F::zero();

                for &i in indices {
                    let dx = x[i] - mean_x;
                    let dy = y[i] - mean_y;
                    sum_xy = sum_xy + dx * dy;
                    sum_x2 = sum_x2 + dx * dx;
                    sum_y2 = sum_y2 + dy * dy;
                }

                (sum_xy, sum_x2, sum_y2)
            },
            |(xy1, x2_1, y2_1), (xy2, x2_2, y2_2)| (xy1 + xy2, x2_1 + x2_2, y2_1 + y2_2),
        );

        let (sum_xy, sum_x2, sum_y2) = result;
        let denom = (sum_x2 * sum_y2).sqrt();

        if denom > F::zero() {
            Ok(sum_xy / denom)
        } else {
            // Zero variance in either input makes the correlation undefined.
            Ok(F::zero())
        }
    }
}

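/// Row, column, and overall summary statistics produced by
/// [`EnhancedParallelProcessor::matrix_operations_parallel`].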
#[derive(Debug, Clone)]
pub struct MatrixParallelResult<F> {
    pub row_means: Array1<F>,
    pub row_vars: Array1<F>,
    pub col_means: Array1<F>,
    pub col_vars: Array1<F>,
    pub overall_mean: F,
    pub overall_var: F,
    pub shape: (usize, usize),
}

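/// Convenience wrapper around [`EnhancedParallelProcessor::mean_parallel_enhanced`]
/// using the default configuration.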
#[allow(dead_code)]
pub fn mean_parallel_advanced<F>(data: &ArrayView1<F>) -> StatsResult<F>
where
    F: Float
        + NumCast
        + SimdUnifiedOps
        + Zero
        + One
        + PartialOrd
        + Copy
        + Send
        + Sync
        + std::fmt::Display
        + std::iter::Sum<F>
        + scirs2_core::numeric::FromPrimitive,
{
    let processor = EnhancedParallelProcessor::<F>::new();
    processor.mean_parallel_enhanced(data)
}

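/// Convenience wrapper around [`EnhancedParallelProcessor::variance_parallel_enhanced`]
/// using the default configuration.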
#[allow(dead_code)]
pub fn variance_parallel_advanced<F>(data: &ArrayView1<F>, ddof: usize) -> StatsResult<F>
where
    F: Float
        + NumCast
        + SimdUnifiedOps
        + Zero
        + One
        + PartialOrd
        + Copy
        + Send
        + Sync
        + std::fmt::Display
        + std::iter::Sum<F>
        + scirs2_core::numeric::FromPrimitive,
{
    let processor = EnhancedParallelProcessor::<F>::new();
    processor.variance_parallel_enhanced(data, ddof)
}

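/// Convenience wrapper around [`EnhancedParallelProcessor::correlation_matrix_parallel`]
/// using the default configuration.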
#[allow(dead_code)]
pub fn correlation_matrix_parallel_advanced<F>(matrix: &ArrayView2<F>) -> StatsResult<Array2<F>>
where
    F: Float
        + NumCast
        + SimdUnifiedOps
        + Zero
        + One
        + PartialOrd
        + Copy
        + Send
        + Sync
        + std::fmt::Display
        + std::iter::Sum<F>
        + scirs2_core::numeric::FromPrimitive,
{
    let processor = EnhancedParallelProcessor::<F>::new();
    processor.correlation_matrix_parallel(matrix)
}

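/// Convenience wrapper around [`EnhancedParallelProcessor::bootstrap_parallel_enhanced`]
/// using the default configuration.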
#[allow(dead_code)]
pub fn bootstrap_parallel_advanced<F>(
    data: &ArrayView1<F>,
    n_bootstrap: usize,
    statistic_fn: impl Fn(&ArrayView1<F>) -> F + Send + Sync,
    seed: Option<u64>,
) -> StatsResult<Array1<F>>
where
    F: Float
        + NumCast
        + SimdUnifiedOps
        + Zero
        + One
        + PartialOrd
        + Copy
        + Send
        + Sync
        + std::fmt::Display
        + std::iter::Sum<F>
        + scirs2_core::numeric::FromPrimitive,
{
    let processor = EnhancedParallelProcessor::<F>::new();
    processor.bootstrap_parallel_enhanced(data, n_bootstrap, statistic_fn, seed)
}
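
// A minimal smoke-test sketch, assuming `f64` satisfies the trait bounds above
// and that `scirs2_core::ndarray::Array1` mirrors `ndarray::Array1`; adapt to
// the crate's test conventions as needed.
#[cfg(test)]
mod tests {
    use super::*;
    use scirs2_core::ndarray::Array1;

    #[test]
    fn parallel_mean_and_variance_match_sequential_reference() {
        // Large enough to exercise the chunked parallel path (min_chunksize = 1000).
        let data: Array1<f64> = Array1::from_vec((0..5_000).map(|i| i as f64).collect());

        let mean = mean_parallel_advanced(&data.view()).unwrap();
        let var = variance_parallel_advanced(&data.view(), 1).unwrap();

        // Sequential reference values.
        let naive_mean = data.iter().sum::<f64>() / data.len() as f64;
        let naive_var = data
            .iter()
            .map(|&x| (x - naive_mean).powi(2))
            .sum::<f64>()
            / (data.len() as f64 - 1.0);

        assert!((mean - naive_mean).abs() < 1e-9);
        assert!((var - naive_var).abs() < 1e-6 * naive_var);
    }
}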