Skip to main content

scirs2_stats/
parallel_enhanced_advanced.rs

1//! Advanced parallel statistical processing with intelligent optimization
2//!
3//! This module provides state-of-the-art parallel implementations that
4//! automatically adapt to system characteristics and data patterns for
5//! optimal performance across different hardware configurations.
6
7use crate::error::{StatsError, StatsResult};
8use crate::error_standardization::ErrorMessages;
9use crate::simd_enhanced_core::{mean_enhanced, variance_enhanced, ComprehensiveStats};
10use crossbeam;
11use scirs2_core::ndarray::{Array1, Array2, ArrayBase, ArrayView1, Data, Ix1, Ix2};
12use scirs2_core::numeric::{Float, NumCast, One, Zero};
13use scirs2_core::{
14    parallel_ops::*,
15    simd_ops::{PlatformCapabilities, SimdUnifiedOps},
16};
17use std::collections::VecDeque;
18use std::sync::{atomic::AtomicUsize, Arc, Mutex};
19use std::thread;
20
/// Advanced parallel processing configuration
///
/// Tuning knobs read by `AdvancedParallelProcessor`. Only `parallel_threshold`,
/// `num_threads`, and `chunk_strategy` are consumed by code in this module.
// NOTE(review): `numa_aware`, `work_stealing` and `max_memory_usage` are not
// read anywhere in this file — confirm they are consumed elsewhere or remove.
#[derive(Debug, Clone)]
pub struct AdvancedParallelConfig {
    /// Minimum data size to trigger parallel processing
    pub parallel_threshold: usize,
    /// Number of worker threads (None = auto-detect)
    pub num_threads: Option<usize>,
    /// Enable NUMA-aware processing
    pub numa_aware: bool,
    /// Enable work stealing for better load balancing
    pub work_stealing: bool,
    /// Preferred chunk size strategy
    pub chunk_strategy: ChunkStrategy,
    /// Maximum memory usage for intermediate results (bytes)
    pub max_memory_usage: usize,
}
37
38impl Default for AdvancedParallelConfig {
39    fn default() -> Self {
40        Self {
41            parallel_threshold: 10_000,
42            num_threads: None,
43            numa_aware: true,
44            work_stealing: true,
45            chunk_strategy: ChunkStrategy::Adaptive,
46            max_memory_usage: 1024 * 1024 * 1024, // 1GB
47        }
48    }
49}
50
/// Chunking strategies for optimal data access patterns
///
/// Selected via `AdvancedParallelConfig::chunk_strategy`; dispatched on by
/// `mean_parallel_advanced`.
#[derive(Debug, Clone, Copy)]
pub enum ChunkStrategy {
    /// Fixed chunk size (number of elements per chunk)
    Fixed(usize),
    /// Cache-aware chunking (cache-oblivious divide and conquer)
    CacheOptimal,
    /// Adaptive chunking based on data characteristics
    Adaptive,
    /// Work-stealing with dynamic load balancing
    WorkStealing,
}
63
/// Advanced parallel statistics processor
///
/// Holds the configuration and detected platform SIMD capabilities.
/// `thread_pool`, `work_queue` and `active_workers` are currently unused
/// placeholders (marked `dead_code`); the compute methods spawn their own
/// scoped threads or use parallel iterators instead.
pub struct AdvancedParallelProcessor<F: Float + std::fmt::Display> {
    config: AdvancedParallelConfig,
    capabilities: PlatformCapabilities,
    #[allow(dead_code)]
    thread_pool: Option<ThreadPool>,
    #[allow(dead_code)]
    work_queue: Arc<Mutex<VecDeque<ParallelTask<F>>>>,
    #[allow(dead_code)]
    active_workers: Arc<AtomicUsize>,
}
75
/// Task for parallel execution
///
/// Element type of the (currently unused) work queue on
/// `AdvancedParallelProcessor`; no code in this module constructs or
/// consumes these variants yet.
enum ParallelTask<F: Float + std::fmt::Display> {
    Mean(Vec<F>),
    Variance(Vec<F>, F, usize), // data, mean, ddof
    Correlation(Vec<F>, Vec<F>),
    Histogram(Vec<F>, usize), // data, number of bins
}
83
/// Result of parallel computation
///
/// Mirrors `ParallelTask` variant-for-variant. Not produced by any code
/// visible in this module; presumably part of the planned work-queue API.
pub enum ParallelResult<F: Float + std::fmt::Display> {
    Mean(F),
    Variance(F),
    Correlation(F),
    Histogram(Vec<usize>),
}
91
92impl<F> AdvancedParallelProcessor<F>
93where
94    F: Float
95        + NumCast
96        + Send
97        + Sync
98        + SimdUnifiedOps
99        + Copy
100        + 'static
101        + Zero
102        + One
103        + std::fmt::Debug
104        + std::fmt::Display
105        + std::iter::Sum<F>,
106{
107    /// Create a new advanced parallel processor
108    pub fn new(config: AdvancedParallelConfig) -> Self {
109        let capabilities = PlatformCapabilities::detect();
110
111        Self {
112            config,
113            capabilities,
114            thread_pool: None,
115            work_queue: Arc::new(Mutex::new(VecDeque::new())),
116            active_workers: Arc::new(AtomicUsize::new(0)),
117        }
118    }
119
    /// Initialize the thread pool with optimal configuration
    ///
    /// The worker count comes from `config.num_threads`, falling back to
    /// `optimal_thread_count()` when unset.
    ///
    /// # Errors
    /// Propagates a `ThreadPool::new` failure (a thread count of zero).
    pub fn initialize(&mut self) -> StatsResult<()> {
        let num_threads = self
            .config
            .num_threads
            .unwrap_or_else(|| self.optimal_thread_count());

        self.thread_pool = Some(ThreadPool::new(num_threads, self.config.clone())?);
        Ok(())
    }
130
    /// Compute mean using advanced parallel processing
    ///
    /// Arrays below `config.parallel_threshold` are handled by the
    /// sequential SIMD-enhanced path; larger arrays dispatch on the
    /// configured `ChunkStrategy`.
    ///
    /// # Errors
    /// Returns an error if `x` is empty.
    pub fn mean_parallel_advanced<D>(&self, x: &ArrayBase<D, Ix1>) -> StatsResult<F>
    where
        D: Data<Elem = F> + Sync + Send,
    {
        if x.is_empty() {
            return Err(ErrorMessages::empty_array("x"));
        }

        let n = x.len();

        // Use sequential processing for small arrays
        if n < self.config.parallel_threshold {
            return mean_enhanced(x);
        }

        // Choose optimal parallel strategy
        match self.config.chunk_strategy {
            ChunkStrategy::WorkStealing => self.mean_work_stealing(x),
            ChunkStrategy::Adaptive => self.mean_adaptive_chunking(x),
            ChunkStrategy::CacheOptimal => self.mean_cache_optimal(x),
            ChunkStrategy::Fixed(chunksize) => self.mean_fixed_chunks(x, chunksize),
        }
    }
155
    /// Compute variance using advanced parallel processing with numerical stability
    ///
    /// `ddof` is the delta degrees of freedom: the divisor used is `n - ddof`.
    /// Small arrays use the sequential SIMD path; large ones use a parallel
    /// chunked Welford algorithm with pairwise combination.
    ///
    /// # Errors
    /// Returns an error if `x` is empty or if `n <= ddof`.
    pub fn variance_parallel_advanced<D>(
        &self,
        x: &ArrayBase<D, Ix1>,
        ddof: usize,
    ) -> StatsResult<F>
    where
        D: Data<Elem = F> + Sync + Send,
    {
        let n = x.len();
        if n == 0 {
            return Err(ErrorMessages::empty_array("x"));
        }
        // Need strictly more observations than ddof for a finite divisor.
        if n <= ddof {
            return Err(ErrorMessages::insufficientdata(
                "variance calculation",
                ddof + 1,
                n,
            ));
        }

        if n < self.config.parallel_threshold {
            return variance_enhanced(x, ddof);
        }

        // Use parallel Welford's algorithm for better numerical stability
        self.variance_welford_parallel(x, ddof)
    }
184
    /// Compute correlation matrix in parallel for multivariate data
    ///
    /// `data` is (n_samples, n_features); the result is the symmetric
    /// (n_features, n_features) pairwise correlation matrix with a unit
    /// diagonal. Only the upper triangle is computed; the lower triangle is
    /// mirrored afterwards.
    ///
    /// # Errors
    /// Returns an error if `data` has zero rows or zero columns.
    pub fn correlation_matrix_parallel<D>(&self, data: &ArrayBase<D, Ix2>) -> StatsResult<Array2<F>>
    where
        D: Data<Elem = F> + Sync + Send,
    {
        let (n_samples_, n_features) = data.dim();

        if n_samples_ == 0 {
            return Err(ErrorMessages::empty_array("data"));
        }
        if n_features == 0 {
            return Err(ErrorMessages::insufficientdata(
                "correlation matrix",
                2,
                n_features,
            ));
        }

        // Identity start: diagonal entries are 1 by definition.
        let mut correlation_matrix = Array2::eye(n_features);

        // Parallel computation of upper triangle; only worthwhile for enough
        // feature pairs and enough samples per pair.
        if n_features > 4 && n_samples_ > self.config.parallel_threshold {
            self.correlation_matrix_parallel_upper_triangle(data, &mut correlation_matrix)?;
        } else {
            self.correlation_matrix_sequential(data, &mut correlation_matrix)?;
        }

        // Fill lower triangle (correlation matrix is symmetric)
        for i in 0..n_features {
            for j in 0..i {
                correlation_matrix[[i, j]] = correlation_matrix[[j, i]];
            }
        }

        Ok(correlation_matrix)
    }
221
    /// Batch parallel processing for multiple statistical operations
    ///
    /// Computes mean, variance (divisor `n - ddof`), standard deviation,
    /// skewness, and kurtosis in one pass over the data. Small arrays are
    /// delegated to the SIMD single-pass implementation.
    ///
    /// # Errors
    /// Returns an error if `x` is empty or `n <= ddof`.
    pub fn batch_statistics_parallel<D>(
        &self,
        x: &ArrayBase<D, Ix1>,
        ddof: usize,
    ) -> StatsResult<ComprehensiveStats<F>>
    where
        D: Data<Elem = F> + Sync + Send,
    {
        let n = x.len();
        if n == 0 {
            return Err(ErrorMessages::empty_array("x"));
        }
        if n <= ddof {
            return Err(ErrorMessages::insufficientdata(
                "comprehensive statistics",
                ddof + 1,
                n,
            ));
        }

        if n < self.config.parallel_threshold {
            // Use the enhanced SIMD version for smaller datasets
            return crate::simd_enhanced_core::comprehensive_stats_simd(x, ddof);
        }

        // Parallel single-pass computation of all statistics
        self.comprehensive_stats_single_pass_parallel(x, ddof)
    }
251
    /// Parallel bootstrap resampling with intelligent load balancing
    ///
    /// Draws `n_samples_` bootstrap resamples (with replacement) of `x`,
    /// applies `statistic_fn` to each, and returns the statistics. When
    /// `seed` is `Some`, each worker thread derives its RNG from
    /// `seed + thread_id`.
    ///
    /// # Errors
    /// Returns an error if `x` is empty or `n_samples_` is zero.
    pub fn bootstrap_parallel<D>(
        &self,
        x: &ArrayBase<D, Ix1>,
        n_samples_: usize,
        statistic_fn: impl Fn(&ArrayView1<F>) -> F + Send + Sync + Clone,
        seed: Option<u64>,
    ) -> StatsResult<Array1<F>>
    where
        D: Data<Elem = F> + Sync + Send,
    {
        if x.is_empty() {
            return Err(ErrorMessages::empty_array("x"));
        }
        if n_samples_ == 0 {
            return Err(ErrorMessages::insufficientdata("bootstrap", 1, 0));
        }

        let num_threads = self
            .config
            .num_threads
            .unwrap_or_else(|| self.optimal_thread_count());
        // Ceiling division so the workers cover at least n_samples_ in total.
        let samples_per_thread = n_samples_.div_ceil(num_threads);

        // Parallel bootstrap computation with work stealing
        self.bootstrap_work_stealing(x, n_samples_, samples_per_thread, statistic_fn, seed)
    }
279
280    // Private helper methods
281
282    fn optimal_thread_count(&self) -> usize {
283        let logical_cores = std::thread::available_parallelism()
284            .map(|n| n.get())
285            .unwrap_or(4);
286
287        // Account for hyperthreading - usually optimal to use physical cores
288        // Simple heuristic: if we have more than 2 logical cores, assume hyperthreading
289
290        // For CPU-intensive tasks, use physical cores
291        // For memory-bound tasks, might benefit from more threads
292        // Use physical cores for better performance
293        if logical_cores > 2 {
294            logical_cores / 2
295        } else {
296            logical_cores
297        }
298    }
299
300    fn mean_work_stealing<D>(&self, x: &ArrayBase<D, Ix1>) -> StatsResult<F>
301    where
302        D: Data<Elem = F> + Sync + Send,
303    {
304        let n = x.len();
305        let num_threads = self
306            .config
307            .num_threads
308            .unwrap_or_else(|| self.optimal_thread_count());
309        let initial_chunksize = n.div_ceil(num_threads);
310
311        // Create work queue with initial chunks
312        let work_queue: Arc<Mutex<VecDeque<(usize, usize)>>> =
313            Arc::new(Mutex::new(VecDeque::new()));
314
315        for i in 0..num_threads {
316            let start = i * initial_chunksize;
317            let end = ((i + 1) * initial_chunksize).min(n);
318            if start < end {
319                work_queue
320                    .lock()
321                    .expect("Operation failed")
322                    .push_back((start, end));
323            }
324        }
325
326        let partial_sums: Arc<Mutex<Vec<F>>> = Arc::new(Mutex::new(Vec::new()));
327        let data_slice = x
328            .as_slice()
329            .ok_or(StatsError::InvalidInput("Data not contiguous".to_string()))?;
330
331        crossbeam::scope(|s| {
332            for _ in 0..num_threads {
333                let work_queue = Arc::clone(&work_queue);
334                let partial_sums = Arc::clone(&partial_sums);
335
336                s.spawn(move |_| {
337                    let mut local_sum = F::zero();
338
339                    while let Some((start, end)) =
340                        work_queue.lock().expect("Operation failed").pop_front()
341                    {
342                        // Process chunk safely
343                        for &val in &data_slice[start..end] {
344                            local_sum = local_sum + val;
345                        }
346
347                        // Split remaining work if chunk was large
348                        if end - start > 1000 {
349                            let mid = (start + end) / 2;
350                            if mid > start {
351                                work_queue
352                                    .lock()
353                                    .expect("Operation failed")
354                                    .push_back((mid, end));
355                            }
356                        }
357                    }
358
359                    partial_sums
360                        .lock()
361                        .expect("Operation failed")
362                        .push(local_sum);
363                });
364            }
365        })
366        .expect("Operation failed");
367
368        let total_sum = partial_sums
369            .lock()
370            .expect("Operation failed")
371            .iter()
372            .fold(F::zero(), |acc, &val| acc + val);
373        Ok(total_sum / F::from(n).expect("Failed to convert to float"))
374    }
375
    /// Parallel mean with cache-size-adaptive chunking.
    ///
    /// Chunk size is chosen so a chunk fits in L1 (or L2 for larger inputs);
    /// chunks are summed in parallel (SIMD when available and the chunk is
    /// long enough) and the partial sums combined sequentially.
    // NOTE(review): the 32 KB / 256 KB cache sizes are hard-coded "typical"
    // values, not detected from the platform — confirm acceptable.
    fn mean_adaptive_chunking<D>(&self, x: &ArrayBase<D, Ix1>) -> StatsResult<F>
    where
        D: Data<Elem = F> + Sync + Send,
    {
        let n = x.len();
        let elementsize = std::mem::size_of::<F>();

        // Adaptive chunk size based on cache hierarchy
        let l1_cache = 32 * 1024; // 32KB L1 cache (typical)
        let l2_cache = 256 * 1024; // 256KB L2 cache (typical)

        let chunksize = if n * elementsize <= l1_cache {
            n // Fits in L1, no chunking needed
        } else if n * elementsize <= l2_cache {
            l1_cache / elementsize // Chunk to fit in L1
        } else {
            l2_cache / elementsize // Chunk to fit in L2
        };

        let num_chunks = n.div_ceil(chunksize);
        let _num_threads = self
            .config
            .num_threads
            .unwrap_or_else(|| self.optimal_thread_count());

        // Use thread pool for processing
        let chunks: Vec<_> = (0..num_chunks)
            .map(|i| {
                let start = i * chunksize;
                let end = ((i + 1) * chunksize).min(n);
                x.slice(scirs2_core::ndarray::s![start..end])
            })
            .collect();

        // SIMD summation only pays off for chunks above ~64 elements.
        let partial_sums: Vec<F> = chunks
            .into_par_iter()
            .map(|chunk| {
                if self.capabilities.simd_available && chunk.len() > 64 {
                    F::simd_sum(&chunk)
                } else {
                    chunk.iter().fold(F::zero(), |acc, &val| acc + val)
                }
            })
            .collect();

        let total_sum = partial_sums
            .into_iter()
            .fold(F::zero(), |acc, val| acc + val);
        Ok(total_sum / F::from(n).expect("Failed to convert to float"))
    }
426
    /// Mean via the cache-oblivious divide-and-conquer algorithm.
    ///
    /// Thin dispatcher over `mean_cache_oblivious_static` covering the full
    /// array. Precondition (enforced by the caller): `x` is non-empty.
    fn mean_cache_optimal<D>(&self, x: &ArrayBase<D, Ix1>) -> StatsResult<F>
    where
        D: Data<Elem = F> + Sync + Send,
    {
        // Use cache-oblivious algorithm for optimal performance
        Self::mean_cache_oblivious_static(x, 0, x.len())
    }
434
    /// Instance-method wrapper around `mean_cache_oblivious_static`.
    ///
    /// Kept (dead) for API symmetry; `&self` is unused.
    #[allow(dead_code)]
    fn mean_cache_oblivious<D>(
        &self,
        x: &ArrayBase<D, Ix1>,
        start: usize,
        len: usize,
    ) -> StatsResult<F>
    where
        D: Data<Elem = F> + Sync + Send,
    {
        Self::mean_cache_oblivious_static(x, start, len)
    }
447
    // Static version that can be used in threads
    //
    // Recursive divide-and-conquer mean of x[start .. start + len]. Below
    // CACHE_THRESHOLD elements the sum is computed directly; above it the
    // range is halved and the two sub-means are recombined weighted by
    // sub-range length, which bounds floating-point error growth.
    //
    // Precondition: len > 0 (len == 0 would divide by zero in the base case;
    // callers in this module always pass a non-empty range).
    fn mean_cache_oblivious_static<D>(
        x: &ArrayBase<D, Ix1>,
        start: usize,
        len: usize,
    ) -> StatsResult<F>
    where
        D: Data<Elem = F> + Sync + Send,
        F: Float + Send + Sync + 'static + std::fmt::Display,
    {
        const CACHE_THRESHOLD: usize = 1024; // Empirically determined threshold

        if len <= CACHE_THRESHOLD {
            // Base case: compute directly
            let slice = x.slice(scirs2_core::ndarray::s![start..start + len]);
            let sum = slice.iter().fold(F::zero(), |acc, &val| acc + val);
            Ok(sum / F::from(len).expect("Failed to convert to float"))
        } else {
            // Divide and conquer (sequential to avoid lifetime issues)
            let mid = len / 2;
            let left_result = Self::mean_cache_oblivious_static(x, start, mid)?;
            let right_result = Self::mean_cache_oblivious_static(x, start + mid, len - mid)?;

            // Combine results weighted by size
            let left_weight = F::from(mid).expect("Failed to convert to float");
            let right_weight = F::from(len - mid).expect("Failed to convert to float");
            let total_weight = F::from(len).expect("Failed to convert to float");

            Ok((left_result * left_weight + right_result * right_weight) / total_weight)
        }
    }
479
480    fn mean_fixed_chunks<D>(&self, x: &ArrayBase<D, Ix1>, chunksize: usize) -> StatsResult<F>
481    where
482        D: Data<Elem = F> + Sync + Send,
483    {
484        let n = x.len();
485        let chunks: Vec<_> = x
486            .exact_chunks(chunksize)
487            .into_iter()
488            .chain(if !n.is_multiple_of(chunksize) {
489                vec![x.slice(scirs2_core::ndarray::s![n - (n % chunksize)..])]
490            } else {
491                vec![]
492            })
493            .collect();
494
495        let partial_sums: Vec<F> = chunks
496            .into_par_iter()
497            .map(|chunk| chunk.iter().fold(F::zero(), |acc, &val| acc + val))
498            .collect();
499
500        let total_sum = partial_sums
501            .into_iter()
502            .fold(F::zero(), |acc, val| acc + val);
503        Ok(total_sum / F::from(n).expect("Failed to convert to float"))
504    }
505
    /// Parallel chunked Welford variance.
    ///
    /// Each thread runs Welford's online algorithm over its chunk, yielding
    /// (chunk mean, chunk M2, chunk count). The per-chunk results are then
    /// merged pairwise with the Chan et al. combination formula, which
    /// corrects M2 for the difference between chunk means. The final result
    /// is M2 / (n - ddof).
    ///
    /// Preconditions (enforced by the caller): n > 0 and n > ddof.
    fn variance_welford_parallel<D>(&self, x: &ArrayBase<D, Ix1>, ddof: usize) -> StatsResult<F>
    where
        D: Data<Elem = F> + Sync + Send,
    {
        // Parallel Welford's algorithm implementation
        let n = x.len();
        let num_threads = self
            .config
            .num_threads
            .unwrap_or_else(|| self.optimal_thread_count());
        let chunksize = n.div_ceil(num_threads);

        // Per-chunk (mean, M2, count) via sequential Welford updates.
        let results: Vec<(F, F, usize)> = (0..num_threads)
            .into_par_iter()
            .map(|i| {
                let start = i * chunksize;
                let end = ((i + 1) * chunksize).min(n);

                if start >= end {
                    // Empty chunk (more threads than chunks); skipped in combine.
                    return (F::zero(), F::zero(), 0);
                }

                let chunk = x.slice(scirs2_core::ndarray::s![start..end]);
                let mut mean = F::zero();
                let mut m2 = F::zero();
                let count = chunk.len();

                for (j, &val) in chunk.iter().enumerate() {
                    let n = F::from(j + 1).expect("Failed to convert to float");
                    let delta = val - mean;
                    mean = mean + delta / n;
                    let delta2 = val - mean;
                    m2 = m2 + delta * delta2;
                }

                (mean, m2, count)
            })
            .collect();

        // Combine results using parallel reduction
        // (Chan/Welford pairwise merge: M2 gains delta^2 * na * nb / n.)
        let (_final_mean, final_m2, final_count) = results.into_iter().fold(
            (F::zero(), F::zero(), 0),
            |(mean_a, m2_a, count_a), (mean_b, m2_b, count_b)| {
                if count_b == 0 {
                    return (mean_a, m2_a, count_a);
                }
                if count_a == 0 {
                    return (mean_b, m2_b, count_b);
                }

                let total_count = count_a + count_b;
                let count_a_f = F::from(count_a).expect("Failed to convert to float");
                let count_b_f = F::from(count_b).expect("Failed to convert to float");
                let total_count_f = F::from(total_count).expect("Failed to convert to float");

                let delta = mean_b - mean_a;
                let combined_mean = (mean_a * count_a_f + mean_b * count_b_f) / total_count_f;
                let combined_m2 =
                    m2_a + m2_b + delta * delta * count_a_f * count_b_f / total_count_f;

                (combined_mean, combined_m2, total_count)
            },
        );

        Ok(final_m2 / F::from(n - ddof).expect("Failed to convert to float"))
    }
572
573    fn correlation_matrix_parallel_upper_triangle<D>(
574        &self,
575        data: &ArrayBase<D, Ix2>,
576        correlation_matrix: &mut Array2<F>,
577    ) -> StatsResult<()>
578    where
579        D: Data<Elem = F> + Sync + Send,
580    {
581        let (_, n_features) = data.dim();
582
583        // Generate pairs for upper triangle
584        let pairs: Vec<(usize, usize)> = (0..n_features)
585            .flat_map(|i| (i + 1..n_features).map(move |j| (i, j)))
586            .collect();
587
588        let results: Vec<((usize, usize), F)> = pairs
589            .into_par_iter()
590            .map(|(i, j)| {
591                let x = data.column(i);
592                let y = data.column(j);
593                let corr = crate::simd_enhanced_core::correlation_simd_enhanced(&x, &y)
594                    .unwrap_or(F::zero());
595                ((i, j), corr)
596            })
597            .collect();
598
599        // Fill the correlation _matrix
600        for ((i, j), corr) in results {
601            correlation_matrix[[i, j]] = corr;
602        }
603
604        Ok(())
605    }
606
607    fn correlation_matrix_sequential<D>(
608        &self,
609        data: &ArrayBase<D, Ix2>,
610        correlation_matrix: &mut Array2<F>,
611    ) -> StatsResult<()>
612    where
613        D: Data<Elem = F> + Sync + Send,
614    {
615        let (_, n_features) = data.dim();
616
617        for i in 0..n_features {
618            for j in i + 1..n_features {
619                let x = data.column(i);
620                let y = data.column(j);
621                let corr = crate::simd_enhanced_core::correlation_simd_enhanced(&x, &y)?;
622                correlation_matrix[[i, j]] = corr;
623            }
624        }
625
626        Ok(())
627    }
628
629    fn comprehensive_stats_single_pass_parallel<D>(
630        &self,
631        x: &ArrayBase<D, Ix1>,
632        ddof: usize,
633    ) -> StatsResult<ComprehensiveStats<F>>
634    where
635        D: Data<Elem = F> + Sync + Send,
636    {
637        let n = x.len();
638        let num_threads = self
639            .config
640            .num_threads
641            .unwrap_or_else(|| self.optimal_thread_count());
642        let chunksize = n.div_ceil(num_threads);
643
644        // Parallel computation of all moments
645        let results: Vec<(F, F, F, F, usize)> = (0..num_threads)
646            .into_par_iter()
647            .map(|i| {
648                let start = i * chunksize;
649                let end = ((i + 1) * chunksize).min(n);
650
651                if start >= end {
652                    return (F::zero(), F::zero(), F::zero(), F::zero(), 0);
653                }
654
655                let chunk = x.slice(scirs2_core::ndarray::s![start..end]);
656                let count = chunk.len();
657                let count_f = F::from(count).expect("Failed to convert to float");
658
659                // Single pass computation of all moments
660                let mean = chunk.iter().fold(F::zero(), |acc, &val| acc + val) / count_f;
661
662                let (m2, m3, m4) =
663                    chunk
664                        .iter()
665                        .fold((F::zero(), F::zero(), F::zero()), |(m2, m3, m4), &val| {
666                            let dev = val - mean;
667                            let dev2 = dev * dev;
668                            let dev3 = dev2 * dev;
669                            let dev4 = dev2 * dev2;
670                            (m2 + dev2, m3 + dev3, m4 + dev4)
671                        });
672
673                (mean, m2, m3, m4, count)
674            })
675            .collect();
676
677        // Combine results
678        let (total_mean, total_m2_, total_m3, total_m4, total_count) = results.into_iter().fold(
679            (F::zero(), F::zero(), F::zero(), F::zero(), 0),
680            |(mean_acc, m2_acc, m3_acc, m4_acc, count_acc), (mean, m2, m3, m4, count)| {
681                if count == 0 {
682                    return (mean_acc, m2_acc, m3_acc, m4_acc, count_acc);
683                }
684                if count_acc == 0 {
685                    return (mean, m2, m3, m4, count);
686                }
687
688                // Combine means
689                let total_count = count_acc + count;
690                let count_f = F::from(count).expect("Failed to convert to float");
691                let count_acc_f = F::from(count_acc).expect("Failed to convert to float");
692                let total_count_f = F::from(total_count).expect("Failed to convert to float");
693
694                let combined_mean = (mean_acc * count_acc_f + mean * count_f) / total_count_f;
695
696                // For simplicity, recalculate moments (could be optimized further)
697                (
698                    combined_mean,
699                    m2_acc + m2,
700                    m3_acc + m3,
701                    m4_acc + m4,
702                    total_count,
703                )
704            },
705        );
706
707        let variance = total_m2_ / F::from(n - ddof).expect("Failed to convert to float");
708        let std = variance.sqrt();
709
710        let skewness = if variance > F::epsilon() {
711            (total_m3 / F::from(n).expect("Failed to convert to float"))
712                / variance.powf(F::from(1.5).expect("Failed to convert constant to float"))
713        } else {
714            F::zero()
715        };
716
717        let kurtosis = if variance > F::epsilon() {
718            (total_m4 / F::from(n).expect("Failed to convert to float")) / (variance * variance)
719                - F::from(3.0).expect("Failed to convert constant to float")
720        } else {
721            F::zero()
722        };
723
724        Ok(ComprehensiveStats {
725            mean: total_mean,
726            variance,
727            std,
728            skewness,
729            kurtosis,
730            count: n,
731        })
732    }
733
    /// Worker-thread bootstrap driver.
    ///
    /// Spawns one scoped thread per worker; each draws `samples_per_thread`
    /// bootstrap resamples (indices uniform in [0, n) with replacement),
    /// applies `statistic_fn`, and appends its local results to a shared
    /// vector. Since workers collectively produce >= n_samples_ statistics
    /// (ceiling division upstream), the result is truncated to exactly
    /// `n_samples_`.
    ///
    /// # Errors
    /// Currently infallible after the caller's validation; the `StatsResult`
    /// return keeps the signature uniform with the other parallel methods.
    // NOTE(review): with a fixed seed the per-thread streams are reproducible
    // (seed + thread_id), but the order in which threads append their results
    // is scheduling-dependent — confirm callers don't rely on element order.
    fn bootstrap_work_stealing<D>(
        &self,
        x: &ArrayBase<D, Ix1>,
        n_samples_: usize,
        samples_per_thread: usize,
        statistic_fn: impl Fn(&ArrayView1<F>) -> F + Send + Sync + Clone,
        seed: Option<u64>,
    ) -> StatsResult<Array1<F>>
    where
        D: Data<Elem = F> + Sync + Send,
    {
        use scirs2_core::random::ChaCha8Rng;
        use scirs2_core::random::{Rng, SeedableRng};

        let num_threads = self
            .config
            .num_threads
            .unwrap_or_else(|| self.optimal_thread_count());
        let _results: Vec<F> = Vec::with_capacity(n_samples_);

        // Owned copy of the data, shared immutably with all workers.
        let data_vec: Vec<F> = x.iter().cloned().collect();
        let data_arc = Arc::new(data_vec);

        let partial_results: Arc<Mutex<Vec<F>>> = Arc::new(Mutex::new(Vec::new()));

        crossbeam::scope(|s| {
            for thread_id in 0..num_threads {
                let data_arc = Arc::clone(&data_arc);
                let partial_results = Arc::clone(&partial_results);
                let statistic_fn = statistic_fn.clone();

                s.spawn(move |_| {
                    // Per-thread RNG: seeded deterministically when a seed is
                    // given, otherwise from the thread-local entropy source.
                    let mut rng = if let Some(seed) = seed {
                        ChaCha8Rng::seed_from_u64(seed + thread_id as u64)
                    } else {
                        ChaCha8Rng::from_rng(&mut scirs2_core::random::thread_rng())
                    };

                    let mut local_results = Vec::with_capacity(samples_per_thread);
                    let ndata = data_arc.len();

                    for _ in 0..samples_per_thread {
                        // Generate bootstrap sample (draw n indices with replacement)
                        let bootstrap_indices: Vec<usize> =
                            (0..ndata).map(|_| rng.random_range(0..ndata)).collect();

                        let bootstrap_sample: Vec<F> =
                            bootstrap_indices.into_iter().map(|i| data_arc[i]).collect();

                        let sample_array = Array1::from(bootstrap_sample);
                        let statistic = statistic_fn(&sample_array.view());
                        local_results.push(statistic);
                    }

                    // Single lock per thread: batch-append the local results.
                    partial_results
                        .lock()
                        .expect("Operation failed")
                        .extend(local_results);
                });
            }
        })
        .expect("Operation failed");

        let mut all_results = partial_results.lock().expect("Operation failed");
        all_results.truncate(n_samples_); // Ensure exact number of samples
        Ok(Array1::from(all_results.clone()))
    }
802}
803
/// Simple thread pool for parallel execution
///
/// Workers share a single mpsc receiver behind a mutex and run jobs until
/// they receive `Message::Terminate`.
struct ThreadPool {
    workers: Vec<thread::JoinHandle<()>>,
    sender: std::sync::mpsc::Sender<Message>,
}
809
/// A unit of work executed once by a pool worker.
type Job = Box<dyn FnOnce() + Send + 'static>;

