Skip to main content

scirs2_stats/
correlation_parallel_enhanced.rs

1//! Enhanced parallel correlation computations
2//!
//! This module provides SIMD-accelerated correlation routines, with parallel-processing
//! settings collected in `ParallelCorrelationConfig`, built on scirs2-core's unified
//! optimization framework.
5
6use crate::error::{StatsError, StatsResult};
7use crate::{kendall_tau, pearson_r, spearman_r};
8use scirs2_core::ndarray::{s, Array1, Array2, ArrayBase, ArrayView1, ArrayView2, Data, Ix1, Ix2};
9use scirs2_core::numeric::{Float, NumCast, One, Zero};
10use scirs2_core::{
11    simd_ops::{AutoOptimizer, SimdUnifiedOps},
12    validation::*,
13};
14use std::sync::{Arc, Mutex};
15
/// Parallel configuration for correlation computations
///
/// Tuning knobs shared by the correlation routines. Start from
/// [`Default::default`] and override individual fields with struct-update
/// syntax, e.g. `ParallelCorrelationConfig { use_simd: false, ..Default::default() }`.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct ParallelCorrelationConfig {
    /// Minimum problem size (e.g. the number of variables of a correlation
    /// matrix) at which parallel processing is requested.
    pub min_parallelsize: usize,
    /// Number of work items per chunk; `None` lets the routine choose.
    pub chunksize: Option<usize>,
    /// Use SIMD-accelerated kernels where one exists (e.g. Pearson correlation).
    pub use_simd: bool,
    /// Request work-stealing load balancing.
    ///
    /// NOTE(review): no routine in this file reads this flag — confirm whether it
    /// is honoured elsewhere.
    pub work_stealing: bool,
}

