1use crate::error::{StatsError, StatsResult};
7use crate::{kendall_tau, pearson_r, spearman_r};
8use scirs2_core::ndarray::{s, Array1, Array2, ArrayBase, ArrayView1, ArrayView2, Data, Ix1, Ix2};
9use scirs2_core::numeric::{Float, NumCast, One, Zero};
10use scirs2_core::{
11 simd_ops::{AutoOptimizer, SimdUnifiedOps},
12 validation::*,
13};
14use std::sync::{Arc, Mutex};
15
/// Tuning knobs for the correlation routines in this module.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct ParallelCorrelationConfig {
    /// Threshold at which the chunked code path is taken.
    ///
    /// `corrcoef_parallel_enhanced` compares it with the number of variables (columns);
    /// `batch_correlations_parallel` compares `min(min_parallelsize, 10)` with the number
    /// of pairs.
    pub min_parallelsize: usize,
    /// Number of work items per chunk; `None` lets each routine derive a size from its
    /// input (a quarter of the work items, at least one).
    pub chunksize: Option<usize>,
    /// When `true`, Pearson correlations use `pearson_r_simd_enhanced` instead of
    /// `pearson_r`.
    pub use_simd: bool,
    /// Work-stealing preference.
    ///
    /// NOTE(review): not read by the functions in this file — confirm before relying on it.
    pub work_stealing: bool,
}

