1use crate::error::{StatsError, StatsResult};
7use crate::{kendall_tau, pearson_r, spearman_r};
8use scirs2_core::ndarray::{s, Array1, Array2, ArrayBase, ArrayView1, ArrayView2, Data, Ix1, Ix2};
9use scirs2_core::numeric::{Float, NumCast, One, Zero};
10use scirs2_core::{
11 simd_ops::{AutoOptimizer, SimdUnifiedOps},
12 validation::*,
13};
14use std::sync::{Arc, Mutex};
15
/// Tuning knobs shared by the correlation routines in this module.
#[derive(Debug, Clone)]
pub struct ParallelCorrelationConfig {
    /// Problem-size threshold for the chunked code path.
    ///
    /// `corrcoef_parallel_enhanced` compares it against the number of
    /// variables; `batch_correlations_parallel` compares the number of pairs
    /// against `min(min_parallelsize, 10)`.
    pub min_parallelsize: usize,
    /// Number of pairs per chunk on the chunked path. `None` lets each
    /// routine pick a quarter of the pair count (at least 1).
    pub chunksize: Option<usize>,
    /// When `true`, the `"pearson"` method is computed with
    /// `pearson_r_simd_enhanced` instead of `pearson_r`.
    pub use_simd: bool,
    /// Not read by any routine visible in this file.
    // NOTE(review): presumably reserved for a work-stealing scheduler — confirm.
    pub work_stealing: bool,
}

