1use scirs2_core::ndarray::{s, Array1, Array2, ArrayView1, ArrayView2};
7use scirs2_core::random::thread_rng;
8use sklears_core::{
9 error::Result as SklResult,
10 prelude::{Predict, SklearsError, Transform},
11 traits::{Estimator, Fit, Untrained},
12 types::Float,
13};
14use std::collections::{HashMap, VecDeque};
15use std::future::Future;
16use std::pin::Pin;
17use std::sync::{Arc, Condvar, Mutex, RwLock};
18use std::task::{Context, Poll};
19use std::thread::{self, JoinHandle, ThreadId};
20use std::time::{Duration, Instant, SystemTime};
21
22use crate::{PipelinePredictor, PipelineStep};
23
/// Configuration for the parallel task executor and parallel pipelines.
#[derive(Debug, Clone)]
pub struct ParallelConfig {
    /// Number of worker threads to spawn.
    pub num_workers: usize,
    /// Thread-pool sizing strategy.
    pub pool_type: ThreadPoolType,
    /// Whether idle workers may steal queued tasks from other workers.
    pub work_stealing: bool,
    /// How newly submitted tasks are assigned to workers.
    pub load_balancing: LoadBalancingStrategy,
    /// Ordering policy for queued tasks (not consulted by the code visible here).
    pub scheduling: SchedulingStrategy,
    /// Upper bound on queued tasks (not enforced by the code visible here).
    pub max_queue_size: usize,
    /// How long an idle worker waits before re-checking for work.
    pub idle_timeout: Duration,
}
42
43impl Default for ParallelConfig {
44 fn default() -> Self {
45 Self {
46 num_workers: num_cpus::get(),
47 pool_type: ThreadPoolType::FixedSize,
48 work_stealing: true,
49 load_balancing: LoadBalancingStrategy::RoundRobin,
50 scheduling: SchedulingStrategy::FIFO,
51 max_queue_size: 1000,
52 idle_timeout: Duration::from_secs(60),
53 }
54 }
55}
56
/// Sizing strategy for the executor's thread pool.
#[derive(Debug, Clone)]
pub enum ThreadPoolType {
    /// A fixed number of threads (see `ParallelConfig::num_workers`).
    FixedSize,
    /// Thread count may vary between the given bounds.
    /// NOTE(review): dynamic resizing is not implemented by the visible executor.
    Dynamic {
        /// Minimum number of live threads.
        min_threads: usize,
        /// Maximum number of live threads.
        max_threads: usize,
    },
    /// All work runs on a single thread.
    SingleThreaded,
}
70
/// Policy for assigning a new task to a worker.
#[derive(Debug, Clone)]
pub enum LoadBalancingStrategy {
    /// Cycle through workers in index order.
    RoundRobin,
    /// Pick the worker with the fewest tracked queued tasks.
    LeastLoaded,
    /// Pick a uniformly random worker.
    Random,
    /// Prefer an idle worker (placeholder — no real locality data is used yet).
    LocalityAware,
}
83
/// Ordering policy for queued tasks.
/// NOTE(review): this enum is carried in `ParallelConfig` but never consulted
/// by the visible executor, which always pops FIFO from the local queue.
#[derive(Debug, Clone)]
pub enum SchedulingStrategy {
    /// First in, first out.
    FIFO,
    /// Last in, first out.
    LIFO,
    /// Highest `ParallelTask::priority` first.
    Priority,
    /// Owner pops the front; thieves pop the back.
    WorkStealing,
}
96
/// A unit of work submitted to the `ParallelExecutor`.
pub struct ParallelTask {
    /// Unique identifier used to look up the result later.
    pub id: String,
    /// The work itself; consumed exactly once by a worker thread.
    pub task_fn: Box<dyn FnOnce() -> SklResult<TaskResult> + Send>,
    /// Scheduling priority (not consulted by the visible executor code).
    pub priority: u32,
    /// Caller-provided duration estimate (informational only here).
    pub estimated_duration: Duration,
    /// Ids of tasks that must finish first (not enforced by the visible code).
    pub dependencies: Vec<String>,
    /// Arbitrary key/value annotations.
    pub metadata: HashMap<String, String>,
}
112
113impl std::fmt::Debug for ParallelTask {
114 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
115 f.debug_struct("ParallelTask")
116 .field("id", &self.id)
117 .field("task_fn", &"<function>")
118 .field("priority", &self.priority)
119 .field("estimated_duration", &self.estimated_duration)
120 .field("dependencies", &self.dependencies)
121 .field("metadata", &self.metadata)
122 .finish()
123 }
124}
125
/// Outcome of one executed `ParallelTask`.
#[derive(Debug, Clone)]
pub struct TaskResult {
    /// Id of the task that produced this result.
    pub task_id: String,
    /// Opaque serialized payload produced by the task.
    pub data: Vec<u8>,
    /// Wall-clock execution time.
    pub duration: Duration,
    /// OS thread that ran the task.
    pub worker_id: ThreadId,
    /// Whether the task function returned `Ok`.
    pub success: bool,
    /// Debug-formatted error message when `success` is false.
    pub error: Option<String>,
}
142
/// Per-worker bookkeeping owned by the executor.
#[derive(Debug)]
pub struct WorkerState {
    /// Stable index of this worker within the executor.
    pub worker_id: usize,
    /// Join handle of the spawned thread; `None` before start / after stop.
    pub thread_handle: Option<JoinHandle<()>>,
    /// This worker's own FIFO queue (peers may steal from its back).
    pub task_queue: Arc<Mutex<VecDeque<ParallelTask>>>,
    /// Last known status (informational; the visible worker loop does not update it).
    pub status: WorkerStatus,
    /// Shared counters (the worker loop keeps its own local copy instead).
    pub stats: WorkerStatistics,
    /// Secondary deque reserved for stealing; unused by the visible code.
    pub steal_deque: Arc<Mutex<VecDeque<ParallelTask>>>,
}
159
/// Lifecycle state of a worker thread.
#[derive(Debug, Clone, PartialEq)]
pub enum WorkerStatus {
    /// Waiting for work.
    Idle,
    /// Executing a task.
    Working,
    /// Attempting to steal work from a peer.
    Stealing,
    /// The worker thread has exited.
    Terminated,
}
172
/// Per-worker execution counters.
#[derive(Debug, Clone)]
pub struct WorkerStatistics {
    /// Tasks that returned `Ok`.
    pub tasks_completed: u64,
    /// Tasks that returned `Err`.
    pub tasks_failed: u64,
    /// Sum of all task execution times.
    pub total_execution_time: Duration,
    /// `total_execution_time` divided by the number of executed tasks.
    pub avg_task_duration: Duration,
    /// When this worker last finished a task.
    pub last_activity: SystemTime,
    /// Tasks this worker stole from peers.
    pub work_stolen: u64,
    /// Tasks peers stole from this worker (not updated by the visible code).
    pub work_given: u64,
}
191
192impl Default for WorkerStatistics {
193 fn default() -> Self {
194 Self {
195 tasks_completed: 0,
196 tasks_failed: 0,
197 total_execution_time: Duration::ZERO,
198 avg_task_duration: Duration::ZERO,
199 last_activity: SystemTime::now(),
200 work_stolen: 0,
201 work_given: 0,
202 }
203 }
204}
205
/// A work-stealing thread-pool executor for `ParallelTask`s.
#[derive(Debug)]
pub struct ParallelExecutor {
    /// Pool configuration (worker count, stealing, timeouts, ...).
    config: ParallelConfig,
    /// One state entry per spawned worker.
    workers: Vec<WorkerState>,
    /// Routes submitted tasks to workers.
    dispatcher: TaskDispatcher,
    /// Overflow queue; not consumed by the visible worker loop.
    global_queue: Arc<Mutex<VecDeque<ParallelTask>>>,
    /// Finished task results keyed by task id.
    completed_tasks: Arc<Mutex<HashMap<String, TaskResult>>>,
    /// Aggregate counters shared with the worker threads.
    statistics: Arc<RwLock<ExecutorStatistics>>,
    /// Cleared by `stop` to make the worker loops exit.
    is_running: Arc<Mutex<bool>>,
    /// Notified on shutdown to wake idle workers.
    shutdown_signal: Arc<Condvar>,
}
226
/// Chooses which worker receives each submitted task.
#[derive(Debug)]
pub struct TaskDispatcher {
    /// Cursor for the round-robin strategy.
    round_robin_index: Mutex<usize>,
    /// Active load-balancing strategy.
    strategy: LoadBalancingStrategy,
    /// Tracked queued-task count per worker index.
    worker_loads: Arc<RwLock<Vec<usize>>>,
}
237
/// Aggregate counters for the whole executor.
#[derive(Debug, Clone)]
pub struct ExecutorStatistics {
    /// Tasks accepted via `submit_task`.
    pub tasks_submitted: u64,
    /// Tasks that finished successfully.
    pub tasks_completed: u64,
    /// Tasks that finished with an error.
    pub tasks_failed: u64,
    /// Mean task duration (not maintained by the visible code).
    pub avg_task_duration: Duration,
    /// Tasks per second (not maintained by the visible code).
    pub throughput: f64,
    /// Fraction of busy workers (not maintained by the visible code).
    pub worker_utilization: f64,
    /// Current queued-task count (not maintained by the visible code).
    pub queue_depth: usize,
    /// Timestamp of the most recent update.
    pub last_updated: SystemTime,
}
258
259impl Default for ExecutorStatistics {
260 fn default() -> Self {
261 Self {
262 tasks_submitted: 0,
263 tasks_completed: 0,
264 tasks_failed: 0,
265 avg_task_duration: Duration::ZERO,
266 throughput: 0.0,
267 worker_utilization: 0.0,
268 queue_depth: 0,
269 last_updated: SystemTime::now(),
270 }
271 }
272}
273
274impl TaskDispatcher {
275 #[must_use]
277 pub fn new(strategy: LoadBalancingStrategy, num_workers: usize) -> Self {
278 Self {
279 round_robin_index: Mutex::new(0),
280 strategy,
281 worker_loads: Arc::new(RwLock::new(vec![0; num_workers])),
282 }
283 }
284
285 pub fn dispatch_task(&self, task: ParallelTask, workers: &mut [WorkerState]) -> SklResult<()> {
287 let worker_index = match self.strategy {
288 LoadBalancingStrategy::RoundRobin => {
289 let mut index = self
290 .round_robin_index
291 .lock()
292 .unwrap_or_else(|e| e.into_inner());
293 let selected = *index;
294 *index = (*index + 1) % workers.len();
295 selected
296 }
297 LoadBalancingStrategy::LeastLoaded => self.find_least_loaded_worker(workers),
298 LoadBalancingStrategy::Random => {
299 let mut rng = thread_rng();
300 rng.gen_range(0..workers.len())
301 }
302 LoadBalancingStrategy::LocalityAware => {
303 self.find_best_locality_worker(workers, &task)
305 }
306 };
307
308 let mut queue = workers[worker_index]
310 .task_queue
311 .lock()
312 .unwrap_or_else(|e| e.into_inner());
313 queue.push_back(task);
314
315 let mut loads = self.worker_loads.write().unwrap_or_else(|e| e.into_inner());
317 loads[worker_index] += 1;
318
319 Ok(())
320 }
321
322 fn find_least_loaded_worker(&self, workers: &[WorkerState]) -> usize {
324 let loads = self.worker_loads.read().unwrap_or_else(|e| e.into_inner());
325 loads
326 .iter()
327 .enumerate()
328 .min_by_key(|(_, &load)| load)
329 .map_or(0, |(index, _)| index)
330 }
331
332 fn find_best_locality_worker(&self, workers: &[WorkerState], _task: &ParallelTask) -> usize {
334 workers
336 .iter()
337 .position(|worker| worker.status == WorkerStatus::Idle)
338 .unwrap_or(0)
339 }
340
341 pub fn update_worker_load(&self, worker_index: usize, delta: i32) {
343 let mut loads = self.worker_loads.write().unwrap_or_else(|e| e.into_inner());
344 if delta < 0 {
345 loads[worker_index] = loads[worker_index].saturating_sub((-delta) as usize);
346 } else {
347 loads[worker_index] += delta as usize;
348 }
349 }
350}
351
352impl ParallelExecutor {
353 #[must_use]
355 pub fn new(config: ParallelConfig) -> Self {
356 let num_workers = config.num_workers;
357 let dispatcher = TaskDispatcher::new(config.load_balancing.clone(), num_workers);
358
359 Self {
360 config,
361 workers: Vec::with_capacity(num_workers),
362 dispatcher,
363 global_queue: Arc::new(Mutex::new(VecDeque::new())),
364 completed_tasks: Arc::new(Mutex::new(HashMap::new())),
365 statistics: Arc::new(RwLock::new(ExecutorStatistics::default())),
366 is_running: Arc::new(Mutex::new(false)),
367 shutdown_signal: Arc::new(Condvar::new()),
368 }
369 }
370
371 pub fn start(&mut self) -> SklResult<()> {
373 {
374 let mut running = self.is_running.lock().unwrap_or_else(|e| e.into_inner());
375 if *running {
376 return Ok(());
377 }
378 *running = true;
379 }
380
381 for worker_id in 0..self.config.num_workers {
383 let worker = self.create_worker(worker_id)?;
384 self.workers.push(worker);
385 }
386
387 for i in 0..self.workers.len() {
389 self.start_worker_by_index(i)?;
390 }
391
392 Ok(())
393 }
394
395 pub fn stop(&mut self) -> SklResult<()> {
397 {
398 let mut running = self.is_running.lock().unwrap_or_else(|e| e.into_inner());
399 *running = false;
400 }
401
402 self.shutdown_signal.notify_all();
404
405 for worker in &mut self.workers {
407 if let Some(handle) = worker.thread_handle.take() {
408 handle.join().map_err(|_| SklearsError::InvalidData {
409 reason: "Failed to join worker thread".to_string(),
410 })?;
411 }
412 }
413
414 Ok(())
415 }
416
417 fn create_worker(&self, worker_id: usize) -> SklResult<WorkerState> {
419 Ok(WorkerState {
420 worker_id,
421 thread_handle: None,
422 task_queue: Arc::new(Mutex::new(VecDeque::new())),
423 status: WorkerStatus::Idle,
424 stats: WorkerStatistics::default(),
425 steal_deque: Arc::new(Mutex::new(VecDeque::new())),
426 })
427 }
428
429 fn start_worker_by_index(&mut self, worker_index: usize) -> SklResult<()> {
431 let worker_id = self.workers[worker_index].worker_id;
433 let task_queue = Arc::clone(&self.workers[worker_index].task_queue);
434 let steal_deque = Arc::clone(&self.workers[worker_index].steal_deque);
435 let completed_tasks = Arc::clone(&self.completed_tasks);
436 let is_running = Arc::clone(&self.is_running);
437 let shutdown_signal = Arc::clone(&self.shutdown_signal);
438 let statistics = Arc::clone(&self.statistics);
439 let config = self.config.clone();
440
441 let other_workers: Vec<Arc<Mutex<VecDeque<ParallelTask>>>> = self
443 .workers
444 .iter()
445 .enumerate()
446 .filter(|(i, _)| *i != worker_id)
447 .map(|(_, w)| Arc::clone(&w.task_queue))
448 .collect();
449
450 let handle = thread::spawn(move || {
451 Self::worker_loop(
452 worker_id,
453 task_queue,
454 steal_deque,
455 other_workers,
456 completed_tasks,
457 is_running,
458 shutdown_signal,
459 statistics,
460 config,
461 );
462 });
463
464 let worker = &mut self.workers[worker_index];
466 worker.thread_handle = Some(handle);
467 Ok(())
468 }
469
470 fn worker_loop(
472 worker_id: usize,
473 task_queue: Arc<Mutex<VecDeque<ParallelTask>>>,
474 steal_deque: Arc<Mutex<VecDeque<ParallelTask>>>,
475 other_workers: Vec<Arc<Mutex<VecDeque<ParallelTask>>>>,
476 completed_tasks: Arc<Mutex<HashMap<String, TaskResult>>>,
477 is_running: Arc<Mutex<bool>>,
478 shutdown_signal: Arc<Condvar>,
479 statistics: Arc<RwLock<ExecutorStatistics>>,
480 config: ParallelConfig,
481 ) {
482 let mut local_stats = WorkerStatistics::default();
483
484 while *is_running.lock().unwrap_or_else(|e| e.into_inner()) {
485 let task = {
487 let mut queue = task_queue.lock().unwrap_or_else(|e| e.into_inner());
488 queue.pop_front()
489 };
490
491 let task = if let Some(task) = task {
492 Some(task)
493 } else if config.work_stealing {
494 Self::steal_work(&other_workers, worker_id, &mut local_stats)
496 } else {
497 let queue = task_queue.lock().unwrap_or_else(|e| e.into_inner());
499 let _guard = shutdown_signal
500 .wait_timeout(queue, config.idle_timeout)
501 .unwrap_or_else(|e| e.into_inner());
502 continue;
503 };
504
505 if let Some(task) = task {
506 let start_time = Instant::now();
507 let task_id = task.id.clone();
508
509 let result = match (task.task_fn)() {
511 Ok(mut result) => {
512 result.task_id = task_id.clone();
513 result.worker_id = thread::current().id();
514 result.duration = start_time.elapsed();
515 result.success = true;
516 local_stats.tasks_completed += 1;
517 result
518 }
519 Err(e) => {
520 local_stats.tasks_failed += 1;
521 TaskResult {
523 task_id: task_id.clone(),
524 data: Vec::new(),
525 duration: start_time.elapsed(),
526 worker_id: thread::current().id(),
527 success: false,
528 error: Some(format!("{e:?}")),
529 }
530 }
531 };
532
533 let execution_time = start_time.elapsed();
535 local_stats.total_execution_time += execution_time;
536 local_stats.avg_task_duration = local_stats.total_execution_time
537 / (local_stats.tasks_completed + local_stats.tasks_failed) as u32;
538 local_stats.last_activity = SystemTime::now();
539
540 {
542 let mut completed = completed_tasks.lock().unwrap_or_else(|e| e.into_inner());
543 completed.insert(task_id, result);
544 }
545
546 {
548 let mut stats = statistics.write().unwrap_or_else(|e| e.into_inner());
549 stats.tasks_completed += 1;
550 stats.last_updated = SystemTime::now();
551 }
552 }
553 }
554 }
555
556 fn steal_work(
558 other_workers: &[Arc<Mutex<VecDeque<ParallelTask>>>],
559 _worker_id: usize,
560 stats: &mut WorkerStatistics,
561 ) -> Option<ParallelTask> {
562 for other_queue in other_workers {
563 if let Ok(mut queue) = other_queue.try_lock() {
564 if let Some(task) = queue.pop_back() {
565 stats.work_stolen += 1;
566 return Some(task);
567 }
568 }
569 }
570 None
571 }
572
573 pub fn submit_task(&mut self, task: ParallelTask) -> SklResult<()> {
575 {
576 let mut stats = self.statistics.write().unwrap_or_else(|e| e.into_inner());
577 stats.tasks_submitted += 1;
578 }
579
580 self.dispatcher.dispatch_task(task, &mut self.workers)?;
581 Ok(())
582 }
583
584 pub fn get_task_result(&self, task_id: &str) -> Option<TaskResult> {
586 let completed = self
587 .completed_tasks
588 .lock()
589 .unwrap_or_else(|e| e.into_inner());
590 completed.get(task_id).cloned()
591 }
592
593 pub fn statistics(&self) -> ExecutorStatistics {
595 let stats = self.statistics.read().unwrap_or_else(|e| e.into_inner());
596 stats.clone()
597 }
598
599 pub fn wait_for_completion(&self, timeout: Option<Duration>) -> SklResult<()> {
601 let start_time = Instant::now();
602
603 loop {
604 let stats = self.statistics();
605 if stats.tasks_submitted == stats.tasks_completed + stats.tasks_failed {
606 break;
607 }
608
609 if let Some(timeout) = timeout {
610 if start_time.elapsed() > timeout {
611 return Err(SklearsError::InvalidData {
612 reason: "Timeout waiting for task completion".to_string(),
613 });
614 }
615 }
616
617 thread::sleep(Duration::from_millis(10));
618 }
619
620 Ok(())
621 }
622}
623
/// A pipeline whose steps are fitted/applied with a configurable parallel
/// strategy. `S` is a typestate marker: `Untrained` before `fit`,
/// `ParallelPipelineTrained` afterwards.
#[derive(Debug)]
pub struct ParallelPipeline<S = Untrained> {
    // Typestate payload.
    state: S,
    // Ordered (name, transformer) steps applied before the final estimator.
    steps: Vec<(String, Box<dyn PipelineStep>)>,
    // Optional terminal predictor fitted on the transformed data.
    final_estimator: Option<Box<dyn PipelinePredictor>>,
    // Executor slot; the visible code creates a fresh executor inside `fit`
    // instead of storing one here.
    executor: Option<ParallelExecutor>,
    // Thread/scheduling configuration used during fitting.
    parallel_config: ParallelConfig,
    // How steps/data are parallelized.
    execution_strategy: ParallelExecutionStrategy,
}
634
/// How a `ParallelPipeline` distributes its work.
#[derive(Debug, Clone)]
pub enum ParallelExecutionStrategy {
    /// Fit each step as soon as possible (currently sequential in practice).
    FullParallel,
    /// Fit steps in independent batches of `batch_size`.
    BatchParallel { batch_size: usize },
    /// Overlap pipeline stages (currently delegates to the sequential path).
    PipelineParallel,
    /// Transform rows in chunks of `chunk_size`.
    DataParallel { chunk_size: usize },
}
647
/// Typestate payload of a fitted `ParallelPipeline`.
#[derive(Debug)]
pub struct ParallelPipelineTrained {
    // Fitted transformation steps, in application order.
    fitted_steps: Vec<(String, Box<dyn PipelineStep>)>,
    // Fitted final predictor, when one was configured and targets were given.
    fitted_estimator: Option<Box<dyn PipelinePredictor>>,
    // Configuration used during fitting.
    parallel_config: ParallelConfig,
    // Strategy used for fitting and later for `transform`.
    execution_strategy: ParallelExecutionStrategy,
    // Number of input features seen during fit.
    n_features_in: usize,
    // Input feature names, when known.
    feature_names_in: Option<Vec<String>>,
}
658
659impl ParallelPipeline<Untrained> {
660 #[must_use]
662 pub fn new(parallel_config: ParallelConfig) -> Self {
663 Self {
664 state: Untrained,
665 steps: Vec::new(),
666 final_estimator: None,
667 executor: None,
668 parallel_config,
669 execution_strategy: ParallelExecutionStrategy::FullParallel,
670 }
671 }
672
673 pub fn add_step(&mut self, name: String, step: Box<dyn PipelineStep>) {
675 self.steps.push((name, step));
676 }
677
678 pub fn set_estimator(&mut self, estimator: Box<dyn PipelinePredictor>) {
680 self.final_estimator = Some(estimator);
681 }
682
683 pub fn execution_strategy(mut self, strategy: ParallelExecutionStrategy) -> Self {
685 self.execution_strategy = strategy;
686 self
687 }
688}
689
impl Estimator for ParallelPipeline<Untrained> {
    type Config = ();
    type Error = SklearsError;
    type Float = Float;