impl Default for ParallelCorrelationConfig {
    /// Chunked path from 50 items, automatic chunk size, SIMD Pearson kernel enabled,
    /// work stealing enabled.
    fn default() -> Self {
        Self {
            min_parallelsize: 50,
            chunksize: None,
            use_simd: true,
            work_stealing: true,
        }
    }
}
39
40#[allow(dead_code)]
74pub fn corrcoef_parallel_enhanced<F>(
75 data: &ArrayView2<F>,
76 method: &str,
77 config: &ParallelCorrelationConfig,
78) -> StatsResult<Array2<F>>
79where
80 F: Float
81 + NumCast
82 + SimdUnifiedOps
83 + Zero
84 + One
85 + Copy
86 + Send
87 + Sync
88 + std::iter::Sum<F>
89 + std::fmt::Debug
90 + std::fmt::Display,
91{
92 checkarray_finite_2d(data, "data")?;
94
95 match method {
96 "pearson" | "spearman" | "kendall" => {}
97 _ => {
98 return Err(StatsError::InvalidArgument(format!(
99 "Method must be 'pearson', 'spearman', or 'kendall', got {}",
100 method
101 )))
102 }
103 }
104
105 let (n_obs, n_vars) = data.dim();
106
107 if n_obs == 0 || n_vars == 0 {
108 return Err(StatsError::InvalidArgument(
109 "Data array cannot be empty".to_string(),
110 ));
111 }
112
113 let mut corr_mat = Array2::<F>::zeros((n_vars, n_vars));
115
116 for i in 0..n_vars {
118 corr_mat[[i, i]] = F::one();
119 }
120
121 let mut pairs = Vec::new();
123 for i in 0..n_vars {
124 for j in (i + 1)..n_vars {
125 pairs.push((i, j));
126 }
127 }
128
129 let use_parallel = n_vars >= config.min_parallelsize;
131
132 if use_parallel {
133 let chunksize = config
135 .chunksize
136 .unwrap_or(std::cmp::max(1, pairs.len() / 4));
137
138 let results = Arc::new(Mutex::new(Vec::new()));
140
141 pairs.chunks(chunksize).for_each(|chunk| {
142 let mut local_results = Vec::new();
143
144 for &(i, j) in chunk {
145 let var_i = data.slice(s![.., i]);
146 let var_j = data.slice(s![.., j]);
147
148 let corr = match method {
149 "pearson" => {
150 if config.use_simd {
151 match pearson_r_simd_enhanced(&var_i, &var_j) {
152 Ok(val) => val,
153 Err(_) => continue,
154 }
155 } else {
156 match pearson_r(&var_i, &var_j) {
157 Ok(val) => val,
158 Err(_) => continue,
159 }
160 }
161 }
162 "spearman" => match spearman_r(&var_i, &var_j) {
163 Ok(val) => val,
164 Err(_) => continue,
165 },
166 "kendall" => match kendall_tau(&var_i, &var_j, "b") {
167 Ok(val) => val,
168 Err(_) => continue,
169 },
170 _ => unreachable!(),
171 };
172
173 local_results.push((i, j, corr));
174 }
175
176 let mut global_results = results.lock().unwrap();
177 global_results.extend(local_results);
178 });
179
180 let all_results = Arc::try_unwrap(results).unwrap().into_inner().unwrap();
181
182 for (i, j, corr) in all_results {
184 corr_mat[[i, j]] = corr;
185 corr_mat[[j, i]] = corr; }
187 } else {
188 for (i, j) in pairs {
190 let var_i = data.slice(s![.., i]);
191 let var_j = data.slice(s![.., j]);
192
193 let corr = match method {
194 "pearson" => {
195 if config.use_simd {
196 pearson_r_simd_enhanced(&var_i, &var_j)?
197 } else {
198 pearson_r(&var_i, &var_j)?
199 }
200 }
201 "spearman" => spearman_r(&var_i, &var_j)?,
202 "kendall" => kendall_tau(&var_i, &var_j, "b")?,
203 _ => unreachable!(),
204 };
205
206 corr_mat[[i, j]] = corr;
207 corr_mat[[j, i]] = corr; }
209 }
210
211 Ok(corr_mat)
212}
213
214#[allow(dead_code)]
219pub fn pearson_r_simd_enhanced<F, D>(x: &ArrayBase<D, Ix1>, y: &ArrayBase<D, Ix1>) -> StatsResult<F>
220where
221 F: Float + NumCast + SimdUnifiedOps + Zero + One + Copy + std::iter::Sum<F>,
222 D: Data<Elem = F>,
223{
224 if x.len() != y.len() {
226 return Err(StatsError::DimensionMismatch(
227 "Arrays must have the same length".to_string(),
228 ));
229 }
230
231 if x.is_empty() {
232 return Err(StatsError::InvalidArgument(
233 "Arrays cannot be empty".to_string(),
234 ));
235 }
236
237 let n = x.len();
238 let n_f = F::from(n).unwrap();
239 let optimizer = AutoOptimizer::new();
240
241 let (mean_x, mean_y) = if optimizer.should_use_simd(n) {
243 let sum_x = F::simd_sum(&x.view());
244 let sum_y = F::simd_sum(&y.view());
245 (sum_x / n_f, sum_y / n_f)
246 } else {
247 let mean_x = x.iter().fold(F::zero(), |acc, &val| acc + val) / n_f;
248 let mean_y = y.iter().fold(F::zero(), |acc, &val| acc + val) / n_f;
249 (mean_x, mean_y)
250 };
251
252 let (sum_xy, sum_x2, sum_y2) = if optimizer.should_use_simd(n) {
254 let mean_x_array = Array1::from_elem(n, mean_x);
256 let mean_y_array = Array1::from_elem(n, mean_y);
257
258 let x_dev = F::simd_sub(&x.view(), &mean_x_array.view());
260 let y_dev = F::simd_sub(&y.view(), &mean_y_array.view());
261
262 let xy_prod = F::simd_mul(&x_dev.view(), &y_dev.view());
264 let x_sq = F::simd_mul(&x_dev.view(), &x_dev.view());
265 let y_sq = F::simd_mul(&y_dev.view(), &y_dev.view());
266
267 let sum_xy = F::simd_sum(&xy_prod.view());
269 let sum_x2 = F::simd_sum(&x_sq.view());
270 let sum_y2 = F::simd_sum(&y_sq.view());
271
272 (sum_xy, sum_x2, sum_y2)
273 } else {
274 let mut sum_xy = F::zero();
276 let mut sum_x2 = F::zero();
277 let mut sum_y2 = F::zero();
278
279 for i in 0..n {
280 let x_dev = x[i] - mean_x;
281 let y_dev = y[i] - mean_y;
282
283 sum_xy = sum_xy + x_dev * y_dev;
284 sum_x2 = sum_x2 + x_dev * x_dev;
285 sum_y2 = sum_y2 + y_dev * y_dev;
286 }
287
288 (sum_xy, sum_x2, sum_y2)
289 };
290
291 if sum_x2 <= F::epsilon() || sum_y2 <= F::epsilon() {
293 return Err(StatsError::InvalidArgument(
294 "Cannot compute correlation when one or both variables have zero variance".to_string(),
295 ));
296 }
297
298 let corr = sum_xy / (sum_x2 * sum_y2).sqrt();
300
301 let corr = if corr > F::one() {
303 F::one()
304 } else if corr < -F::one() {
305 -F::one()
306 } else {
307 corr
308 };
309
310 Ok(corr)
311}
312
313#[allow(dead_code)]
328pub fn batch_correlations_parallel<'a, F>(
329 pairs: &[(ArrayView1<'a, F>, ArrayView1<'a, F>)],
330 method: &str,
331 config: &ParallelCorrelationConfig,
332) -> StatsResult<Vec<F>>
333where
334 F: Float
335 + NumCast
336 + SimdUnifiedOps
337 + Zero
338 + One
339 + Copy
340 + Send
341 + Sync
342 + std::iter::Sum<F>
343 + std::fmt::Debug
344 + std::fmt::Display,
345{
346 if pairs.is_empty() {
347 return Ok(Vec::new());
348 }
349
350 match method {
352 "pearson" | "spearman" | "kendall" => {}
353 _ => {
354 return Err(StatsError::InvalidArgument(format!(
355 "Method must be 'pearson', 'spearman', or 'kendall', got {}",
356 method
357 )))
358 }
359 }
360
361 let n_pairs = pairs.len();
362 let use_parallel = n_pairs >= config.min_parallelsize.min(10); if use_parallel {
365 let chunksize = config.chunksize.unwrap_or(std::cmp::max(1, n_pairs / 4));
367
368 let results = Arc::new(Mutex::new(Vec::new()));
369 let error_occurred = Arc::new(Mutex::new(false));
370
371 pairs.chunks(chunksize).for_each(|chunk| {
372 let mut local_results = Vec::new();
373 let mut has_error = false;
374
375 for (x, y) in chunk {
376 let corr = match method {
377 "pearson" => {
378 if config.use_simd {
379 pearson_r_simd_enhanced(x, y)
380 } else {
381 pearson_r(x, y)
382 }
383 }
384 "spearman" => spearman_r(x, y),
385 "kendall" => kendall_tau(x, y, "b"),
386 _ => unreachable!(),
387 };
388
389 match corr {
390 Ok(val) => local_results.push(val),
391 Err(_) => {
392 has_error = true;
393 break;
394 }
395 }
396 }
397
398 if has_error {
399 *error_occurred.lock().unwrap() = true;
400 } else {
401 results.lock().unwrap().extend(local_results);
402 }
403 });
404
405 if *error_occurred.lock().unwrap() {
406 return Err(StatsError::InvalidArgument(
407 "Error occurred during batch correlation computation".to_string(),
408 ));
409 }
410
411 let final_results = Arc::try_unwrap(results).unwrap().into_inner().unwrap();
412 Ok(final_results)
413 } else {
414 let mut results = Vec::with_capacity(n_pairs);
416
417 for (x, y) in pairs {
418 let corr = match method {
419 "pearson" => {
420 if config.use_simd {
421 pearson_r_simd_enhanced(x, y)?
422 } else {
423 pearson_r(x, y)?
424 }
425 }
426 "spearman" => spearman_r(x, y)?,
427 "kendall" => kendall_tau(x, y, "b")?,
428 _ => unreachable!(),
429 };
430 results.push(corr);
431 }
432
433 Ok(results)
434 }
435}
436
437#[allow(dead_code)]
442pub fn rolling_correlation_parallel<F>(
443 x: &ArrayView1<F>,
444 y: &ArrayView1<F>,
445 windowsize: usize,
446 method: &str,
447 config: &ParallelCorrelationConfig,
448) -> StatsResult<Array1<F>>
449where
450 F: Float
451 + NumCast
452 + SimdUnifiedOps
453 + Zero
454 + One
455 + Copy
456 + Send
457 + Sync
458 + std::iter::Sum<F>
459 + std::fmt::Debug
460 + std::fmt::Display,
461{
462 if x.len() != y.len() {
463 return Err(StatsError::DimensionMismatch(format!(
464 "x and y must have the same length, got {} and {}",
465 x.len(),
466 y.len()
467 )));
468 }
469 check_positive(windowsize, "windowsize")?;
470
471 if windowsize > x.len() {
472 return Err(StatsError::InvalidArgument(
473 "Window size cannot be larger than data length".to_string(),
474 ));
475 }
476
477 let n_windows = x.len() - windowsize + 1;
478 let mut results = Array1::zeros(n_windows);
479
480 let window_pairs: Vec<_> = (0..n_windows)
482 .map(|i| {
483 let x_window = x.slice(s![i..i + windowsize]);
484 let y_window = y.slice(s![i..i + windowsize]);
485 (x_window, y_window)
486 })
487 .collect();
488
489 let correlations = batch_correlations_parallel(&window_pairs, method, config)?;
491
492 for (i, corr) in correlations.into_iter().enumerate() {
494 results[i] = corr;
495 }
496
497 Ok(results)
498}
499
500#[allow(dead_code)]
502fn checkarray_finite_2d<F, D>(arr: &ArrayBase<D, Ix2>, name: &str) -> StatsResult<()>
503where
504 F: Float,
505 D: Data<Elem = F>,
506{
507 for &val in arr.iter() {
508 if !val.is_finite() {
509 return Err(StatsError::InvalidArgument(format!(
510 "{} contains non-finite values",
511 name
512 )));
513 }
514 }
515 Ok(())
516}
517
#[cfg(test)]
mod tests {
    use super::*;
    use crate::corrcoef;
    use scirs2_core::ndarray::array;