impl Default for ParallelCorrelationConfig {
    /// Returns the default configuration: threshold of 50, automatic chunk
    /// size, SIMD enabled, work stealing enabled.
    fn default() -> Self {
        let min_parallelsize = 50;
        let chunksize = None;
        let use_simd = true;
        let work_stealing = true;

        ParallelCorrelationConfig {
            min_parallelsize,
            chunksize,
            use_simd,
            work_stealing,
        }
    }
}
39
40#[allow(dead_code)]
74pub fn corrcoef_parallel_enhanced<F>(
75 data: &ArrayView2<F>,
76 method: &str,
77 config: &ParallelCorrelationConfig,
78) -> StatsResult<Array2<F>>
79where
80 F: Float
81 + NumCast
82 + SimdUnifiedOps
83 + Zero
84 + One
85 + Copy
86 + Send
87 + Sync
88 + std::iter::Sum<F>
89 + std::fmt::Debug
90 + std::fmt::Display
91 + 'static,
92{
93 checkarray_finite_2d(data, "data")?;
95
96 match method {
97 "pearson" | "spearman" | "kendall" => {}
98 _ => {
99 return Err(StatsError::InvalidArgument(format!(
100 "Method must be 'pearson', 'spearman', or 'kendall', got {}",
101 method
102 )))
103 }
104 }
105
106 let (n_obs, n_vars) = data.dim();
107
108 if n_obs == 0 || n_vars == 0 {
109 return Err(StatsError::InvalidArgument(
110 "Data array cannot be empty".to_string(),
111 ));
112 }
113
114 let mut corr_mat = Array2::<F>::zeros((n_vars, n_vars));
116
117 for i in 0..n_vars {
119 corr_mat[[i, i]] = F::one();
120 }
121
122 let mut pairs = Vec::new();
124 for i in 0..n_vars {
125 for j in (i + 1)..n_vars {
126 pairs.push((i, j));
127 }
128 }
129
130 let use_parallel = n_vars >= config.min_parallelsize;
132
133 if use_parallel {
134 let chunksize = config
136 .chunksize
137 .unwrap_or(std::cmp::max(1, pairs.len() / 4));
138
139 let results = Arc::new(Mutex::new(Vec::new()));
141
142 pairs.chunks(chunksize).for_each(|chunk| {
143 let mut local_results = Vec::new();
144
145 for &(i, j) in chunk {
146 let var_i = data.slice(s![.., i]);
147 let var_j = data.slice(s![.., j]);
148
149 let corr = match method {
150 "pearson" => {
151 if config.use_simd {
152 match pearson_r_simd_enhanced(&var_i, &var_j) {
153 Ok(val) => val,
154 Err(_) => continue,
155 }
156 } else {
157 match pearson_r(&var_i, &var_j) {
158 Ok(val) => val,
159 Err(_) => continue,
160 }
161 }
162 }
163 "spearman" => match spearman_r(&var_i, &var_j) {
164 Ok(val) => val,
165 Err(_) => continue,
166 },
167 "kendall" => match kendall_tau(&var_i, &var_j, "b") {
168 Ok(val) => val,
169 Err(_) => continue,
170 },
171 _ => unreachable!(),
172 };
173
174 local_results.push((i, j, corr));
175 }
176
177 let mut global_results = results.lock().expect("Operation failed");
178 global_results.extend(local_results);
179 });
180
181 let all_results = Arc::try_unwrap(results)
182 .expect("Operation failed")
183 .into_inner()
184 .expect("Operation failed");
185
186 for (i, j, corr) in all_results {
188 corr_mat[[i, j]] = corr;
189 corr_mat[[j, i]] = corr; }
191 } else {
192 for (i, j) in pairs {
194 let var_i = data.slice(s![.., i]);
195 let var_j = data.slice(s![.., j]);
196
197 let corr = match method {
198 "pearson" => {
199 if config.use_simd {
200 pearson_r_simd_enhanced(&var_i, &var_j)?
201 } else {
202 pearson_r(&var_i, &var_j)?
203 }
204 }
205 "spearman" => spearman_r(&var_i, &var_j)?,
206 "kendall" => kendall_tau(&var_i, &var_j, "b")?,
207 _ => unreachable!(),
208 };
209
210 corr_mat[[i, j]] = corr;
211 corr_mat[[j, i]] = corr; }
213 }
214
215 Ok(corr_mat)
216}
217
218#[allow(dead_code)]
223pub fn pearson_r_simd_enhanced<F, D>(x: &ArrayBase<D, Ix1>, y: &ArrayBase<D, Ix1>) -> StatsResult<F>
224where
225 F: Float + NumCast + SimdUnifiedOps + Zero + One + Copy + std::iter::Sum<F>,
226 D: Data<Elem = F>,
227{
228 if x.len() != y.len() {
230 return Err(StatsError::DimensionMismatch(
231 "Arrays must have the same length".to_string(),
232 ));
233 }
234
235 if x.is_empty() {
236 return Err(StatsError::InvalidArgument(
237 "Arrays cannot be empty".to_string(),
238 ));
239 }
240
241 let n = x.len();
242 let n_f = F::from(n).expect("Failed to convert to float");
243 let optimizer = AutoOptimizer::new();
244
245 let (mean_x, mean_y) = if optimizer.should_use_simd(n) {
247 let sum_x = F::simd_sum(&x.view());
248 let sum_y = F::simd_sum(&y.view());
249 (sum_x / n_f, sum_y / n_f)
250 } else {
251 let mean_x = x.iter().fold(F::zero(), |acc, &val| acc + val) / n_f;
252 let mean_y = y.iter().fold(F::zero(), |acc, &val| acc + val) / n_f;
253 (mean_x, mean_y)
254 };
255
256 let (sum_xy, sum_x2, sum_y2) = if optimizer.should_use_simd(n) {
258 let mean_x_array = Array1::from_elem(n, mean_x);
260 let mean_y_array = Array1::from_elem(n, mean_y);
261
262 let x_dev = F::simd_sub(&x.view(), &mean_x_array.view());
264 let y_dev = F::simd_sub(&y.view(), &mean_y_array.view());
265
266 let xy_prod = F::simd_mul(&x_dev.view(), &y_dev.view());
268 let x_sq = F::simd_mul(&x_dev.view(), &x_dev.view());
269 let y_sq = F::simd_mul(&y_dev.view(), &y_dev.view());
270
271 let sum_xy = F::simd_sum(&xy_prod.view());
273 let sum_x2 = F::simd_sum(&x_sq.view());
274 let sum_y2 = F::simd_sum(&y_sq.view());
275
276 (sum_xy, sum_x2, sum_y2)
277 } else {
278 let mut sum_xy = F::zero();
280 let mut sum_x2 = F::zero();
281 let mut sum_y2 = F::zero();
282
283 for i in 0..n {
284 let x_dev = x[i] - mean_x;
285 let y_dev = y[i] - mean_y;
286
287 sum_xy = sum_xy + x_dev * y_dev;
288 sum_x2 = sum_x2 + x_dev * x_dev;
289 sum_y2 = sum_y2 + y_dev * y_dev;
290 }
291
292 (sum_xy, sum_x2, sum_y2)
293 };
294
295 if sum_x2 <= F::epsilon() || sum_y2 <= F::epsilon() {
297 return Err(StatsError::InvalidArgument(
298 "Cannot compute correlation when one or both variables have zero variance".to_string(),
299 ));
300 }
301
302 let corr = sum_xy / (sum_x2 * sum_y2).sqrt();
304
305 let corr = if corr > F::one() {
307 F::one()
308 } else if corr < -F::one() {
309 -F::one()
310 } else {
311 corr
312 };
313
314 Ok(corr)
315}
316
317#[allow(dead_code)]
332pub fn batch_correlations_parallel<'a, F>(
333 pairs: &[(ArrayView1<'a, F>, ArrayView1<'a, F>)],
334 method: &str,
335 config: &ParallelCorrelationConfig,
336) -> StatsResult<Vec<F>>
337where
338 F: Float
339 + NumCast
340 + SimdUnifiedOps
341 + Zero
342 + One
343 + Copy
344 + Send
345 + Sync
346 + std::iter::Sum<F>
347 + std::fmt::Debug
348 + std::fmt::Display
349 + 'static,
350{
351 if pairs.is_empty() {
352 return Ok(Vec::new());
353 }
354
355 match method {
357 "pearson" | "spearman" | "kendall" => {}
358 _ => {
359 return Err(StatsError::InvalidArgument(format!(
360 "Method must be 'pearson', 'spearman', or 'kendall', got {}",
361 method
362 )))
363 }
364 }
365
366 let n_pairs = pairs.len();
367 let use_parallel = n_pairs >= config.min_parallelsize.min(10); if use_parallel {
370 let chunksize = config.chunksize.unwrap_or(std::cmp::max(1, n_pairs / 4));
372
373 let results = Arc::new(Mutex::new(Vec::new()));
374 let error_occurred = Arc::new(Mutex::new(false));
375
376 pairs.chunks(chunksize).for_each(|chunk| {
377 let mut local_results = Vec::new();
378 let mut has_error = false;
379
380 for (x, y) in chunk {
381 let corr = match method {
382 "pearson" => {
383 if config.use_simd {
384 pearson_r_simd_enhanced(x, y)
385 } else {
386 pearson_r(x, y)
387 }
388 }
389 "spearman" => spearman_r(x, y),
390 "kendall" => kendall_tau(x, y, "b"),
391 _ => unreachable!(),
392 };
393
394 match corr {
395 Ok(val) => local_results.push(val),
396 Err(_) => {
397 has_error = true;
398 break;
399 }
400 }
401 }
402
403 if has_error {
404 *error_occurred.lock().expect("Operation failed") = true;
405 } else {
406 results
407 .lock()
408 .expect("Operation failed")
409 .extend(local_results);
410 }
411 });
412
413 if *error_occurred.lock().expect("Operation failed") {
414 return Err(StatsError::InvalidArgument(
415 "Error occurred during batch correlation computation".to_string(),
416 ));
417 }
418
419 let final_results = Arc::try_unwrap(results)
420 .expect("Operation failed")
421 .into_inner()
422 .expect("Operation failed");
423 Ok(final_results)
424 } else {
425 let mut results = Vec::with_capacity(n_pairs);
427
428 for (x, y) in pairs {
429 let corr = match method {
430 "pearson" => {
431 if config.use_simd {
432 pearson_r_simd_enhanced(x, y)?
433 } else {
434 pearson_r(x, y)?
435 }
436 }
437 "spearman" => spearman_r(x, y)?,
438 "kendall" => kendall_tau(x, y, "b")?,
439 _ => unreachable!(),
440 };
441 results.push(corr);
442 }
443
444 Ok(results)
445 }
446}
447
448#[allow(dead_code)]
453pub fn rolling_correlation_parallel<F>(
454 x: &ArrayView1<F>,
455 y: &ArrayView1<F>,
456 windowsize: usize,
457 method: &str,
458 config: &ParallelCorrelationConfig,
459) -> StatsResult<Array1<F>>
460where
461 F: Float
462 + NumCast
463 + SimdUnifiedOps
464 + Zero
465 + One
466 + Copy
467 + Send
468 + Sync
469 + std::iter::Sum<F>
470 + std::fmt::Debug
471 + std::fmt::Display
472 + 'static,
473{
474 if x.len() != y.len() {
475 return Err(StatsError::DimensionMismatch(format!(
476 "x and y must have the same length, got {} and {}",
477 x.len(),
478 y.len()
479 )));
480 }
481 check_positive(windowsize, "windowsize")?;
482
483 if windowsize > x.len() {
484 return Err(StatsError::InvalidArgument(
485 "Window size cannot be larger than data length".to_string(),
486 ));
487 }
488
489 let n_windows = x.len() - windowsize + 1;
490 let mut results = Array1::zeros(n_windows);
491
492 let window_pairs: Vec<_> = (0..n_windows)
494 .map(|i| {
495 let x_window = x.slice(s![i..i + windowsize]);
496 let y_window = y.slice(s![i..i + windowsize]);
497 (x_window, y_window)
498 })
499 .collect();
500
501 let correlations = batch_correlations_parallel(&window_pairs, method, config)?;
503
504 for (i, corr) in correlations.into_iter().enumerate() {
506 results[i] = corr;
507 }
508
509 Ok(results)
510}
511
512#[allow(dead_code)]
514fn checkarray_finite_2d<F, D>(arr: &ArrayBase<D, Ix2>, name: &str) -> StatsResult<()>
515where
516 F: Float,
517 D: Data<Elem = F>,
518{
519 for &val in arr.iter() {
520 if !val.is_finite() {
521 return Err(StatsError::InvalidArgument(format!(
522 "{} contains non-finite values",
523 name
524 )));
525 }
526 }
527 Ok(())
528}
529
#[cfg(test)]
mod tests {
    use super::*;
    use crate::corrcoef;
    use scirs2_core::ndarray::array;

