scirs2_stats/
correlation_parallel_enhanced.rs

1//! Enhanced parallel correlation computations
2//!
3//! This module provides SIMD and parallel-accelerated implementations of correlation
4//! operations using scirs2-core's unified optimization framework.
5
6use crate::error::{StatsError, StatsResult};
7use crate::{kendall_tau, pearson_r, spearman_r};
8use scirs2_core::ndarray::{s, Array1, Array2, ArrayBase, ArrayView1, ArrayView2, Data, Ix1, Ix2};
9use scirs2_core::numeric::{Float, NumCast, One, Zero};
10use scirs2_core::{
11    simd_ops::{AutoOptimizer, SimdUnifiedOps},
12    validation::*,
13};
14use std::sync::{Arc, Mutex};
15
/// Tuning options accepted by the correlation routines in this module
#[derive(Clone, Debug)]
pub struct ParallelCorrelationConfig {
    /// Problem size from which the parallel/chunked strategy is meant to be used
    pub min_parallelsize: usize,
    /// Preferred number of work items per chunk; `None` leaves the choice to the routine
    pub chunksize: Option<usize>,
    /// Route Pearson correlations through the SIMD-enhanced kernel
    pub use_simd: bool,
    /// Request work stealing for load balancing
    /// (NOTE(review): no function visible in this file reads this flag)
    pub work_stealing: bool,
}
28
29impl Default for ParallelCorrelationConfig {
30    fn default() -> Self {
31        Self {
32            min_parallelsize: 50, // 50x50 matrix threshold
33            chunksize: None,      // Auto-determine
34            use_simd: true,
35            work_stealing: true,
36        }
37    }
38}
39
40/// Parallel and SIMD-optimized correlation matrix computation
41///
42/// Computes pairwise correlations between all variables in a matrix using
43/// parallel processing for the correlation pairs and SIMD for individual
44/// correlation calculations.
45///
46/// # Arguments
47///
48/// * `data` - Input data matrix (observations × variables)
49/// * `method` - Correlation method ("pearson", "spearman", "kendall")
50/// * `config` - Parallel processing configuration
51///
52/// # Returns
53///
54/// * Correlation matrix (variables × variables)
55///
56/// # Examples
57///
58/// ```
59/// use scirs2_core::ndarray::array;
60/// use scirs2_stats::{corrcoef_parallel_enhanced, ParallelCorrelationConfig};
61///
62/// let data = array![
63///     [1.0, 5.0, 10.0],
64///     [2.0, 4.0, 9.0],
65///     [3.0, 3.0, 8.0],
66///     [4.0, 2.0, 7.0],
67///     [5.0, 1.0, 6.0]
68/// ];
69///
70/// let config = ParallelCorrelationConfig::default();
71/// let corr_matrix = corrcoef_parallel_enhanced(&data.view(), "pearson", &config).unwrap();
72/// ```
73#[allow(dead_code)]
74pub fn corrcoef_parallel_enhanced<F>(
75    data: &ArrayView2<F>,
76    method: &str,
77    config: &ParallelCorrelationConfig,
78) -> StatsResult<Array2<F>>
79where
80    F: Float
81        + NumCast
82        + SimdUnifiedOps
83        + Zero
84        + One
85        + Copy
86        + Send
87        + Sync
88        + std::iter::Sum<F>
89        + std::fmt::Debug
90        + std::fmt::Display,
91{
92    // Validate inputs
93    checkarray_finite_2d(data, "data")?;
94
95    match method {
96        "pearson" | "spearman" | "kendall" => {}
97        _ => {
98            return Err(StatsError::InvalidArgument(format!(
99                "Method must be 'pearson', 'spearman', or 'kendall', got {}",
100                method
101            )))
102        }
103    }
104
105    let (n_obs, n_vars) = data.dim();
106
107    if n_obs == 0 || n_vars == 0 {
108        return Err(StatsError::InvalidArgument(
109            "Data array cannot be empty".to_string(),
110        ));
111    }
112
113    // Initialize correlation matrix
114    let mut corr_mat = Array2::<F>::zeros((n_vars, n_vars));
115
116    // Set diagonal elements to 1
117    for i in 0..n_vars {
118        corr_mat[[i, i]] = F::one();
119    }
120
121    // Generate upper triangular pairs for parallel processing
122    let mut pairs = Vec::new();
123    for i in 0..n_vars {
124        for j in (i + 1)..n_vars {
125            pairs.push((i, j));
126        }
127    }
128
129    // Decide whether to use parallel processing
130    let use_parallel = n_vars >= config.min_parallelsize;
131
132    if use_parallel {
133        // Parallel processing with result collection
134        let chunksize = config
135            .chunksize
136            .unwrap_or(std::cmp::max(1, pairs.len() / 4));
137
138        // Process pairs in parallel and collect results
139        let results = Arc::new(Mutex::new(Vec::new()));
140
141        pairs.chunks(chunksize).for_each(|chunk| {
142            let mut local_results = Vec::new();
143
144            for &(i, j) in chunk {
145                let var_i = data.slice(s![.., i]);
146                let var_j = data.slice(s![.., j]);
147
148                let corr = match method {
149                    "pearson" => {
150                        if config.use_simd {
151                            match pearson_r_simd_enhanced(&var_i, &var_j) {
152                                Ok(val) => val,
153                                Err(_) => continue,
154                            }
155                        } else {
156                            match pearson_r(&var_i, &var_j) {
157                                Ok(val) => val,
158                                Err(_) => continue,
159                            }
160                        }
161                    }
162                    "spearman" => match spearman_r(&var_i, &var_j) {
163                        Ok(val) => val,
164                        Err(_) => continue,
165                    },
166                    "kendall" => match kendall_tau(&var_i, &var_j, "b") {
167                        Ok(val) => val,
168                        Err(_) => continue,
169                    },
170                    _ => unreachable!(),
171                };
172
173                local_results.push((i, j, corr));
174            }
175
176            let mut global_results = results.lock().unwrap();
177            global_results.extend(local_results);
178        });
179
180        let all_results = Arc::try_unwrap(results).unwrap().into_inner().unwrap();
181
182        // Write results back to matrix
183        for (i, j, corr) in all_results {
184            corr_mat[[i, j]] = corr;
185            corr_mat[[j, i]] = corr; // Symmetric
186        }
187    } else {
188        // Sequential processing for smaller matrices
189        for (i, j) in pairs {
190            let var_i = data.slice(s![.., i]);
191            let var_j = data.slice(s![.., j]);
192
193            let corr = match method {
194                "pearson" => {
195                    if config.use_simd {
196                        pearson_r_simd_enhanced(&var_i, &var_j)?
197                    } else {
198                        pearson_r(&var_i, &var_j)?
199                    }
200                }
201                "spearman" => spearman_r(&var_i, &var_j)?,
202                "kendall" => kendall_tau(&var_i, &var_j, "b")?,
203                _ => unreachable!(),
204            };
205
206            corr_mat[[i, j]] = corr;
207            corr_mat[[j, i]] = corr; // Symmetric
208        }
209    }
210
211    Ok(corr_mat)
212}
213
214/// SIMD-enhanced Pearson correlation computation
215///
216/// Optimized version of Pearson correlation using SIMD operations
217/// for improved performance on large datasets.
218#[allow(dead_code)]
219pub fn pearson_r_simd_enhanced<F, D>(x: &ArrayBase<D, Ix1>, y: &ArrayBase<D, Ix1>) -> StatsResult<F>
220where
221    F: Float + NumCast + SimdUnifiedOps + Zero + One + Copy + std::iter::Sum<F>,
222    D: Data<Elem = F>,
223{
224    // Check dimensions
225    if x.len() != y.len() {
226        return Err(StatsError::DimensionMismatch(
227            "Arrays must have the same length".to_string(),
228        ));
229    }
230
231    if x.is_empty() {
232        return Err(StatsError::InvalidArgument(
233            "Arrays cannot be empty".to_string(),
234        ));
235    }
236
237    let n = x.len();
238    let n_f = F::from(n).unwrap();
239    let optimizer = AutoOptimizer::new();
240
241    // Use SIMD for mean calculations if beneficial
242    let (mean_x, mean_y) = if optimizer.should_use_simd(n) {
243        let sum_x = F::simd_sum(&x.view());
244        let sum_y = F::simd_sum(&y.view());
245        (sum_x / n_f, sum_y / n_f)
246    } else {
247        let mean_x = x.iter().fold(F::zero(), |acc, &val| acc + val) / n_f;
248        let mean_y = y.iter().fold(F::zero(), |acc, &val| acc + val) / n_f;
249        (mean_x, mean_y)
250    };
251
252    // SIMD-optimized correlation calculation
253    let (sum_xy, sum_x2, sum_y2) = if optimizer.should_use_simd(n) {
254        // Create arrays with means for SIMD subtraction
255        let mean_x_array = Array1::from_elem(n, mean_x);
256        let mean_y_array = Array1::from_elem(n, mean_y);
257
258        // Compute deviations
259        let x_dev = F::simd_sub(&x.view(), &mean_x_array.view());
260        let y_dev = F::simd_sub(&y.view(), &mean_y_array.view());
261
262        // Compute products and squares
263        let xy_prod = F::simd_mul(&x_dev.view(), &y_dev.view());
264        let x_sq = F::simd_mul(&x_dev.view(), &x_dev.view());
265        let y_sq = F::simd_mul(&y_dev.view(), &y_dev.view());
266
267        // Sum the results
268        let sum_xy = F::simd_sum(&xy_prod.view());
269        let sum_x2 = F::simd_sum(&x_sq.view());
270        let sum_y2 = F::simd_sum(&y_sq.view());
271
272        (sum_xy, sum_x2, sum_y2)
273    } else {
274        // Scalar fallback
275        let mut sum_xy = F::zero();
276        let mut sum_x2 = F::zero();
277        let mut sum_y2 = F::zero();
278
279        for i in 0..n {
280            let x_dev = x[i] - mean_x;
281            let y_dev = y[i] - mean_y;
282
283            sum_xy = sum_xy + x_dev * y_dev;
284            sum_x2 = sum_x2 + x_dev * x_dev;
285            sum_y2 = sum_y2 + y_dev * y_dev;
286        }
287
288        (sum_xy, sum_x2, sum_y2)
289    };
290
291    // Check for zero variances
292    if sum_x2 <= F::epsilon() || sum_y2 <= F::epsilon() {
293        return Err(StatsError::InvalidArgument(
294            "Cannot compute correlation when one or both variables have zero variance".to_string(),
295        ));
296    }
297
298    // Calculate correlation coefficient
299    let corr = sum_xy / (sum_x2 * sum_y2).sqrt();
300
301    // Clamp to valid range [-1, 1]
302    let corr = if corr > F::one() {
303        F::one()
304    } else if corr < -F::one() {
305        -F::one()
306    } else {
307        corr
308    };
309
310    Ok(corr)
311}
312
313/// Parallel batch correlation computation
314///
315/// Computes correlations between multiple pairs of arrays in parallel,
316/// useful for large-scale correlation analysis.
317///
318/// # Arguments
319///
320/// * `pairs` - Vector of array pairs to correlate
321/// * `method` - Correlation method
322/// * `config` - Parallel processing configuration
323///
324/// # Returns
325///
326/// * Vector of correlation coefficients in the same order as input pairs
327#[allow(dead_code)]
328pub fn batch_correlations_parallel<'a, F>(
329    pairs: &[(ArrayView1<'a, F>, ArrayView1<'a, F>)],
330    method: &str,
331    config: &ParallelCorrelationConfig,
332) -> StatsResult<Vec<F>>
333where
334    F: Float
335        + NumCast
336        + SimdUnifiedOps
337        + Zero
338        + One
339        + Copy
340        + Send
341        + Sync
342        + std::iter::Sum<F>
343        + std::fmt::Debug
344        + std::fmt::Display,
345{
346    if pairs.is_empty() {
347        return Ok(Vec::new());
348    }
349
350    // Validate method
351    match method {
352        "pearson" | "spearman" | "kendall" => {}
353        _ => {
354            return Err(StatsError::InvalidArgument(format!(
355                "Method must be 'pearson', 'spearman', or 'kendall', got {}",
356                method
357            )))
358        }
359    }
360
361    let n_pairs = pairs.len();
362    let use_parallel = n_pairs >= config.min_parallelsize.min(10); // Lower threshold for batch operations
363
364    if use_parallel {
365        // Parallel processing with chunking
366        let chunksize = config.chunksize.unwrap_or(std::cmp::max(1, n_pairs / 4));
367
368        let results = Arc::new(Mutex::new(Vec::new()));
369        let error_occurred = Arc::new(Mutex::new(false));
370
371        pairs.chunks(chunksize).for_each(|chunk| {
372            let mut local_results = Vec::new();
373            let mut has_error = false;
374
375            for (x, y) in chunk {
376                let corr = match method {
377                    "pearson" => {
378                        if config.use_simd {
379                            pearson_r_simd_enhanced(x, y)
380                        } else {
381                            pearson_r(x, y)
382                        }
383                    }
384                    "spearman" => spearman_r(x, y),
385                    "kendall" => kendall_tau(x, y, "b"),
386                    _ => unreachable!(),
387                };
388
389                match corr {
390                    Ok(val) => local_results.push(val),
391                    Err(_) => {
392                        has_error = true;
393                        break;
394                    }
395                }
396            }
397
398            if has_error {
399                *error_occurred.lock().unwrap() = true;
400            } else {
401                results.lock().unwrap().extend(local_results);
402            }
403        });
404
405        if *error_occurred.lock().unwrap() {
406            return Err(StatsError::InvalidArgument(
407                "Error occurred during batch correlation computation".to_string(),
408            ));
409        }
410
411        let final_results = Arc::try_unwrap(results).unwrap().into_inner().unwrap();
412        Ok(final_results)
413    } else {
414        // Sequential processing
415        let mut results = Vec::with_capacity(n_pairs);
416
417        for (x, y) in pairs {
418            let corr = match method {
419                "pearson" => {
420                    if config.use_simd {
421                        pearson_r_simd_enhanced(x, y)?
422                    } else {
423                        pearson_r(x, y)?
424                    }
425                }
426                "spearman" => spearman_r(x, y)?,
427                "kendall" => kendall_tau(x, y, "b")?,
428                _ => unreachable!(),
429            };
430            results.push(corr);
431        }
432
433        Ok(results)
434    }
435}
436
437/// Rolling correlation computation with parallel processing
438///
439/// Computes rolling correlations between two time series using
440/// parallel processing for multiple windows.
441#[allow(dead_code)]
442pub fn rolling_correlation_parallel<F>(
443    x: &ArrayView1<F>,
444    y: &ArrayView1<F>,
445    windowsize: usize,
446    method: &str,
447    config: &ParallelCorrelationConfig,
448) -> StatsResult<Array1<F>>
449where
450    F: Float
451        + NumCast
452        + SimdUnifiedOps
453        + Zero
454        + One
455        + Copy
456        + Send
457        + Sync
458        + std::iter::Sum<F>
459        + std::fmt::Debug
460        + std::fmt::Display,
461{
462    if x.len() != y.len() {
463        return Err(StatsError::DimensionMismatch(format!(
464            "x and y must have the same length, got {} and {}",
465            x.len(),
466            y.len()
467        )));
468    }
469    check_positive(windowsize, "windowsize")?;
470
471    if windowsize > x.len() {
472        return Err(StatsError::InvalidArgument(
473            "Window size cannot be larger than data length".to_string(),
474        ));
475    }
476
477    let n_windows = x.len() - windowsize + 1;
478    let mut results = Array1::zeros(n_windows);
479
480    // Generate window pairs
481    let window_pairs: Vec<_> = (0..n_windows)
482        .map(|i| {
483            let x_window = x.slice(s![i..i + windowsize]);
484            let y_window = y.slice(s![i..i + windowsize]);
485            (x_window, y_window)
486        })
487        .collect();
488
489    // Compute correlations in parallel
490    let correlations = batch_correlations_parallel(&window_pairs, method, config)?;
491
492    // Copy results
493    for (i, corr) in correlations.into_iter().enumerate() {
494        results[i] = corr;
495    }
496
497    Ok(results)
498}
499
500// Helper function for 2D array validation
501#[allow(dead_code)]
502fn checkarray_finite_2d<F, D>(arr: &ArrayBase<D, Ix2>, name: &str) -> StatsResult<()>
503where
504    F: Float,
505    D: Data<Elem = F>,
506{
507    for &val in arr.iter() {
508        if !val.is_finite() {
509            return Err(StatsError::InvalidArgument(format!(
510                "{} contains non-finite values",
511                name
512            )));
513        }
514    }
515    Ok(())
516}
517
#[cfg(test)]
mod tests {
    use super::*;
    use crate::corrcoef;
    use scirs2_core::ndarray::array;