    /// Config that forces the chunked path with single-item chunks.
    fn chunked_config() -> ParallelCorrelationConfig {
        ParallelCorrelationConfig {
            min_parallelsize: 0,
            chunksize: Some(1),
            ..Default::default()
        }
    }

    /// Config that keeps every routine on its non-chunked path.
    fn sequential_config() -> ParallelCorrelationConfig {
        ParallelCorrelationConfig {
            min_parallelsize: usize::MAX,
            ..Default::default()
        }
    }

    #[test]
    fn test_corrcoef_parallel_enhanced_consistency() {
        let data = array![
            [1.0, 5.0, 10.0],
            [2.0, 4.0, 9.0],
            [3.0, 3.0, 8.0],
            [4.0, 2.0, 7.0],
            [5.0, 1.0, 6.0]
        ];

        let config = ParallelCorrelationConfig::default();
        let parallel_result = corrcoef_parallel_enhanced(&data.view(), "pearson", &config).unwrap();
        let sequential_result = corrcoef(&data.view(), "pearson").unwrap();

        for i in 0..3 {
            for j in 0..3 {
                assert!(
                    (parallel_result[[i, j]] - sequential_result[[i, j]]).abs() < 1e-10,
                    "Mismatch at [{}, {}]: parallel {} vs sequential {}",
                    i,
                    j,
                    parallel_result[[i, j]],
                    sequential_result[[i, j]]
                );
            }
        }
    }

