use scirs2_core::ndarray::{s, Array1, Array2, ArrayView1, ArrayView2};
use scirs2_core::random::thread_rng;
use sklears_core::{
    error::Result as SklResult,
    prelude::{Predict, SklearsError, Transform},
    traits::{Estimator, Fit, Untrained},
    types::Float,
};
use std::collections::{HashMap, VecDeque};
use std::future::Future;
use std::pin::Pin;
use std::sync::{Arc, Condvar, Mutex, RwLock};
use std::task::{Context, Poll};
use std::thread::{self, JoinHandle, ThreadId};
use std::time::{Duration, Instant, SystemTime};

use crate::{PipelinePredictor, PipelineStep};

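/// Configuration for parallel pipeline execution: worker count, pool type,
/// work stealing, load balancing, scheduling, queue bound, and idle timeout.
///
/// # Example
///
/// A minimal sketch (not compiled as a doctest); the field values are
/// illustrative only.
///
/// ```ignore
/// let config = ParallelConfig {
///     num_workers: 4,
///     work_stealing: false,
///     ..ParallelConfig::default()
/// };
/// ```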
#[derive(Debug, Clone)]
pub struct ParallelConfig {
    pub num_workers: usize,
    pub pool_type: ThreadPoolType,
    pub work_stealing: bool,
    pub load_balancing: LoadBalancingStrategy,
    pub scheduling: SchedulingStrategy,
    pub max_queue_size: usize,
    pub idle_timeout: Duration,
}

impl Default for ParallelConfig {
    fn default() -> Self {
        Self {
            num_workers: num_cpus::get(),
            pool_type: ThreadPoolType::FixedSize,
            work_stealing: true,
            load_balancing: LoadBalancingStrategy::RoundRobin,
            scheduling: SchedulingStrategy::FIFO,
            max_queue_size: 1000,
            idle_timeout: Duration::from_secs(60),
        }
    }
}

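/// Thread pool sizing strategy: a fixed-size pool, a dynamic pool bounded by
/// `min_threads`/`max_threads`, or single-threaded execution.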
#[derive(Debug, Clone)]
pub enum ThreadPoolType {
    FixedSize,
    Dynamic {
        min_threads: usize,
        max_threads: usize,
    },
    SingleThreaded,
}

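/// Strategy used by the dispatcher to assign submitted tasks to workers.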
#[derive(Debug, Clone)]
pub enum LoadBalancingStrategy {
    RoundRobin,
    LeastLoaded,
    Random,
    LocalityAware,
}

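/// Order in which queued tasks are scheduled for execution.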
#[derive(Debug, Clone)]
pub enum SchedulingStrategy {
    FIFO,
    LIFO,
    Priority,
    WorkStealing,
}

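/// A unit of work submitted to the executor: an identifier, the boxed task
/// closure, a priority, an estimated duration, task dependencies, and
/// arbitrary string metadata.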
pub struct ParallelTask {
    pub id: String,
    pub task_fn: Box<dyn FnOnce() -> SklResult<TaskResult> + Send>,
    pub priority: u32,
    pub estimated_duration: Duration,
    pub dependencies: Vec<String>,
    pub metadata: HashMap<String, String>,
}

impl std::fmt::Debug for ParallelTask {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("ParallelTask")
            .field("id", &self.id)
            .field("task_fn", &"<function>")
            .field("priority", &self.priority)
            .field("estimated_duration", &self.estimated_duration)
            .field("dependencies", &self.dependencies)
            .field("metadata", &self.metadata)
            .finish()
    }
}

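/// Outcome of an executed task: the task id, raw result bytes, execution
/// duration, the executing worker's thread id, and success/error information.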
#[derive(Debug, Clone)]
pub struct TaskResult {
    pub task_id: String,
    pub data: Vec<u8>,
    pub duration: Duration,
    pub worker_id: ThreadId,
    pub success: bool,
    pub error: Option<String>,
}

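/// Per-worker state: the worker's id, its thread handle, its local task
/// queue, its current status, its statistics, and the deque used for work
/// stealing.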
#[derive(Debug)]
pub struct WorkerState {
    pub worker_id: usize,
    pub thread_handle: Option<JoinHandle<()>>,
    pub task_queue: Arc<Mutex<VecDeque<ParallelTask>>>,
    pub status: WorkerStatus,
    pub stats: WorkerStatistics,
    pub steal_deque: Arc<Mutex<VecDeque<ParallelTask>>>,
}

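/// Lifecycle status of a worker thread.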
#[derive(Debug, Clone, PartialEq)]
pub enum WorkerStatus {
    Idle,
    Working,
    Stealing,
    Terminated,
}

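/// Per-worker counters: completed/failed tasks, cumulative and average
/// execution time, last activity timestamp, and work-stealing counts.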
#[derive(Debug, Clone)]
pub struct WorkerStatistics {
    pub tasks_completed: u64,
    pub tasks_failed: u64,
    pub total_execution_time: Duration,
    pub avg_task_duration: Duration,
    pub last_activity: SystemTime,
    pub work_stolen: u64,
    pub work_given: u64,
}

impl Default for WorkerStatistics {
    fn default() -> Self {
        Self {
            tasks_completed: 0,
            tasks_failed: 0,
            total_execution_time: Duration::ZERO,
            avg_task_duration: Duration::ZERO,
            last_activity: SystemTime::now(),
            work_stolen: 0,
            work_given: 0,
        }
    }
}

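/// Thread-pool executor that dispatches [`ParallelTask`]s to worker threads
/// and collects their [`TaskResult`]s.
///
/// # Example
///
/// A minimal usage sketch (not compiled as a doctest); the task body, id, and
/// timeout are illustrative only.
///
/// ```ignore
/// let mut executor = ParallelExecutor::new(ParallelConfig::default());
/// executor.start()?;
/// executor.submit_task(ParallelTask {
///     id: "example".to_string(),
///     task_fn: Box::new(|| {
///         Ok(TaskResult {
///             task_id: String::new(),
///             data: vec![42],
///             duration: Duration::ZERO,
///             worker_id: std::thread::current().id(),
///             success: true,
///             error: None,
///         })
///     }),
///     priority: 0,
///     estimated_duration: Duration::from_millis(1),
///     dependencies: Vec::new(),
///     metadata: HashMap::new(),
/// })?;
/// executor.wait_for_completion(Some(Duration::from_secs(5)))?;
/// let result = executor.get_task_result("example");
/// executor.stop()?;
/// ```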
#[derive(Debug)]
pub struct ParallelExecutor {
    config: ParallelConfig,
    workers: Vec<WorkerState>,
    dispatcher: TaskDispatcher,
    global_queue: Arc<Mutex<VecDeque<ParallelTask>>>,
    completed_tasks: Arc<Mutex<HashMap<String, TaskResult>>>,
    statistics: Arc<RwLock<ExecutorStatistics>>,
    is_running: Arc<Mutex<bool>>,
    shutdown_signal: Arc<Condvar>,
}

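/// Assigns tasks to worker queues according to a [`LoadBalancingStrategy`],
/// tracking a per-worker load counter.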
#[derive(Debug)]
pub struct TaskDispatcher {
    round_robin_index: Mutex<usize>,
    strategy: LoadBalancingStrategy,
    worker_loads: Arc<RwLock<Vec<usize>>>,
}

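/// Aggregate executor metrics: submitted/completed/failed task counts,
/// average task duration, throughput, worker utilization, and queue depth.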
#[derive(Debug, Clone)]
pub struct ExecutorStatistics {
    pub tasks_submitted: u64,
    pub tasks_completed: u64,
    pub tasks_failed: u64,
    pub avg_task_duration: Duration,
    pub throughput: f64,
    pub worker_utilization: f64,
    pub queue_depth: usize,
    pub last_updated: SystemTime,
}

impl Default for ExecutorStatistics {
    fn default() -> Self {
        Self {
            tasks_submitted: 0,
            tasks_completed: 0,
            tasks_failed: 0,
            avg_task_duration: Duration::ZERO,
            throughput: 0.0,
            worker_utilization: 0.0,
            queue_depth: 0,
            last_updated: SystemTime::now(),
        }
    }
}

impl TaskDispatcher {
    #[must_use]
    pub fn new(strategy: LoadBalancingStrategy, num_workers: usize) -> Self {
        Self {
            round_robin_index: Mutex::new(0),
            strategy,
            worker_loads: Arc::new(RwLock::new(vec![0; num_workers])),
        }
    }

    pub fn dispatch_task(&self, task: ParallelTask, workers: &mut [WorkerState]) -> SklResult<()> {
        let worker_index = match self.strategy {
            LoadBalancingStrategy::RoundRobin => {
                let mut index = self.round_robin_index.lock().unwrap();
                let selected = *index;
                *index = (*index + 1) % workers.len();
                selected
            }
            LoadBalancingStrategy::LeastLoaded => self.find_least_loaded_worker(workers),
            LoadBalancingStrategy::Random => {
                let mut rng = thread_rng();
                rng.gen_range(0..workers.len())
            }
            LoadBalancingStrategy::LocalityAware => {
                self.find_best_locality_worker(workers, &task)
            }
        };

        let mut queue = workers[worker_index].task_queue.lock().unwrap();
        queue.push_back(task);

        let mut loads = self.worker_loads.write().unwrap();
        loads[worker_index] += 1;

        Ok(())
    }

    fn find_least_loaded_worker(&self, _workers: &[WorkerState]) -> usize {
        // Loads are tracked in `worker_loads`, so the worker slice itself is unused here.
        let loads = self.worker_loads.read().unwrap();
        loads
            .iter()
            .enumerate()
            .min_by_key(|(_, &load)| load)
            .map_or(0, |(index, _)| index)
    }

    fn find_best_locality_worker(&self, workers: &[WorkerState], _task: &ParallelTask) -> usize {
        workers
            .iter()
            .position(|worker| worker.status == WorkerStatus::Idle)
            .unwrap_or(0)
    }

    pub fn update_worker_load(&self, worker_index: usize, delta: i32) {
        let mut loads = self.worker_loads.write().unwrap();
        if delta < 0 {
            loads[worker_index] = loads[worker_index].saturating_sub((-delta) as usize);
        } else {
            loads[worker_index] += delta as usize;
        }
    }
}

impl ParallelExecutor {
    #[must_use]
    pub fn new(config: ParallelConfig) -> Self {
        let num_workers = config.num_workers;
        let dispatcher = TaskDispatcher::new(config.load_balancing.clone(), num_workers);

        Self {
            config,
            workers: Vec::with_capacity(num_workers),
            dispatcher,
            global_queue: Arc::new(Mutex::new(VecDeque::new())),
            completed_tasks: Arc::new(Mutex::new(HashMap::new())),
            statistics: Arc::new(RwLock::new(ExecutorStatistics::default())),
            is_running: Arc::new(Mutex::new(false)),
            shutdown_signal: Arc::new(Condvar::new()),
        }
    }

    pub fn start(&mut self) -> SklResult<()> {
        {
            let mut running = self.is_running.lock().unwrap();
            if *running {
                return Ok(());
            }
            *running = true;
        }

        for worker_id in 0..self.config.num_workers {
            let worker = self.create_worker(worker_id)?;
            self.workers.push(worker);
        }

        for i in 0..self.workers.len() {
            self.start_worker_by_index(i)?;
        }

        Ok(())
    }

    pub fn stop(&mut self) -> SklResult<()> {
        {
            let mut running = self.is_running.lock().unwrap();
            *running = false;
        }

        self.shutdown_signal.notify_all();

        for worker in &mut self.workers {
            if let Some(handle) = worker.thread_handle.take() {
                handle.join().map_err(|_| SklearsError::InvalidData {
                    reason: "Failed to join worker thread".to_string(),
                })?;
            }
        }

        Ok(())
    }

    fn create_worker(&self, worker_id: usize) -> SklResult<WorkerState> {
        Ok(WorkerState {
            worker_id,
            thread_handle: None,
            task_queue: Arc::new(Mutex::new(VecDeque::new())),
            status: WorkerStatus::Idle,
            stats: WorkerStatistics::default(),
            steal_deque: Arc::new(Mutex::new(VecDeque::new())),
        })
    }

    fn start_worker_by_index(&mut self, worker_index: usize) -> SklResult<()> {
        let worker_id = self.workers[worker_index].worker_id;
        let task_queue = Arc::clone(&self.workers[worker_index].task_queue);
        let steal_deque = Arc::clone(&self.workers[worker_index].steal_deque);
        let completed_tasks = Arc::clone(&self.completed_tasks);
        let is_running = Arc::clone(&self.is_running);
        let shutdown_signal = Arc::clone(&self.shutdown_signal);
        let statistics = Arc::clone(&self.statistics);
        let config = self.config.clone();

        // Collect handles to every other worker's queue (by position, not
        // worker id) so this worker can steal from them.
        let other_workers: Vec<Arc<Mutex<VecDeque<ParallelTask>>>> = self
            .workers
            .iter()
            .enumerate()
            .filter(|(i, _)| *i != worker_index)
            .map(|(_, w)| Arc::clone(&w.task_queue))
            .collect();

        let handle = thread::spawn(move || {
            Self::worker_loop(
                worker_id,
                task_queue,
                steal_deque,
                other_workers,
                completed_tasks,
                is_running,
                shutdown_signal,
                statistics,
                config,
            );
        });

        let worker = &mut self.workers[worker_index];
        worker.thread_handle = Some(handle);
        Ok(())
    }

    fn worker_loop(
        worker_id: usize,
        task_queue: Arc<Mutex<VecDeque<ParallelTask>>>,
        _steal_deque: Arc<Mutex<VecDeque<ParallelTask>>>,
        other_workers: Vec<Arc<Mutex<VecDeque<ParallelTask>>>>,
        completed_tasks: Arc<Mutex<HashMap<String, TaskResult>>>,
        is_running: Arc<Mutex<bool>>,
        shutdown_signal: Arc<Condvar>,
        statistics: Arc<RwLock<ExecutorStatistics>>,
        config: ParallelConfig,
    ) {
        let mut local_stats = WorkerStatistics::default();

        while *is_running.lock().unwrap() {
            // Take the next task from this worker's own queue, if any.
            let task = {
                let mut queue = task_queue.lock().unwrap();
                queue.pop_front()
            };

            let task = if let Some(task) = task {
                Some(task)
            } else if config.work_stealing {
                // No local work: try to steal from another worker's queue.
                Self::steal_work(&other_workers, worker_id, &mut local_stats)
            } else {
                // No local work and stealing disabled: wait for a shutdown
                // notification or the idle timeout, then re-check the queue.
                let queue = task_queue.lock().unwrap();
                let _guard = shutdown_signal
                    .wait_timeout(queue, config.idle_timeout)
                    .unwrap();
                continue;
            };

            if let Some(task) = task {
                let start_time = Instant::now();
                let task_id = task.id.clone();

                let result = match (task.task_fn)() {
                    Ok(mut result) => {
                        result.task_id = task_id.clone();
                        result.worker_id = thread::current().id();
                        result.duration = start_time.elapsed();
                        result.success = true;
                        local_stats.tasks_completed += 1;
                        result
                    }
                    Err(e) => {
                        local_stats.tasks_failed += 1;
                        TaskResult {
                            task_id: task_id.clone(),
                            data: Vec::new(),
                            duration: start_time.elapsed(),
                            worker_id: thread::current().id(),
                            success: false,
                            error: Some(format!("{e:?}")),
                        }
                    }
                };

                let execution_time = start_time.elapsed();
                local_stats.total_execution_time += execution_time;
                local_stats.avg_task_duration = local_stats.total_execution_time
                    / (local_stats.tasks_completed + local_stats.tasks_failed) as u32;
                local_stats.last_activity = SystemTime::now();

                let task_succeeded = result.success;

                {
                    let mut completed = completed_tasks.lock().unwrap();
                    completed.insert(task_id, result);
                }

                {
                    // Count failures separately so the executor statistics
                    // distinguish completed from failed tasks.
                    let mut stats = statistics.write().unwrap();
                    if task_succeeded {
                        stats.tasks_completed += 1;
                    } else {
                        stats.tasks_failed += 1;
                    }
                    stats.last_updated = SystemTime::now();
                }
            } else {
                // Nothing to run and nothing to steal; back off briefly to
                // avoid a busy spin.
                thread::sleep(Duration::from_millis(1));
            }
        }
    }

    fn steal_work(
        other_workers: &[Arc<Mutex<VecDeque<ParallelTask>>>],
        _worker_id: usize,
        stats: &mut WorkerStatistics,
    ) -> Option<ParallelTask> {
        for other_queue in other_workers {
            if let Ok(mut queue) = other_queue.try_lock() {
                if let Some(task) = queue.pop_back() {
                    stats.work_stolen += 1;
                    return Some(task);
                }
            }
        }
        None
    }

    pub fn submit_task(&mut self, task: ParallelTask) -> SklResult<()> {
        {
            let mut stats = self.statistics.write().unwrap();
            stats.tasks_submitted += 1;
        }

        self.dispatcher.dispatch_task(task, &mut self.workers)?;
        Ok(())
    }

    pub fn get_task_result(&self, task_id: &str) -> Option<TaskResult> {
        let completed = self.completed_tasks.lock().unwrap();
        completed.get(task_id).cloned()
    }

    pub fn statistics(&self) -> ExecutorStatistics {
        let stats = self.statistics.read().unwrap();
        stats.clone()
    }

    pub fn wait_for_completion(&self, timeout: Option<Duration>) -> SklResult<()> {
        let start_time = Instant::now();

        loop {
            let stats = self.statistics();
            if stats.tasks_submitted == stats.tasks_completed + stats.tasks_failed {
                break;
            }

            if let Some(timeout) = timeout {
                if start_time.elapsed() > timeout {
                    return Err(SklearsError::InvalidData {
                        reason: "Timeout waiting for task completion".to_string(),
                    });
                }
            }

            thread::sleep(Duration::from_millis(10));
        }

        Ok(())
    }
}

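/// A pipeline whose steps are fitted and applied using a [`ParallelExecutor`],
/// parameterized by its training state (`Untrained` or
/// [`ParallelPipelineTrained`]).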
#[derive(Debug)]
pub struct ParallelPipeline<S = Untrained> {
    state: S,
    steps: Vec<(String, Box<dyn PipelineStep>)>,
    final_estimator: Option<Box<dyn PipelinePredictor>>,
    executor: Option<ParallelExecutor>,
    parallel_config: ParallelConfig,
    execution_strategy: ParallelExecutionStrategy,
}

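/// How pipeline work is parallelized: all steps at once, in fixed-size
/// batches of steps, stage-by-stage, or by splitting the data into chunks.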
#[derive(Debug, Clone)]
pub enum ParallelExecutionStrategy {
    FullParallel,
    BatchParallel { batch_size: usize },
    PipelineParallel,
    DataParallel { chunk_size: usize },
}

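/// Trained state of a [`ParallelPipeline`]: the fitted steps and estimator,
/// the configuration used for fitting, and input-shape metadata.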
#[derive(Debug)]
pub struct ParallelPipelineTrained {
    fitted_steps: Vec<(String, Box<dyn PipelineStep>)>,
    fitted_estimator: Option<Box<dyn PipelinePredictor>>,
    parallel_config: ParallelConfig,
    execution_strategy: ParallelExecutionStrategy,
    n_features_in: usize,
    feature_names_in: Option<Vec<String>>,
}

impl ParallelPipeline<Untrained> {
    #[must_use]
    pub fn new(parallel_config: ParallelConfig) -> Self {
        Self {
            state: Untrained,
            steps: Vec::new(),
            final_estimator: None,
            executor: None,
            parallel_config,
            execution_strategy: ParallelExecutionStrategy::FullParallel,
        }
    }

    pub fn add_step(&mut self, name: String, step: Box<dyn PipelineStep>) {
        self.steps.push((name, step));
    }

    pub fn set_estimator(&mut self, estimator: Box<dyn PipelinePredictor>) {
        self.final_estimator = Some(estimator);
    }

    pub fn execution_strategy(mut self, strategy: ParallelExecutionStrategy) -> Self {
        self.execution_strategy = strategy;
        self
    }
}

impl Estimator for ParallelPipeline<Untrained> {
    type Config = ();
    type Error = SklearsError;
    type Float = Float;

    fn config(&self) -> &Self::Config {
        &()
    }
}

impl Fit<ArrayView2<'_, Float>, Option<&ArrayView1<'_, Float>>> for ParallelPipeline<Untrained> {
    type Fitted = ParallelPipeline<ParallelPipelineTrained>;

    fn fit(
        mut self,
        x: &ArrayView2<'_, Float>,
        y: &Option<&ArrayView1<'_, Float>>,
    ) -> SklResult<Self::Fitted> {
        let mut executor = ParallelExecutor::new(self.parallel_config.clone());
        executor.start()?;

        let fitted_steps = match self.execution_strategy {
            ParallelExecutionStrategy::FullParallel => {
                self.fit_steps_parallel(x, y, &mut executor)?
            }
            ParallelExecutionStrategy::BatchParallel { batch_size } => {
                self.fit_steps_batch_parallel(x, y, &mut executor, batch_size)?
            }
            ParallelExecutionStrategy::PipelineParallel => {
                self.fit_steps_pipeline_parallel(x, y, &mut executor)?
            }
            ParallelExecutionStrategy::DataParallel { chunk_size } => {
                self.fit_steps_data_parallel(x, y, &mut executor, chunk_size)?
            }
        };

        let fitted_estimator = if let Some(mut estimator) = self.final_estimator {
            // Push the training data through the fitted steps before fitting
            // the final estimator.
            let mut current_x = x.to_owned();
            for (_, step) in &fitted_steps {
                current_x = step.transform(&current_x.view())?;
            }

            if let Some(y_values) = y.as_ref() {
                let mapped_x = current_x.view().mapv(|v| v as Float);
                estimator.fit(&mapped_x.view(), y_values)?;
                Some(estimator)
            } else {
                None
            }
        } else {
            None
        };

        executor.stop()?;

        Ok(ParallelPipeline {
            state: ParallelPipelineTrained {
                fitted_steps,
                fitted_estimator,
                parallel_config: self.parallel_config,
                execution_strategy: self.execution_strategy,
                n_features_in: x.ncols(),
                feature_names_in: None,
            },
            steps: Vec::new(),
            final_estimator: None,
            executor: None,
            parallel_config: ParallelConfig::default(),
            execution_strategy: ParallelExecutionStrategy::FullParallel,
        })
    }
}

impl ParallelPipeline<Untrained> {
    fn fit_steps_parallel(
        &mut self,
        x: &ArrayView2<'_, Float>,
        y: &Option<&ArrayView1<'_, Float>>,
        _executor: &mut ParallelExecutor,
    ) -> SklResult<Vec<(String, Box<dyn PipelineStep>)>> {
        let mut fitted_steps = Vec::new();

        // Steps are fitted sequentially because each step consumes the
        // previous step's output.
        let mut current_x = x.to_owned();
        for (name, mut step) in self.steps.drain(..) {
            step.fit(&current_x.view(), y.as_ref().copied())?;
            current_x = step.transform(&current_x.view())?;
            fitted_steps.push((name, step));
        }

        Ok(fitted_steps)
    }

    fn fit_steps_batch_parallel(
        &mut self,
        x: &ArrayView2<'_, Float>,
        y: &Option<&ArrayView1<'_, Float>>,
        _executor: &mut ParallelExecutor,
        batch_size: usize,
    ) -> SklResult<Vec<(String, Box<dyn PipelineStep>)>> {
        let mut fitted_steps = Vec::new();
        let mut steps = self.steps.drain(..).collect::<Vec<_>>();

        while !steps.is_empty() {
            let batch_size = batch_size.min(steps.len());
            let batch: Vec<_> = steps.drain(0..batch_size).collect();
            let mut batch_fitted = Vec::new();

            for (name, mut step) in batch {
                step.fit(x, y.as_ref().copied())?;
                batch_fitted.push((name, step));
            }

            fitted_steps.extend(batch_fitted);
        }

        Ok(fitted_steps)
    }

    fn fit_steps_pipeline_parallel(
        &mut self,
        x: &ArrayView2<'_, Float>,
        y: &Option<&ArrayView1<'_, Float>>,
        executor: &mut ParallelExecutor,
    ) -> SklResult<Vec<(String, Box<dyn PipelineStep>)>> {
        self.fit_steps_parallel(x, y, executor)
    }

    fn fit_steps_data_parallel(
        &mut self,
        x: &ArrayView2<'_, Float>,
        y: &Option<&ArrayView1<'_, Float>>,
        _executor: &mut ParallelExecutor,
        _chunk_size: usize,
    ) -> SklResult<Vec<(String, Box<dyn PipelineStep>)>> {
        let mut fitted_steps = Vec::new();

        let mut current_x = x.to_owned();
        for (name, mut step) in self.steps.drain(..) {
            step.fit(&current_x.view(), y.as_ref().copied())?;
            current_x = step.transform(&current_x.view())?;
            fitted_steps.push((name, step));
        }

        Ok(fitted_steps)
    }
}

impl ParallelPipeline<ParallelPipelineTrained> {
    pub fn transform(&self, x: &ArrayView2<'_, Float>) -> SklResult<Array2<f64>> {
        if let ParallelExecutionStrategy::DataParallel { chunk_size } =
            self.state.execution_strategy
        {
            self.transform_data_parallel(x, chunk_size)
        } else {
            let mut current_x = x.to_owned();
            for (_, step) in &self.state.fitted_steps {
                current_x = step.transform(&current_x.view())?;
            }
            Ok(current_x)
        }
    }

    fn transform_data_parallel(
        &self,
        x: &ArrayView2<'_, Float>,
        chunk_size: usize,
    ) -> SklResult<Array2<f64>> {
        let n_rows = x.nrows();
        let n_chunks = (n_rows + chunk_size - 1) / chunk_size;
        let mut results = Vec::with_capacity(n_chunks);

        for chunk_start in (0..n_rows).step_by(chunk_size) {
            let chunk_end = std::cmp::min(chunk_start + chunk_size, n_rows);
            let chunk = x.slice(s![chunk_start..chunk_end, ..]);

            let mut current_chunk = chunk.to_owned();
            for (_, step) in &self.state.fitted_steps {
                current_chunk = step.transform(&current_chunk.view())?;
            }

            results.push(current_chunk);
        }

        if results.is_empty() {
            return Ok(Array2::zeros((0, 0)));
        }

        let total_rows: usize = results
            .iter()
            .map(scirs2_core::ndarray::ArrayBase::nrows)
            .sum();
        let n_cols = results[0].ncols();
        let mut combined = Array2::zeros((total_rows, n_cols));

        let mut row_offset = 0;
        for result in results {
            let end_offset = row_offset + result.nrows();
            combined
                .slice_mut(s![row_offset..end_offset, ..])
                .assign(&result);
            row_offset = end_offset;
        }

        Ok(combined)
    }

    pub fn predict(&self, x: &ArrayView2<'_, Float>) -> SklResult<Array1<f64>> {
        let transformed = self.transform(x)?;

        if let Some(estimator) = &self.state.fitted_estimator {
            let mapped_data = transformed.view().mapv(|v| v as Float);
            estimator.predict(&mapped_data.view())
        } else {
            Err(SklearsError::NotFitted {
                operation: "predict".to_string(),
            })
        }
    }
}

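/// A boxed future yielding a [`TaskResult`], usable with async executors.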
pub struct AsyncTask {
    future: Pin<Box<dyn Future<Output = SklResult<TaskResult>> + Send>>,
}

impl AsyncTask {
    pub fn new<F>(future: F) -> Self
    where
        F: Future<Output = SklResult<TaskResult>> + Send + 'static,
    {
        Self {
            future: Box::pin(future),
        }
    }
}

impl Future for AsyncTask {
    type Output = SklResult<TaskResult>;

    fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
        self.future.as_mut().poll(cx)
    }
}

#[allow(non_snake_case)]
#[cfg(test)]
mod tests {
    use super::*;
    use crate::MockTransformer;

    #[test]
    fn test_parallel_config() {
        let config = ParallelConfig::default();
        assert!(config.num_workers > 0);
        assert!(matches!(config.pool_type, ThreadPoolType::FixedSize));
        assert!(config.work_stealing);
    }

    #[test]
    fn test_task_dispatcher() {
        let dispatcher = TaskDispatcher::new(LoadBalancingStrategy::RoundRobin, 4);

        let mut workers = vec![
            WorkerState {
                worker_id: 0,
                thread_handle: None,
                task_queue: Arc::new(Mutex::new(VecDeque::new())),
                status: WorkerStatus::Idle,
                stats: WorkerStatistics::default(),
                steal_deque: Arc::new(Mutex::new(VecDeque::new())),
            },
            WorkerState {
                worker_id: 1,
                thread_handle: None,
                task_queue: Arc::new(Mutex::new(VecDeque::new())),
                status: WorkerStatus::Idle,
                stats: WorkerStatistics::default(),
                steal_deque: Arc::new(Mutex::new(VecDeque::new())),
            },
        ];

        let task = ParallelTask {
            id: "test_task".to_string(),
            task_fn: Box::new(|| {
                Ok(TaskResult {
                    task_id: "test_task".to_string(),
                    data: vec![1, 2, 3],
                    duration: Duration::from_millis(10),
                    worker_id: thread::current().id(),
                    success: true,
                    error: None,
                })
            }),
            priority: 1,
            estimated_duration: Duration::from_millis(100),
            dependencies: Vec::new(),
            metadata: HashMap::new(),
        };

        assert!(dispatcher.dispatch_task(task, &mut workers).is_ok());
    }

    #[test]
    fn test_worker_statistics() {
        let stats = WorkerStatistics::default();
        assert_eq!(stats.tasks_completed, 0);
        assert_eq!(stats.tasks_failed, 0);
        assert_eq!(stats.work_stolen, 0);
    }

    #[test]
    fn test_parallel_pipeline_creation() {
        let config = ParallelConfig::default();
        let mut pipeline = ParallelPipeline::new(config);

        pipeline.add_step("step1".to_string(), Box::new(MockTransformer::new()));
        pipeline.set_estimator(Box::new(crate::MockPredictor::new()));

        assert_eq!(pipeline.steps.len(), 1);
        assert!(pipeline.final_estimator.is_some());
    }

    #[test]
    fn test_execution_strategies() {
        let strategies = vec![
            ParallelExecutionStrategy::FullParallel,
            ParallelExecutionStrategy::BatchParallel { batch_size: 2 },
            ParallelExecutionStrategy::PipelineParallel,
            ParallelExecutionStrategy::DataParallel { chunk_size: 100 },
        ];

        for strategy in strategies {
            let config = ParallelConfig::default();
            let pipeline = ParallelPipeline::new(config).execution_strategy(strategy);
            assert!(pipeline.steps.is_empty());
        }
    }

    #[test]
    fn test_task_result() {
        let result = TaskResult {
            task_id: "test".to_string(),
            data: vec![1, 2, 3, 4],
            duration: Duration::from_millis(50),
            worker_id: thread::current().id(),
            success: true,
            error: None,
        };

        assert_eq!(result.task_id, "test");
        assert_eq!(result.data.len(), 4);
        assert!(result.success);
        assert!(result.error.is_none());
    }
}