1use scirs2_core::ndarray::{s, Array1, Array2, ArrayView1, ArrayView2};
7use sklears_core::{
8 error::{Result as SklResult, SklearsError},
9 traits::{Estimator, Fit, Untrained},
10 types::Float,
11};
12use std::collections::HashMap;
13use std::net::SocketAddr;
14use std::sync::{Arc, Mutex, RwLock};
15use std::thread::{self, JoinHandle};
16use std::time::{Duration, SystemTime};
17
18use crate::{PipelinePredictor, PipelineStep};
19
/// Unique identifier of a node in the cluster.
pub type NodeId = String;

/// Unique identifier of a distributed task.
pub type TaskId = String;
25
/// A single worker node known to the cluster.
#[derive(Debug, Clone)]
pub struct ClusterNode {
    /// Unique identifier of this node within the cluster.
    pub id: NodeId,
    /// Network address the node is reachable at.
    pub address: SocketAddr,
    /// Current health status.
    pub status: NodeStatus,
    /// Static hardware capacity of the node.
    pub resources: NodeResources,
    /// Current utilization snapshot.
    pub load: NodeLoad,
    /// Timestamp of the most recent heartbeat received from the node.
    pub last_heartbeat: SystemTime,
    /// Free-form key/value metadata. NOTE(review): keys are not interpreted
    /// anywhere in this module.
    pub metadata: HashMap<String, String>,
}
44
/// Health state of a cluster node. Only `Healthy` nodes are eligible for
/// scheduling; `Failed` is set by the health monitor on missed heartbeats.
#[derive(Debug, Clone, PartialEq)]
pub enum NodeStatus {
    /// Responsive and accepting work.
    Healthy,
    /// Responsive but under heavy load.
    Stressed,
    /// Temporarily not accepting work.
    Unavailable,
    /// Considered dead (e.g. heartbeat timeout).
    Failed,
    /// Draining and about to leave the cluster.
    ShuttingDown,
}
59
/// Static hardware capacity of a node, used for admission checks.
#[derive(Debug, Clone)]
pub struct NodeResources {
    /// Number of CPU cores.
    pub cpu_cores: u32,
    /// Total memory in megabytes.
    pub memory_mb: u64,
    /// Total disk space in megabytes.
    pub disk_mb: u64,
    /// Number of GPUs installed.
    pub gpu_count: u32,
    /// Network bandwidth. NOTE(review): units are unspecified here —
    /// presumably Mbit/s; confirm with the producer of this value.
    pub network_bandwidth: u32,
}
74
/// Point-in-time utilization of a node.
#[derive(Debug, Clone)]
pub struct NodeLoad {
    /// CPU utilization. NOTE(review): the load balancer multiplies this by
    /// 100, which suggests a 0.0–1.0 fraction — confirm the units.
    pub cpu_utilization: f64,
    /// Memory utilization (same units as `cpu_utilization`).
    pub memory_utilization: f64,
    /// Disk utilization.
    pub disk_utilization: f64,
    /// Network utilization.
    pub network_utilization: f64,
    /// Number of tasks currently running on the node.
    pub active_tasks: usize,
}
89
90impl Default for NodeLoad {
91 fn default() -> Self {
92 Self {
93 cpu_utilization: 0.0,
94 memory_utilization: 0.0,
95 disk_utilization: 0.0,
96 network_utilization: 0.0,
97 active_tasks: 0,
98 }
99 }
100}
101
/// A unit of work scheduled onto the cluster.
#[derive(Debug)]
pub struct DistributedTask {
    /// Unique task identifier.
    pub id: TaskId,
    /// Human-readable task name.
    pub name: String,
    /// Pipeline component executed against each input shard.
    pub component: Box<dyn PipelineStep>,
    /// Data shards this task processes.
    pub input_shards: Vec<DataShard>,
    /// Ids of tasks that must complete first. NOTE(review): recorded but not
    /// currently enforced by `ClusterManager`.
    pub dependencies: Vec<TaskId>,
    /// Resources the task needs in order to be scheduled.
    pub resource_requirements: ResourceRequirements,
    /// Retry/timeout/fault-tolerance settings.
    pub config: TaskConfig,
    /// Free-form metadata.
    pub metadata: HashMap<String, String>,
}
122
/// A partition of the input data processed by one task.
#[derive(Debug, Clone)]
pub struct DataShard {
    /// Shard identifier.
    pub id: String,
    /// Feature matrix (rows × columns).
    pub data: Array2<f64>,
    /// Optional targets aligned with the rows of `data`.
    pub targets: Option<Array1<f64>>,
    /// Free-form metadata.
    pub metadata: HashMap<String, String>,
    /// Node the shard originated from, when known.
    pub source_node: Option<NodeId>,
}
137
/// Resources a task needs; compared against `NodeResources` at scheduling time.
#[derive(Debug, Clone)]
pub struct ResourceRequirements {
    /// Minimum CPU cores required.
    pub cpu_cores: u32,
    /// Minimum memory in megabytes.
    pub memory_mb: u64,
    /// Minimum disk space in megabytes.
    pub disk_mb: u64,
    /// Whether at least one GPU is required.
    pub gpu_required: bool,
    /// Expected runtime. NOTE(review): not consulted by the scheduler yet.
    pub estimated_duration: Duration,
    /// Scheduling priority. NOTE(review): not consulted by the scheduler yet.
    pub priority: TaskPriority,
}
154
/// Scheduling priority. Derived `Ord` follows declaration order, so
/// `Low < Normal < High < Critical`.
#[derive(Debug, Clone, PartialEq, Eq, PartialOrd, Ord)]
pub enum TaskPriority {
    /// Background work.
    Low,
    /// Default priority.
    Normal,
    /// Time-sensitive work.
    High,
    /// Must run as soon as possible.
    Critical,
}
167
/// Per-task execution policy. NOTE(review): these settings are carried with
/// the task but are not yet enforced by the in-process executor.
#[derive(Debug, Clone)]
pub struct TaskConfig {
    /// Maximum number of retries before the task is failed.
    pub max_retries: usize,
    /// Hard execution timeout.
    pub timeout: Duration,
    /// How the task reacts to execution failures.
    pub failure_tolerance: FailureTolerance,
    /// Interval between checkpoints; `None` disables checkpointing.
    pub checkpoint_interval: Option<Duration>,
    /// Whether results should be persisted.
    pub persist_results: bool,
}
182
/// Reaction to a task execution failure. NOTE(review): variant semantics are
/// declared here but not yet acted upon by `ClusterManager`.
#[derive(Debug, Clone)]
pub enum FailureTolerance {
    /// Abort on the first failure.
    FailFast,
    /// Retry on the same node up to `max_retries` times.
    RetryOnNode { max_retries: usize },
    /// Move the task to a different node and retry.
    MigrateNode,
    /// Drop failed work and continue with the rest.
    SkipFailed,
    /// Compute a substitute result for the failed shard.
    Fallback {
        fallback_fn: fn(&DataShard) -> SklResult<Array2<f64>>,
    },
}
199
/// Outcome of a finished task, stored in the cluster manager's result map.
#[derive(Debug, Clone)]
pub struct TaskResult {
    /// Id of the task this result belongs to.
    pub task_id: TaskId,
    /// Final status (`Completed` or `Failed` for recorded results).
    pub status: TaskStatus,
    /// Output matrix on success.
    pub result: Option<Array2<f64>>,
    /// Error on failure.
    pub error: Option<SklearsError>,
    /// Timing and resource metrics for the run.
    pub metrics: ExecutionMetrics,
    /// Node the task was assigned to.
    pub node_id: NodeId,
}
216
/// Lifecycle state of a distributed task.
#[derive(Debug, Clone, PartialEq)]
pub enum TaskStatus {
    /// Submitted but not yet started.
    Pending,
    /// Currently executing.
    Running,
    /// Finished successfully.
    Completed,
    /// Finished with an error.
    Failed,
    /// Being retried after a failure.
    Retrying,
    /// Cancelled before completion.
    Cancelled,
}
233
/// Timing and resource accounting for one task execution.
#[derive(Debug, Clone)]
pub struct ExecutionMetrics {
    /// Wall-clock start of execution.
    pub start_time: SystemTime,
    /// Wall-clock end, once finished.
    pub end_time: Option<SystemTime>,
    /// Elapsed execution time, once finished.
    pub duration: Option<Duration>,
    /// Resource utilization during the run.
    pub resource_usage: NodeLoad,
    /// Network transfer accounting for the run.
    pub data_transfer: DataTransferMetrics,
}
248
/// Network transfer accounting for one task execution.
#[derive(Debug, Clone)]
pub struct DataTransferMetrics {
    /// Bytes sent to other nodes.
    pub bytes_sent: u64,
    /// Bytes received from other nodes.
    pub bytes_received: u64,
    /// Total time spent transferring data.
    pub transfer_time: Duration,
    /// Number of network errors observed.
    pub network_errors: usize,
}
261
/// Coordinator that tracks nodes, schedules tasks, and records results.
#[derive(Debug)]
pub struct ClusterManager {
    /// Registry of known nodes, keyed by id.
    nodes: Arc<RwLock<HashMap<NodeId, ClusterNode>>>,
    /// Tasks submitted for execution, keyed by id.
    active_tasks: Arc<Mutex<HashMap<TaskId, DistributedTask>>>,
    /// Results of finished tasks, keyed by task id.
    task_results: Arc<Mutex<HashMap<TaskId, TaskResult>>>,
    /// Node-selection policy for task placement.
    load_balancer: LoadBalancer,
    /// Heartbeat-based failure detection and recovery strategies.
    fault_detector: FaultDetector,
    /// Cluster-wide tuning parameters.
    config: ClusterConfig,
}
278
/// Cluster-wide tuning parameters.
#[derive(Debug, Clone)]
pub struct ClusterConfig {
    /// How often the health monitor wakes up to check heartbeats.
    pub heartbeat_interval: Duration,
    /// Age of the last heartbeat after which a node is considered failed.
    pub failure_timeout: Duration,
    /// Upper bound on concurrent tasks per node.
    pub max_tasks_per_node: usize,
    /// Number of replicas for replicated tasks.
    pub replication_factor: usize,
    /// Node-selection policy handed to the load balancer.
    pub load_balancing: LoadBalancingStrategy,
}
293
294impl Default for ClusterConfig {
295 fn default() -> Self {
296 Self {
297 heartbeat_interval: Duration::from_secs(10),
298 failure_timeout: Duration::from_secs(30),
299 max_tasks_per_node: 10,
300 replication_factor: 2,
301 load_balancing: LoadBalancingStrategy::RoundRobin,
302 }
303 }
304}
305
/// Policy used by [`LoadBalancer`] to pick a node for a task.
#[derive(Debug, Clone)]
pub enum LoadBalancingStrategy {
    /// Cycle through candidate nodes in order.
    RoundRobin,
    /// Pick the node with the lowest combined CPU/active-task score.
    LeastLoaded,
    /// Pick a candidate uniformly at random.
    Random,
    /// Prefer nodes close to the data. NOTE(review): currently implemented
    /// as "first candidate" only.
    LocalityAware,
    /// Caller-supplied selection function; returns `None` when no node fits.
    Custom {
        balance_fn: fn(&[ClusterNode], &ResourceRequirements) -> Option<NodeId>,
    },
}
322
/// Picks a node for each task according to a [`LoadBalancingStrategy`].
#[derive(Debug)]
pub struct LoadBalancer {
    /// Selection policy.
    strategy: LoadBalancingStrategy,
    /// Cursor for the round-robin strategy.
    round_robin_index: Mutex<usize>,
    /// Task-to-node assignment record. NOTE(review): never read or written
    /// in this module — confirm whether it is still needed.
    node_assignments: Arc<Mutex<HashMap<TaskId, NodeId>>>,
}
330
331impl LoadBalancer {
332 #[must_use]
334 pub fn new(strategy: LoadBalancingStrategy) -> Self {
335 Self {
336 strategy,
337 round_robin_index: Mutex::new(0),
338 node_assignments: Arc::new(Mutex::new(HashMap::new())),
339 }
340 }
341
342 pub fn select_node(
344 &self,
345 nodes: &[ClusterNode],
346 requirements: &ResourceRequirements,
347 ) -> SklResult<NodeId> {
348 let available_nodes: Vec<_> = nodes
349 .iter()
350 .filter(|node| {
351 node.status == NodeStatus::Healthy
352 && self.can_satisfy_requirements(node, requirements)
353 })
354 .collect();
355
356 if available_nodes.is_empty() {
357 return Err(SklearsError::InvalidInput(
358 "No available nodes satisfy requirements".to_string(),
359 ));
360 }
361
362 match &self.strategy {
363 LoadBalancingStrategy::RoundRobin => {
364 let mut index = self
365 .round_robin_index
366 .lock()
367 .unwrap_or_else(|e| e.into_inner());
368 let selected = &available_nodes[*index % available_nodes.len()];
369 *index = (*index + 1) % available_nodes.len();
370 Ok(selected.id.clone())
371 }
372 LoadBalancingStrategy::LeastLoaded => {
373 let least_loaded = available_nodes
374 .iter()
375 .min_by_key(|node| {
376 (node.load.cpu_utilization * 100.0) as u32 + node.load.active_tasks as u32
377 })
378 .expect("should have available nodes");
379 Ok(least_loaded.id.clone())
380 }
381 LoadBalancingStrategy::Random => {
382 use scirs2_core::random::thread_rng;
383 let mut rng = thread_rng();
384 let selected = &available_nodes[rng.gen_range(0..available_nodes.len())];
385 Ok(selected.id.clone())
386 }
387 LoadBalancingStrategy::LocalityAware => {
388 Ok(available_nodes[0].id.clone())
390 }
391 LoadBalancingStrategy::Custom { balance_fn } => {
392 let nodes_vec: Vec<ClusterNode> = available_nodes.into_iter().cloned().collect();
393 balance_fn(&nodes_vec, requirements).ok_or_else(|| {
394 SklearsError::InvalidInput("Custom balancer failed to select node".to_string())
395 })
396 }
397 }
398 }
399
400 fn can_satisfy_requirements(
402 &self,
403 node: &ClusterNode,
404 requirements: &ResourceRequirements,
405 ) -> bool {
406 node.resources.cpu_cores >= requirements.cpu_cores
407 && node.resources.memory_mb >= requirements.memory_mb
408 && node.resources.disk_mb >= requirements.disk_mb
409 && (!requirements.gpu_required || node.resources.gpu_count > 0)
410 && node.load.active_tasks < 10 }
412}
413
/// Detects node failures from heartbeat age and maps failure types to
/// recovery strategies.
#[derive(Debug)]
pub struct FaultDetector {
    /// Per-node timestamps of recorded failures.
    failure_history: Arc<Mutex<HashMap<NodeId, Vec<SystemTime>>>>,
    /// Recovery strategy per failure-type key (e.g. "node_failure").
    recovery_strategies: HashMap<String, RecoveryStrategy>,
}
422
/// Action taken after a detected failure. NOTE(review): strategies are
/// looked up via `get_recovery_strategy` but not yet executed anywhere in
/// this module.
#[derive(Debug)]
pub enum RecoveryStrategy {
    /// Re-run the task on the node it failed on.
    RestartSameNode,
    /// Move the task to a different node.
    MigrateTask,
    /// Run `replicas` copies of the task.
    ReplicateTask { replicas: usize },
    /// Serve previously cached results instead of re-running.
    UseCachedResults,
    /// Give up on the task.
    SkipTask,
}
437
438impl Default for FaultDetector {
439 fn default() -> Self {
440 Self::new()
441 }
442}
443
444impl FaultDetector {
445 #[must_use]
447 pub fn new() -> Self {
448 let mut recovery_strategies = HashMap::new();
449 recovery_strategies.insert("node_failure".to_string(), RecoveryStrategy::MigrateTask);
450 recovery_strategies.insert(
451 "task_failure".to_string(),
452 RecoveryStrategy::RestartSameNode,
453 );
454 recovery_strategies.insert(
455 "network_partition".to_string(),
456 RecoveryStrategy::ReplicateTask { replicas: 2 },
457 );
458
459 Self {
460 failure_history: Arc::new(Mutex::new(HashMap::new())),
461 recovery_strategies,
462 }
463 }
464
465 #[must_use]
467 pub fn detect_node_failure(&self, node: &ClusterNode, timeout: Duration) -> bool {
468 node.last_heartbeat.elapsed().unwrap_or(Duration::MAX) > timeout
469 }
470
471 pub fn record_failure(&self, node_id: &NodeId) {
473 let mut history = self
474 .failure_history
475 .lock()
476 .unwrap_or_else(|e| e.into_inner());
477 history
478 .entry(node_id.clone())
479 .or_default()
480 .push(SystemTime::now());
481 }
482
483 #[must_use]
485 pub fn get_recovery_strategy(&self, failure_type: &str) -> Option<&RecoveryStrategy> {
486 self.recovery_strategies.get(failure_type)
487 }
488}
489
490impl ClusterManager {
491 #[must_use]
493 pub fn new(config: ClusterConfig) -> Self {
494 Self {
495 nodes: Arc::new(RwLock::new(HashMap::new())),
496 active_tasks: Arc::new(Mutex::new(HashMap::new())),
497 task_results: Arc::new(Mutex::new(HashMap::new())),
498 load_balancer: LoadBalancer::new(config.load_balancing.clone()),
499 fault_detector: FaultDetector::new(),
500 config,
501 }
502 }
503
504 pub fn add_node(&self, node: ClusterNode) -> SklResult<()> {
506 let mut nodes = self.nodes.write().unwrap_or_else(|e| e.into_inner());
507 nodes.insert(node.id.clone(), node);
508 Ok(())
509 }
510
511 pub fn remove_node(&self, node_id: &NodeId) -> SklResult<()> {
513 let mut nodes = self.nodes.write().unwrap_or_else(|e| e.into_inner());
514 nodes.remove(node_id);
515 Ok(())
516 }
517
518 pub fn submit_task(&self, task: DistributedTask) -> SklResult<TaskId> {
520 let task_id = task.id.clone();
521
522 let nodes = self.nodes.read().unwrap_or_else(|e| e.into_inner());
524 let available_nodes: Vec<ClusterNode> = nodes.values().cloned().collect();
525 drop(nodes);
526
527 let selected_node = self
528 .load_balancer
529 .select_node(&available_nodes, &task.resource_requirements)?;
530
531 let mut active_tasks = self.active_tasks.lock().unwrap_or_else(|e| e.into_inner());
533 active_tasks.insert(task_id.clone(), task);
534 drop(active_tasks);
535
536 self.execute_task_on_node(&task_id, &selected_node)?;
538
539 Ok(task_id)
540 }
541
542 fn execute_task_on_node(&self, task_id: &TaskId, node_id: &NodeId) -> SklResult<()> {
544 let active_tasks = self.active_tasks.lock().unwrap_or_else(|e| e.into_inner());
545 let task = active_tasks
546 .get(task_id)
547 .ok_or_else(|| SklearsError::InvalidInput(format!("Task {task_id} not found")))?;
548
549 let start_time = SystemTime::now();
550 let mut metrics = ExecutionMetrics {
551 start_time,
552 end_time: None,
553 duration: None,
554 resource_usage: NodeLoad::default(),
555 data_transfer: DataTransferMetrics {
556 bytes_sent: 0,
557 bytes_received: 0,
558 transfer_time: Duration::ZERO,
559 network_errors: 0,
560 },
561 };
562
563 let result = self.execute_pipeline_component(&task.component, &task.input_shards);
565
566 let end_time = SystemTime::now();
567 metrics.end_time = Some(end_time);
568 metrics.duration = start_time.elapsed().ok();
569
570 let (result_data, error_info) = match result {
572 Ok(data) => (Some(data), None),
573 Err(e) => (None, Some(e)),
574 };
575
576 let task_result = TaskResult {
577 task_id: task_id.clone(),
578 status: if result_data.is_some() {
579 TaskStatus::Completed
580 } else {
581 TaskStatus::Failed
582 },
583 result: result_data,
584 error: error_info,
585 metrics,
586 node_id: node_id.clone(),
587 };
588
589 let mut results = self.task_results.lock().unwrap_or_else(|e| e.into_inner());
590 results.insert(task_id.clone(), task_result);
591
592 Ok(())
593 }
594
595 fn execute_pipeline_component(
597 &self,
598 component: &Box<dyn PipelineStep>,
599 shards: &[DataShard],
600 ) -> SklResult<Array2<f64>> {
601 let mut all_results = Vec::new();
602
603 for shard in shards {
604 let mapped_data = shard.data.view().mapv(|v| v as Float);
605 let result = component.transform(&mapped_data.view())?;
606 all_results.push(result);
607 }
608
609 if all_results.is_empty() {
611 return Ok(Array2::zeros((0, 0)));
612 }
613
614 let total_rows: usize = all_results
615 .iter()
616 .map(scirs2_core::ndarray::ArrayBase::nrows)
617 .sum();
618 let n_cols = all_results[0].ncols();
619
620 let mut concatenated = Array2::zeros((total_rows, n_cols));
621 let mut row_idx = 0;
622
623 for result in all_results {
624 let end_idx = row_idx + result.nrows();
625 concatenated
626 .slice_mut(s![row_idx..end_idx, ..])
627 .assign(&result);
628 row_idx = end_idx;
629 }
630
631 Ok(concatenated)
632 }
633
634 pub fn get_task_result(&self, task_id: &TaskId) -> Option<TaskResult> {
636 let results = self.task_results.lock().unwrap_or_else(|e| e.into_inner());
637 results.get(task_id).cloned()
638 }
639
640 pub fn cluster_status(&self) -> ClusterStatus {
642 let nodes = self.nodes.read().unwrap_or_else(|e| e.into_inner());
643 let active_tasks = self.active_tasks.lock().unwrap_or_else(|e| e.into_inner());
644 let task_results = self.task_results.lock().unwrap_or_else(|e| e.into_inner());
645
646 let healthy_nodes = nodes
647 .values()
648 .filter(|n| n.status == NodeStatus::Healthy)
649 .count();
650 let total_nodes = nodes.len();
651 let pending_tasks = active_tasks.len();
652 let completed_tasks = task_results
653 .values()
654 .filter(|r| r.status == TaskStatus::Completed)
655 .count();
656 let failed_tasks = task_results
657 .values()
658 .filter(|r| r.status == TaskStatus::Failed)
659 .count();
660
661 ClusterStatus {
663 total_nodes,
664 healthy_nodes,
665 pending_tasks,
666 completed_tasks,
667 failed_tasks,
668 cluster_load: self.calculate_cluster_load(&nodes),
669 }
670 }
671
672 fn calculate_cluster_load(&self, nodes: &HashMap<NodeId, ClusterNode>) -> f64 {
674 if nodes.is_empty() {
675 return 0.0;
676 }
677
678 let total_load: f64 = nodes.values().map(|node| node.load.cpu_utilization).sum();
679
680 total_load / nodes.len() as f64
681 }
682
683 pub fn start_health_monitoring(&self) -> JoinHandle<()> {
685 let nodes = Arc::clone(&self.nodes);
686 let fault_detector = FaultDetector::new();
687 let heartbeat_interval = self.config.heartbeat_interval;
688 let failure_timeout = self.config.failure_timeout;
689
690 thread::spawn(move || {
691 loop {
692 thread::sleep(heartbeat_interval);
693
694 let mut nodes_guard = nodes.write().unwrap_or_else(|e| e.into_inner());
695 let mut failed_nodes = Vec::new();
696
697 for (node_id, node) in nodes_guard.iter_mut() {
698 if fault_detector.detect_node_failure(node, failure_timeout) {
699 node.status = NodeStatus::Failed;
700 failed_nodes.push(node_id.clone());
701 fault_detector.record_failure(node_id);
702 }
703 }
704
705 drop(nodes_guard);
706
707 for failed_node in failed_nodes {
709 println!("Node {failed_node} has failed");
710 }
711 }
712 })
713 }
714}
715
/// Aggregate snapshot returned by `ClusterManager::cluster_status`.
#[derive(Debug, Clone)]
pub struct ClusterStatus {
    /// Total number of registered nodes.
    pub total_nodes: usize,
    /// Number of nodes currently `Healthy`.
    pub healthy_nodes: usize,
    /// Number of tasks currently in the active-task map.
    pub pending_tasks: usize,
    /// Number of recorded results with status `Completed`.
    pub completed_tasks: usize,
    /// Number of recorded results with status `Failed`.
    pub failed_tasks: usize,
    /// Mean CPU utilization across all nodes.
    pub cluster_load: f64,
}
732
/// Type-state map-reduce pipeline: `S = Untrained` before `fit`, then
/// `S = MapReducePipelineTrained` afterwards.
#[derive(Debug)]
pub struct MapReducePipeline<S = Untrained> {
    /// Type-state payload.
    state: S,
    /// Map component (present only before fitting).
    mapper: Option<Box<dyn PipelineStep>>,
    /// Reduce component (present only before fitting).
    reducer: Option<Box<dyn PipelineStep>>,
    /// Cluster used to run distributed tasks.
    cluster_manager: Arc<ClusterManager>,
    /// How input data is split into shards.
    partitioning_strategy: PartitioningStrategy,
    /// Ids of submitted map tasks. NOTE(review): never populated in this module.
    map_tasks: Vec<TaskId>,
    /// Ids of submitted reduce tasks. NOTE(review): never populated in this module.
    reduce_tasks: Vec<TaskId>,
}
744
/// How input data is split into [`DataShard`]s for the map phase.
#[derive(Debug)]
pub enum PartitioningStrategy {
    /// Contiguous chunks of at most `partition_size` rows.
    EqualSize { partition_size: usize },
    /// Rows assigned to `num_partitions` buckets by row index modulo.
    HashBased { num_partitions: usize },
    /// Rows bucketed by the value of the *first* feature into `[min, max)`
    /// ranges; rows outside every range are dropped.
    RangeBased { ranges: Vec<(f64, f64)> },
    /// Caller-supplied partitioning function.
    Custom {
        partition_fn: fn(&Array2<f64>) -> Vec<DataShard>,
    },
}
759
/// Trained state of [`MapReducePipeline`]: fitted components plus metadata.
#[derive(Debug)]
pub struct MapReducePipelineTrained {
    /// Mapper fitted on the training data.
    fitted_mapper: Box<dyn PipelineStep>,
    /// Reducer fitted on the training data.
    fitted_reducer: Box<dyn PipelineStep>,
    /// Cluster used to run distributed tasks.
    cluster_manager: Arc<ClusterManager>,
    /// How input data is split into shards.
    partitioning_strategy: PartitioningStrategy,
    /// Number of features seen during fit.
    n_features_in: usize,
    /// Feature names seen during fit, when available.
    feature_names_in: Option<Vec<String>>,
}
770
771impl MapReducePipeline<Untrained> {
772 pub fn new(
774 mapper: Box<dyn PipelineStep>,
775 reducer: Box<dyn PipelineStep>,
776 cluster_manager: Arc<ClusterManager>,
777 ) -> Self {
778 Self {
779 state: Untrained,
780 mapper: Some(mapper),
781 reducer: Some(reducer),
782 cluster_manager,
783 partitioning_strategy: PartitioningStrategy::EqualSize {
784 partition_size: 1000,
785 },
786 map_tasks: Vec::new(),
787 reduce_tasks: Vec::new(),
788 }
789 }
790
791 #[must_use]
793 pub fn partitioning_strategy(mut self, strategy: PartitioningStrategy) -> Self {
794 self.partitioning_strategy = strategy;
795 self
796 }
797}
798
impl Estimator for MapReducePipeline<Untrained> {
    type Config = ();
    type Error = SklearsError;
    type Float = Float;

