scirs2_stats/parallel_enhanced_v2.rs

//! Enhanced parallel processing for v1.0.0
//!
//! This module provides improved parallel implementations with:
//! - Dynamic threshold adjustment
//! - Better work distribution
//! - Support for non-contiguous arrays
//! - Task-based parallelism

use crate::error::{StatsError, StatsResult};
use scirs2_core::ndarray::{Array1, Array2, ArrayBase, ArrayView1, Data, Ix1, Ix2};
use scirs2_core::numeric::{Float, NumCast};
use scirs2_core::parallel_ops::{num_threads, par_chunks, IntoParallelIterator, ParallelIterator};
use scirs2_core::validation::check_not_empty;
use std::sync::Arc;

/// Configuration for parallel operations
#[derive(Debug, Clone)]
pub struct ParallelConfig {
    /// Minimum size for parallel execution
    pub minsize: usize,
    /// Target chunk size per thread
    pub chunksize: Option<usize>,
    /// Maximum number of threads to use
    pub max_threads: Option<usize>,
    /// Whether to use adaptive thresholds
    pub adaptive: bool,
}

impl Default for ParallelConfig {
    fn default() -> Self {
        Self {
            minsize: 5_000,    // Threshold used when `adaptive` is false
            chunksize: None,   // Auto-determine
            max_threads: None, // Use all available
            adaptive: true,
        }
    }
}

impl ParallelConfig {
    /// Create config with specific thread count
    pub fn with_threads(mut self, threads: usize) -> Self {
        self.max_threads = Some(threads);
        self
    }

    /// Create config with specific chunk size
    pub fn with_chunksize(mut self, size: usize) -> Self {
        self.chunksize = Some(size);
        self
    }

    /// Determine if parallel execution should be used
    pub fn should_parallelize(&self, n: usize) -> bool {
        if self.adaptive {
            // Adaptive threshold based on thread count and data size
            let threads = self.max_threads.unwrap_or_else(num_threads);

            // Dynamic overhead estimation based on available cores
            let base_overhead = 800;
            let overhead_factor = base_overhead + (threads.saturating_sub(1) * 200);

            // For very large arrays (at least 100_000 elements), always parallelize
            if n >= 100_000 {
                return true;
            }

            // For small arrays, prefer sequential
            if n < 1_000 {
                return false;
            }

            // Adaptive decision for medium arrays
            n > threads * overhead_factor
        } else {
            n >= self.minsize
        }
    }

    /// Get optimal chunk size for the given data size
    pub fn get_chunksize(&self, n: usize) -> usize {
        if let Some(size) = self.chunksize {
            size
        } else {
            // Simple adaptive chunk size: divide data among available threads
            let threads = self.max_threads.unwrap_or_else(num_threads);
            (n / threads).max(1000)
        }
    }
}
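
// Illustrative usage sketch (not part of the public API; the helper name below is
// hypothetical): how the builder methods and the adaptive policy combine. With an
// explicit chunk size the requested value is used verbatim; otherwise the data is
// split roughly evenly across threads with a floor of 1_000 elements per chunk.
#[allow(dead_code)]
fn example_chunking_plan(n: usize) -> (bool, usize) {
    let cfg = ParallelConfig::default().with_threads(4).with_chunksize(2_000);
    // For n = 50_000 this yields (true, 2_000): large enough to parallelize on
    // four threads, using the explicitly requested chunk size.
    (cfg.should_parallelize(n), cfg.get_chunksize(n))
}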

/// Enhanced parallel mean computation
///
/// Handles non-contiguous arrays and provides better load balancing
#[allow(dead_code)]
pub fn mean_parallel_enhanced<F, D>(
    x: &ArrayBase<D, Ix1>,
    config: Option<ParallelConfig>,
) -> StatsResult<F>
where
    F: Float + NumCast + Send + Sync + std::iter::Sum<F> + std::fmt::Display,
    D: Data<Elem = F> + Sync,
{
    // Use scirs2-core validation
    check_not_empty(x, "x")
        .map_err(|_| StatsError::invalid_argument("Cannot compute mean of empty array"))?;

    let config = config.unwrap_or_default();
    let n = x.len();

    if !config.should_parallelize(n) {
        // Sequential computation
        let sum = x.iter().fold(F::zero(), |acc, &val| acc + val);
        return Ok(sum / F::from(n).unwrap());
    }

    // Parallel computation with better handling
    let sum = if let Some(slice) = x.as_slice() {
        // Contiguous array - use slice-based parallelism
        parallel_sum_slice(slice, &config)
    } else {
        // Non-contiguous array - use index-based parallelism
        parallel_sum_indexed(x, &config)
    };

    Ok(sum / F::from(n).unwrap())
}

/// Parallel variance with single-pass algorithm
///
/// Uses parallel Welford's algorithm for numerical stability
#[allow(dead_code)]
pub fn variance_parallel_enhanced<F, D>(
    x: &ArrayBase<D, Ix1>,
    ddof: usize,
    config: Option<ParallelConfig>,
) -> StatsResult<F>
where
    F: Float + NumCast + Send + Sync + std::iter::Sum<F> + std::fmt::Display,
    D: Data<Elem = F> + Sync,
{
    let n = x.len();
    if n <= ddof {
        return Err(StatsError::invalid_argument(
            "Not enough data points for the given degrees of freedom",
        ));
    }

    let config = config.unwrap_or_default();

    if !config.should_parallelize(n) {
        // Use sequential Welford's algorithm
        return variance_sequential_welford(x, ddof);
    }

    // Parallel Welford's algorithm
    let chunksize = config.get_chunksize(n);
    let n_chunks = n.div_ceil(chunksize);

    // Each chunk computes local mean and M2
    let chunk_stats: Vec<(F, F, usize)> = (0..n_chunks)
        .collect::<Vec<_>>()
        .into_par_iter()
        .map(|chunk_idx| {
            let start = chunk_idx * chunksize;
            let end = (start + chunksize).min(n);

            let mut local_mean = F::zero();
            let mut local_m2 = F::zero();
            let mut count = 0;

            for i in start..end {
                count += 1;
                let val = x[i];
                let delta = val - local_mean;
                local_mean = local_mean + delta / F::from(count).unwrap();
                let delta2 = val - local_mean;
                local_m2 = local_m2 + delta * delta2;
            }

            (local_mean, local_m2, count)
        })
        .collect();

    // Combine chunk statistics
    let (_total_mean, total_m2, _total_count) = combine_welford_stats(&chunk_stats);

    Ok(total_m2 / F::from(n - ddof).unwrap())
}

/// Parallel correlation matrix computation
///
/// Efficiently computes correlation matrix for multivariate data
#[allow(dead_code)]
pub fn corrcoef_parallel_enhanced<F, D>(
    data: &ArrayBase<D, Ix2>,
    config: Option<ParallelConfig>,
) -> StatsResult<Array2<F>>
where
    F: Float + NumCast + Send + Sync + std::iter::Sum<F> + std::fmt::Display,
    D: Data<Elem = F> + Sync,
{
    let (n_samples, n_features) = data.dim();

    if n_samples == 0 || n_features == 0 {
        return Err(StatsError::invalid_argument("Empty data matrix"));
    }

    let config = config.unwrap_or_default();

    // Compute means for each feature in parallel
    let means: Vec<F> = (0..n_features)
        .collect::<Vec<_>>()
        .into_par_iter()
        .map(|j| {
            let col = data.column(j);
            mean_parallel_enhanced(&col, Some(config.clone())).unwrap_or(F::zero())
        })
        .collect();

    // Compute correlation matrix in parallel
    let mut corr_matrix = Array2::zeros((n_features, n_features));

    // Only compute upper triangle (correlation matrix is symmetric)
    let indices: Vec<(usize, usize)> = (0..n_features)
        .flat_map(|i| (i..n_features).map(move |j| (i, j)))
        .collect();

    let correlations: Vec<((usize, usize), F)> = indices
        .into_par_iter()
        .map(|(i, j)| {
            let corr = if i == j {
                F::one() // Diagonal is always 1
            } else {
                compute_correlation_pair(&data.column(i), &data.column(j), means[i], means[j])
            };
            ((i, j), corr)
        })
        .collect();

    // Fill the correlation matrix
    for ((i, j), corr) in correlations {
        corr_matrix[(i, j)] = corr;
        if i != j {
            corr_matrix[(j, i)] = corr; // Symmetric
        }
    }

    Ok(corr_matrix)
}

/// Parallel bootstrap resampling
///
/// Generates bootstrap samples in parallel for faster computation
#[allow(dead_code)]
pub fn bootstrap_parallel_enhanced<F, D>(
    data: &ArrayBase<D, Ix1>,
    n_samples: usize,
    statistic_fn: impl Fn(&ArrayView1<F>) -> F + Send + Sync,
    config: Option<ParallelConfig>,
) -> StatsResult<Array1<F>>
where
    F: Float + NumCast + Send + Sync,
    D: Data<Elem = F> + Sync,
{
    if data.is_empty() {
        return Err(StatsError::invalid_argument("Cannot bootstrap empty data"));
    }

    let _config = config.unwrap_or_default();
    let data_arc = Arc::new(data.to_owned());
    let n = data.len();

    // Generate bootstrap statistics in parallel
    let stats: Vec<F> = (0..n_samples)
        .collect::<Vec<_>>()
        .into_par_iter()
        .map(|sample_idx| {
            use scirs2_core::random::rngs::StdRng;
            use scirs2_core::random::{Rng, SeedableRng};

            // Create deterministic RNG for reproducibility
            let mut rng = StdRng::seed_from_u64(sample_idx as u64);
            let mut sample = Array1::zeros(n);

            // Generate bootstrap sample
            for i in 0..n {
                let idx = rng.gen_range(0..n);
                sample[i] = data_arc[idx];
            }

            statistic_fn(&sample.view())
        })
        .collect();

    Ok(Array1::from(stats))
}

/// Helper function for parallel sum on slices
#[allow(dead_code)]
fn parallel_sum_slice<F>(slice: &[F], config: &ParallelConfig) -> F
where
    F: Float + NumCast + Send + Sync + std::iter::Sum + std::fmt::Display,
{
    let chunksize = config.get_chunksize(slice.len());

    par_chunks(slice, chunksize)
        .map(|chunk| chunk.iter().fold(F::zero(), |acc, &val| acc + val))
        .reduce(|| F::zero(), |a, b| a + b)
}

/// Helper function for parallel sum on indexed arrays
#[allow(dead_code)]
fn parallel_sum_indexed<F, D>(arr: &ArrayBase<D, Ix1>, config: &ParallelConfig) -> F
where
    F: Float + NumCast + Send + Sync + std::iter::Sum<F> + std::fmt::Display,
    D: Data<Elem = F> + Sync,
{
    let n = arr.len();
    let chunksize = config.get_chunksize(n);
    let n_chunks = n.div_ceil(chunksize);

    (0..n_chunks)
        .collect::<Vec<_>>()
        .into_par_iter()
        .map(|chunk_idx| {
            let start = chunk_idx * chunksize;
            let end = (start + chunksize).min(n);

            (start..end)
                .map(|i| arr[i])
                .fold(F::zero(), |acc, val| acc + val)
        })
        .reduce(|| F::zero(), |a, b| a + b)
}

/// Sequential Welford's algorithm (fallback)
#[allow(dead_code)]
fn variance_sequential_welford<F, D>(x: &ArrayBase<D, Ix1>, ddof: usize) -> StatsResult<F>
where
    F: Float + NumCast,
    D: Data<Elem = F>,
{
    let mut mean = F::zero();
    let mut m2 = F::zero();
    let mut count = 0;

    for &val in x.iter() {
        count += 1;
        let delta = val - mean;
        mean = mean + delta / F::from(count).unwrap();
        let delta2 = val - mean;
        m2 = m2 + delta * delta2;
    }

    Ok(m2 / F::from(count - ddof).unwrap())
}

/// Combine Welford statistics from parallel chunks
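///
/// This is the standard pairwise combination used in parallel variance algorithms
/// (Chan et al.): for two chunks A and B with counts `n_a` and `n_b`, let
/// `delta = mean_b - mean_a`; the merged mean is `mean_a + delta * n_b / (n_a + n_b)`
/// and the merged sum of squared deviations is
/// `m2_a + m2_b + delta^2 * n_a * n_b / (n_a + n_b)`.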
#[allow(dead_code)]
fn combine_welford_stats<F>(stats: &[(F, F, usize)]) -> (F, F, usize)
where
    F: Float + NumCast + std::fmt::Display,
{
    stats.iter().fold(
        (F::zero(), F::zero(), 0),
        |(mean_a, m2_a, count_a), &(mean_b, m2_b, count_b)| {
            let count = count_a + count_b;
            let delta = mean_b - mean_a;
            let mean = mean_a + delta * F::from(count_b).unwrap() / F::from(count).unwrap();
            let m2 = m2_a
                + m2_b
                + delta * delta * F::from(count_a).unwrap() * F::from(count_b).unwrap()
                    / F::from(count).unwrap();
            (mean, m2, count)
        },
    )
}

/// Compute correlation between two vectors
#[allow(dead_code)]
fn compute_correlation_pair<F>(x: &ArrayView1<F>, y: &ArrayView1<F>, mean_x: F, mean_y: F) -> F
where
    F: Float + NumCast + std::fmt::Display,
{
    let n = x.len();
    let mut cov = F::zero();
    let mut var_x = F::zero();
    let mut var_y = F::zero();

    for i in 0..n {
        let dx = x[i] - mean_x;
        let dy = y[i] - mean_y;
        cov = cov + dx * dy;
        var_x = var_x + dx * dx;
        var_y = var_y + dy * dy;
    }

    if var_x > F::epsilon() && var_y > F::epsilon() {
        cov / (var_x * var_y).sqrt()
    } else {
        F::zero()
    }
}

#[cfg(test)]
mod tests {
    use super::*;
    use scirs2_core::ndarray::array;

    #[test]
    fn test_parallel_config() {
        let config = ParallelConfig::default();
        assert!(config.should_parallelize(100_000));
        assert!(!config.should_parallelize(100));

        let config_fixed = ParallelConfig::default()
            .with_threads(4)
            .with_chunksize(1000);
        assert_eq!(config_fixed.get_chunksize(10_000), 1000);
    }
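
    #[test]
    fn test_get_chunksize_auto_bounds() {
        // Illustrative check: with no explicit chunk size, get_chunksize divides
        // the data among the available threads with a floor of 1_000 elements, so
        // the result is machine-dependent but always within [1_000, n] here.
        let chunk = ParallelConfig::default().get_chunksize(100_000);
        assert!(chunk >= 1_000);
        assert!(chunk <= 100_000);
    }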

    #[test]
    fn test_mean_parallel_enhanced() {
        let data = Array1::from_vec((0..10_000).map(|i| i as f64).collect());
        let mean = mean_parallel_enhanced(&data.view(), None).unwrap();
        assert!((mean - 4999.5).abs() < 1e-10);
    }
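
    #[test]
    fn test_mean_parallel_enhanced_non_contiguous() {
        // Illustrative check of the non-contiguous path: a column of a row-major
        // Array2 exposes no contiguous slice, so mean_parallel_enhanced uses the
        // index-based summation whenever the input is large enough to parallelize.
        let data = Array2::from_shape_fn((20_000, 2), |(i, _)| i as f64);
        let col = data.column(0);
        let mean = mean_parallel_enhanced(&col, None).unwrap();
        assert!((mean - 9999.5).abs() < 1e-10);
    }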

    #[test]
    fn test_variance_parallel_enhanced() {
        let data = array![1.0, 2.0, 3.0, 4.0, 5.0];
        let var = variance_parallel_enhanced(&data.view(), 1, None).unwrap();
        assert!((var - 2.5).abs() < 1e-10);
    }
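
    #[test]
    fn test_combine_welford_stats_matches_single_pass() {
        // Worked example for the pairwise combination: the chunks [1, 2] and
        // [3, 4, 5] have (mean, m2) = (1.5, 0.5) and (4.0, 2.0); merging them must
        // reproduce the single-pass result mean = 3.0, m2 = 10.0 over all five values.
        let stats = vec![(1.5_f64, 0.5, 2_usize), (4.0, 2.0, 3)];
        let (mean, m2, count) = combine_welford_stats(&stats);
        assert_eq!(count, 5);
        assert!((mean - 3.0).abs() < 1e-12);
        assert!((m2 - 10.0).abs() < 1e-12);
    }

    #[test]
    fn test_corrcoef_parallel_enhanced_perfect_correlation() {
        // Minimal sketch: two perfectly linearly related columns should yield a
        // correlation matrix of all ones.
        let data = array![[1.0, 2.0], [2.0, 4.0], [3.0, 6.0]];
        let corr = corrcoef_parallel_enhanced(&data.view(), None).unwrap();
        assert_eq!(corr.dim(), (2, 2));
        for &c in corr.iter() {
            assert!((c - 1.0).abs() < 1e-10);
        }
    }

    #[test]
    fn test_bootstrap_parallel_enhanced_mean_in_range() {
        // Smoke test: every bootstrap mean of values drawn from [1, 5] must lie in
        // that range, and one statistic is produced per requested resample.
        let data = array![1.0, 2.0, 3.0, 4.0, 5.0];
        let stats =
            bootstrap_parallel_enhanced(&data.view(), 50, |s| s.sum() / s.len() as f64, None)
                .unwrap();
        assert_eq!(stats.len(), 50);
        assert!(stats.iter().all(|&m| (1.0..=5.0).contains(&m)));
    }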
}