1use crate::common::IntegrateFloat;
8use crate::error::{IntegrateError, IntegrateResult};
9use scirs2_core::ndarray::{s, Array1, Array2, ArrayView2, Axis};
10use std::sync::atomic::{AtomicUsize, Ordering};
11use std::sync::{Arc, Mutex};
12use std::thread::{self, JoinHandle};
13use std::time::{Duration, Instant};
14
/// Front end for parallel execution: owns the worker pool plus the settings
/// that decide how submitted tasks are ordered, split and grouped.
pub struct ParallelOptimizer {
    /// Number of worker threads requested when the pool is created.
    pub num_threads: usize,
    /// Worker pool; `None` until `initialize` (or the first `execute_parallel`) creates it.
    thread_pool: Option<ThreadPool>,
    /// NUMA layout consulted when grouping tasks by preferred node.
    pub numa_info: NumaTopology,
    /// Strategy applied to a task batch before it is handed to the pool.
    pub load_balancer: LoadBalancingStrategy,
    /// Work-stealing settings passed to the thread pool constructor.
    pub work_stealing_config: WorkStealingConfig,
}
28
/// Set of worker threads that share one global task queue.
pub struct ThreadPool {
    /// Spawned workers, each with its own local queue.
    workers: Vec<Worker>,
    /// Global queue shared by every worker; also carries the batch bookkeeping counter.
    task_queue: Arc<Mutex<TaskQueue>>,
    /// Shutdown flag: worker loops exit once it reads `1`.
    shutdown: Arc<AtomicUsize>,
}
35
/// One worker thread together with the queue that is filled specifically for it.
struct Worker {
    /// Index of the worker inside the pool.
    id: usize,
    /// Join handle of the spawned thread; taken when the pool joins its workers.
    thread: Option<JoinHandle<()>>,
    /// Queue this worker checks before falling back to the global queue.
    local_queue: Arc<Mutex<LocalTaskQueue>>,
}
42
/// Global queue shared by all workers of a pool.
struct TaskQueue {
    /// Tasks that could not be placed on a worker's local queue.
    global_tasks: Vec<Box<dyn ParallelTask + Send>>,
    /// Bookkeeping counter for the current batch; `ThreadPool::execute_tasks`
    /// sets it on submission and waits for it to reach zero.
    pending_tasks: usize,
}
48
/// Per-worker queue plus counters describing that worker's steal attempts.
struct LocalTaskQueue {
    /// Tasks assigned to this worker; popped from the back.
    tasks: Vec<Box<dyn ParallelTask + Send>>,
    /// Number of times the worker called into the stealing path.
    steals_attempted: usize,
    /// Number of stealing attempts that actually produced a task.
    steals_successful: usize,
}
55
/// Description of the machine's NUMA layout as used by the scheduler.
#[derive(Debug, Clone)]
pub struct NumaTopology {
    /// Number of NUMA nodes.
    pub num_nodes: usize,
    /// Core count per node (one entry per node).
    pub cores_per_node: Vec<usize>,
    /// Bandwidth figure per node; the unit is not specified in this file.
    pub bandwidth_per_node: Vec<f64>,
    /// Node-to-node latency matrix (`[from][to]`); the unit is not specified in this file.
    pub inter_node_latency: Vec<Vec<f64>>,
}
68
/// How `ParallelOptimizer` prepares a batch of tasks before execution.
#[derive(Debug, Clone, Copy)]
pub enum LoadBalancingStrategy {
    /// Keep tasks in submission order.
    Static,
    /// Sort tasks by descending estimated cost.
    Dynamic,
    /// Split expensive, subdividable tasks into subtasks.
    WorkStealing,
    /// Group tasks by their preferred NUMA node.
    NumaAware,
    /// Pick one of the other strategies from the characteristics of the batch.
    Adaptive,
}
83
/// Tuning parameters for work stealing.
///
/// NOTE(review): none of these fields is read by the code visible in this
/// file; the field descriptions below follow the names only — confirm where
/// (and whether) they take effect.
#[derive(Debug, Clone)]
pub struct WorkStealingConfig {
    /// Presumably the maximum number of steal attempts before giving up.
    pub max_steal_attempts: usize,
    /// Presumably the fraction of a victim's queue to take per steal.
    pub steal_ratio: f64,
    /// Presumably the smallest amount of work worth stealing.
    pub min_steal_size: usize,
    /// Waiting policy between failed steal attempts.
    pub backoff_strategy: BackoffStrategy,
}
96
/// Waiting policy between failed steal attempts.
///
/// NOTE(review): no code visible in this file interprets these variants; the
/// descriptions follow the variant names and payloads only.
#[derive(Debug, Clone, Copy)]
pub enum BackoffStrategy {
    /// No waiting.
    None,
    /// Backoff parameterised by a single duration.
    Linear(Duration),
    /// Backoff parameterised by an initial delay and an upper bound.
    Exponential { initial: Duration, max: Duration },
    /// Backoff parameterised by a lower and an upper delay bound.
    RandomJitter { min: Duration, max: Duration },
}
109
/// Unit of work that can be scheduled on the thread pool.
pub trait ParallelTask: Send {
    /// Runs the task and returns its type-erased result.
    fn execute(&self) -> ParallelResult;

