1use crate::error::{StatsError, StatsResult};
7use scirs2_core::ndarray::{Array1, Array2, ArrayBase, ArrayView1, Axis, Data, Ix1, Ix2};
8use scirs2_core::numeric::{Float, NumCast};
9use scirs2_core::parallel_ops::*;
10use std::sync::{Arc, Mutex};
11
/// Tuning knobs shared by the parallel statistics helpers in this module.
#[derive(Clone, Debug)]
pub struct AdvancedParallelConfig {
    /// Inputs smaller than this are processed sequentially instead of being
    /// split into parallel chunks.
    pub minsize: usize,
    /// Fixed number of elements per chunk; `None` lets
    /// `AdvancedParallelConfig::get_optimal_chunksize` derive one.
    pub chunksize: Option<usize>,
    /// Thread count assumed when deriving a chunk size; `None` means the CPU count.
    pub max_threads: Option<usize>,
    /// Whether work stealing is requested.
    /// NOTE(review): no code in this module reads this flag — confirm intended use.
    pub work_stealing: bool,
    /// When `true`, derived chunks target several chunks per thread (clamped);
    /// when `false`, one chunk per thread.
    pub dynamic_chunks: bool,
}
26
27impl Default for AdvancedParallelConfig {
28 fn default() -> Self {
29 Self {
30 minsize: 2_000, chunksize: None, max_threads: None, work_stealing: true,
34 dynamic_chunks: true,
35 }
36 }
37}
38
39impl AdvancedParallelConfig {
40 pub fn get_optimal_chunksize(&self, n: usize) -> usize {
42 if let Some(size) = self.chunksize {
43 return size;
44 }
45
46 let threads = self.max_threads.unwrap_or_else(num_cpus::get);
47
48 if self.dynamic_chunks {
49 let base_chunk = n / (threads * 4); base_chunk.clamp(100, 10_000) } else {
53 n / threads
54 }
55 }
56}
57
58pub struct ParallelBatchProcessor<F> {
63 config: AdvancedParallelConfig,
64 _phantom: std::marker::PhantomData<F>,
65}
66
67impl<F> ParallelBatchProcessor<F>
68where
69 F: Float + NumCast + Send + Sync + std::iter::Sum + std::fmt::Display,
70{
71 pub fn new(config: AdvancedParallelConfig) -> Self {
72 Self {
73 config,
74 _phantom: std::marker::PhantomData,
75 }
76 }
77
78 pub fn batch_descriptive_stats<D>(
80 &self,
81 datasets: &[ArrayBase<D, Ix1>],
82 ) -> StatsResult<Vec<(F, F, F, F)>>
83 where
85 D: Data<Elem = F> + Sync,
86 {
87 if datasets.is_empty() {
88 return Ok(Vec::new());
89 }
90
91 let results: Vec<StatsResult<(F, F, F, F)>> = datasets
92 .iter()
93 .map(|dataset| self.compute_singledataset_stats(dataset))
94 .collect();
95
96 results.into_iter().collect()
97 }
98
99 fn compute_singledataset_stats<D>(&self, data: &ArrayBase<D, Ix1>) -> StatsResult<(F, F, F, F)>
100 where
101 D: Data<Elem = F>,
102 {
103 let n = data.len();
104 if n == 0 {
105 return Err(StatsError::InvalidArgument(
106 "Dataset cannot be empty".to_string(),
107 ));
108 }
109
110 if n < self.config.minsize {
111 let mean = data.iter().fold(F::zero(), |acc, &x| acc + x)
113 / F::from(n).expect("Failed to convert to float");
114 let variance = data
115 .iter()
116 .map(|&x| {
117 let diff = x - mean;
118 diff * diff
119 })
120 .fold(F::zero(), |acc, x| acc + x)
121 / F::from(n - 1).expect("Failed to convert to float");
122 let min = data
123 .iter()
124 .fold(data[0], |min_val, &x| if x < min_val { x } else { min_val });
125 let max = data
126 .iter()
127 .fold(data[0], |max_val, &x| if x > max_val { x } else { max_val });
128
129 return Ok((mean, variance, min, max));
130 }
131
132 let chunksize = self.config.get_optimal_chunksize(n);
134
135 let results: Vec<(F, F, F, F, usize)> = data
137 .as_slice()
138 .expect("Operation failed")
139 .par_chunks(chunksize)
140 .map(|chunk| {
141 let len = chunk.len();
142 let sum = chunk.iter().fold(F::zero(), |acc, &x| acc + x);
143 let min = chunk.iter().fold(
144 chunk[0],
145 |min_val, &x| if x < min_val { x } else { min_val },
146 );
147 let max = chunk.iter().fold(
148 chunk[0],
149 |max_val, &x| if x > max_val { x } else { max_val },
150 );
151
152 let local_mean = sum / F::from(len).expect("Failed to convert to float");
154 let sum_sq_dev = chunk
155 .iter()
156 .map(|&x| {
157 let diff = x - local_mean;
158 diff * diff
159 })
160 .fold(F::zero(), |acc, x| acc + x);
161
162 (sum, sum_sq_dev, min, max, len)
163 })
164 .collect();
165
166 let total_sum = results
168 .iter()
169 .map(|(sum__, _, _, _, _)| *sum__)
170 .fold(F::zero(), |acc, x| acc + x);
171 let total_len = results.iter().map(|(_, _, _, _, len)| *len).sum::<usize>();
172 let global_mean = total_sum / F::from(total_len).expect("Failed to convert to float");
173
174 let global_min =
175 results
176 .iter()
177 .map(|(_, _, min__, _, _)| *min__)
178 .fold(
179 results[0].2,
180 |min_val, x| if x < min_val { x } else { min_val },
181 );
182 let global_max =
183 results
184 .iter()
185 .map(|(_, _, _, max_, _)| *max_)
186 .fold(
187 results[0].3,
188 |max_val, x| if x > max_val { x } else { max_val },
189 );
190
191 let global_variance = par_chunks(data.as_slice().expect("Operation failed"), chunksize)
193 .map(|chunk| {
194 chunk
195 .iter()
196 .map(|&x| {
197 let diff = x - global_mean;
198 diff * diff
199 })
200 .fold(F::zero(), |acc, x| acc + x)
201 })
202 .reduce(|| F::zero(), |a, b| a + b)
203 / F::from(total_len - 1).expect("Failed to convert to float");
204
205 Ok((global_mean, global_variance, global_min, global_max))
206 }
207}
208
209pub struct ParallelCrossValidator<F> {
214 k_folds: usize,
215 config: AdvancedParallelConfig,
216 _phantom: std::marker::PhantomData<F>,
217}
218
219impl<F> ParallelCrossValidator<F>
220where
221 F: Float + NumCast + Send + Sync + std::fmt::Display,
222{
223 pub fn new(_kfolds: usize, config: AdvancedParallelConfig) -> Self {
224 Self {
225 k_folds: _kfolds,
226 config,
227 _phantom: std::marker::PhantomData,
228 }
229 }
230
231 pub fn cross_validate_correlation<D1, D2>(
233 &self,
234 x: &ArrayBase<D1, Ix1>,
235 y: &ArrayBase<D2, Ix1>,
236 ) -> StatsResult<(F, F)>
237 where
239 D1: Data<Elem = F> + Sync,
240 D2: Data<Elem = F> + Sync,
241 {
242 if x.len() != y.len() {
243 return Err(StatsError::DimensionMismatch(
244 "Arrays must have same length".to_string(),
245 ));
246 }
247
248 let n = x.len();
249 if n < self.k_folds {
250 return Err(StatsError::InvalidArgument(
251 "Not enough data for k-fold validation".to_string(),
252 ));
253 }
254
255 let foldsize = n / self.k_folds;
256 let x_arc = Arc::new(x.to_owned());
257 let y_arc = Arc::new(y.to_owned());
258
259 let correlations: Vec<F> = (0..self.k_folds)
261 .map(|fold| {
262 let start = fold * foldsize;
263 let end = if fold == self.k_folds - 1 {
264 n
265 } else {
266 (fold + 1) * foldsize
267 };
268
269 let mut train_x = Vec::new();
271 let mut train_y = Vec::new();
272
273 for i in 0..n {
274 if i < start || i >= end {
275 train_x.push(x_arc[i]);
276 train_y.push(y_arc[i]);
277 }
278 }
279
280 let x_train = Array1::from(train_x);
282 let y_train = Array1::from(train_y);
283
284 self.compute_pearson_correlation(&x_train.view(), &y_train.view())
285 .unwrap_or(F::zero())
286 })
287 .collect();
288
289 let mean_corr = correlations.iter().fold(F::zero(), |acc, &x| acc + x)
291 / F::from(self.k_folds).expect("Failed to convert to float");
292 let var_corr = correlations
293 .iter()
294 .map(|&corr| {
295 let diff = corr - mean_corr;
296 diff * diff
297 })
298 .fold(F::zero(), |acc, x| acc + x)
299 / F::from(self.k_folds - 1).expect("Failed to convert to float");
300 let std_corr = var_corr.sqrt();
301
302 Ok((mean_corr, std_corr))
303 }
304
305 fn compute_pearson_correlation(&self, x: &ArrayView1<F>, y: &ArrayView1<F>) -> StatsResult<F> {
306 let n = x.len();
307 let mean_x = x.iter().fold(F::zero(), |acc, &val| acc + val)
308 / F::from(n).expect("Failed to convert to float");
309 let mean_y = y.iter().fold(F::zero(), |acc, &val| acc + val)
310 / F::from(n).expect("Failed to convert to float");
311
312 let mut sum_xy = F::zero();
313 let mut sum_x2 = F::zero();
314 let mut sum_y2 = F::zero();
315
316 for (&xi, &yi) in x.iter().zip(y.iter()) {
317 let x_dev = xi - mean_x;
318 let y_dev = yi - mean_y;
319 sum_xy = sum_xy + x_dev * y_dev;
320 sum_x2 = sum_x2 + x_dev * x_dev;
321 sum_y2 = sum_y2 + y_dev * y_dev;
322 }
323
324 let epsilon = F::from(1e-15)
325 .unwrap_or_else(|| F::from(0.0).expect("Failed to convert constant to float"));
326 if sum_x2 <= epsilon || sum_y2 <= epsilon {
327 return Ok(F::zero());
328 }
329
330 Ok(sum_xy / (sum_x2 * sum_y2).sqrt())
331 }
332}
333
334pub struct ParallelMonteCarlo<F> {
338 n_simulations: usize,
339 config: AdvancedParallelConfig,
340 _phantom: std::marker::PhantomData<F>,
341}
342
343impl<F> ParallelMonteCarlo<F>
344where
345 F: Float + NumCast + Send + Sync + std::fmt::Display,
346{
347 pub fn new(_nsimulations: usize, config: AdvancedParallelConfig) -> Self {
348 Self {
349 n_simulations: _nsimulations,
350 config,
351 _phantom: std::marker::PhantomData,
352 }
353 }
354
355 pub fn bootstrap_confidence_interval<D>(
357 &self,
358 data: &ArrayBase<D, Ix1>,
359 statistic_fn: impl Fn(&ArrayView1<F>) -> F + Send + Sync,
360 confidence_level: F,
361 ) -> StatsResult<(F, F, F)>
362 where
364 D: Data<Elem = F> + Sync,
365 {
366 if data.is_empty() {
367 return Err(StatsError::InvalidArgument(
368 "Data cannot be empty".to_string(),
369 ));
370 }
371
372 if confidence_level <= F::zero() || confidence_level >= F::one() {
373 return Err(StatsError::InvalidArgument(
374 "Confidence _level must be between 0 and 1".to_string(),
375 ));
376 }
377
378 let data_arc = Arc::new(data.to_owned());
379 let n = data.len();
380
381 let bootstrap_stats: Vec<F> = (0..self.n_simulations)
383 .map(|seed| {
384 use scirs2_core::random::rngs::StdRng;
385 use scirs2_core::random::SeedableRng;
386
387 let mut rng = StdRng::seed_from_u64(seed as u64);
388 let mut bootstrap_sample = Array1::zeros(n);
389
390 for i in 0..n {
391 use scirs2_core::random::Rng;
392 let idx = rng.gen_range(0..n);
393 bootstrap_sample[i] = data_arc[idx];
394 }
395
396 statistic_fn(&bootstrap_sample.view())
397 })
398 .collect();
399
400 let mut sorted_stats = bootstrap_stats;
402 sorted_stats.sort_by(|a, b| a.partial_cmp(b).unwrap_or(std::cmp::Ordering::Equal));
403
404 let alpha = F::one() - confidence_level;
406 let lower_percentile = alpha / F::from(2.0).expect("Failed to convert constant to float");
407 let upper_percentile = F::one() - lower_percentile;
408
409 let lower_idx = (lower_percentile
410 * F::from(self.n_simulations - 1).expect("Failed to convert to float"))
411 .floor()
412 .to_usize()
413 .expect("Operation failed");
414 let upper_idx = (upper_percentile
415 * F::from(self.n_simulations - 1).expect("Failed to convert to float"))
416 .ceil()
417 .to_usize()
418 .expect("Operation failed");
419
420 let original_estimate = statistic_fn(&data.view());
421 let lower_bound = sorted_stats[lower_idx];
422 let upper_bound = sorted_stats[upper_idx];
423
424 Ok((original_estimate, lower_bound, upper_bound))
425 }
426
427 pub fn permutation_test<D1, D2>(
429 &self,
430 group1: &ArrayBase<D1, Ix1>,
431 group2: &ArrayBase<D2, Ix1>,
432 test_statistic: impl Fn(&ArrayView1<F>, &ArrayView1<F>) -> F + Send + Sync,
433 ) -> StatsResult<F>
434 where
436 D1: Data<Elem = F> + Sync,
437 D2: Data<Elem = F> + Sync,
438 {
439 if group1.is_empty() || group2.is_empty() {
440 return Err(StatsError::InvalidArgument(
441 "Groups cannot be empty".to_string(),
442 ));
443 }
444
445 let combined: Vec<F> = group1.iter().chain(group2.iter()).cloned().collect();
447 let n1 = group1.len();
448 let n2 = group2.len();
449 let _total_n = n1 + n2;
450
451 let observed_stat = test_statistic(&group1.view(), &group2.view());
453
454 let combined_arc = Arc::new(combined);
456 let count_extreme = Arc::new(Mutex::new(0usize));
457
458 (0..self.n_simulations).for_each(|seed| {
459 use scirs2_core::random::rngs::StdRng;
460 use scirs2_core::random::{SeedableRng, SliceRandom};
461
462 let mut rng = StdRng::seed_from_u64(seed as u64);
463 let mut permuted = combined_arc.as_ref().clone();
464 permuted.shuffle(&mut rng);
465
466 let perm_group1 = Array1::from_vec(permuted[0..n1].to_vec());
468 let perm_group2 = Array1::from_vec(permuted[n1..].to_vec());
469
470 let perm_stat = test_statistic(&perm_group1.view(), &perm_group2.view());
471
472 if perm_stat.abs() >= observed_stat.abs() {
474 let mut count = count_extreme.lock().expect("Operation failed");
475 *count += 1;
476 }
477 });
478
479 let extreme_count = *count_extreme.lock().expect("Operation failed");
480 let p_value = F::from(extreme_count).expect("Failed to convert to float")
481 / F::from(self.n_simulations).expect("Failed to convert to float");
482
483 Ok(p_value)
484 }
485}
486
/// Namespace (stateless unit struct) for matrix/vector kernels that switch to
/// row-parallel evaluation for large inputs.
pub struct ParallelMatrixOps;
489
490impl ParallelMatrixOps {
491 pub fn matvec_parallel<F, D1, D2>(
493 matrix: &ArrayBase<D1, Ix2>,
494 vector: &ArrayBase<D2, Ix1>,
495 config: Option<AdvancedParallelConfig>,
496 ) -> StatsResult<Array1<F>>
497 where
498 F: Float + NumCast + Send + Sync + std::iter::Sum,
499 D1: Data<Elem = F> + Sync,
500 D2: Data<Elem = F> + Sync,
501 {
502 let (m, n) = matrix.dim();
503 if n != vector.len() {
504 return Err(StatsError::DimensionMismatch(
505 "Matrix columns must match vector length".to_string(),
506 ));
507 }
508
509 let config = config.unwrap_or_default();
510 let mut result = Array1::zeros(m);
511
512 if m < config.minsize {
513 for i in 0..m {
515 let row = matrix.row(i);
516 result[i] = row.iter().zip(vector.iter()).map(|(&a, &b)| a * b).sum();
517 }
518 } else {
519 let chunksize = config.get_optimal_chunksize(m);
521
522 result
523 .axis_chunks_iter_mut(Axis(0), chunksize)
524 .enumerate()
525 .for_each(|(chunk_idx, mut result_chunk)| {
526 let start_row = chunk_idx * chunksize;
527 let end_row = (start_row + result_chunk.len()).min(m);
528
529 for (local_idx, i) in (start_row..end_row).enumerate() {
530 let row = matrix.row(i);
531 result_chunk[local_idx] =
532 row.iter().zip(vector.iter()).map(|(&a, &b)| a * b).sum();
533 }
534 });
535 }
536
537 Ok(result)
538 }
539
540 pub fn outer_product_parallel<F, D1, D2>(
542 a: &ArrayBase<D1, Ix1>,
543 b: &ArrayBase<D2, Ix1>,
544 config: Option<AdvancedParallelConfig>,
545 ) -> Array2<F>
546 where
547 F: Float + NumCast + Send + Sync,
548 D1: Data<Elem = F> + Sync,
549 D2: Data<Elem = F> + Sync,
550 {
551 let m = a.len();
552 let n = b.len();
553 let mut result = Array2::zeros((m, n));
554
555 let config = config.unwrap_or_default();
556
557 if m * n < config.minsize {
558 for i in 0..m {
560 for j in 0..n {
561 result[(i, j)] = a[i] * b[j];
562 }
563 }
564 } else {
565 result
567 .axis_iter_mut(Axis(0))
568 .enumerate()
569 .par_bridge()
570 .for_each(|(i, mut row)| {
571 for (j, elem) in row.iter_mut().enumerate() {
572 *elem = a[i] * b[j];
573 }
574 });
575 }
576
577 result
578 }
579}
580
#[cfg(test)]
mod tests {
    use super::*;
    use approx::assert_relative_eq;
    use scirs2_core::ndarray::array;