impl Default for ParallelCorrelationConfig {
    /// Parallel threshold of 50 (a 50×50 correlation matrix), automatic chunk
    /// size, SIMD enabled, work stealing enabled.
    fn default() -> Self {
        Self {
            min_parallelsize: 50,
            chunksize: None,
            use_simd: true,
            work_stealing: true,
        }
    }
}
39
40/// Parallel and SIMD-optimized correlation matrix computation
41///
42/// Computes pairwise correlations between all variables in a matrix using
43/// parallel processing for the correlation pairs and SIMD for individual
44/// correlation calculations.
45///
46/// # Arguments
47///
48/// * `data` - Input data matrix (observations × variables)
49/// * `method` - Correlation method ("pearson", "spearman", "kendall")
50/// * `config` - Parallel processing configuration
51///
52/// # Returns
53///
54/// * Correlation matrix (variables × variables)
55///
56/// # Examples
57///
58/// ```
59/// use scirs2_core::ndarray::array;
60/// use scirs2_stats::{corrcoef_parallel_enhanced, ParallelCorrelationConfig};
61///
62/// let data = array![
63///     [1.0, 5.0, 10.0],
64///     [2.0, 4.0, 9.0],
65///     [3.0, 3.0, 8.0],
66///     [4.0, 2.0, 7.0],
67///     [5.0, 1.0, 6.0]
68/// ];
69///
70/// let config = ParallelCorrelationConfig::default();
71/// let corr_matrix = corrcoef_parallel_enhanced(&data.view(), "pearson", &config).expect("Operation failed");
72/// ```
73#[allow(dead_code)]
74pub fn corrcoef_parallel_enhanced<F>(
75    data: &ArrayView2<F>,
76    method: &str,
77    config: &ParallelCorrelationConfig,
78) -> StatsResult<Array2<F>>
79where
80    F: Float
81        + NumCast
82        + SimdUnifiedOps
83        + Zero
84        + One
85        + Copy
86        + Send
87        + Sync
88        + std::iter::Sum<F>
89        + std::fmt::Debug
90        + std::fmt::Display
91        + 'static,
92{
93    // Validate inputs
94    checkarray_finite_2d(data, "data")?;
95
96    match method {
97        "pearson" | "spearman" | "kendall" => {}
98        _ => {
99            return Err(StatsError::InvalidArgument(format!(
100                "Method must be 'pearson', 'spearman', or 'kendall', got {}",
101                method
102            )))
103        }
104    }
105
106    let (n_obs, n_vars) = data.dim();
107
108    if n_obs == 0 || n_vars == 0 {
109        return Err(StatsError::InvalidArgument(
110            "Data array cannot be empty".to_string(),
111        ));
112    }
113
114    // Initialize correlation matrix
115    let mut corr_mat = Array2::<F>::zeros((n_vars, n_vars));
116
117    // Set diagonal elements to 1
118    for i in 0..n_vars {
119        corr_mat[[i, i]] = F::one();
120    }
121
122    // Generate upper triangular pairs for parallel processing
123    let mut pairs = Vec::new();
124    for i in 0..n_vars {
125        for j in (i + 1)..n_vars {
126            pairs.push((i, j));
127        }
128    }
129
130    // Decide whether to use parallel processing
131    let use_parallel = n_vars >= config.min_parallelsize;
132
133    if use_parallel {
134        // Parallel processing with result collection
135        let chunksize = config
136            .chunksize
137            .unwrap_or(std::cmp::max(1, pairs.len() / 4));
138
139        // Process pairs in parallel and collect results
140        let results = Arc::new(Mutex::new(Vec::new()));
141
142        pairs.chunks(chunksize).for_each(|chunk| {
143            let mut local_results = Vec::new();
144
145            for &(i, j) in chunk {
146                let var_i = data.slice(s![.., i]);
147                let var_j = data.slice(s![.., j]);
148
149                let corr = match method {
150                    "pearson" => {
151                        if config.use_simd {
152                            match pearson_r_simd_enhanced(&var_i, &var_j) {
153                                Ok(val) => val,
154                                Err(_) => continue,
155                            }
156                        } else {
157                            match pearson_r(&var_i, &var_j) {
158                                Ok(val) => val,
159                                Err(_) => continue,
160                            }
161                        }
162                    }
163                    "spearman" => match spearman_r(&var_i, &var_j) {
164                        Ok(val) => val,
165                        Err(_) => continue,
166                    },
167                    "kendall" => match kendall_tau(&var_i, &var_j, "b") {
168                        Ok(val) => val,
169                        Err(_) => continue,
170                    },
171                    _ => unreachable!(),
172                };
173
174                local_results.push((i, j, corr));
175            }
176
177            let mut global_results = results.lock().expect("Operation failed");
178            global_results.extend(local_results);
179        });
180
181        let all_results = Arc::try_unwrap(results)
182            .expect("Operation failed")
183            .into_inner()
184            .expect("Operation failed");
185
186        // Write results back to matrix
187        for (i, j, corr) in all_results {
188            corr_mat[[i, j]] = corr;
189            corr_mat[[j, i]] = corr; // Symmetric
190        }
191    } else {
192        // Sequential processing for smaller matrices
193        for (i, j) in pairs {
194            let var_i = data.slice(s![.., i]);
195            let var_j = data.slice(s![.., j]);
196
197            let corr = match method {
198                "pearson" => {
199                    if config.use_simd {
200                        pearson_r_simd_enhanced(&var_i, &var_j)?
201                    } else {
202                        pearson_r(&var_i, &var_j)?
203                    }
204                }
205                "spearman" => spearman_r(&var_i, &var_j)?,
206                "kendall" => kendall_tau(&var_i, &var_j, "b")?,
207                _ => unreachable!(),
208            };
209
210            corr_mat[[i, j]] = corr;
211            corr_mat[[j, i]] = corr; // Symmetric
212        }
213    }
214
215    Ok(corr_mat)
216}
217
218/// SIMD-enhanced Pearson correlation computation
219///
220/// Optimized version of Pearson correlation using SIMD operations
221/// for improved performance on large datasets.
222#[allow(dead_code)]
223pub fn pearson_r_simd_enhanced<F, D>(x: &ArrayBase<D, Ix1>, y: &ArrayBase<D, Ix1>) -> StatsResult<F>
224where
225    F: Float + NumCast + SimdUnifiedOps + Zero + One + Copy + std::iter::Sum<F>,
226    D: Data<Elem = F>,
227{
228    // Check dimensions
229    if x.len() != y.len() {
230        return Err(StatsError::DimensionMismatch(
231            "Arrays must have the same length".to_string(),
232        ));
233    }
234
235    if x.is_empty() {
236        return Err(StatsError::InvalidArgument(
237            "Arrays cannot be empty".to_string(),
238        ));
239    }
240
241    let n = x.len();
242    let n_f = F::from(n).expect("Failed to convert to float");
243    let optimizer = AutoOptimizer::new();
244
245    // Use SIMD for mean calculations if beneficial
246    let (mean_x, mean_y) = if optimizer.should_use_simd(n) {
247        let sum_x = F::simd_sum(&x.view());
248        let sum_y = F::simd_sum(&y.view());
249        (sum_x / n_f, sum_y / n_f)
250    } else {
251        let mean_x = x.iter().fold(F::zero(), |acc, &val| acc + val) / n_f;
252        let mean_y = y.iter().fold(F::zero(), |acc, &val| acc + val) / n_f;
253        (mean_x, mean_y)
254    };
255
256    // SIMD-optimized correlation calculation
257    let (sum_xy, sum_x2, sum_y2) = if optimizer.should_use_simd(n) {
258        // Create arrays with means for SIMD subtraction
259        let mean_x_array = Array1::from_elem(n, mean_x);
260        let mean_y_array = Array1::from_elem(n, mean_y);
261
262        // Compute deviations
263        let x_dev = F::simd_sub(&x.view(), &mean_x_array.view());
264        let y_dev = F::simd_sub(&y.view(), &mean_y_array.view());
265
266        // Compute products and squares
267        let xy_prod = F::simd_mul(&x_dev.view(), &y_dev.view());
268        let x_sq = F::simd_mul(&x_dev.view(), &x_dev.view());
269        let y_sq = F::simd_mul(&y_dev.view(), &y_dev.view());
270
271        // Sum the results
272        let sum_xy = F::simd_sum(&xy_prod.view());
273        let sum_x2 = F::simd_sum(&x_sq.view());
274        let sum_y2 = F::simd_sum(&y_sq.view());
275
276        (sum_xy, sum_x2, sum_y2)
277    } else {
278        // Scalar fallback
279        let mut sum_xy = F::zero();
280        let mut sum_x2 = F::zero();
281        let mut sum_y2 = F::zero();
282
283        for i in 0..n {
284            let x_dev = x[i] - mean_x;
285            let y_dev = y[i] - mean_y;
286
287            sum_xy = sum_xy + x_dev * y_dev;
288            sum_x2 = sum_x2 + x_dev * x_dev;
289            sum_y2 = sum_y2 + y_dev * y_dev;
290        }
291
292        (sum_xy, sum_x2, sum_y2)
293    };
294
295    // Check for zero variances
296    if sum_x2 <= F::epsilon() || sum_y2 <= F::epsilon() {
297        return Err(StatsError::InvalidArgument(
298            "Cannot compute correlation when one or both variables have zero variance".to_string(),
299        ));
300    }
301
302    // Calculate correlation coefficient
303    let corr = sum_xy / (sum_x2 * sum_y2).sqrt();
304
305    // Clamp to valid range [-1, 1]
306    let corr = if corr > F::one() {
307        F::one()
308    } else if corr < -F::one() {
309        -F::one()
310    } else {
311        corr
312    };
313
314    Ok(corr)
315}
316
317/// Parallel batch correlation computation
318///
319/// Computes correlations between multiple pairs of arrays in parallel,
320/// useful for large-scale correlation analysis.
321///
322/// # Arguments
323///
324/// * `pairs` - Vector of array pairs to correlate
325/// * `method` - Correlation method
326/// * `config` - Parallel processing configuration
327///
328/// # Returns
329///
330/// * Vector of correlation coefficients in the same order as input pairs
331#[allow(dead_code)]
332pub fn batch_correlations_parallel<'a, F>(
333    pairs: &[(ArrayView1<'a, F>, ArrayView1<'a, F>)],
334    method: &str,
335    config: &ParallelCorrelationConfig,
336) -> StatsResult<Vec<F>>
337where
338    F: Float
339        + NumCast
340        + SimdUnifiedOps
341        + Zero
342        + One
343        + Copy
344        + Send
345        + Sync
346        + std::iter::Sum<F>
347        + std::fmt::Debug
348        + std::fmt::Display
349        + 'static,
350{
351    if pairs.is_empty() {
352        return Ok(Vec::new());
353    }
354
355    // Validate method
356    match method {
357        "pearson" | "spearman" | "kendall" => {}
358        _ => {
359            return Err(StatsError::InvalidArgument(format!(
360                "Method must be 'pearson', 'spearman', or 'kendall', got {}",
361                method
362            )))
363        }
364    }
365
366    let n_pairs = pairs.len();
367    let use_parallel = n_pairs >= config.min_parallelsize.min(10); // Lower threshold for batch operations
368
369    if use_parallel {
370        // Parallel processing with chunking
371        let chunksize = config.chunksize.unwrap_or(std::cmp::max(1, n_pairs / 4));
372
373        let results = Arc::new(Mutex::new(Vec::new()));
374        let error_occurred = Arc::new(Mutex::new(false));
375
376        pairs.chunks(chunksize).for_each(|chunk| {
377            let mut local_results = Vec::new();
378            let mut has_error = false;
379
380            for (x, y) in chunk {
381                let corr = match method {
382                    "pearson" => {
383                        if config.use_simd {
384                            pearson_r_simd_enhanced(x, y)
385                        } else {
386                            pearson_r(x, y)
387                        }
388                    }
389                    "spearman" => spearman_r(x, y),
390                    "kendall" => kendall_tau(x, y, "b"),
391                    _ => unreachable!(),
392                };
393
394                match corr {
395                    Ok(val) => local_results.push(val),
396                    Err(_) => {
397                        has_error = true;
398                        break;
399                    }
400                }
401            }
402
403            if has_error {
404                *error_occurred.lock().expect("Operation failed") = true;
405            } else {
406                results
407                    .lock()
408                    .expect("Operation failed")
409                    .extend(local_results);
410            }
411        });
412
413        if *error_occurred.lock().expect("Operation failed") {
414            return Err(StatsError::InvalidArgument(
415                "Error occurred during batch correlation computation".to_string(),
416            ));
417        }
418
419        let final_results = Arc::try_unwrap(results)
420            .expect("Operation failed")
421            .into_inner()
422            .expect("Operation failed");
423        Ok(final_results)
424    } else {
425        // Sequential processing
426        let mut results = Vec::with_capacity(n_pairs);
427
428        for (x, y) in pairs {
429            let corr = match method {
430                "pearson" => {
431                    if config.use_simd {
432                        pearson_r_simd_enhanced(x, y)?
433                    } else {
434                        pearson_r(x, y)?
435                    }
436                }
437                "spearman" => spearman_r(x, y)?,
438                "kendall" => kendall_tau(x, y, "b")?,
439                _ => unreachable!(),
440            };
441            results.push(corr);
442        }
443
444        Ok(results)
445    }
446}
447
448/// Rolling correlation computation with parallel processing
449///
450/// Computes rolling correlations between two time series using
451/// parallel processing for multiple windows.
452#[allow(dead_code)]
453pub fn rolling_correlation_parallel<F>(
454    x: &ArrayView1<F>,
455    y: &ArrayView1<F>,
456    windowsize: usize,
457    method: &str,
458    config: &ParallelCorrelationConfig,
459) -> StatsResult<Array1<F>>
460where
461    F: Float
462        + NumCast
463        + SimdUnifiedOps
464        + Zero
465        + One
466        + Copy
467        + Send
468        + Sync
469        + std::iter::Sum<F>
470        + std::fmt::Debug
471        + std::fmt::Display
472        + 'static,
473{
474    if x.len() != y.len() {
475        return Err(StatsError::DimensionMismatch(format!(
476            "x and y must have the same length, got {} and {}",
477            x.len(),
478            y.len()
479        )));
480    }
481    check_positive(windowsize, "windowsize")?;
482
483    if windowsize > x.len() {
484        return Err(StatsError::InvalidArgument(
485            "Window size cannot be larger than data length".to_string(),
486        ));
487    }
488
489    let n_windows = x.len() - windowsize + 1;
490    let mut results = Array1::zeros(n_windows);
491
492    // Generate window pairs
493    let window_pairs: Vec<_> = (0..n_windows)
494        .map(|i| {
495            let x_window = x.slice(s![i..i + windowsize]);
496            let y_window = y.slice(s![i..i + windowsize]);
497            (x_window, y_window)
498        })
499        .collect();
500
501    // Compute correlations in parallel
502    let correlations = batch_correlations_parallel(&window_pairs, method, config)?;
503
504    // Copy results
505    for (i, corr) in correlations.into_iter().enumerate() {
506        results[i] = corr;
507    }
508
509    Ok(results)
510}
511
512// Helper function for 2D array validation
513#[allow(dead_code)]
514fn checkarray_finite_2d<F, D>(arr: &ArrayBase<D, Ix2>, name: &str) -> StatsResult<()>
515where
516    F: Float,
517    D: Data<Elem = F>,
518{
519    for &val in arr.iter() {
520        if !val.is_finite() {
521            return Err(StatsError::InvalidArgument(format!(
522                "{} contains non-finite values",
523                name
524            )));
525        }
526    }
527    Ok(())
528}
529
#[cfg(test)]
mod tests {
    use super::*;
    use crate::corrcoef;
    use scirs2_core::ndarray::array;

