use super::{init_thread_pool, ThreadPoolError};
use crate::Float;
use scirs2_core::ndarray::{Array, Axis, IxDyn, Zip};
use scirs2_core::parallel_ops::*;

/// Configuration for parallel array operations.
#[derive(Debug, Clone)]
pub struct ParallelConfig {
    /// Minimum number of elements before the parallel path is taken.
    pub min_parallel_size: usize,
    /// Explicit number of chunks; `None` derives a chunk count from the
    /// available threads.
    pub num_chunks: Option<usize>,
    /// Use the adaptive `Zip`-based parallel path; when disabled, some
    /// operations fall back to a sequential slice loop.
    pub adaptive_chunking: bool,
    /// Preferred chunk size used when deriving chunk sizes.
    pub preferred_chunk_size: usize,
}

impl Default for ParallelConfig {
    fn default() -> Self {
        Self {
            min_parallel_size: 1000,
            num_chunks: None,
            adaptive_chunking: true,
            preferred_chunk_size: 10000,
        }
    }
}

/// Parallel element-wise operations on dynamic-dimension arrays.
pub struct ParallelElementWise;

impl ParallelElementWise {
    /// Element-wise addition, parallelized above `min_parallel_size`.
    pub fn add<F: Float>(
        left: &Array<F, IxDyn>,
        right: &Array<F, IxDyn>,
        config: &ParallelConfig,
    ) -> Result<Array<F, IxDyn>, ThreadPoolError> {
        if left.len() < config.min_parallel_size {
            return Ok(left + right);
        }

        let mut result = Array::zeros(left.raw_dim());

        if config.adaptive_chunking {
            Zip::from(&mut result)
                .and(left)
                .and(right)
                .par_for_each(|r, &l, &r_val| {
                    *r = l + r_val;
                });
        } else {
            // Without adaptive chunking, fall back to a plain sequential loop
            // over contiguous slices.
            let result_slice = result
                .as_slice_mut()
                .expect("result array must be contiguous");
            let left_slice = left.as_slice().expect("left array must be contiguous");
            let right_slice = right.as_slice().expect("right array must be contiguous");

            for i in 0..left.len() {
                result_slice[i] = left_slice[i] + right_slice[i];
            }
        }

        Ok(result)
    }

    /// Element-wise multiplication, parallelized above `min_parallel_size`.
    pub fn mul<F: Float>(
        left: &Array<F, IxDyn>,
        right: &Array<F, IxDyn>,
        config: &ParallelConfig,
    ) -> Result<Array<F, IxDyn>, ThreadPoolError> {
        if left.len() < config.min_parallel_size {
            return Ok(left * right);
        }

        let mut result = Array::zeros(left.raw_dim());

        Zip::from(&mut result)
            .and(left)
            .and(right)
            .par_for_each(|r, &l, &r_val| {
                *r = l * r_val;
            });

        Ok(result)
    }

    /// Applies `func` to every element, parallelized above `min_parallel_size`.
    pub fn map<F: Float, Func>(
        array: &Array<F, IxDyn>,
        func: Func,
        config: &ParallelConfig,
    ) -> Result<Array<F, IxDyn>, ThreadPoolError>
    where
        Func: Fn(F) -> F + Sync + Send,
    {
        if array.len() < config.min_parallel_size {
            return Ok(array.mapv(func));
        }

        let mut result = Array::zeros(array.raw_dim());

        Zip::from(&mut result).and(array).par_for_each(|r, &val| {
            *r = func(val);
        });

        Ok(result)
    }

    /// Derives a chunk size from the config, or from the number of available
    /// threads when no explicit chunk count is set.
    #[allow(dead_code)]
    fn calculate_chunk_size(total_size: usize, config: &ParallelConfig) -> usize {
        if let Some(num_chunks) = config.num_chunks {
            total_size.div_ceil(num_chunks)
        } else {
            let num_threads = std::thread::available_parallelism()
                .map(|n| n.get())
                .unwrap_or(4);

            let chunk_size = total_size.div_ceil(num_threads);
            chunk_size.max(config.preferred_chunk_size / num_threads)
        }
    }
}

/// Parallel reductions over dynamic-dimension arrays.
pub struct ParallelReduction;

impl ParallelReduction {
    /// Sums all elements, parallelized above `min_parallel_size`.
    pub fn sum<F: Float>(
        array: &Array<F, IxDyn>,
        config: &ParallelConfig,
    ) -> Result<F, ThreadPoolError> {
        if array.len() < config.min_parallel_size {
            return Ok(array.sum());
        }

        let result = array.par_iter().cloned().reduce(|| F::zero(), |a, b| a + b);
        Ok(result)
    }

    /// Sums along `axis`. Axis reductions currently delegate to ndarray's
    /// `sum_axis` and are not parallelized.
    pub fn sum_axis<F: Float>(
        array: &Array<F, IxDyn>,
        axis: usize,
        config: &ParallelConfig,
    ) -> Result<Array<F, IxDyn>, ThreadPoolError> {
        if array.len() < config.min_parallel_size {
            return Ok(array.sum_axis(Axis(axis)));
        }

        let result = array.sum_axis(Axis(axis));
        Ok(result)
    }

    /// Returns the maximum element, or an error for an empty array.
    pub fn max<F: Float>(
        array: &Array<F, IxDyn>,
        config: &ParallelConfig,
    ) -> Result<F, ThreadPoolError> {
        if array.len() < config.min_parallel_size || array.is_empty() {
            return array
                .iter()
                .cloned()
                .reduce(|a, b| if a > b { a } else { b })
                .ok_or(ThreadPoolError::ExecutionFailed);
        }

        let result = array
            .par_iter()
            .cloned()
            .reduce(|| F::neg_infinity(), |a, b| if a > b { a } else { b });
        Ok(result)
    }

    /// Returns the minimum element, or an error for an empty array.
    pub fn min<F: Float>(
        array: &Array<F, IxDyn>,
        config: &ParallelConfig,
    ) -> Result<F, ThreadPoolError> {
        if array.len() < config.min_parallel_size || array.is_empty() {
            return array
                .iter()
                .cloned()
                .reduce(|a, b| if a < b { a } else { b })
                .ok_or(ThreadPoolError::ExecutionFailed);
        }

        let result = array
            .par_iter()
            .cloned()
            .reduce(|| F::infinity(), |a, b| if a < b { a } else { b });
        Ok(result)
    }

    /// Arithmetic mean; errors on an empty array.
    pub fn mean<F: Float>(
        array: &Array<F, IxDyn>,
        config: &ParallelConfig,
    ) -> Result<F, ThreadPoolError> {
        if array.is_empty() {
            return Err(ThreadPoolError::ExecutionFailed);
        }

        let sum = Self::sum(array, config)?;
        let count = F::from(array.len()).expect("length must be representable as F");
        Ok(sum / count)
    }

    /// Population variance (divides by `n`); errors on an empty array.
    pub fn variance<F: Float>(
        array: &Array<F, IxDyn>,
        config: &ParallelConfig,
    ) -> Result<F, ThreadPoolError> {
        if array.is_empty() {
            return Err(ThreadPoolError::ExecutionFailed);
        }

        let mean = Self::mean(array, config)?;
        let variance = if array.len() < config.min_parallel_size {
            array
                .iter()
                .map(|&x| {
                    let diff = x - mean;
                    diff * diff
                })
                .fold(F::zero(), |acc, x| acc + x)
        } else {
            array
                .par_iter()
                .map(|&x| {
                    let diff = x - mean;
                    diff * diff
                })
                .fold(|| F::zero(), |acc, x| acc + x)
                .reduce(|| F::zero(), |a, b| a + b)
        };

        let count = F::from(array.len()).expect("length must be representable as F");
        Ok(variance / count)
    }
}

/// Cache-friendly matrix operations on 2-D dynamic-dimension arrays.
pub struct ParallelMatrix;

impl ParallelMatrix {
    /// Matrix multiplication. Small problems use a naive triple loop; larger
    /// ones use cache-friendly blocking.
    pub fn matmul<F: Float>(
        left: &Array<F, IxDyn>,
        right: &Array<F, IxDyn>,
        config: &ParallelConfig,
    ) -> Result<Array<F, IxDyn>, ThreadPoolError> {
        if left.ndim() != 2 || right.ndim() != 2 {
            return Err(ThreadPoolError::ExecutionFailed);
        }

        let (m, k) = (left.shape()[0], left.shape()[1]);
        let (k2, n) = (right.shape()[0], right.shape()[1]);

        if k != k2 {
            return Err(ThreadPoolError::ExecutionFailed);
        }

        let total_ops = m * n * k;
        if total_ops < config.min_parallel_size {
            // Naive triple loop for small problems.
            let mut result = Array::zeros(IxDyn(&[m, n]));
            for i in 0..m {
                for j in 0..n {
                    let mut sum = F::zero();
                    for k_idx in 0..k {
                        sum += left[[i, k_idx]] * right[[k_idx, j]];
                    }
                    result[[i, j]] = sum;
                }
            }
            return Ok(result);
        }

        let mut result = Array::zeros(IxDyn(&[m, n]));
        let block_size = Self::calculate_block_size(m, n, k, config);

        for i_start in (0..m).step_by(block_size) {
            let i_end = (i_start + block_size).min(m);

            for j_start in (0..n).step_by(block_size) {
                let j_end = (j_start + block_size).min(n);

                for k_start in (0..k).step_by(block_size) {
                    let k_end = (k_start + block_size).min(k);

                    Self::multiply_block(
                        left,
                        right,
                        &mut result,
                        i_start,
                        i_end,
                        j_start,
                        j_end,
                        k_start,
                        k_end,
                    );
                }
            }
        }

        Ok(result)
    }

    /// Picks a block size so that three blocks fit in a typical L1 cache.
    /// With a 32 KiB cache and 8-byte elements: 32768 / (3 * 8) = 1365
    /// elements per block, sqrt(1365) ≈ 36, clamped to [32, 512].
    fn calculate_block_size(m: usize, n: usize, k: usize, _config: &ParallelConfig) -> usize {
        let cache_size = 32 * 1024; // assume a 32 KiB L1 data cache
        let element_size = std::mem::size_of::<f64>();
        // Three blocks (left, right, result) should fit in cache at once.
        let max_block_elements = cache_size / (3 * element_size);
        let suggested_block_size = (max_block_elements as f64).sqrt() as usize;

        suggested_block_size.clamp(32, 512).min(m).min(n).min(k)
    }

    /// Multiplies a single block and accumulates into `result`.
    fn multiply_block<F: Float>(
        left: &Array<F, IxDyn>,
        right: &Array<F, IxDyn>,
        result: &mut Array<F, IxDyn>,
        i_start: usize,
        i_end: usize,
        j_start: usize,
        j_end: usize,
        k_start: usize,
        k_end: usize,
    ) {
        for i in i_start..i_end {
            for j in j_start..j_end {
                let mut sum = F::zero();
                for k in k_start..k_end {
                    sum += left[[i, k]] * right[[k, j]];
                }
                result[[i, j]] += sum;
            }
        }
    }

    /// Transposes a 2-D array using cache-friendly blocking.
    pub fn transpose<F: Float>(
        array: &Array<F, IxDyn>,
        config: &ParallelConfig,
    ) -> Result<Array<F, IxDyn>, ThreadPoolError> {
        if array.ndim() != 2 {
            return Err(ThreadPoolError::ExecutionFailed);
        }

        let (rows, cols) = (array.shape()[0], array.shape()[1]);

        if rows * cols < config.min_parallel_size {
            return Ok(array.t().to_owned());
        }

        let mut result = Array::zeros(IxDyn(&[cols, rows]));

        // Copy in fixed-size blocks to keep accesses cache-friendly.
        let block_size = 64;
        for i_start in (0..rows).step_by(block_size) {
            let i_end = (i_start + block_size).min(rows);

            for j_start in (0..cols).step_by(block_size) {
                let j_end = (j_start + block_size).min(cols);

                for i in i_start..i_end {
                    for j in j_start..j_end {
                        result[[j, i]] = array[[i, j]];
                    }
                }
            }
        }

        Ok(result)
    }
}

/// Parallel 1-D convolution.
pub struct ParallelConvolution;

impl ParallelConvolution {
    /// Full 1-D convolution (output length `input + kernel - 1`), parallelized
    /// over output positions above `min_parallel_size`.
    pub fn conv1d<F: Float>(
        input: &Array<F, IxDyn>,
        kernel: &Array<F, IxDyn>,
        config: &ParallelConfig,
    ) -> Result<Array<F, IxDyn>, ThreadPoolError> {
        if input.ndim() != 1 || kernel.ndim() != 1 {
            return Err(ThreadPoolError::ExecutionFailed);
        }

        let input_len = input.len();
        let kernel_len = kernel.len();
        let output_len = input_len + kernel_len - 1;

        if output_len < config.min_parallel_size {
            return Self::conv1d_sequential(input, kernel);
        }

        // Compute each output position independently in parallel and collect
        // the results into the output array.
        let data: Vec<F> = (0..output_len)
            .into_par_iter()
            .map(|i| {
                let mut sum = F::zero();

                for j in 0..kernel_len {
                    if i >= j && (i - j) < input_len {
                        sum += input[i - j] * kernel[j];
                    }
                }

                sum
            })
            .collect();

        Array::from_shape_vec(IxDyn(&[output_len]), data)
            .map_err(|_| ThreadPoolError::ExecutionFailed)
    }

    fn conv1d_sequential<F: Float>(
        input: &Array<F, IxDyn>,
        kernel: &Array<F, IxDyn>,
    ) -> Result<Array<F, IxDyn>, ThreadPoolError> {
        let input_len = input.len();
        let kernel_len = kernel.len();
        let output_len = input_len + kernel_len - 1;
        let mut output = Array::zeros(IxDyn(&[output_len]));

        for i in 0..output_len {
            let mut sum = F::zero();

            for j in 0..kernel_len {
                if i >= j && (i - j) < input_len {
                    sum += input[i - j] * kernel[j];
                }
            }

            output[i] = sum;
        }

        Ok(output)
    }
}

/// Parallel sorting utilities.
pub struct ParallelSort;

impl ParallelSort {
    /// Returns a sorted copy of the array (incomparable values such as NaN
    /// compare as equal).
    pub fn sort<F: Float>(
        array: &Array<F, IxDyn>,
        config: &ParallelConfig,
    ) -> Result<Array<F, IxDyn>, ThreadPoolError> {
        let mut data: Vec<F> = array.iter().cloned().collect();

        if data.len() < config.min_parallel_size {
            data.sort_by(|a, b| a.partial_cmp(b).unwrap_or(std::cmp::Ordering::Equal));
        } else {
            data.par_sort_by(|a, b| a.partial_cmp(b).unwrap_or(std::cmp::Ordering::Equal));
        }

        Array::from_shape_vec(array.raw_dim(), data).map_err(|_| ThreadPoolError::ExecutionFailed)
    }

    /// Returns the indices that would sort the array.
    pub fn argsort<F: Float>(
        array: &Array<F, IxDyn>,
        config: &ParallelConfig,
    ) -> Result<Array<usize, IxDyn>, ThreadPoolError> {
        let mut indices: Vec<(usize, F)> = array.iter().cloned().enumerate().collect();

        if indices.len() < config.min_parallel_size {
            indices.sort_by(|(_, a), (_, b)| a.partial_cmp(b).unwrap_or(std::cmp::Ordering::Equal));
        } else {
            indices.par_sort_by(|(_, a), (_, b)| {
                a.partial_cmp(b).unwrap_or(std::cmp::Ordering::Equal)
            });
        }

        let sorted_indices: Vec<usize> = indices.into_iter().map(|(idx, _)| idx).collect();

        Array::from_shape_vec(array.raw_dim(), sorted_indices)
            .map_err(|_| ThreadPoolError::ExecutionFailed)
    }
}

/// Dispatches operations through the shared thread pool with a stored config.
pub struct ParallelDispatcher {
    config: ParallelConfig,
}

impl ParallelDispatcher {
    /// Creates a dispatcher with the default configuration, initializing the
    /// global thread pool if necessary.
    pub fn new() -> Self {
        let _ = init_thread_pool();

        Self {
            config: ParallelConfig::default(),
        }
    }

    /// Creates a dispatcher with a custom configuration.
    pub fn with_config(config: ParallelConfig) -> Self {
        let _ = init_thread_pool();

        Self { config }
    }

    /// Applies `operation` element-wise across several same-shaped arrays.
    pub fn dispatch_elementwise<F, Op>(
        &self,
        arrays: &[&Array<F, IxDyn>],
        operation: Op,
    ) -> Result<Array<F, IxDyn>, ThreadPoolError>
    where
        F: Float,
        Op: Fn(&[F]) -> F + Sync + Send,
    {
        if arrays.is_empty() {
            return Err(ThreadPoolError::ExecutionFailed);
        }

        let shape = arrays[0].raw_dim();
        let size = arrays[0].len();

        for array in arrays.iter().skip(1) {
            if array.raw_dim() != shape {
                return Err(ThreadPoolError::ExecutionFailed);
            }
        }

        let mut result = Array::zeros(shape);

        if size < self.config.min_parallel_size {
            for (i, result_elem) in result.iter_mut().enumerate() {
                let values: Vec<F> = arrays
                    .iter()
                    .map(|arr| arr.as_slice().expect("input arrays must be contiguous")[i])
                    .collect();
                *result_elem = operation(&values);
            }
        } else {
            // Parallel path over the contiguous result slice. This assumes
            // `parallel_ops` re-exports rayon-style slice iterators, as the
            // other parallel calls in this module already rely on.
            let result_slice = result
                .as_slice_mut()
                .expect("result array must be contiguous");
            result_slice
                .par_iter_mut()
                .enumerate()
                .for_each(|(i, result_elem)| {
                    let values: Vec<F> = arrays
                        .iter()
                        .map(|arr| arr.as_slice().expect("input arrays must be contiguous")[i])
                        .collect();
                    *result_elem = operation(&values);
                });
        }

        Ok(result)
    }

    /// Returns the current configuration.
    pub fn get_config(&self) -> &ParallelConfig {
        &self.config
    }

    /// Replaces the configuration.
    pub fn set_config(&mut self, config: ParallelConfig) {
        self.config = config;
    }
}

impl Default for ParallelDispatcher {
    fn default() -> Self {
        Self::new()
    }
}

#[cfg(test)]
mod tests {
    use super::*;
    #[allow(unused_imports)]
    use scirs2_core::ndarray::Array1;

    #[test]
    fn test_parallel_element_wise_add() {
        let config = ParallelConfig::default();

        let a = Array::from_shape_vec(IxDyn(&[4]), vec![1.0, 2.0, 3.0, 4.0])
            .expect("failed to create array");
        let b = Array::from_shape_vec(IxDyn(&[4]), vec![5.0, 6.0, 7.0, 8.0])
            .expect("failed to create array");

        let result = ParallelElementWise::add(&a, &b, &config).expect("add failed");
        let expected = vec![6.0, 8.0, 10.0, 12.0];

        assert_eq!(result.as_slice().expect("not contiguous"), &expected);
    }
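
    // A sketch of the remaining element-wise entry points: `mul` mirrors `add`
    // and `map` applies an arbitrary closure. Small fixed inputs are enough to
    // exercise the sequential fallback paths below `min_parallel_size`.
    #[test]
    fn test_parallel_element_wise_mul_and_map() {
        let config = ParallelConfig::default();

        let a = Array::from_shape_vec(IxDyn(&[3]), vec![1.0, 2.0, 3.0])
            .expect("failed to create array");
        let b = Array::from_shape_vec(IxDyn(&[3]), vec![4.0, 5.0, 6.0])
            .expect("failed to create array");

        let product = ParallelElementWise::mul(&a, &b, &config).expect("mul failed");
        assert_eq!(
            product.as_slice().expect("not contiguous"),
            &[4.0, 10.0, 18.0]
        );

        let doubled = ParallelElementWise::map(&a, |x| x * 2.0, &config).expect("map failed");
        assert_eq!(
            doubled.as_slice().expect("not contiguous"),
            &[2.0, 4.0, 6.0]
        );
    }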

    #[test]
    fn test_parallel_reduction_sum() {
        let config = ParallelConfig::default();

        let a = Array::from_shape_vec(IxDyn(&[4]), vec![1.0, 2.0, 3.0, 4.0])
            .expect("failed to create array");
        let result = ParallelReduction::sum(&a, &config).expect("sum failed");

        assert_eq!(result, 10.0);
    }
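
    // Worked example for the remaining reductions: for [4, 1, 3, 2] the mean is
    // 2.5 and the population variance is
    // (1.5^2 + 1.5^2 + 0.5^2 + 0.5^2) / 4 = 5.0 / 4 = 1.25, all exact in f64.
    #[test]
    fn test_parallel_reduction_max_min_mean_variance() {
        let config = ParallelConfig::default();

        let a = Array::from_shape_vec(IxDyn(&[4]), vec![4.0, 1.0, 3.0, 2.0])
            .expect("failed to create array");

        assert_eq!(ParallelReduction::max(&a, &config).expect("max failed"), 4.0);
        assert_eq!(ParallelReduction::min(&a, &config).expect("min failed"), 1.0);
        assert_eq!(
            ParallelReduction::mean(&a, &config).expect("mean failed"),
            2.5
        );
        assert_eq!(
            ParallelReduction::variance(&a, &config).expect("variance failed"),
            1.25
        );
    }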

    #[test]
    fn test_parallel_matrix_multiplication() {
        let config = ParallelConfig::default();

        let a = Array::from_shape_vec(IxDyn(&[2, 2]), vec![1.0, 2.0, 3.0, 4.0])
            .expect("failed to create array");
        let b = Array::from_shape_vec(IxDyn(&[2, 2]), vec![5.0, 6.0, 7.0, 8.0])
            .expect("failed to create array");

        let result = ParallelMatrix::matmul(&a, &b, &config).expect("matmul failed");

        assert_eq!(result[[0, 0]], 19.0);
        assert_eq!(result[[0, 1]], 22.0);
        assert_eq!(result[[1, 0]], 43.0);
        assert_eq!(result[[1, 1]], 50.0);
    }

    #[test]
    fn test_parallel_transpose() {
        let config = ParallelConfig::default();

        let a = Array::from_shape_vec(IxDyn(&[2, 3]), vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0])
            .expect("failed to create array");
        let result = ParallelMatrix::transpose(&a, &config).expect("transpose failed");

        assert_eq!(result.shape(), &[3, 2]);
        assert_eq!(result[[0, 0]], 1.0);
        assert_eq!(result[[1, 0]], 2.0);
        assert_eq!(result[[2, 0]], 3.0);
        assert_eq!(result[[0, 1]], 4.0);
        assert_eq!(result[[1, 1]], 5.0);
        assert_eq!(result[[2, 1]], 6.0);
    }
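
    // A minimal check of the full 1-D convolution: [1, 2, 3] convolved with
    // [1, 1] gives [1, 3, 5, 3] (output length = input + kernel - 1).
    #[test]
    fn test_parallel_conv1d() {
        let config = ParallelConfig::default();

        let input = Array::from_shape_vec(IxDyn(&[3]), vec![1.0, 2.0, 3.0])
            .expect("failed to create array");
        let kernel = Array::from_shape_vec(IxDyn(&[2]), vec![1.0, 1.0])
            .expect("failed to create array");

        let result =
            ParallelConvolution::conv1d(&input, &kernel, &config).expect("conv1d failed");
        assert_eq!(
            result.as_slice().expect("not contiguous"),
            &[1.0, 3.0, 5.0, 3.0]
        );
    }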

    #[test]
    fn test_parallel_sort() {
        let config = ParallelConfig::default();

        let a = Array::from_shape_vec((4,), vec![4.0, 1.0, 3.0, 2.0])
            .expect("failed to create array")
            .into_dyn();
        let result = ParallelSort::sort(&a, &config).expect("sort failed");

        assert_eq!(
            result.as_slice().expect("not contiguous"),
            &[1.0, 2.0, 3.0, 4.0]
        );
    }

    #[test]
    fn test_parallel_argsort() {
        let config = ParallelConfig::default();

        let a = Array::from_shape_vec((4,), vec![4.0, 1.0, 3.0, 2.0])
            .expect("failed to create array")
            .into_dyn();
        let result = ParallelSort::argsort(&a, &config).expect("argsort failed");

        assert_eq!(result.as_slice().expect("not contiguous"), &[1, 3, 2, 0]);
    }
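
    // Axis reductions currently delegate to ndarray's `sum_axis`; summing the
    // 2x2 matrix [[1, 2], [3, 4]] along axis 0 yields the column sums [4, 6].
    #[test]
    fn test_parallel_sum_axis() {
        let config = ParallelConfig::default();

        let a = Array::from_shape_vec(IxDyn(&[2, 2]), vec![1.0, 2.0, 3.0, 4.0])
            .expect("failed to create array");
        let result = ParallelReduction::sum_axis(&a, 0, &config).expect("sum_axis failed");

        assert_eq!(result.as_slice().expect("not contiguous"), &[4.0, 6.0]);
    }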

    #[test]
    fn test_parallel_dispatcher() {
        let dispatcher = ParallelDispatcher::new();

        let a = Array::from_shape_vec((3,), vec![1.0, 2.0, 3.0])
            .expect("failed to create array")
            .into_dyn();
        let b = Array::from_shape_vec((3,), vec![4.0, 5.0, 6.0])
            .expect("failed to create array")
            .into_dyn();

        let result = dispatcher
            .dispatch_elementwise(&[&a, &b], |values| values[0] + values[1])
            .expect("dispatch failed");

        assert_eq!(
            result.as_slice().expect("not contiguous"),
            &[5.0, 7.0, 9.0]
        );
    }

    #[test]
    fn test_parallel_config() {
        let config = ParallelConfig {
            min_parallel_size: 500,
            num_chunks: Some(8),
            adaptive_chunking: false,
            preferred_chunk_size: 1000,
        };

        assert_eq!(config.min_parallel_size, 500);
        assert_eq!(config.num_chunks, Some(8));
        assert!(!config.adaptive_chunking);
        assert_eq!(config.preferred_chunk_size, 1000);
    }
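
    // Sketch of the chunk-size helper: with an explicit chunk count the size is
    // a ceiling division, e.g. 100 elements in 8 chunks -> chunks of 13. The
    // thread-count branch is not asserted here because it depends on the host.
    #[test]
    fn test_calculate_chunk_size_explicit_chunks() {
        let config = ParallelConfig {
            num_chunks: Some(8),
            ..Default::default()
        };

        let chunk = ParallelElementWise::calculate_chunk_size(100, &config);
        assert_eq!(chunk, 13);
    }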
}