    /// This pipeline carries no estimator-level configuration; `&()` is a
    /// constant-promoted `'static` reference to the unit value.
    fn config(&self) -> &Self::Config {
        &()
    }
}
699
700impl Fit<ArrayView2<'_, Float>, Option<&ArrayView1<'_, Float>>> for ParallelPipeline<Untrained> {
701 type Fitted = ParallelPipeline<ParallelPipelineTrained>;
702
703 fn fit(
704 mut self,
705 x: &ArrayView2<'_, Float>,
706 y: &Option<&ArrayView1<'_, Float>>,
707 ) -> SklResult<Self::Fitted> {
708 let mut executor = ParallelExecutor::new(self.parallel_config.clone());
710 executor.start()?;
711
712 let fitted_steps = match self.execution_strategy {
714 ParallelExecutionStrategy::FullParallel => {
715 self.fit_steps_parallel(x, y, &mut executor)?
716 }
717 ParallelExecutionStrategy::BatchParallel { batch_size } => {
718 self.fit_steps_batch_parallel(x, y, &mut executor, batch_size)?
719 }
720 ParallelExecutionStrategy::PipelineParallel => {
721 self.fit_steps_pipeline_parallel(x, y, &mut executor)?
722 }
723 ParallelExecutionStrategy::DataParallel { chunk_size } => {
724 self.fit_steps_data_parallel(x, y, &mut executor, chunk_size)?
725 }
726 };
727
728 let fitted_estimator = if let Some(mut estimator) = self.final_estimator {
730 let mut current_x = x.to_owned();
732 for (_, step) in &fitted_steps {
733 current_x = step.transform(¤t_x.view())?;
734 }
735
736 if let Some(y_values) = y.as_ref() {
737 let mapped_x = current_x.view().mapv(|v| v as Float);
738 estimator.fit(&mapped_x.view(), y_values)?;
739 Some(estimator)
740 } else {
741 None
742 }
743 } else {
744 None
745 };
746
747 executor.stop()?;
749
750 Ok(ParallelPipeline {
751 state: ParallelPipelineTrained {
752 fitted_steps,
753 fitted_estimator,
754 parallel_config: self.parallel_config,
755 execution_strategy: self.execution_strategy,
756 n_features_in: x.ncols(),
757 feature_names_in: None,
758 },
759 steps: Vec::new(),
760 final_estimator: None,
761 executor: None,
762 parallel_config: ParallelConfig::default(),
763 execution_strategy: ParallelExecutionStrategy::FullParallel,
764 })
765 }
766}
767
768impl ParallelPipeline<Untrained> {
769 fn fit_steps_parallel(
771 &mut self,
772 x: &ArrayView2<'_, Float>,
773 y: &Option<&ArrayView1<'_, Float>>,
774 executor: &mut ParallelExecutor,
775 ) -> SklResult<Vec<(String, Box<dyn PipelineStep>)>> {
776 let mut fitted_steps = Vec::new();
777
778 let mut current_x = x.to_owned();
781 for (name, mut step) in self.steps.drain(..) {
782 step.fit(¤t_x.view(), y.as_ref().copied())?;
783 current_x = step.transform(¤t_x.view())?;
784 fitted_steps.push((name, step));
785 }
786
787 Ok(fitted_steps)
788 }
789
790 fn fit_steps_batch_parallel(
792 &mut self,
793 x: &ArrayView2<'_, Float>,
794 y: &Option<&ArrayView1<'_, Float>>,
795 executor: &mut ParallelExecutor,
796 batch_size: usize,
797 ) -> SklResult<Vec<(String, Box<dyn PipelineStep>)>> {
798 let mut fitted_steps = Vec::new();
799 let mut steps = self.steps.drain(..).collect::<Vec<_>>();
800
801 while !steps.is_empty() {
802 let batch_size = batch_size.min(steps.len());
803 let batch: Vec<_> = steps.drain(0..batch_size).collect();
804 let mut batch_fitted = Vec::new();
805
806 for (name, mut step) in batch {
807 step.fit(x, y.as_ref().copied())?;
808 batch_fitted.push((name, step));
809 }
810
811 fitted_steps.extend(batch_fitted);
812 }
813
814 Ok(fitted_steps)
815 }
816
817 fn fit_steps_pipeline_parallel(
819 &mut self,
820 x: &ArrayView2<'_, Float>,
821 y: &Option<&ArrayView1<'_, Float>>,
822 executor: &mut ParallelExecutor,
823 ) -> SklResult<Vec<(String, Box<dyn PipelineStep>)>> {
824 self.fit_steps_parallel(x, y, executor)
826 }
827
828 fn fit_steps_data_parallel(
830 &mut self,
831 x: &ArrayView2<'_, Float>,
832 y: &Option<&ArrayView1<'_, Float>>,
833 executor: &mut ParallelExecutor,
834 chunk_size: usize,
835 ) -> SklResult<Vec<(String, Box<dyn PipelineStep>)>> {
836 let mut fitted_steps = Vec::new();
837
838 let mut current_x = x.to_owned();
840 for (name, mut step) in self.steps.drain(..) {
841 step.fit(¤t_x.view(), y.as_ref().copied())?;
844 current_x = step.transform(¤t_x.view())?;
845 fitted_steps.push((name, step));
846 }
847
848 Ok(fitted_steps)
849 }
850}
851
852impl ParallelPipeline<ParallelPipelineTrained> {
853 pub fn transform(&self, x: &ArrayView2<'_, Float>) -> SklResult<Array2<f64>> {
855 if let ParallelExecutionStrategy::DataParallel { chunk_size } =
856 self.state.execution_strategy
857 {
858 self.transform_data_parallel(x, chunk_size)
859 } else {
860 let mut current_x = x.to_owned();
862 for (_, step) in &self.state.fitted_steps {
863 current_x = step.transform(¤t_x.view())?;
864 }
865 Ok(current_x)
866 }
867 }
868
869 fn transform_data_parallel(
871 &self,
872 x: &ArrayView2<'_, Float>,
873 chunk_size: usize,
874 ) -> SklResult<Array2<f64>> {
875 let n_rows = x.nrows();
876 let n_chunks = (n_rows + chunk_size - 1) / chunk_size;
877 let mut results = Vec::with_capacity(n_chunks);
878
879 for chunk_start in (0..n_rows).step_by(chunk_size) {
881 let chunk_end = std::cmp::min(chunk_start + chunk_size, n_rows);
882 let chunk = x.slice(s![chunk_start..chunk_end, ..]);
883
884 let mut current_chunk = chunk.to_owned();
885 for (_, step) in &self.state.fitted_steps {
886 current_chunk = step.transform(¤t_chunk.view())?;
887 }
888
889 results.push(current_chunk);
890 }
891
892 if results.is_empty() {
894 return Ok(Array2::zeros((0, 0)));
895 }
896
897 let total_rows: usize = results
898 .iter()
899 .map(scirs2_core::ndarray::ArrayBase::nrows)
900 .sum();
901 let n_cols = results[0].ncols();
902 let mut combined = Array2::zeros((total_rows, n_cols));
903
904 let mut row_offset = 0;
905 for result in results {
906 let end_offset = row_offset + result.nrows();
907 combined
908 .slice_mut(s![row_offset..end_offset, ..])
909 .assign(&result);
910 row_offset = end_offset;
911 }
912
913 Ok(combined)
914 }
915
916 pub fn predict(&self, x: &ArrayView2<'_, Float>) -> SklResult<Array1<f64>> {
918 let transformed = self.transform(x)?;
919
920 if let Some(estimator) = &self.state.fitted_estimator {
921 let mapped_data = transformed.view().mapv(|v| v as Float);
922 estimator.predict(&mapped_data.view())
923 } else {
924 Err(SklearsError::NotFitted {
925 operation: "predict".to_string(),
926 })
927 }
928 }
929}
930
/// A boxed, pinned future producing a `TaskResult`, usable with any
/// executor that polls `Future`s.
pub struct AsyncTask {
    // Pinned on the heap so a (possibly self-referential) future can be
    // polled safely.
    future: Pin<Box<dyn Future<Output = SklResult<TaskResult>> + Send>>,
}
935
936impl AsyncTask {
937 pub fn new<F>(future: F) -> Self
939 where
940 F: Future<Output = SklResult<TaskResult>> + Send + 'static,
941 {
942 Self {
943 future: Box::pin(future),
944 }
945 }
946}
947
impl Future for AsyncTask {
    type Output = SklResult<TaskResult>;