    /// The enhanced matrix routine must agree with the crate's reference `corrcoef`.
    #[test]
    fn test_corrcoef_parallel_enhanced_consistency() {
        let data = array![
            [1.0, 5.0, 10.0],
            [2.0, 4.0, 9.0],
            [3.0, 3.0, 8.0],
            [4.0, 2.0, 7.0],
            [5.0, 1.0, 6.0]
        ];
        let config = ParallelCorrelationConfig::default();

        let enhanced = corrcoef_parallel_enhanced(&data.view(), "pearson", &config)
            .expect("Operation failed");
        let reference = corrcoef(&data.view(), "pearson").expect("Operation failed");

        let cells = (0..3usize).flat_map(|row| (0..3usize).map(move |col| (row, col)));
        for (row, col) in cells {
            let (got, want) = (enhanced[[row, col]], reference[[row, col]]);
            assert!(
                (got - want).abs() < 1e-10,
                "Mismatch at [{}, {}]: parallel {} vs sequential {}",
                row,
                col,
                got,
                want
            );
        }
    }

    /// The SIMD Pearson kernel must agree with the scalar `pearson_r`.
    #[test]
    fn test_pearson_r_simd_enhanced_consistency() {
        let ascending = array![1.0, 2.0, 3.0, 4.0, 5.0];
        let descending = array![5.0, 4.0, 3.0, 2.0, 1.0];

        let simd = pearson_r_simd_enhanced(&ascending.view(), &descending.view())
            .expect("Operation failed");
        let baseline = pearson_r(&ascending.view(), &descending.view()).expect("Operation failed");

        assert!((simd - baseline).abs() < 1e-10);
    }