    /// This pipeline has no estimator-level configuration; `&()` is a
    /// (const-promoted) reference to the unit value.
    fn config(&self) -> &Self::Config {
        &()
    }
}
808
impl Fit<ArrayView2<'_, Float>, Option<&ArrayView1<'_, Float>>> for MapReducePipeline<Untrained> {
    type Fitted = MapReducePipeline<MapReducePipelineTrained>;

    /// Fits the mapper and the reducer on the full (non-distributed) data
    /// and returns the trained pipeline.
    ///
    /// # Errors
    /// Fails when mapper or reducer is missing (already consumed) or when
    /// fitting either component fails.
    fn fit(
        self,
        x: &ArrayView2<'_, Float>,
        y: &Option<&ArrayView1<'_, Float>>,
    ) -> SklResult<Self::Fitted> {
        let mut mapper = self
            .mapper
            .ok_or_else(|| SklearsError::InvalidInput("No mapper provided".to_string()))?;

        let mut reducer = self
            .reducer
            .ok_or_else(|| SklearsError::InvalidInput("No reducer provided".to_string()))?;

        // NOTE(review): both components are fitted on the complete data set
        // here; only transform/map-reduce work is distributed later.
        mapper.fit(x, y.as_ref().copied())?;
        reducer.fit(x, y.as_ref().copied())?;

        Ok(MapReducePipeline {
            state: MapReducePipelineTrained {
                fitted_mapper: mapper,
                fitted_reducer: reducer,
                cluster_manager: self.cluster_manager,
                partitioning_strategy: self.partitioning_strategy,
                n_features_in: x.ncols(),
                feature_names_in: None,
            },
            mapper: None,
            reducer: None,
            // The outer fields are unused in the trained state; a fresh
            // default manager merely fills the slot. NOTE(review): consider
            // restructuring so the trained type does not carry these.
            cluster_manager: Arc::new(ClusterManager::new(ClusterConfig::default())),
            partitioning_strategy: PartitioningStrategy::EqualSize {
                partition_size: 1000,
            },
            map_tasks: Vec::new(),
            reduce_tasks: Vec::new(),
        })
    }
}
849
850impl MapReducePipeline<MapReducePipelineTrained> {
851 pub fn map_reduce(&mut self, x: &ArrayView2<'_, Float>) -> SklResult<Array2<f64>> {
853 let partitions = self.partition_data(x)?;
855
856 let mut map_task_ids = Vec::new();
858 for (i, partition) in partitions.into_iter().enumerate() {
859 let map_task = DistributedTask {
860 id: format!("map_task_{i}"),
861 name: format!("Map Task {i}"),
862 component: self.state.fitted_mapper.clone_step(),
863 input_shards: vec![partition],
864 dependencies: Vec::new(),
865 resource_requirements: ResourceRequirements {
866 cpu_cores: 1,
867 memory_mb: 512,
868 disk_mb: 100,
869 gpu_required: false,
870 estimated_duration: Duration::from_secs(60),
871 priority: TaskPriority::Normal,
872 },
873 config: TaskConfig {
874 max_retries: 3,
875 timeout: Duration::from_secs(300),
876 failure_tolerance: FailureTolerance::RetryOnNode { max_retries: 2 },
877 checkpoint_interval: None,
878 persist_results: true,
879 },
880 metadata: HashMap::new(),
881 };
882
883 let task_id = self.state.cluster_manager.submit_task(map_task)?;
884 map_task_ids.push(task_id);
885 }
886
887 let map_results = self.wait_for_tasks(&map_task_ids)?;
889
890 let reduce_shard = DataShard {
892 id: "reduce_input".to_string(),
893 data: self.combine_map_results(map_results)?,
894 targets: None,
895 metadata: HashMap::new(),
896 source_node: None,
897 };
898
899 let reduce_task = DistributedTask {
900 id: "reduce_task".to_string(),
901 name: "Reduce Task".to_string(),
902 component: self.state.fitted_reducer.clone_step(),
903 input_shards: vec![reduce_shard],
904 dependencies: map_task_ids,
905 resource_requirements: ResourceRequirements {
906 cpu_cores: 2,
907 memory_mb: 1024,
908 disk_mb: 200,
909 gpu_required: false,
910 estimated_duration: Duration::from_secs(120),
911 priority: TaskPriority::High,
912 },
913 config: TaskConfig {
914 max_retries: 3,
915 timeout: Duration::from_secs(600),
916 failure_tolerance: FailureTolerance::RetryOnNode { max_retries: 2 },
917 checkpoint_interval: None,
918 persist_results: true,
919 },
920 metadata: HashMap::new(),
921 };
922
923 let reduce_task_id = self.state.cluster_manager.submit_task(reduce_task)?;
924
925 let reduce_results = self.wait_for_tasks(&[reduce_task_id])?;
927
928 if let Some(result) = reduce_results.into_iter().next() {
929 Ok(result)
930 } else {
931 Err(SklearsError::InvalidData {
932 reason: "Reduce task produced no result".to_string(),
933 })
934 }
935 }
936
937 fn partition_data(&self, x: &ArrayView2<'_, Float>) -> SklResult<Vec<DataShard>> {
939 match &self.state.partitioning_strategy {
940 PartitioningStrategy::EqualSize { partition_size } => {
941 let mut partitions = Vec::new();
942 let n_rows = x.nrows();
943
944 for (i, chunk_start) in (0..n_rows).step_by(*partition_size).enumerate() {
945 let chunk_end = std::cmp::min(chunk_start + partition_size, n_rows);
946 let chunk = x.slice(s![chunk_start..chunk_end, ..]).to_owned();
947
948 let shard = DataShard {
949 id: format!("partition_{i}"),
950 data: chunk.mapv(|v| v),
951 targets: None,
952 metadata: HashMap::new(),
953 source_node: None,
954 };
955
956 partitions.push(shard);
957 }
958
959 Ok(partitions)
960 }
961 PartitioningStrategy::HashBased { num_partitions } => {
962 let mut partitions: Vec<Vec<usize>> = vec![Vec::new(); *num_partitions];
964
965 for i in 0..x.nrows() {
966 let hash = i % num_partitions; partitions[hash].push(i);
968 }
969
970 let mut shards = Vec::new();
971 for (partition_idx, indices) in partitions.into_iter().enumerate() {
972 if !indices.is_empty() {
973 let mut partition_data = Array2::zeros((indices.len(), x.ncols()));
974 for (row_idx, &original_idx) in indices.iter().enumerate() {
975 partition_data
976 .row_mut(row_idx)
977 .assign(&x.row(original_idx).mapv(|v| v));
978 }
979
980 let shard = DataShard {
981 id: format!("hash_partition_{partition_idx}"),
982 data: partition_data,
983 targets: None,
984 metadata: HashMap::new(),
985 source_node: None,
986 };
987
988 shards.push(shard);
989 }
990 }
991
992 Ok(shards)
993 }
994 PartitioningStrategy::RangeBased { ranges } => {
995 let mut shards = Vec::new();
997
998 for (range_idx, (min_val, max_val)) in ranges.iter().enumerate() {
999 let mut selected_rows = Vec::new();
1000
1001 for i in 0..x.nrows() {
1002 let feature_val = x[[i, 0]];
1003 if feature_val >= *min_val && feature_val < *max_val {
1004 selected_rows.push(i);
1005 }
1006 }
1007
1008 if !selected_rows.is_empty() {
1009 let mut partition_data = Array2::zeros((selected_rows.len(), x.ncols()));
1010 for (row_idx, &original_idx) in selected_rows.iter().enumerate() {
1011 partition_data
1012 .row_mut(row_idx)
1013 .assign(&x.row(original_idx).mapv(|v| v));
1014 }
1015
1016 let shard = DataShard {
1017 id: format!("range_partition_{range_idx}"),
1018 data: partition_data,
1019 targets: None,
1020 metadata: HashMap::new(),
1021 source_node: None,
1022 };
1023
1024 shards.push(shard);
1025 }
1026 }
1027
1028 Ok(shards)
1029 }
1030 PartitioningStrategy::Custom { partition_fn } => Ok(partition_fn(&x.mapv(|v| v))),
1031 }
1032 }
1033
1034 fn wait_for_tasks(&self, task_ids: &[TaskId]) -> SklResult<Vec<Array2<f64>>> {
1036 let mut results = Vec::new();
1037
1038 for task_id in task_ids {
1039 let mut attempts = 0;
1041 const MAX_ATTEMPTS: usize = 100;
1042
1043 loop {
1044 if let Some(task_result) = self.state.cluster_manager.get_task_result(task_id) {
1045 match task_result.status {
1046 TaskStatus::Completed => {
1047 if let Some(result) = task_result.result {
1048 results.push(result);
1049 }
1050 break;
1051 }
1052 TaskStatus::Failed => {
1053 return Err(task_result.error.unwrap_or_else(|| {
1054 SklearsError::InvalidData {
1055 reason: format!("Task {task_id} failed"),
1056 }
1057 }));
1058 }
1059 _ => {
1060 }
1062 }
1063 }
1064
1065 attempts += 1;
1066 if attempts >= MAX_ATTEMPTS {
1067 return Err(SklearsError::InvalidData {
1068 reason: format!("Task {task_id} timed out"),
1069 });
1070 }
1071
1072 thread::sleep(Duration::from_millis(100));
1073 }
1074 }
1075
1076 Ok(results)
1077 }
1078
1079 fn combine_map_results(&self, results: Vec<Array2<f64>>) -> SklResult<Array2<f64>> {
1081 if results.is_empty() {
1082 return Ok(Array2::zeros((0, 0)));
1083 }
1084
1085 let total_rows: usize = results
1086 .iter()
1087 .map(scirs2_core::ndarray::ArrayBase::nrows)
1088 .sum();
1089 let n_cols = results[0].ncols();
1090
1091 let mut combined = Array2::zeros((total_rows, n_cols));
1092 let mut row_idx = 0;
1093
1094 for result in results {
1095 let end_idx = row_idx + result.nrows();
1096 combined.slice_mut(s![row_idx..end_idx, ..]).assign(&result);
1097 row_idx = end_idx;
1098 }
1099
1100 Ok(combined)
1101 }
1102
1103 #[must_use]
1105 pub fn cluster_manager(&self) -> &Arc<ClusterManager> {
1106 &self.state.cluster_manager
1107 }
1108}
1109
#[allow(non_snake_case)]
#[cfg(test)]
mod tests {
    use super::*;
    use crate::MockTransformer;
    use scirs2_core::ndarray::array;
    use std::net::{IpAddr, Ipv4Addr};