    /// Forwards the poll to the inner boxed future.
    fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
        // `self.future` is already `Pin<Box<...>>`; re-borrowing it with
        // `as_mut()` upholds the pinning contract.
        self.future.as_mut().poll(cx)
    }
}
955
#[allow(non_snake_case)]
#[cfg(test)]
mod tests {
    use super::*;
    use crate::MockTransformer;

    // The default configuration should be immediately usable.
    #[test]
    fn test_parallel_config() {
        let config = ParallelConfig::default();
        assert!(config.num_workers > 0);
        assert!(matches!(config.pool_type, ThreadPoolType::FixedSize));
        assert!(config.work_stealing);
    }

    // Round-robin dispatch onto two idle workers should succeed.
    #[test]
    fn test_task_dispatcher() {
        let dispatcher = TaskDispatcher::new(LoadBalancingStrategy::RoundRobin, 4);

        let mut workers = vec![
            WorkerState {
                worker_id: 0,
                thread_handle: None,
                task_queue: Arc::new(Mutex::new(VecDeque::new())),
                status: WorkerStatus::Idle,
                stats: WorkerStatistics::default(),
                steal_deque: Arc::new(Mutex::new(VecDeque::new())),
            },
            WorkerState {
                worker_id: 1,
                thread_handle: None,
                task_queue: Arc::new(Mutex::new(VecDeque::new())),
                status: WorkerStatus::Idle,
                stats: WorkerStatistics::default(),
                steal_deque: Arc::new(Mutex::new(VecDeque::new())),
            },
        ];

        // A trivial task whose closure immediately returns a success result.
        let task = ParallelTask {
            id: "test_task".to_string(),
            task_fn: Box::new(|| {
                Ok(TaskResult {
                    task_id: "test_task".to_string(),
                    data: vec![1, 2, 3],
                    duration: Duration::from_millis(10),
                    worker_id: thread::current().id(),
                    success: true,
                    error: None,
                })
            }),
            priority: 1,
            estimated_duration: Duration::from_millis(100),
            dependencies: Vec::new(),
            metadata: HashMap::new(),
        };

        assert!(dispatcher.dispatch_task(task, &mut workers).is_ok());
    }