    /// Absolute tolerance used by the floating-point comparisons below.
    const TOLERANCE: f64 = 1e-10;

    /// The enhanced matrix routine must agree element-wise with `corrcoef`.
    #[test]
    fn test_corrcoef_parallel_enhanced_consistency() {
        let data = array![
            [1.0, 5.0, 10.0],
            [2.0, 4.0, 9.0],
            [3.0, 3.0, 8.0],
            [4.0, 2.0, 7.0],
            [5.0, 1.0, 6.0]
        ];

        let config = ParallelCorrelationConfig::default();
        let enhanced = corrcoef_parallel_enhanced(&data.view(), "pearson", &config).unwrap();
        let reference = corrcoef(&data.view(), "pearson").unwrap();

        for row in 0..3 {
            for col in 0..3 {
                let got: f64 = enhanced[[row, col]];
                let want: f64 = reference[[row, col]];
                assert!(
                    (got - want).abs() < TOLERANCE,
                    "Mismatch at [{}, {}]: parallel {} vs sequential {}",
                    row,
                    col,
                    got,
                    want
                );
            }
        }
    }

    /// The SIMD-enhanced Pearson kernel must agree with `pearson_r`.
    #[test]
    fn test_pearson_r_simd_enhanced_consistency() {
        let x = array![1.0, 2.0, 3.0, 4.0, 5.0];
        let y = array![5.0, 4.0, 3.0, 2.0, 1.0];

        let enhanced: f64 = pearson_r_simd_enhanced(&x.view(), &y.view()).unwrap();
        let reference: f64 = pearson_r(&x.view(), &y.view()).unwrap();

        assert!((enhanced - reference).abs() < TOLERANCE);
    }