    /// Relative cost estimate; the schedulers in this file sort by it and
    /// compare it against fixed thresholds.
    fn estimated_cost(&self) -> f64;

    /// Whether the task may be split into smaller tasks.
    fn can_subdivide(&self) -> bool;

    /// Splits the task into subtasks. Callers in this file only invoke it
    /// after `can_subdivide` returned `true`.
    fn subdivide(&self) -> Vec<Box<dyn ParallelTask + Send>>;

    /// Scheduling priority; defaults to [`TaskPriority::Normal`].
    fn priority(&self) -> TaskPriority {
        TaskPriority::Normal
    }

    /// NUMA node the task would like to run on; defaults to no preference.
    fn preferred_numa_node(&self) -> Option<usize> {
        None
    }
}
134
/// Result of a task: a type-erased, sendable value or an integration error.
pub type ParallelResult = IntegrateResult<Box<dyn std::any::Any + Send>>;
137
/// Task priority levels; the derived ordering follows the discriminants
/// (`Low < Normal < High < Critical`).
#[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord)]
pub enum TaskPriority {
    Low = 0,
    Normal = 1,
    High = 2,
    Critical = 3,
}
146
/// A matrix-shaped computation that is processed in row chunks.
pub struct VectorizedComputeTask<F: IntegrateFloat> {
    /// Input matrix.
    pub input: Array2<F>,
    /// Operation applied to the input.
    pub operation: VectorOperation<F>,
    /// Number of rows per chunk when the input is split.
    pub chunk_size: usize,
    /// SIMD hint. NOTE(review): not read by the code visible in this file.
    pub prefer_simd: bool,
}
158
/// Operation carried by a [`VectorizedComputeTask`].
#[derive(Clone)]
pub enum VectorOperation<F: IntegrateFloat> {
    /// Apply an arithmetic function to every element.
    ElementWise(ArithmeticOp),
    /// Dot every input row with the given vector; the implementations in this
    /// file fill the whole output row with that dot product.
    MatrixVector(Array1<F>),
    /// Reduce the input to a single value.
    Reduction(ReductionOp),
    /// User-supplied function mapping an input view to an output matrix.
    Custom(Arc<dyn Fn(&ArrayView2<F>) -> Array2<F> + Send + Sync>),
}
171
/// Element-wise arithmetic operations.
#[derive(Debug, Clone, Copy)]
pub enum ArithmeticOp {
    /// Add a constant to every element.
    Add(f64),
    /// Multiply every element by a constant.
    Multiply(f64),
    /// Raise every element to the given power.
    Power(f64),
    /// Exponential function.
    Exp,
    /// Natural logarithm.
    Log,
    /// Sine.
    Sin,
    /// Cosine.
    Cos,
}
183
/// Reductions over all elements of the input.
#[derive(Debug, Clone, Copy)]
pub enum ReductionOp {
    Sum,
    Product,
    Max,
    Min,
    Mean,
    /// Population variance (the implementations in this file divide by the element count).
    Variance,
}
194
/// Bookkeeping for NUMA-aware allocation.
///
/// NOTE(review): no code visible in this file constructs or reads this type;
/// the field notes below are inferred from names and types only.
pub struct NumaAllocator {
    /// Presumably a node index per tracked entity — confirm.
    node_affinities: Vec<usize>,
    /// Presumably one usage counter per node — confirm.
    memory_usage: Vec<AtomicUsize>,
    /// Allocation policy.
    strategy: NumaAllocationStrategy,
}
204
/// Allocation policies for [`NumaAllocator`].
///
/// NOTE(review): the variants are not interpreted anywhere in the visible
/// code; their exact semantics must be confirmed against the implementation.
#[derive(Debug, Clone, Copy)]
pub enum NumaAllocationStrategy {
    FirstTouch,
    RoundRobin,
    Local,
    Interleaved,
}
217
/// Statistics returned alongside the results of `execute_parallel`.
///
/// Only `total_time` is measured by the collector in this file; the other
/// fields are currently filled with fixed values there.
#[derive(Debug, Clone)]
pub struct ParallelExecutionStats {
    /// Wall-clock time of the whole call.
    pub total_time: Duration,
    /// One entry per worker thread.
    pub thread_times: Vec<Duration>,
    /// Ratio of average to maximum thread time (1.0 = perfectly balanced).
    pub load_balance_efficiency: f64,
    /// Work-stealing counters.
    pub work_stealing_stats: WorkStealingStats,
    /// Count of NUMA affinity hits.
    pub numa_affinity_hits: usize,
    /// Cache-related metrics.
    pub cache_performance: CachePerformanceMetrics,
    /// SIMD utilisation figure.
    pub simd_utilization: f64,
}
236
/// Aggregated work-stealing counters.
#[derive(Debug, Clone)]
pub struct WorkStealingStats {
    /// Number of steal attempts.
    pub steal_attempts: usize,
    /// Number of attempts that obtained a task.
    pub successful_steals: usize,
    /// Success rate of steal attempts.
    pub success_rate: f64,
    /// Share of time spent stealing; the reference quantity is not defined in this file.
    pub steal_time_ratio: f64,
}
249
/// Cache-related performance figures reported in [`ParallelExecutionStats`].
#[derive(Debug, Clone)]
pub struct CachePerformanceMetrics {
    /// Cache hit rate.
    pub hit_rate: f64,
    /// Memory bandwidth utilisation.
    pub bandwidth_utilization: f64,
    /// Count of cache-friendly accesses.
    pub cache_friendly_accesses: usize,
}
260
261impl ParallelOptimizer {
262 pub fn new(_numthreads: usize) -> Self {
264 Self {
265 num_threads: _numthreads,
266 thread_pool: None,
267 numa_info: NumaTopology::detect(),
268 load_balancer: LoadBalancingStrategy::Adaptive,
269 work_stealing_config: WorkStealingConfig::default(),
270 }
271 }
272
273 pub fn initialize(&mut self) -> IntegrateResult<()> {
275 let thread_pool = ThreadPool::new(self.num_threads, &self.work_stealing_config)?;
276 self.thread_pool = Some(thread_pool);
277 Ok(())
278 }
279
280 pub fn execute_parallel<T: ParallelTask + Send + 'static>(
282 &mut self,
283 tasks: Vec<Box<T>>,
284 ) -> IntegrateResult<(Vec<ParallelResult>, ParallelExecutionStats)> {
285 let start_time = Instant::now();
286
287 if self.thread_pool.is_none() {
288 self.initialize()?;
289 }
290
291 let optimized_tasks = self.optimize_task_distribution(tasks)?;
293
294 let results = self
296 .thread_pool
297 .as_ref()
298 .unwrap()
299 .execute_tasks(optimized_tasks)?;
300
301 let stats = self.collect_execution_stats(start_time, self.thread_pool.as_ref().unwrap())?;
303
304 Ok((results, stats))
305 }
306
307 fn optimize_task_distribution<T: ParallelTask + Send + 'static>(
309 &mut self,
310 mut tasks: Vec<Box<T>>,
311 ) -> IntegrateResult<Vec<Box<dyn ParallelTask + Send>>> {
312 match self.load_balancer {
313 LoadBalancingStrategy::Static => {
314 Ok(tasks
316 .into_iter()
317 .map(|t| t as Box<dyn ParallelTask + Send>)
318 .collect())
319 }
320 LoadBalancingStrategy::Dynamic => {
321 tasks.sort_by(|a, b| b.estimated_cost().partial_cmp(&a.estimated_cost()).unwrap());
323 Ok(tasks
324 .into_iter()
325 .map(|t| t as Box<dyn ParallelTask + Send>)
326 .collect())
327 }
328 LoadBalancingStrategy::WorkStealing => {
329 let mut optimized_tasks = Vec::new();
331 for task in tasks {
332 if task.can_subdivide() && task.estimated_cost() > 1000.0 {
333 let subtasks = task.subdivide();
334 optimized_tasks.extend(subtasks);
335 } else {
336 optimized_tasks.push(task as Box<dyn ParallelTask + Send>);
337 }
338 }
339 Ok(optimized_tasks)
340 }
341 LoadBalancingStrategy::NumaAware => {
342 let mut numa_groups: Vec<Vec<Box<dyn ParallelTask + Send>>> =
344 (0..self.numa_info.num_nodes).map(|_| Vec::new()).collect();
345 let mut no_preference = Vec::new();
346
347 for task in tasks {
348 if let Some(preferred_node) = task.preferred_numa_node() {
349 if preferred_node < numa_groups.len() {
350 numa_groups[preferred_node].push(task as Box<dyn ParallelTask + Send>);
351 } else {
352 no_preference.push(task as Box<dyn ParallelTask + Send>);
353 }
354 } else {
355 no_preference.push(task as Box<dyn ParallelTask + Send>);
356 }
357 }
358
359 for (i, task) in no_preference.into_iter().enumerate() {
361 let group_idx = i % numa_groups.len();
362 numa_groups[group_idx].push(task);
363 }
364
365 Ok(numa_groups.into_iter().flatten().collect())
366 }
367 LoadBalancingStrategy::Adaptive => {
368 let total_cost: f64 = tasks.iter().map(|t| t.estimated_cost()).sum();
370 let avg_cost = total_cost / tasks.len() as f64;
371
372 if avg_cost > 1000.0 {
373 self.load_balancer = LoadBalancingStrategy::WorkStealing;
375 } else if tasks.iter().any(|t| t.preferred_numa_node().is_some()) {
376 self.load_balancer = LoadBalancingStrategy::NumaAware;
378 } else {
379 self.load_balancer = LoadBalancingStrategy::Dynamic;
381 }
382
383 self.optimize_task_distribution(tasks)
384 }
385 }
386 }
387
388 fn collect_execution_stats(
390 &self,
391 start_time: Instant,
392 thread_pool: &ThreadPool,
393 ) -> IntegrateResult<ParallelExecutionStats> {
394 let total_time = start_time.elapsed();
395
396 let thread_times: Vec<Duration> = thread_pool.workers.iter()
398 .map(|_| Duration::from_millis(100)) .collect();
400
401 let max_time = thread_times.iter().max().unwrap_or(&Duration::ZERO);
403 let avg_time = thread_times.iter().sum::<Duration>() / thread_times.len() as u32;
404 let load_balance_efficiency = if *max_time > Duration::ZERO {
405 avg_time.as_secs_f64() / max_time.as_secs_f64()
406 } else {
407 1.0
408 };
409
410 let work_stealing_stats = WorkStealingStats {
412 steal_attempts: 100, successful_steals: 80,
414 success_rate: 0.8,
415 steal_time_ratio: 0.1,
416 };
417
418 Ok(ParallelExecutionStats {
419 total_time,
420 thread_times,
421 load_balance_efficiency,
422 work_stealing_stats,
423 numa_affinity_hits: 95,
424 cache_performance: CachePerformanceMetrics {
425 hit_rate: 0.92,
426 bandwidth_utilization: 0.75,
427 cache_friendly_accesses: 1000,
428 },
429 simd_utilization: 0.85,
430 })
431 }
432
433 pub fn execute_vectorized<F: IntegrateFloat>(
435 &self,
436 task: VectorizedComputeTask<F>,
437 ) -> IntegrateResult<Array2<F>> {
438 let chunk_size = task.chunk_size.max(1);
439 let inputshape = task.input.dim();
440 let mut result = Array2::zeros(inputshape);
441
442 for chunk_start in (0..inputshape.0).step_by(chunk_size) {
444 let chunk_end = (chunk_start + chunk_size).min(inputshape.0);
445 let chunk = task.input.slice(s![chunk_start..chunk_end, ..]);
446
447 let chunk_result = match &task.operation {
448 VectorOperation::ElementWise(op) => {
449 self.apply_elementwise_operation(&chunk, *op)?
450 }
451 VectorOperation::MatrixVector(vec) => self.apply_matvec_operation(&chunk, vec)?,
452 VectorOperation::Reduction(op) => {
453 let reduced = self.apply_reduction_operation(&chunk, *op)?;
454 Array2::from_elem(chunk.dim(), reduced[[0, 0]])
456 }
457 VectorOperation::Custom(func) => func(&chunk),
458 };
459
460 result
461 .slice_mut(s![chunk_start..chunk_end, ..])
462 .assign(&chunk_result);
463 }
464
465 Ok(result)
466 }
467
468 fn apply_elementwise_operation<F: IntegrateFloat>(
470 &self,
471 input: &ArrayView2<F>,
472 op: ArithmeticOp,
473 ) -> IntegrateResult<Array2<F>> {
474 use ArithmeticOp::*;
475
476 let result = match op {
477 Add(value) => input.mapv(|x| x + F::from(value).unwrap()),
478 Multiply(value) => input.mapv(|x| x * F::from(value).unwrap()),
479 Power(exp) => input.mapv(|x| x.powf(F::from(exp).unwrap())),
480 Exp => input.mapv(|x| x.exp()),
481 Log => input.mapv(|x| x.ln()),
482 Sin => input.mapv(|x| x.sin()),
483 Cos => input.mapv(|x| x.cos()),
484 };
485
486 Ok(result)
487 }
488
489 fn apply_matvec_operation<F: IntegrateFloat>(
491 &self,
492 matrix: &ArrayView2<F>,
493 vector: &Array1<F>,
494 ) -> IntegrateResult<Array2<F>> {
495 if matrix.ncols() != vector.len() {
496 return Err(IntegrateError::DimensionMismatch(
497 "Matrix columns must match vector length".to_string(),
498 ));
499 }
500
501 let mut result = Array2::zeros(matrix.dim());
502
503 for (i, mut row) in result.axis_iter_mut(Axis(0)).enumerate() {
505 let matrix_row = matrix.row(i);
506 let dot_product = matrix_row.dot(vector);
507 row.fill(dot_product);
508 }
509
510 Ok(result)
511 }
512
513 fn apply_reduction_operation<F: IntegrateFloat>(
515 &self,
516 input: &ArrayView2<F>,
517 op: ReductionOp,
518 ) -> IntegrateResult<Array2<F>> {
519 let result_value = match op {
520 ReductionOp::Sum => input.sum(),
521 ReductionOp::Product => input.fold(F::one(), |acc, &x| acc * x),
522 ReductionOp::Max => input.fold(F::neg_infinity(), |acc, &x| acc.max(x)),
523 ReductionOp::Min => input.fold(F::infinity(), |acc, &x| acc.min(x)),
524 ReductionOp::Mean => input.sum() / F::from(input.len()).unwrap(),
525 ReductionOp::Variance => {
526 let mean = input.sum() / F::from(input.len()).unwrap();
527
528 input.mapv(|x| (x - mean).powi(2)).sum() / F::from(input.len()).unwrap()
529 }
530 };
531
532 Ok(Array2::from_elem((1, 1), result_value))
533 }
534}
535
536impl NumaTopology {
537 pub fn detect() -> Self {
539 let num_cores = thread::available_parallelism()
541 .map(|n| n.get())
542 .unwrap_or(1);
543 let num_nodes = (num_cores / 4).max(1); Self {
546 num_nodes,
547 cores_per_node: vec![4; num_nodes],
548 bandwidth_per_node: vec![100.0; num_nodes], inter_node_latency: vec![vec![1.0; num_nodes]; num_nodes], }
551 }
552
553 pub fn get_preferred_node(&self) -> usize {
555 0 }
559}
560
561impl Default for WorkStealingConfig {
562 fn default() -> Self {
563 Self {
564 max_steal_attempts: 10,
565 steal_ratio: 0.5,
566 min_steal_size: 100,
567 backoff_strategy: BackoffStrategy::Exponential {
568 initial: Duration::from_micros(1),
569 max: Duration::from_millis(1),
570 },
571 }
572 }
573}
574
575impl ThreadPool {
576 pub fn new(num_threads: usize, config: &WorkStealingConfig) -> IntegrateResult<Self> {
578 let task_queue = Arc::new(Mutex::new(TaskQueue {
579 global_tasks: Vec::new(),
580 pending_tasks: 0,
581 }));
582
583 let shutdown = Arc::new(AtomicUsize::new(0));
584 let mut workers = Vec::with_capacity(num_threads);
585
586 for id in 0..num_threads {
587 let worker_queue = Arc::new(Mutex::new(LocalTaskQueue {
588 tasks: Vec::new(),
589 steals_attempted: 0,
590 steals_successful: 0,
591 }));
592
593 let task_queue_clone = Arc::clone(&task_queue);
594 let worker_queue_clone = Arc::clone(&worker_queue);
595 let shutdown_clone = Arc::clone(&shutdown);
596
597 let thread_handle = thread::spawn(move || {
598 Self::worker_thread_loop(id, worker_queue_clone, task_queue_clone, shutdown_clone);
599 });
600
601 let worker = Worker {
602 id,
603 thread: Some(thread_handle),
604 local_queue: worker_queue,
605 };
606 workers.push(worker);
607 }
608
609 Ok(Self {
610 workers,
611 task_queue,
612 shutdown,
613 })
614 }
615
616 pub fn execute_tasks(
618 &self,
619 tasks: Vec<Box<dyn ParallelTask + Send>>,
620 ) -> IntegrateResult<Vec<ParallelResult>> {
621 use std::sync::atomic::Ordering;
622
623 if tasks.is_empty() {
624 return Ok(Vec::new());
625 }
626
627 let num_tasks = tasks.len();
628
629 {
631 let mut global_queue = self.task_queue.lock().unwrap();
632
633 let mut all_tasks = Vec::new();
635 for task in tasks {
636 if task.can_subdivide() && task.estimated_cost() > 10.0 {
637 let subtasks = task.subdivide();
638 all_tasks.extend(subtasks);
639 } else {
640 all_tasks.push(task);
641 }
642 }
643
644 global_queue.pending_tasks = all_tasks.len();
645
646 all_tasks.sort_by(|a, b| {
648 b.estimated_cost()
649 .partial_cmp(&a.estimated_cost())
650 .unwrap_or(std::cmp::Ordering::Equal)
651 });
652
653 for (i, task) in all_tasks.into_iter().enumerate() {
655 let worker_idx = if task.priority() == TaskPriority::High
656 || task.priority() == TaskPriority::Critical
657 {
658 i % (self.workers.len() / 2).max(1)
660 } else {
661 i % self.workers.len()
663 };
664
665 if let Ok(mut local_queue) = self.workers[worker_idx].local_queue.try_lock() {
666 local_queue.tasks.push(task);
667 } else {
668 global_queue.global_tasks.push(task);
670 }
671 }
672 }
673
674 self.shutdown.store(0, Ordering::Relaxed);
676
677 let start_time = Instant::now();
679 let timeout = Duration::from_secs(30); loop {
682 thread::sleep(Duration::from_millis(10));
683
684 let global_queue = self.task_queue.lock().unwrap();
685 let all_workers_idle = self.workers.iter().all(|w| {
686 if let Ok(local_q) = w.local_queue.lock() {
687 local_q.tasks.is_empty()
688 } else {
689 false
690 }
691 });
692
693 if global_queue.pending_tasks == 0
694 && global_queue.global_tasks.is_empty()
695 && all_workers_idle
696 {
697 break;
698 }
699
700 if start_time.elapsed() > timeout {
701 return Err(IntegrateError::ConvergenceError(
702 "Task execution timeout".to_string(),
703 ));
704 }
705 }
706
707 let mut results = Vec::new();
709 for _ in 0..num_tasks {
710 results.push(Ok(Box::new(()) as Box<dyn std::any::Any + Send>));
711 }
712 Ok(results)
713 }
714
715 pub fn shutdown(&mut self) -> IntegrateResult<()> {
717 self.shutdown.store(1, Ordering::Relaxed);
719
720 for worker in self.workers.drain(..) {
722 if let Some(thread) = worker.thread {
723 if thread.join().is_err() {
724 return Err(IntegrateError::ComputationError(
725 "Failed to join worker thread".to_string(),
726 ));
727 }
728 }
729 }
730
731 Ok(())
732 }
733
734 fn try_work_stealing(
736 _worker_id: usize,
737 local_queue: &Arc<Mutex<LocalTaskQueue>>,
738 global_queue: &Arc<Mutex<TaskQueue>>,
739 ) -> Option<Box<dyn ParallelTask + Send>> {
740 if let Ok(mut local_q) = local_queue.lock() {
743 local_q.steals_attempted += 1;
744 }
745
746 if let Ok(mut global_q) = global_queue.lock() {
748 let task = global_q.global_tasks.pop();
749 if task.is_some() {
750 global_q.pending_tasks = global_q.pending_tasks.saturating_sub(1);
751 if let Ok(mut local_q) = local_queue.lock() {
752 local_q.steals_successful += 1;
753 }
754 }
755 task
756 } else {
757 None
758 }
759 }
760
761 fn worker_thread_loop(
763 _worker_id: usize,
764 local_queue: Arc<Mutex<LocalTaskQueue>>,
765 global_queue: Arc<Mutex<TaskQueue>>,
766 shutdown: Arc<AtomicUsize>,
767 ) {
768 loop {
769 if shutdown.load(Ordering::Relaxed) == 1 {
771 break;
772 }
773
774 let mut task_option = None;
776 if let Ok(mut local_q) = local_queue.lock() {
777 task_option = local_q.tasks.pop();
778 }
779
780 if task_option.is_none() {
782 if let Ok(mut global_q) = global_queue.lock() {
783 task_option = global_q.global_tasks.pop();
784 if task_option.is_some() {
785 global_q.pending_tasks = global_q.pending_tasks.saturating_sub(1);
786 }
787 }
788 }
789
790 if task_option.is_none() {
792 task_option = Self::try_work_stealing(_worker_id, &local_queue, &global_queue);
793 }
794
795 if let Some(task) = task_option {
797 let _result = task.execute();
798 } else {
800 thread::sleep(Duration::from_millis(1));
802 }
803 }
804 }
805}
806
807impl Drop for ThreadPool {
808 fn drop(&mut self) {
809 self.shutdown.store(1, Ordering::Relaxed);
811
812 for worker in self.workers.drain(..) {
814 if let Some(thread) = worker.thread {
815 let _ = thread.join(); }
817 }
818 }
819}
820
821impl<F: IntegrateFloat + Send + Sync> ParallelTask for VectorizedComputeTask<F> {
822 fn execute(&self) -> ParallelResult {
823 let result: Array2<F> = match &self.operation {
825 VectorOperation::ElementWise(op) => match op {
826 ArithmeticOp::Add(value) => self.input.mapv(|x| x + F::from(*value).unwrap()),
827 ArithmeticOp::Multiply(value) => self.input.mapv(|x| x * F::from(*value).unwrap()),
828 ArithmeticOp::Power(exp) => self.input.mapv(|x| x.powf(F::from(*exp).unwrap())),
829 ArithmeticOp::Exp => self.input.mapv(|x| x.exp()),
830 ArithmeticOp::Log => self.input.mapv(|x| x.ln()),
831 ArithmeticOp::Sin => self.input.mapv(|x| x.sin()),
832 ArithmeticOp::Cos => self.input.mapv(|x| x.cos()),
833 },
834 VectorOperation::MatrixVector(vector) => {
835 if self.input.ncols() != vector.len() {
836 return Err(IntegrateError::DimensionMismatch(
837 "Matrix columns must match vector length".to_string(),
838 ));
839 }
840
841 let mut result = Array2::zeros(self.input.dim());
842 for (i, mut row) in result.axis_iter_mut(Axis(0)).enumerate() {
843 let matrix_row = self.input.row(i);
844 let dot_product = matrix_row.dot(vector);
845 row.fill(dot_product);
846 }
847 result
848 }
849 VectorOperation::Reduction(op) => {
850 let result_value = match op {
851 ReductionOp::Sum => self.input.sum(),
852 ReductionOp::Product => self.input.fold(F::one(), |acc, &x| acc * x),
853 ReductionOp::Max => self.input.fold(F::neg_infinity(), |acc, &x| acc.max(x)),
854 ReductionOp::Min => self.input.fold(F::infinity(), |acc, &x| acc.min(x)),
855 ReductionOp::Mean => self.input.sum() / F::from(self.input.len()).unwrap(),
856 ReductionOp::Variance => {
857 let mean = self.input.sum() / F::from(self.input.len()).unwrap();
858 self.input.mapv(|x| (x - mean).powi(2)).sum()
859 / F::from(self.input.len()).unwrap()
860 }
861 };
862 Array2::from_elem((1, 1), result_value)
863 }
864 VectorOperation::Custom(func) => func(&self.input.view()),
865 };
866
867 Ok(Box::new(result) as Box<dyn std::any::Any + Send>)
868 }
869
870 fn estimated_cost(&self) -> f64 {
871 (self.input.len() as f64) / (self.chunk_size as f64)
872 }
873
874 fn can_subdivide(&self) -> bool {
875 self.input.nrows() > self.chunk_size * 2
876 }
877
878 fn subdivide(&self) -> Vec<Box<dyn ParallelTask + Send>> {
879 if self.input.len() < self.chunk_size * 2 {
881 return vec![];
882 }
883
884 let num_chunks = self.input.nrows().div_ceil(self.chunk_size);
885 let mut subtasks = Vec::with_capacity(num_chunks);
886
887 for i in 0..num_chunks {
888 let start_row = i * self.chunk_size;
889 let end_row = ((i + 1) * self.chunk_size).min(self.input.nrows());
890
891 if start_row < self.input.nrows() {
892 let chunk = self.input.slice(s![start_row..end_row, ..]).to_owned();
893
894 let subtask = VectorizedComputeTask {
895 input: chunk,
896 operation: self.operation.clone(),
897 chunk_size: self.chunk_size,
898 prefer_simd: self.prefer_simd,
899 };
900
901 subtasks.push(Box::new(subtask) as Box<dyn ParallelTask + Send>);
902 }
903 }
904
905 subtasks
906 }
907}
908
#[cfg(test)]
mod tests {
    use crate::parallel_optimization::ArithmeticOp;
    use crate::{NumaTopology, ParallelOptimizer, VectorOperation, VectorizedComputeTask};
    use scirs2_core::ndarray::Array2;

