use crate::error::{StatsError, StatsResult};
use crate::error_standardization::ErrorMessages;
use crate::simd_enhanced_core::{mean_enhanced, variance_enhanced, ComprehensiveStats};
use crossbeam;
use scirs2_core::ndarray::{Array1, Array2, ArrayBase, ArrayView1, Data, Ix1, Ix2};
use scirs2_core::numeric::{Float, NumCast, One, Zero};
use scirs2_core::{
    parallel_ops::*,
    simd_ops::{PlatformCapabilities, SimdUnifiedOps},
};
use std::collections::VecDeque;
use std::sync::{atomic::AtomicUsize, Arc, Mutex};
use std::thread;

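/// Configuration for advanced parallel statistical processing.
///
/// A construction sketch (field values are illustrative):
///
/// ```ignore
/// let config = AdvancedParallelConfig {
///     parallel_threshold: 50_000,
///     num_threads: Some(8),
///     ..Default::default()
/// };
/// ```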
#[derive(Debug, Clone)]
pub struct AdvancedParallelConfig {
    /// Minimum number of elements before parallel processing is used.
    pub parallel_threshold: usize,
    /// Number of worker threads; `None` selects a count automatically.
    pub num_threads: Option<usize>,
    /// Whether to prefer NUMA-aware scheduling.
    pub numa_aware: bool,
    /// Whether to enable work stealing between threads.
    pub work_stealing: bool,
    /// Strategy used to split data into chunks.
    pub chunk_strategy: ChunkStrategy,
    /// Upper bound on memory usage, in bytes.
    pub max_memory_usage: usize,
}

impl Default for AdvancedParallelConfig {
    fn default() -> Self {
        Self {
            parallel_threshold: 10_000,
            num_threads: None,
            numa_aware: true,
            work_stealing: true,
            chunk_strategy: ChunkStrategy::Adaptive,
            max_memory_usage: 1024 * 1024 * 1024, // 1 GiB
        }
    }
}

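/// Strategy for splitting data into chunks for parallel processing.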
#[derive(Debug, Clone, Copy)]
pub enum ChunkStrategy {
    /// Fixed chunk size, in elements.
    Fixed(usize),
    /// Chunk size derived from cache-size heuristics.
    CacheOptimal,
    /// Chunk size adapted to the data size and cache hierarchy.
    Adaptive,
    /// Dynamic chunks served from a shared work queue.
    WorkStealing,
}

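/// Parallel processor for statistical computations, driven by an
/// [`AdvancedParallelConfig`] and the detected platform capabilities.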
pub struct AdvancedParallelProcessor<F: Float + std::fmt::Display> {
    config: AdvancedParallelConfig,
    capabilities: PlatformCapabilities,
    #[allow(dead_code)]
    thread_pool: Option<ThreadPool>,
    #[allow(dead_code)]
    work_queue: Arc<Mutex<VecDeque<ParallelTask<F>>>>,
    #[allow(dead_code)]
    active_workers: Arc<AtomicUsize>,
}

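/// Unit of work that can be queued for a worker thread.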
enum ParallelTask<F: Float + std::fmt::Display> {
    Mean(Vec<F>),
    Variance(Vec<F>, F, usize),
    Correlation(Vec<F>, Vec<F>),
    Histogram(Vec<F>, usize),
}

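/// Result of a completed [`ParallelTask`].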
pub enum ParallelResult<F: Float + std::fmt::Display> {
    Mean(F),
    Variance(F),
    Correlation(F),
    Histogram(Vec<usize>),
}

impl<F> AdvancedParallelProcessor<F>
where
    F: Float
        + NumCast
        + Send
        + Sync
        + SimdUnifiedOps
        + Copy
        + 'static
        + Zero
        + One
        + std::fmt::Debug
        + std::fmt::Display
        + std::iter::Sum<F>,
{
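    /// Creates a new processor with the given configuration, detecting
    /// platform capabilities on construction.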
    pub fn new(config: AdvancedParallelConfig) -> Self {
        let capabilities = PlatformCapabilities::detect();

        Self {
            config,
            capabilities,
            thread_pool: None,
            work_queue: Arc::new(Mutex::new(VecDeque::new())),
            active_workers: Arc::new(AtomicUsize::new(0)),
        }
    }

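    /// Initializes the internal thread pool.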
    pub fn initialize(&mut self) -> StatsResult<()> {
        let num_threads = self
            .config
            .num_threads
            .unwrap_or_else(|| self.optimal_thread_count());

        self.thread_pool = Some(ThreadPool::new(num_threads, self.config.clone())?);
        Ok(())
    }

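    /// Computes the mean, dispatching to the configured chunking strategy
    /// for large inputs and to the SIMD-enhanced sequential path otherwise.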
    pub fn mean_parallel_advanced<D>(&self, x: &ArrayBase<D, Ix1>) -> StatsResult<F>
    where
        D: Data<Elem = F> + Sync + Send,
    {
        if x.is_empty() {
            return Err(ErrorMessages::empty_array("x"));
        }

        let n = x.len();

        if n < self.config.parallel_threshold {
            return mean_enhanced(x);
        }

        match self.config.chunk_strategy {
            ChunkStrategy::WorkStealing => self.mean_work_stealing(x),
            ChunkStrategy::Adaptive => self.mean_adaptive_chunking(x),
            ChunkStrategy::CacheOptimal => self.mean_cache_optimal(x),
            ChunkStrategy::Fixed(chunksize) => self.mean_fixed_chunks(x, chunksize),
        }
    }

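    /// Computes the variance with `ddof` delta degrees of freedom, using a
    /// parallel Welford algorithm for large inputs.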
    pub fn variance_parallel_advanced<D>(
        &self,
        x: &ArrayBase<D, Ix1>,
        ddof: usize,
    ) -> StatsResult<F>
    where
        D: Data<Elem = F> + Sync + Send,
    {
        let n = x.len();
        if n == 0 {
            return Err(ErrorMessages::empty_array("x"));
        }
        if n <= ddof {
            return Err(ErrorMessages::insufficientdata(
                "variance calculation",
                ddof + 1,
                n,
            ));
        }

        if n < self.config.parallel_threshold {
            return variance_enhanced(x, ddof);
        }

        self.variance_welford_parallel(x, ddof)
    }

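    /// Computes the full correlation matrix for the columns of `data`,
    /// parallelizing over feature pairs when the input is large enough.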
    pub fn correlation_matrix_parallel<D>(&self, data: &ArrayBase<D, Ix2>) -> StatsResult<Array2<F>>
    where
        D: Data<Elem = F> + Sync + Send,
    {
        let (n_samples_, n_features) = data.dim();

        if n_samples_ == 0 {
            return Err(ErrorMessages::empty_array("data"));
        }
        if n_features == 0 {
            return Err(ErrorMessages::insufficientdata(
                "correlation matrix",
                2,
                n_features,
            ));
        }

        let mut correlation_matrix = Array2::eye(n_features);

        if n_features > 4 && n_samples_ > self.config.parallel_threshold {
            self.correlation_matrix_parallel_upper_triangle(data, &mut correlation_matrix)?;
        } else {
            self.correlation_matrix_sequential(data, &mut correlation_matrix)?;
        }

        // The matrix is symmetric, so mirror the upper triangle into the
        // lower triangle instead of recomputing each pair.
        for i in 0..n_features {
            for j in 0..i {
                correlation_matrix[[i, j]] = correlation_matrix[[j, i]];
            }
        }

        Ok(correlation_matrix)
    }

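    /// Computes mean, variance, standard deviation, skewness, and kurtosis
    /// in a single pass over the data.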
    pub fn batch_statistics_parallel<D>(
        &self,
        x: &ArrayBase<D, Ix1>,
        ddof: usize,
    ) -> StatsResult<ComprehensiveStats<F>>
    where
        D: Data<Elem = F> + Sync + Send,
    {
        let n = x.len();
        if n == 0 {
            return Err(ErrorMessages::empty_array("x"));
        }
        if n <= ddof {
            return Err(ErrorMessages::insufficientdata(
                "comprehensive statistics",
                ddof + 1,
                n,
            ));
        }

        if n < self.config.parallel_threshold {
            return crate::simd_enhanced_core::comprehensive_stats_simd(x, ddof);
        }

        self.comprehensive_stats_single_pass_parallel(x, ddof)
    }

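    /// Draws `n_samples_` bootstrap resamples of `x` and evaluates
    /// `statistic_fn` on each resample across worker threads. Supplying a
    /// `seed` makes the resampling reproducible.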
    pub fn bootstrap_parallel<D>(
        &self,
        x: &ArrayBase<D, Ix1>,
        n_samples_: usize,
        statistic_fn: impl Fn(&ArrayView1<F>) -> F + Send + Sync + Clone,
        seed: Option<u64>,
    ) -> StatsResult<Array1<F>>
    where
        D: Data<Elem = F> + Sync + Send,
    {
        if x.is_empty() {
            return Err(ErrorMessages::empty_array("x"));
        }
        if n_samples_ == 0 {
            return Err(ErrorMessages::insufficientdata("bootstrap", 1, 0));
        }

        let num_threads = self
            .config
            .num_threads
            .unwrap_or_else(|| self.optimal_thread_count());
        let samples_per_thread = n_samples_.div_ceil(num_threads);

        self.bootstrap_work_stealing(x, n_samples_, samples_per_thread, statistic_fn, seed)
    }

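    /// Heuristic thread count for compute-bound statistical work.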
    fn optimal_thread_count(&self) -> usize {
        let logical_cores = std::thread::available_parallelism()
            .map(|n| n.get())
            .unwrap_or(4);

        // Hyperthreads rarely help compute-bound numeric kernels, so
        // approximate the physical core count as half the logical count.
        if logical_cores > 2 {
            logical_cores / 2
        } else {
            logical_cores
        }
    }

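    /// Mean via work stealing: threads pull ranges from a shared queue and
    /// split large ranges back into it so idle workers stay busy.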
    fn mean_work_stealing<D>(&self, x: &ArrayBase<D, Ix1>) -> StatsResult<F>
    where
        D: Data<Elem = F> + Sync + Send,
    {
        let n = x.len();
        let num_threads = self
            .config
            .num_threads
            .unwrap_or_else(|| self.optimal_thread_count());
        let initial_chunksize = n.div_ceil(num_threads);

        let work_queue: Arc<Mutex<VecDeque<(usize, usize)>>> =
            Arc::new(Mutex::new(VecDeque::new()));

        // Seed the queue with one contiguous range per thread.
        for i in 0..num_threads {
            let start = i * initial_chunksize;
            let end = ((i + 1) * initial_chunksize).min(n);
            if start < end {
                work_queue
                    .lock()
                    .expect("Operation failed")
                    .push_back((start, end));
            }
        }

        let partial_sums: Arc<Mutex<Vec<F>>> = Arc::new(Mutex::new(Vec::new()));
        let data_slice = x
            .as_slice()
            .ok_or(StatsError::InvalidInput("Data not contiguous".to_string()))?;

        crossbeam::scope(|s| {
            for _ in 0..num_threads {
                let work_queue = Arc::clone(&work_queue);
                let partial_sums = Arc::clone(&partial_sums);

                s.spawn(move |_| {
                    let mut local_sum = F::zero();

                    loop {
                        // Pop in its own statement so the queue guard is
                        // dropped before the queue is locked again below;
                        // popping in a `while let` head would hold the lock
                        // for the whole loop body and deadlock.
                        let next = work_queue.lock().expect("Operation failed").pop_front();
                        let Some((start, mut end)) = next else {
                            break;
                        };

                        // Split large ranges and return the second half to
                        // the queue so idle workers can steal it; only the
                        // first half is summed here, which avoids counting
                        // the re-queued half twice.
                        if end - start > 1000 {
                            let mid = (start + end) / 2;
                            if mid > start {
                                work_queue
                                    .lock()
                                    .expect("Operation failed")
                                    .push_back((mid, end));
                                end = mid;
                            }
                        }

                        for &val in &data_slice[start..end] {
                            local_sum = local_sum + val;
                        }
                    }

                    partial_sums
                        .lock()
                        .expect("Operation failed")
                        .push(local_sum);
                });
            }
        })
        .expect("Operation failed");

        let total_sum = partial_sums
            .lock()
            .expect("Operation failed")
            .iter()
            .fold(F::zero(), |acc, &val| acc + val);
        Ok(total_sum / F::from(n).expect("Failed to convert to float"))
    }

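    /// Mean with chunk sizes adapted to the cache hierarchy, summing chunks
    /// in parallel (with SIMD where available).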
    fn mean_adaptive_chunking<D>(&self, x: &ArrayBase<D, Ix1>) -> StatsResult<F>
    where
        D: Data<Elem = F> + Sync + Send,
    {
        let n = x.len();
        let elementsize = std::mem::size_of::<F>();

        // Typical cache sizes, used as heuristics rather than probed values.
        let l1_cache = 32 * 1024; // 32 KiB L1 data cache
        let l2_cache = 256 * 1024; // 256 KiB L2 cache

        let chunksize = if n * elementsize <= l1_cache {
            n // Everything fits in L1: use a single chunk.
        } else if n * elementsize <= l2_cache {
            l1_cache / elementsize // Keep each chunk L1-resident.
        } else {
            l2_cache / elementsize // Keep each chunk L2-resident.
        };

        let num_chunks = n.div_ceil(chunksize);

        let chunks: Vec<_> = (0..num_chunks)
            .map(|i| {
                let start = i * chunksize;
                let end = ((i + 1) * chunksize).min(n);
                x.slice(scirs2_core::ndarray::s![start..end])
            })
            .collect();

        let partial_sums: Vec<F> = chunks
            .into_par_iter()
            .map(|chunk| {
                if self.capabilities.simd_available && chunk.len() > 64 {
                    F::simd_sum(&chunk)
                } else {
                    chunk.iter().fold(F::zero(), |acc, &val| acc + val)
                }
            })
            .collect();

        let total_sum = partial_sums
            .into_iter()
            .fold(F::zero(), |acc, val| acc + val);
        Ok(total_sum / F::from(n).expect("Failed to convert to float"))
    }

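    /// Mean via a cache-oblivious divide-and-conquer recursion.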
    fn mean_cache_optimal<D>(&self, x: &ArrayBase<D, Ix1>) -> StatsResult<F>
    where
        D: Data<Elem = F> + Sync + Send,
    {
        Self::mean_cache_oblivious_static(x, 0, x.len())
    }

    #[allow(dead_code)]
    fn mean_cache_oblivious<D>(
        &self,
        x: &ArrayBase<D, Ix1>,
        start: usize,
        len: usize,
    ) -> StatsResult<F>
    where
        D: Data<Elem = F> + Sync + Send,
    {
        Self::mean_cache_oblivious_static(x, start, len)
    }

    fn mean_cache_oblivious_static<D>(
        x: &ArrayBase<D, Ix1>,
        start: usize,
        len: usize,
    ) -> StatsResult<F>
    where
        D: Data<Elem = F> + Sync + Send,
        F: Float + Send + Sync + 'static + std::fmt::Display,
    {
        // Base-case size small enough to stay cache-resident.
        const CACHE_THRESHOLD: usize = 1024;

        if len <= CACHE_THRESHOLD {
            let slice = x.slice(scirs2_core::ndarray::s![start..start + len]);
            let sum = slice.iter().fold(F::zero(), |acc, &val| acc + val);
            Ok(sum / F::from(len).expect("Failed to convert to float"))
        } else {
            // Recurse on both halves and combine the two means, weighting
            // each half by its length.
            let mid = len / 2;
            let left_result = Self::mean_cache_oblivious_static(x, start, mid)?;
            let right_result = Self::mean_cache_oblivious_static(x, start + mid, len - mid)?;

            let left_weight = F::from(mid).expect("Failed to convert to float");
            let right_weight = F::from(len - mid).expect("Failed to convert to float");
            let total_weight = F::from(len).expect("Failed to convert to float");

            Ok((left_result * left_weight + right_result * right_weight) / total_weight)
        }
    }

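    /// Mean over fixed-size chunks summed in parallel.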
    fn mean_fixed_chunks<D>(&self, x: &ArrayBase<D, Ix1>, chunksize: usize) -> StatsResult<F>
    where
        D: Data<Elem = F> + Sync + Send,
    {
        if chunksize == 0 {
            return Err(StatsError::InvalidInput(
                "Chunk size must be non-zero".to_string(),
            ));
        }

        let n = x.len();
        // Full chunks, plus the remainder (if any) as one final chunk.
        let chunks: Vec<_> = x
            .exact_chunks(chunksize)
            .into_iter()
            .chain(if !n.is_multiple_of(chunksize) {
                vec![x.slice(scirs2_core::ndarray::s![n - (n % chunksize)..])]
            } else {
                vec![]
            })
            .collect();

        let partial_sums: Vec<F> = chunks
            .into_par_iter()
            .map(|chunk| chunk.iter().fold(F::zero(), |acc, &val| acc + val))
            .collect();

        let total_sum = partial_sums
            .into_iter()
            .fold(F::zero(), |acc, val| acc + val);
        Ok(total_sum / F::from(n).expect("Failed to convert to float"))
    }

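    /// Variance via per-chunk Welford accumulation, with chunk results merged
    /// using Chan's pairwise combination formula.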
    fn variance_welford_parallel<D>(&self, x: &ArrayBase<D, Ix1>, ddof: usize) -> StatsResult<F>
    where
        D: Data<Elem = F> + Sync + Send,
    {
        let n = x.len();
        let num_threads = self
            .config
            .num_threads
            .unwrap_or_else(|| self.optimal_thread_count());
        let chunksize = n.div_ceil(num_threads);

        // Each chunk yields (mean, M2, count) from a local Welford pass.
        let results: Vec<(F, F, usize)> = (0..num_threads)
            .into_par_iter()
            .map(|i| {
                let start = i * chunksize;
                let end = ((i + 1) * chunksize).min(n);

                if start >= end {
                    return (F::zero(), F::zero(), 0);
                }

                let chunk = x.slice(scirs2_core::ndarray::s![start..end]);
                let mut mean = F::zero();
                let mut m2 = F::zero();
                let count = chunk.len();

                for (j, &val) in chunk.iter().enumerate() {
                    let n = F::from(j + 1).expect("Failed to convert to float");
                    let delta = val - mean;
                    mean = mean + delta / n;
                    let delta2 = val - mean;
                    m2 = m2 + delta * delta2;
                }

                (mean, m2, count)
            })
            .collect();

        // Merge chunk results pairwise; the delta term corrects M2 for the
        // difference between the two chunk means.
        let (_final_mean, final_m2, _final_count) = results.into_iter().fold(
            (F::zero(), F::zero(), 0),
            |(mean_a, m2_a, count_a), (mean_b, m2_b, count_b)| {
                if count_b == 0 {
                    return (mean_a, m2_a, count_a);
                }
                if count_a == 0 {
                    return (mean_b, m2_b, count_b);
                }

                let total_count = count_a + count_b;
                let count_a_f = F::from(count_a).expect("Failed to convert to float");
                let count_b_f = F::from(count_b).expect("Failed to convert to float");
                let total_count_f = F::from(total_count).expect("Failed to convert to float");

                let delta = mean_b - mean_a;
                let combined_mean = (mean_a * count_a_f + mean_b * count_b_f) / total_count_f;
                let combined_m2 =
                    m2_a + m2_b + delta * delta * count_a_f * count_b_f / total_count_f;

                (combined_mean, combined_m2, total_count)
            },
        );

        Ok(final_m2 / F::from(n - ddof).expect("Failed to convert to float"))
    }

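    /// Fills the upper triangle of `correlation_matrix`, computing feature
    /// pairs in parallel.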
    fn correlation_matrix_parallel_upper_triangle<D>(
        &self,
        data: &ArrayBase<D, Ix2>,
        correlation_matrix: &mut Array2<F>,
    ) -> StatsResult<()>
    where
        D: Data<Elem = F> + Sync + Send,
    {
        let (_, n_features) = data.dim();

        // Enumerate all (i, j) pairs in the strict upper triangle.
        let pairs: Vec<(usize, usize)> = (0..n_features)
            .flat_map(|i| (i + 1..n_features).map(move |j| (i, j)))
            .collect();

        let results: Vec<((usize, usize), F)> = pairs
            .into_par_iter()
            .map(|(i, j)| {
                let x = data.column(i);
                let y = data.column(j);
                // A pair that fails to compute falls back to zero correlation.
                let corr = crate::simd_enhanced_core::correlation_simd_enhanced(&x, &y)
                    .unwrap_or(F::zero());
                ((i, j), corr)
            })
            .collect();

        for ((i, j), corr) in results {
            correlation_matrix[[i, j]] = corr;
        }

        Ok(())
    }

    fn correlation_matrix_sequential<D>(
        &self,
        data: &ArrayBase<D, Ix2>,
        correlation_matrix: &mut Array2<F>,
    ) -> StatsResult<()>
    where
        D: Data<Elem = F> + Sync + Send,
    {
        let (_, n_features) = data.dim();

        for i in 0..n_features {
            for j in i + 1..n_features {
                let x = data.column(i);
                let y = data.column(j);
                let corr = crate::simd_enhanced_core::correlation_simd_enhanced(&x, &y)?;
                correlation_matrix[[i, j]] = corr;
            }
        }

        Ok(())
    }

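    /// Computes the mean and the second through fourth central moments in
    /// one parallel pass, then derives variance, skewness, and kurtosis.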
    fn comprehensive_stats_single_pass_parallel<D>(
        &self,
        x: &ArrayBase<D, Ix1>,
        ddof: usize,
    ) -> StatsResult<ComprehensiveStats<F>>
    where
        D: Data<Elem = F> + Sync + Send,
    {
        let n = x.len();
        let num_threads = self
            .config
            .num_threads
            .unwrap_or_else(|| self.optimal_thread_count());
        let chunksize = n.div_ceil(num_threads);

        // Each chunk reports (mean, M2, M3, M4, count), with the central
        // moments taken about the chunk's own mean.
        let results: Vec<(F, F, F, F, usize)> = (0..num_threads)
            .into_par_iter()
            .map(|i| {
                let start = i * chunksize;
                let end = ((i + 1) * chunksize).min(n);

                if start >= end {
                    return (F::zero(), F::zero(), F::zero(), F::zero(), 0);
                }

                let chunk = x.slice(scirs2_core::ndarray::s![start..end]);
                let count = chunk.len();
                let count_f = F::from(count).expect("Failed to convert to float");

                let mean = chunk.iter().fold(F::zero(), |acc, &val| acc + val) / count_f;

                let (m2, m3, m4) =
                    chunk
                        .iter()
                        .fold((F::zero(), F::zero(), F::zero()), |(m2, m3, m4), &val| {
                            let dev = val - mean;
                            let dev2 = dev * dev;
                            let dev3 = dev2 * dev;
                            let dev4 = dev2 * dev2;
                            (m2 + dev2, m3 + dev3, m4 + dev4)
                        });

                (mean, m2, m3, m4, count)
            })
            .collect();

        // Merge the per-chunk moments pairwise with Pébay's update formulas;
        // the delta terms account for differences between chunk means, which
        // a plain sum of the M2/M3/M4 values would ignore.
        let three = F::from(3.0).expect("Failed to convert constant to float");
        let four = F::from(4.0).expect("Failed to convert constant to float");
        let six = F::from(6.0).expect("Failed to convert constant to float");

        let (total_mean, total_m2, total_m3, total_m4, _total_count) = results.into_iter().fold(
            (F::zero(), F::zero(), F::zero(), F::zero(), 0),
            |(mean_a, m2_a, m3_a, m4_a, count_a), (mean_b, m2_b, m3_b, m4_b, count_b)| {
                if count_b == 0 {
                    return (mean_a, m2_a, m3_a, m4_a, count_a);
                }
                if count_a == 0 {
                    return (mean_b, m2_b, m3_b, m4_b, count_b);
                }

                let total_count = count_a + count_b;
                let na = F::from(count_a).expect("Failed to convert to float");
                let nb = F::from(count_b).expect("Failed to convert to float");
                let nt = F::from(total_count).expect("Failed to convert to float");

                let delta = mean_b - mean_a;
                let delta2 = delta * delta;

                let combined_mean = (mean_a * na + mean_b * nb) / nt;
                let combined_m2 = m2_a + m2_b + delta2 * na * nb / nt;
                let combined_m3 = m3_a
                    + m3_b
                    + delta2 * delta * na * nb * (na - nb) / (nt * nt)
                    + three * delta * (na * m2_b - nb * m2_a) / nt;
                let combined_m4 = m4_a
                    + m4_b
                    + delta2 * delta2 * na * nb * (na * na - na * nb + nb * nb) / (nt * nt * nt)
                    + six * delta2 * (na * na * m2_b + nb * nb * m2_a) / (nt * nt)
                    + four * delta * (na * m3_b - nb * m3_a) / nt;

                (
                    combined_mean,
                    combined_m2,
                    combined_m3,
                    combined_m4,
                    total_count,
                )
            },
        );

        let variance = total_m2 / F::from(n - ddof).expect("Failed to convert to float");
        let std = variance.sqrt();

        // Population skewness and excess kurtosis from the combined moments.
        let skewness = if variance > F::epsilon() {
            (total_m3 / F::from(n).expect("Failed to convert to float"))
                / variance.powf(F::from(1.5).expect("Failed to convert constant to float"))
        } else {
            F::zero()
        };

        let kurtosis = if variance > F::epsilon() {
            (total_m4 / F::from(n).expect("Failed to convert to float")) / (variance * variance)
                - three
        } else {
            F::zero()
        };

        Ok(ComprehensiveStats {
            mean: total_mean,
            variance,
            std,
            skewness,
            kurtosis,
            count: n,
        })
    }

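    /// Runs bootstrap resampling across threads, each with its own RNG
    /// stream, and collects the per-thread statistics.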
    fn bootstrap_work_stealing<D>(
        &self,
        x: &ArrayBase<D, Ix1>,
        n_samples_: usize,
        samples_per_thread: usize,
        statistic_fn: impl Fn(&ArrayView1<F>) -> F + Send + Sync + Clone,
        seed: Option<u64>,
    ) -> StatsResult<Array1<F>>
    where
        D: Data<Elem = F> + Sync + Send,
    {
        use scirs2_core::random::ChaCha8Rng;
        use scirs2_core::random::{Rng, SeedableRng};

        let num_threads = self
            .config
            .num_threads
            .unwrap_or_else(|| self.optimal_thread_count());

        let data_vec: Vec<F> = x.iter().cloned().collect();
        let data_arc = Arc::new(data_vec);

        let partial_results: Arc<Mutex<Vec<F>>> = Arc::new(Mutex::new(Vec::new()));

        crossbeam::scope(|s| {
            for thread_id in 0..num_threads {
                let data_arc = Arc::clone(&data_arc);
                let partial_results = Arc::clone(&partial_results);
                let statistic_fn = statistic_fn.clone();

                s.spawn(move |_| {
                    // Offset the seed per thread so runs are reproducible
                    // without the threads sharing a single RNG sequence.
                    let mut rng = if let Some(seed) = seed {
                        ChaCha8Rng::seed_from_u64(seed + thread_id as u64)
                    } else {
                        ChaCha8Rng::from_rng(&mut scirs2_core::random::thread_rng())
                    };

                    let mut local_results = Vec::with_capacity(samples_per_thread);
                    let ndata = data_arc.len();

                    for _ in 0..samples_per_thread {
                        // Resample with replacement, then evaluate the statistic.
                        let bootstrap_indices: Vec<usize> =
                            (0..ndata).map(|_| rng.random_range(0..ndata)).collect();

                        let bootstrap_sample: Vec<F> =
                            bootstrap_indices.into_iter().map(|i| data_arc[i]).collect();

                        let sample_array = Array1::from(bootstrap_sample);
                        let statistic = statistic_fn(&sample_array.view());
                        local_results.push(statistic);
                    }

                    partial_results
                        .lock()
                        .expect("Operation failed")
                        .extend(local_results);
                });
            }
        })
        .expect("Operation failed");

        // Threads may produce slightly more than n_samples_ in total, so
        // keep exactly the requested number of statistics.
        let mut all_results = partial_results.lock().expect("Operation failed");
        all_results.truncate(n_samples_);
        Ok(Array1::from(std::mem::take(&mut *all_results)))
    }
}

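/// Minimal thread pool used for background statistical tasks.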
struct ThreadPool {
    workers: Vec<thread::JoinHandle<()>>,
    sender: std::sync::mpsc::Sender<Message>,
}

type Job = Box<dyn FnOnce() + Send + 'static>;

enum Message {
    NewJob(Job),
    Terminate,
}

impl ThreadPool {
    fn new(size: usize, _config: AdvancedParallelConfig) -> StatsResult<ThreadPool> {
        if size == 0 {
            return Err(StatsError::InvalidInput(
                "Thread pool size must be non-zero".to_string(),
            ));
        }

        let (sender, receiver) = std::sync::mpsc::channel();
        let receiver = Arc::new(Mutex::new(receiver));
        let mut workers = Vec::with_capacity(size);

        for _id in 0..size {
            let receiver = Arc::clone(&receiver);

            // Each worker blocks on the shared receiver until a job or a
            // termination message arrives.
            let worker = thread::spawn(move || loop {
                let message = receiver
                    .lock()
                    .expect("Operation failed")
                    .recv()
                    .expect("Operation failed");

                match message {
                    Message::NewJob(job) => {
                        job();
                    }
                    Message::Terminate => {
                        break;
                    }
                }
            });

            workers.push(worker);
        }

        Ok(ThreadPool { workers, sender })
    }

    #[allow(dead_code)]
    fn execute<F>(&self, f: F)
    where
        F: FnOnce() + Send + 'static,
    {
        let job = Box::new(f);
        self.sender
            .send(Message::NewJob(job))
            .expect("Operation failed");
    }
}

impl Drop for ThreadPool {
    fn drop(&mut self) {
        // Ask every worker to shut down, then join each handle so no worker
        // outlives the pool.
        for _ in &self.workers {
            self.sender
                .send(Message::Terminate)
                .expect("Operation failed");
        }

        for worker in self.workers.drain(..) {
            worker.join().expect("Operation failed");
        }
    }
}

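/// Creates an [`AdvancedParallelProcessor`] with the default configuration.
///
/// A minimal usage sketch (illustrative only; it assumes `f64` data and that
/// this module's items are in scope):
///
/// ```ignore
/// use scirs2_core::ndarray::Array1;
///
/// let processor = create_advanced_parallel_processor::<f64>();
/// let data = Array1::from(vec![1.0_f64, 2.0, 3.0, 4.0]);
/// let mean = processor.mean_parallel_advanced(&data)?;
/// ```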
#[allow(dead_code)]
pub fn create_advanced_parallel_processor<F>() -> AdvancedParallelProcessor<F>
where
    F: Float
        + NumCast
        + Send
        + Sync
        + SimdUnifiedOps
        + Copy
        + 'static
        + Zero
        + One
        + std::fmt::Debug
        + std::fmt::Display
        + std::iter::Sum<F>,
{
    AdvancedParallelProcessor::new(AdvancedParallelConfig::default())
}

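/// Creates an [`AdvancedParallelProcessor`] with a caller-supplied
/// configuration.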
#[allow(dead_code)]
pub fn create_configured_parallel_processor<F>(
    config: AdvancedParallelConfig,
) -> AdvancedParallelProcessor<F>
where
    F: Float
        + NumCast
        + Send
        + Sync
        + SimdUnifiedOps
        + Copy
        + 'static
        + Zero
        + One
        + std::fmt::Debug
        + std::fmt::Display
        + std::iter::Sum<F>,
{
    AdvancedParallelProcessor::new(config)
}