    #[test]
    fn test_corrcoef_parallel_enhanced_chunked_path_matches_sequential() {
        let data = array![
            [1.0, 5.0, 10.0],
            [2.0, 4.0, 9.0],
            [3.0, 3.0, 8.0],
            [4.0, 2.0, 7.0],
            [5.0, 1.0, 6.0]
        ];

        for method in ["pearson", "spearman", "kendall"] {
            let chunked =
                corrcoef_parallel_enhanced(&data.view(), method, &chunked_config()).unwrap();
            let sequential =
                corrcoef_parallel_enhanced(&data.view(), method, &sequential_config()).unwrap();
            for (a, b) in chunked.iter().zip(sequential.iter()) {
                assert!((a - b).abs() < 1e-10, "{}: {} vs {}", method, a, b);
            }
        }
    }

    #[test]
    fn test_corrcoef_parallel_enhanced_rejects_unknown_method() {
        let data = array![[1.0, 2.0], [2.0, 1.0], [3.0, 0.5]];
        let config = ParallelCorrelationConfig::default();
        assert!(corrcoef_parallel_enhanced(&data.view(), "bogus", &config).is_err());
    }

    #[test]
    fn test_pearson_r_simd_enhanced_consistency() {
        let x = array![1.0, 2.0, 3.0, 4.0, 5.0];
        let y = array![5.0, 4.0, 3.0, 2.0, 1.0];

        let simd_result = pearson_r_simd_enhanced(&x.view(), &y.view()).unwrap();
        let standard_result = pearson_r(&x.view(), &y.view()).unwrap();

        assert!((simd_result - standard_result).abs() < 1e-10);
    }