    /// A node can be constructed and exposes its fields unchanged.
    #[test]
    fn test_cluster_node_creation() {
        let node = ClusterNode {
            id: "node1".to_string(),
            address: SocketAddr::new(IpAddr::V4(Ipv4Addr::new(127, 0, 0, 1)), 8080),
            status: NodeStatus::Healthy,
            resources: NodeResources {
                cpu_cores: 4,
                memory_mb: 8192,
                disk_mb: 100000,
                gpu_count: 1,
                network_bandwidth: 1000,
            },
            load: NodeLoad::default(),
            last_heartbeat: SystemTime::now(),
            metadata: HashMap::new(),
        };

        assert_eq!(node.id, "node1");
        assert_eq!(node.status, NodeStatus::Healthy);
        assert_eq!(node.resources.cpu_cores, 4);
    }

    /// Round-robin must alternate between two equally eligible nodes.
    #[test]
    fn test_load_balancer_round_robin() {
        let balancer = LoadBalancer::new(LoadBalancingStrategy::RoundRobin);

        let nodes = vec![
            ClusterNode {
                id: "node1".to_string(),
                address: SocketAddr::new(IpAddr::V4(Ipv4Addr::new(127, 0, 0, 1)), 8080),
                status: NodeStatus::Healthy,
                resources: NodeResources {
                    cpu_cores: 4,
                    memory_mb: 8192,
                    disk_mb: 100000,
                    gpu_count: 0,
                    network_bandwidth: 1000,
                },
                load: NodeLoad::default(),
                last_heartbeat: SystemTime::now(),
                metadata: HashMap::new(),
            },
            ClusterNode {
                id: "node2".to_string(),
                address: SocketAddr::new(IpAddr::V4(Ipv4Addr::new(127, 0, 0, 1)), 8081),
                status: NodeStatus::Healthy,
                resources: NodeResources {
                    cpu_cores: 4,
                    memory_mb: 8192,
                    disk_mb: 100000,
                    gpu_count: 0,
                    network_bandwidth: 1000,
                },
                load: NodeLoad::default(),
                last_heartbeat: SystemTime::now(),
                metadata: HashMap::new(),
            },
        ];

        let requirements = ResourceRequirements {
            cpu_cores: 1,
            memory_mb: 1024,
            disk_mb: 1000,
            gpu_required: false,
            estimated_duration: Duration::from_secs(60),
            priority: TaskPriority::Normal,
        };

        // `expect` instead of `unwrap_or_default`: a selection failure should
        // fail the test loudly, not silently compare two empty strings.
        let selected1 = balancer
            .select_node(&nodes, &requirements)
            .expect("first selection should succeed");
        let selected2 = balancer
            .select_node(&nodes, &requirements)
            .expect("second selection should succeed");

        assert_ne!(selected1, selected2);
    }