    #[test]
    fn test_parallel_batch_processor() {
        let datasets = vec![
            array![1.0, 2.0, 3.0, 4.0, 5.0],
            array![2.0, 4.0, 6.0, 8.0, 10.0],
            array![1.0, 1.0, 1.0, 1.0, 1.0],
        ];

        let processor = ParallelBatchProcessor::new(AdvancedParallelConfig::default());
        let results = processor
            .batch_descriptive_stats(&datasets)
            .expect("Operation failed");

        assert_eq!(results.len(), 3);
        // (mean, sample variance, min, max) for each dataset above.
        let expected = [
            (3.0, 2.5, 1.0, 5.0),
            (6.0, 10.0, 2.0, 10.0),
            (1.0, 0.0, 1.0, 1.0),
        ];
        for (got, want) in results.iter().zip(expected.iter()) {
            assert_relative_eq!(got.0, want.0, epsilon = 1e-10);
            assert_relative_eq!(got.1, want.1, epsilon = 1e-10);
            assert_relative_eq!(got.2, want.2, epsilon = 1e-10);
            assert_relative_eq!(got.3, want.3, epsilon = 1e-10);
        }
    }

    #[test]
    fn test_parallel_batch_processor_chunked_path() {
        // `minsize: 1` forces the chunked parallel path; chunk size 3 leaves a ragged tail.
        let config = AdvancedParallelConfig {
            minsize: 1,
            chunksize: Some(3),
            ..AdvancedParallelConfig::default()
        };
        let datasets = vec![array![0.0, 1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0]];

        let processor = ParallelBatchProcessor::new(config);
        let results = processor
            .batch_descriptive_stats(&datasets)
            .expect("Operation failed");

        let (mean, variance, min, max) = results[0];
        assert_relative_eq!(mean, 4.5, epsilon = 1e-10);
        assert_relative_eq!(variance, 82.5 / 9.0, epsilon = 1e-10);
        assert_relative_eq!(min, 0.0, epsilon = 1e-10);
        assert_relative_eq!(max, 9.0, epsilon = 1e-10);
    }