    #[test]
    fn test_batch_correlations_parallel() {
        let x1 = array![1.0, 2.0, 3.0, 4.0, 5.0];
        let y1 = array![5.0, 4.0, 3.0, 2.0, 1.0];
        let x2 = array![1.0, 2.0, 3.0, 4.0, 5.0];
        let y2 = array![2.0, 4.0, 6.0, 8.0, 10.0];

        let pairs = vec![(x1.view(), y1.view()), (x2.view(), y2.view())];
        let config = ParallelCorrelationConfig::default();

        let results = batch_correlations_parallel(&pairs, "pearson", &config).unwrap();

        assert_eq!(results.len(), 2);
        assert!((results[0] - (-1.0)).abs() < 1e-10);
        assert!((results[1] - 1.0).abs() < 1e-10);
    }

    #[test]
    fn test_batch_correlations_parallel_chunked_path_preserves_order() {
        let x1 = array![1.0, 2.0, 3.0, 4.0, 5.0];
        let y1 = array![5.0, 4.0, 3.0, 2.0, 1.0];
        let x2 = array![1.0, 2.0, 3.0, 4.0, 5.0];
        let y2 = array![2.0, 4.0, 6.0, 8.0, 10.0];

        let pairs = vec![(x1.view(), y1.view()), (x2.view(), y2.view())];
        let results = batch_correlations_parallel(&pairs, "pearson", &chunked_config()).unwrap();

        assert_eq!(results.len(), 2);
        assert!((results[0] - (-1.0)).abs() < 1e-10);
        assert!((results[1] - 1.0).abs() < 1e-10);
    }

    #[test]
    fn test_rolling_correlation_parallel() {
        let x = array![1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0];
        let y = array![10.0, 9.0, 8.0, 7.0, 6.0, 5.0, 4.0, 3.0, 2.0, 1.0];

        let config = ParallelCorrelationConfig::default();
        let rolling_corrs =
            rolling_correlation_parallel(&x.view(), &y.view(), 3, "pearson", &config).unwrap();

        assert_eq!(rolling_corrs.len(), 8);
        for corr in rolling_corrs.iter() {
            assert!(*corr < 0.0);
        }
    }

    #[test]
    fn test_rolling_correlation_parallel_rejects_oversized_window() {
        let x = array![1.0, 2.0, 3.0];
        let y = array![3.0, 2.0, 1.0];
        let config = ParallelCorrelationConfig::default();
        assert!(rolling_correlation_parallel(&x.view(), &y.view(), 4, "pearson", &config).is_err());
    }
}