    // Fresh statistics start with all counters at zero.
    #[test]
    fn test_worker_statistics() {
        let mut stats = WorkerStatistics::default();
        assert_eq!(stats.tasks_completed, 0);
        assert_eq!(stats.tasks_failed, 0);
        assert_eq!(stats.work_stolen, 0);
    }

    // Building a pipeline accumulates steps and the final estimator.
    #[test]
    fn test_parallel_pipeline_creation() {
        let config = ParallelConfig::default();
        let mut pipeline = ParallelPipeline::new(config);

        pipeline.add_step("step1".to_string(), Box::new(MockTransformer::new()));
        pipeline.set_estimator(Box::new(crate::MockPredictor::new()));

        assert_eq!(pipeline.steps.len(), 1);
        assert!(pipeline.final_estimator.is_some());
    }

    // Every execution strategy can be attached to a fresh pipeline.
    #[test]
    fn test_execution_strategies() {
        let strategies = vec![
            ParallelExecutionStrategy::FullParallel,
            ParallelExecutionStrategy::BatchParallel { batch_size: 2 },
            ParallelExecutionStrategy::PipelineParallel,
            ParallelExecutionStrategy::DataParallel { chunk_size: 100 },
        ];

        for strategy in strategies {
            let config = ParallelConfig::default();
            let pipeline = ParallelPipeline::new(config).execution_strategy(strategy);
            assert!(pipeline.steps.is_empty());
        }
    }

    // TaskResult is a plain value type; verify field round-trip.
    #[test]
    fn test_task_result() {
        let result = TaskResult {
            task_id: "test".to_string(),
            data: vec![1, 2, 3, 4],
            duration: Duration::from_millis(50),
            worker_id: thread::current().id(),
            success: true,
            error: None,
        };

        assert_eq!(result.task_id, "test");
        assert_eq!(result.data.len(), 4);
        assert!(result.success);
        assert!(result.error.is_none());
    }
}