    #[test]
    fn test_parallel_cross_validator() {
        let x = array![1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0];
        let y = array![2.0, 4.0, 6.0, 8.0, 10.0, 12.0, 14.0, 16.0, 18.0, 20.0];

        let validator = ParallelCrossValidator::new(5, AdvancedParallelConfig::default());
        let (mean_corr, std_corr) = validator
            .cross_validate_correlation(&x.view(), &y.view())
            .expect("Operation failed");

        // y = 2x exactly, so every training fold is perfectly correlated.
        assert_relative_eq!(mean_corr, 1.0, epsilon = 1e-10);
        assert!(std_corr.abs() < 1e-10);
    }

    #[test]
    fn test_cross_validator_rejects_length_mismatch() {
        let x = array![1.0, 2.0, 3.0];
        let y = array![1.0, 2.0];

        let validator = ParallelCrossValidator::new(2, AdvancedParallelConfig::default());
        assert!(validator
            .cross_validate_correlation(&x.view(), &y.view())
            .is_err());
    }

    #[test]
    fn test_bootstrap_confidence_interval_brackets_estimate() {
        let data = array![1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0];

        let mc = ParallelMonteCarlo::<f64>::new(200, AdvancedParallelConfig::default());
        let (estimate, lower, upper) = mc
            .bootstrap_confidence_interval(&data.view(), |s| s.sum() / s.len() as f64, 0.95)
            .expect("Operation failed");

        assert_relative_eq!(estimate, 5.5, epsilon = 1e-10);
        assert!(lower <= estimate && estimate <= upper);
    }