    /// Batch results come back in input order with the expected signs.
    #[test]
    fn test_batch_correlations_parallel() {
        let line = array![1.0, 2.0, 3.0, 4.0, 5.0];
        let reversed = array![5.0, 4.0, 3.0, 2.0, 1.0];
        let doubled = array![2.0, 4.0, 6.0, 8.0, 10.0];
        let pairs = vec![(line.view(), reversed.view()), (line.view(), doubled.view())];

        let results =
            batch_correlations_parallel(&pairs, "pearson", &ParallelCorrelationConfig::default())
                .expect("Operation failed");

        assert_eq!(results.len(), 2);
        assert!((results[0] + 1.0).abs() < 1e-10); // Perfect negative correlation
        assert!((results[1] - 1.0).abs() < 1e-10); // Perfect positive correlation
    }

    /// Opposite monotone trends give a negative correlation in every window.
    #[test]
    fn test_rolling_correlation_parallel() {
        let rising = array![1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0];
        let falling = array![10.0, 9.0, 8.0, 7.0, 6.0, 5.0, 4.0, 3.0, 2.0, 1.0];
        let config = ParallelCorrelationConfig::default();

        let rolling =
            rolling_correlation_parallel(&rising.view(), &falling.view(), 3, "pearson", &config)
                .expect("Operation failed");

        // 10 observations with a window of 3 give 10 - 3 + 1 windows.
        assert_eq!(rolling.len(), 8);
        assert!(rolling.iter().all(|&r| r < 0.0));
    }
}