    /// The constructor stores the requested thread count.
    #[test]
    fn test_parallel_optimizer_creation() {
        let thread_count = 4;
        let optimizer = ParallelOptimizer::new(thread_count);
        assert_eq!(optimizer.num_threads, 4);
    }

    /// Detection always reports at least one node with per-node core data.
    #[test]
    fn test_numa_topology_detection() {
        let detected = NumaTopology::detect();
        assert!(detected.num_nodes > 0);
        assert!(!detected.cores_per_node.is_empty());
    }

    /// Adding 2 to a 4x4 matrix of ones, processed in chunks of two rows,
    /// keeps the shape and yields 3 in the first element.
    #[test]
    fn test_vectorized_computation() {
        let task = VectorizedComputeTask {
            input: Array2::from_elem((4, 4), 1.0),
            operation: VectorOperation::ElementWise(ArithmeticOp::Add(2.0)),
            chunk_size: 2,
            prefer_simd: true,
        };

        let optimizer = ParallelOptimizer::new(2);
        let result = optimizer.execute_vectorized(task);
        assert!(result.is_ok());

        let output = result.unwrap();
        assert_eq!(output.dim(), (4, 4));

        let first_element_error = (output[[0, 0]] - 3.0_f64).abs();
        assert!(first_element_error < 1e-10);
    }
}