    /// Absolute tolerance used when comparing correlation coefficients.
    const TOL: f64 = 1e-10;

    /// Configuration that forces the chunked code path even for tiny inputs,
    /// with one pair per chunk.
    fn chunked_config() -> ParallelCorrelationConfig {
        ParallelCorrelationConfig {
            min_parallelsize: 1,
            chunksize: Some(1),
            ..ParallelCorrelationConfig::default()
        }
    }

    #[test]
    fn test_corrcoef_parallel_enhanced_consistency() {
        let data = array![
            [1.0, 5.0, 10.0],
            [2.0, 4.0, 9.0],
            [3.0, 3.0, 8.0],
            [4.0, 2.0, 7.0],
            [5.0, 1.0, 6.0]
        ];

        let config = ParallelCorrelationConfig::default();
        let parallel_result =
            corrcoef_parallel_enhanced(&data.view(), "pearson", &config).expect("Operation failed");
        let sequential_result = corrcoef(&data.view(), "pearson").expect("Operation failed");

        for i in 0..3 {
            for j in 0..3 {
                let got = parallel_result[[i, j]];
                let want = sequential_result[[i, j]];
                assert!(
                    (got - want).abs() < TOL,
                    "Mismatch at [{}, {}]: parallel {} vs sequential {}",
                    i,
                    j,
                    got,
                    want
                );
            }
        }
    }