    /// Batch results come back in input order with the expected values.
    #[test]
    fn test_batch_correlations_parallel() {
        let x1 = array![1.0, 2.0, 3.0, 4.0, 5.0];
        let y1 = array![5.0, 4.0, 3.0, 2.0, 1.0];
        let x2 = array![1.0, 2.0, 3.0, 4.0, 5.0];
        let y2 = array![2.0, 4.0, 6.0, 8.0, 10.0];

        let pairs = vec![(x1.view(), y1.view()), (x2.view(), y2.view())];
        let config = ParallelCorrelationConfig::default();

        let results: Vec<f64> = batch_correlations_parallel(&pairs, "pearson", &config).unwrap();

        assert_eq!(results.len(), 2);

        // First pair: perfect negative correlation; second: perfect positive
        let expected: [f64; 2] = [-1.0, 1.0];
        for (got, want) in results.iter().zip(expected.iter()) {
            assert!((got - want).abs() < TOLERANCE);
        }
    }

    /// Every window of an increasing vs. decreasing series correlates negatively.
    #[test]
    fn test_rolling_correlation_parallel() {
        let x = array![1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0];
        let y = array![10.0, 9.0, 8.0, 7.0, 6.0, 5.0, 4.0, 3.0, 2.0, 1.0];

        let config = ParallelCorrelationConfig::default();
        let rolling_corrs =
            rolling_correlation_parallel(&x.view(), &y.view(), 3, "pearson", &config).unwrap();

        // 10 - 3 + 1 windows
        assert_eq!(rolling_corrs.len(), 8);

        // x increases while y decreases, so each window is negatively correlated
        assert!(rolling_corrs.iter().all(|&corr| corr < 0.0));
    }
}