    /// Registering a node is reflected in the cluster status counters.
    #[test]
    fn test_cluster_manager() {
        let config = ClusterConfig::default();
        let manager = ClusterManager::new(config);

        let node = ClusterNode {
            id: "test_node".to_string(),
            address: SocketAddr::new(IpAddr::V4(Ipv4Addr::new(127, 0, 0, 1)), 8080),
            status: NodeStatus::Healthy,
            resources: NodeResources {
                cpu_cores: 4,
                memory_mb: 8192,
                disk_mb: 100000,
                gpu_count: 0,
                network_bandwidth: 1000,
            },
            load: NodeLoad::default(),
            last_heartbeat: SystemTime::now(),
            metadata: HashMap::new(),
        };

        // `expect` instead of `unwrap_or_default`: do not swallow errors.
        manager.add_node(node).expect("adding a node should succeed");

        let status = manager.cluster_status();
        assert_eq!(status.total_nodes, 1);
        assert_eq!(status.healthy_nodes, 1);
    }

    /// A shard keeps its id, data, and targets intact.
    #[test]
    fn test_data_shard_creation() {
        let data = array![[1.0, 2.0], [3.0, 4.0]];
        let targets = array![1.0, 0.0];

        let shard = DataShard {
            id: "test_shard".to_string(),
            data: data.clone(),
            targets: Some(targets.clone()),
            metadata: HashMap::new(),
            source_node: None,
        };

        assert_eq!(shard.id, "test_shard");
        assert_eq!(shard.data, data);
        assert_eq!(shard.targets, Some(targets));
    }