    /// The default threshold keeps small inputs on the plain path, so force
    /// the chunked path and check it produces the same matrix.
    #[test]
    fn test_corrcoef_parallel_enhanced_chunked_path() {
        let data = array![
            [1.0, 5.0, 10.0],
            [2.0, 4.0, 9.0],
            [3.0, 3.0, 8.0],
            [4.0, 2.0, 7.0],
            [5.0, 1.0, 6.0]
        ];

        let chunked = corrcoef_parallel_enhanced(&data.view(), "pearson", &chunked_config())
            .expect("Operation failed");
        let plain =
            corrcoef_parallel_enhanced(&data.view(), "pearson", &ParallelCorrelationConfig::default())
                .expect("Operation failed");

        assert_eq!(chunked.dim(), (3, 3));
        for i in 0..3 {
            for j in 0..3 {
                assert!(
                    (chunked[[i, j]] - plain[[i, j]]).abs() < TOL,
                    "Mismatch at [{}, {}]: chunked {} vs plain {}",
                    i,
                    j,
                    chunked[[i, j]],
                    plain[[i, j]]
                );
            }
        }
    }

    #[test]
    fn test_pearson_r_simd_enhanced_consistency() {
        let x = array![1.0, 2.0, 3.0, 4.0, 5.0];
        let y = array![5.0, 4.0, 3.0, 2.0, 1.0];

        let simd_result = pearson_r_simd_enhanced(&x.view(), &y.view()).expect("Operation failed");
        let standard_result = pearson_r(&x.view(), &y.view()).expect("Operation failed");

        assert!((simd_result - standard_result).abs() < TOL);
    }