    #[test]
    fn test_permutation_test_detects_shift() {
        let group1 = array![1.0, 2.0, 3.0, 4.0, 5.0];
        let group2 = array![6.0, 7.0, 8.0, 9.0, 10.0];

        let mc = ParallelMonteCarlo::<f64>::new(200, AdvancedParallelConfig::default());
        let p_value = mc
            .permutation_test(&group1.view(), &group2.view(), |a, b| {
                a.sum() / a.len() as f64 - b.sum() / b.len() as f64
            })
            .expect("Operation failed");

        assert!((0.0..=1.0).contains(&p_value));
        assert!(p_value < 0.05);
    }

    #[test]
    fn test_parallel_matrix_ops() {
        let matrix = array![[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]];
        let vector = array![1.0, 2.0, 3.0];

        let result = ParallelMatrixOps::matvec_parallel(&matrix.view(), &vector.view(), None)
            .expect("Operation failed");

        assert_relative_eq!(result[0], 14.0, epsilon = 1e-10);
        assert_relative_eq!(result[1], 32.0, epsilon = 1e-10);

        let short_vector = array![1.0, 2.0];
        assert!(
            ParallelMatrixOps::matvec_parallel(&matrix.view(), &short_vector.view(), None)
                .is_err()
        );
    }

    #[test]
    fn test_outer_product() {
        let a = array![1.0, 2.0];
        let b = array![3.0, 4.0, 5.0];

        let outer = ParallelMatrixOps::outer_product_parallel(&a.view(), &b.view(), None);

        assert_eq!(outer, array![[3.0, 4.0, 5.0], [6.0, 8.0, 10.0]]);
    }
}