    /// A new pipeline defaults to 1000-row equal-size partitioning.
    #[test]
    fn test_mapreduce_pipeline_creation() {
        let mapper = Box::new(MockTransformer::new());
        let reducer = Box::new(MockTransformer::new());
        let cluster_manager = Arc::new(ClusterManager::new(ClusterConfig::default()));

        let pipeline = MapReducePipeline::new(mapper, reducer, cluster_manager);

        assert!(matches!(
            pipeline.partitioning_strategy,
            PartitioningStrategy::EqualSize {
                partition_size: 1000
            }
        ));
    }

    /// A 60 s old heartbeat exceeds a 30 s timeout, and the built-in
    /// strategy for "node_failure" is task migration.
    #[test]
    fn test_fault_detector() {
        let detector = FaultDetector::new();

        let node = ClusterNode {
            id: "test_node".to_string(),
            address: SocketAddr::new(IpAddr::V4(Ipv4Addr::new(127, 0, 0, 1)), 8080),
            status: NodeStatus::Healthy,
            resources: NodeResources {
                cpu_cores: 4,
                memory_mb: 8192,
                disk_mb: 100000,
                gpu_count: 0,
                network_bandwidth: 1000,
            },
            load: NodeLoad::default(),
            last_heartbeat: SystemTime::now() - Duration::from_secs(60),
            metadata: HashMap::new(),
        };

        let is_failed = detector.detect_node_failure(&node, Duration::from_secs(30));
        assert!(is_failed);

        detector.record_failure(&node.id);

        let strategy = detector.get_recovery_strategy("node_failure");
        assert!(matches!(strategy, Some(RecoveryStrategy::MigrateTask)));
    }
}