    #[test]
    fn test_batch_correlations_parallel() {
        let x1 = array![1.0, 2.0, 3.0, 4.0, 5.0];
        let y1 = array![5.0, 4.0, 3.0, 2.0, 1.0];
        let x2 = array![1.0, 2.0, 3.0, 4.0, 5.0];
        let y2 = array![2.0, 4.0, 6.0, 8.0, 10.0];

        let pairs = vec![(x1.view(), y1.view()), (x2.view(), y2.view())];
        let config = ParallelCorrelationConfig::default();

        let results =
            batch_correlations_parallel(&pairs, "pearson", &config).expect("Operation failed");

        assert_eq!(results.len(), 2);
        // First pair is perfectly anti-correlated, second perfectly correlated.
        assert!((results[0] - (-1.0)).abs() < TOL);
        assert!((results[1] - 1.0).abs() < TOL);
    }

    /// Same data as above, but routed through the chunked path; the output
    /// must keep the order of the input pairs.
    #[test]
    fn test_batch_correlations_parallel_chunked_path() {
        let x1 = array![1.0, 2.0, 3.0, 4.0, 5.0];
        let y1 = array![5.0, 4.0, 3.0, 2.0, 1.0];
        let x2 = array![1.0, 2.0, 3.0, 4.0, 5.0];
        let y2 = array![2.0, 4.0, 6.0, 8.0, 10.0];

        let pairs = vec![(x1.view(), y1.view()), (x2.view(), y2.view())];

        let results = batch_correlations_parallel(&pairs, "pearson", &chunked_config())
            .expect("Operation failed");

        assert_eq!(results.len(), 2);
        assert!((results[0] - (-1.0)).abs() < TOL);
        assert!((results[1] - 1.0).abs() < TOL);
    }

    #[test]
    fn test_rolling_correlation_parallel() {
        let x = array![1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0];
        let y = array![10.0, 9.0, 8.0, 7.0, 6.0, 5.0, 4.0, 3.0, 2.0, 1.0];

        let config = ParallelCorrelationConfig::default();
        let rolling_corrs =
            rolling_correlation_parallel(&x.view(), &y.view(), 3, "pearson", &config)
                .expect("Operation failed");

        // 10 samples with a window of 3 give 10 - 3 + 1 windows.
        assert_eq!(rolling_corrs.len(), 8);
        // `y` decreases while `x` increases, so every window is anti-correlated.
        for corr in rolling_corrs.iter() {
            assert!(*corr < 0.0);
        }
    }
}