/// Control messages sent to pool workers over the shared channel.
enum Message {
    /// Run the contained job.
    NewJob(Job),
    /// Exit the worker loop.
    Terminate,
}
816
impl ThreadPool {
    /// Build a pool of `size` workers sharing one job channel.
    ///
    /// # Errors
    /// Returns an error when `size` is 0.
    // NOTE(review): `invalid_probability` is an odd error constructor for a
    // zero thread count — confirm ErrorMessages has a better-fitting variant.
    // NOTE(review): `config` is accepted but never read here — confirm it is
    // reserved for future use (NUMA pinning etc.) or drop it.
    fn new(size: usize, config: AdvancedParallelConfig) -> StatsResult<ThreadPool> {
        if size == 0 {
            return Err(ErrorMessages::invalid_probability("thread count", 0.0));
        }

        let (sender, receiver) = std::sync::mpsc::channel();
        // mpsc receivers are single-consumer; the mutex lets workers share it.
        let receiver = Arc::new(Mutex::new(receiver));
        let mut workers = Vec::with_capacity(size);

        for _id in 0..size {
            let receiver = Arc::clone(&receiver);

            // Worker loop: block on the channel, run jobs until Terminate.
            let worker = thread::spawn(move || loop {
                let message = receiver
                    .lock()
                    .expect("Operation failed")
                    .recv()
                    .expect("Operation failed");

                match message {
                    Message::NewJob(job) => {
                        job();
                    }
                    Message::Terminate => {
                        break;
                    }
                }
            });

            workers.push(worker);
        }

        Ok(ThreadPool { workers, sender })
    }

    /// Submit a job for execution on some worker thread.
    ///
    /// Panics if the channel is disconnected (all workers exited).
    #[allow(dead_code)]
    fn execute<F>(&self, f: F)
    where
        F: FnOnce() + Send + 'static,
    {
        let job = Box::new(f);
        self.sender
            .send(Message::NewJob(job))
            .expect("Operation failed");
    }
}
864
865impl Drop for ThreadPool {
866    fn drop(&mut self) {
867        for _ in &self.workers {
868            self.sender
869                .send(Message::Terminate)
870                .expect("Operation failed");
871        }
872
873        for worker in &mut self.workers {
874            if let Some(handle) = worker.thread().name() {
875                println!("Shutting down worker {}", handle);
876            }
877        }
878    }
879}
880
/// Convenience function to create an advanced parallel processor
/// with the default configuration (see `AdvancedParallelConfig::default`).
#[allow(dead_code)]
pub fn create_advanced_parallel_processor<F>() -> AdvancedParallelProcessor<F>
where
    F: Float
        + NumCast
        + Send
        + Sync
        + SimdUnifiedOps
        + Copy
        + 'static
        + Zero
        + One
        + std::fmt::Debug
        + std::fmt::Display
        + std::iter::Sum<F>,
{
    AdvancedParallelProcessor::new(AdvancedParallelConfig::default())
}
900
/// Convenience function to create a processor with custom configuration
///
/// Identical to `AdvancedParallelProcessor::new`; provided for API symmetry
/// with `create_advanced_parallel_processor`.
#[allow(dead_code)]
pub fn create_configured_parallel_processor<F>(
    config: AdvancedParallelConfig,
) -> AdvancedParallelProcessor<F>
where
    F: Float
        + NumCast
        + Send
        + Sync
        + SimdUnifiedOps
        + Copy
        + 'static
        + Zero
        + One
        + std::fmt::Debug
        + std::fmt::Display
        + std::iter::Sum<F>,
{
    AdvancedParallelProcessor::new(config)
}