use scirs2_core::ndarray::{s, Array1, Array2, ArrayView1, ArrayView2};
use sklears_core::{
    error::{Result as SklResult, SklearsError},
    traits::{Estimator, Fit, Untrained},
    types::Float,
};
use std::collections::HashMap;
use std::net::SocketAddr;
use std::sync::{Arc, Mutex, RwLock};
use std::thread::{self, JoinHandle};
use std::time::{Duration, SystemTime};

use crate::{PipelinePredictor, PipelineStep};

pub type NodeId = String;

pub type TaskId = String;

#[derive(Debug, Clone)]
pub struct ClusterNode {
    pub id: NodeId,
    pub address: SocketAddr,
    pub status: NodeStatus,
    pub resources: NodeResources,
    pub load: NodeLoad,
    pub last_heartbeat: SystemTime,
    pub metadata: HashMap<String, String>,
}

#[derive(Debug, Clone, PartialEq)]
pub enum NodeStatus {
    Healthy,
    Stressed,
    Unavailable,
    Failed,
    ShuttingDown,
}

#[derive(Debug, Clone)]
pub struct NodeResources {
    pub cpu_cores: u32,
    pub memory_mb: u64,
    pub disk_mb: u64,
    pub gpu_count: u32,
    pub network_bandwidth: u32,
}

#[derive(Debug, Clone)]
pub struct NodeLoad {
    pub cpu_utilization: f64,
    pub memory_utilization: f64,
    pub disk_utilization: f64,
    pub network_utilization: f64,
    pub active_tasks: usize,
}

impl Default for NodeLoad {
    fn default() -> Self {
        Self {
            cpu_utilization: 0.0,
            memory_utilization: 0.0,
            disk_utilization: 0.0,
            network_utilization: 0.0,
            active_tasks: 0,
        }
    }
}

#[derive(Debug)]
pub struct DistributedTask {
    pub id: TaskId,
    pub name: String,
    pub component: Box<dyn PipelineStep>,
    pub input_shards: Vec<DataShard>,
    pub dependencies: Vec<TaskId>,
    pub resource_requirements: ResourceRequirements,
    pub config: TaskConfig,
    pub metadata: HashMap<String, String>,
}

#[derive(Debug, Clone)]
pub struct DataShard {
    pub id: String,
    pub data: Array2<f64>,
    pub targets: Option<Array1<f64>>,
    pub metadata: HashMap<String, String>,
    pub source_node: Option<NodeId>,
}

#[derive(Debug, Clone)]
pub struct ResourceRequirements {
    pub cpu_cores: u32,
    pub memory_mb: u64,
    pub disk_mb: u64,
    pub gpu_required: bool,
    pub estimated_duration: Duration,
    pub priority: TaskPriority,
}

#[derive(Debug, Clone, PartialEq, Eq, PartialOrd, Ord)]
pub enum TaskPriority {
    Low,
    Normal,
    High,
    Critical,
}

#[derive(Debug, Clone)]
pub struct TaskConfig {
    pub max_retries: usize,
    pub timeout: Duration,
    pub failure_tolerance: FailureTolerance,
    pub checkpoint_interval: Option<Duration>,
    pub persist_results: bool,
}

#[derive(Debug, Clone)]
pub enum FailureTolerance {
    FailFast,
    RetryOnNode { max_retries: usize },
    MigrateNode,
    SkipFailed,
    Fallback {
        fallback_fn: fn(&DataShard) -> SklResult<Array2<f64>>,
    },
}
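
// A minimal sketch of a fallback function compatible with the `Fallback` variant above
// (illustrative only; `passthrough_fallback` is not part of the original API). Because the
// variant stores a plain `fn` pointer, the fallback must be a free function rather than a
// capturing closure:
//
//     fn passthrough_fallback(shard: &DataShard) -> SklResult<Array2<f64>> {
//         Ok(shard.data.clone())
//     }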

#[derive(Debug, Clone)]
pub struct TaskResult {
    pub task_id: TaskId,
    pub status: TaskStatus,
    pub result: Option<Array2<f64>>,
    pub error: Option<SklearsError>,
    pub metrics: ExecutionMetrics,
    pub node_id: NodeId,
}

#[derive(Debug, Clone, PartialEq)]
pub enum TaskStatus {
    Pending,
    Running,
    Completed,
    Failed,
    Retrying,
    Cancelled,
}

#[derive(Debug, Clone)]
pub struct ExecutionMetrics {
    pub start_time: SystemTime,
    pub end_time: Option<SystemTime>,
    pub duration: Option<Duration>,
    pub resource_usage: NodeLoad,
    pub data_transfer: DataTransferMetrics,
}

#[derive(Debug, Clone)]
pub struct DataTransferMetrics {
    pub bytes_sent: u64,
    pub bytes_received: u64,
    pub transfer_time: Duration,
    pub network_errors: usize,
}

#[derive(Debug)]
pub struct ClusterManager {
    nodes: Arc<RwLock<HashMap<NodeId, ClusterNode>>>,
    active_tasks: Arc<Mutex<HashMap<TaskId, DistributedTask>>>,
    task_results: Arc<Mutex<HashMap<TaskId, TaskResult>>>,
    load_balancer: LoadBalancer,
    fault_detector: FaultDetector,
    config: ClusterConfig,
}

#[derive(Debug, Clone)]
pub struct ClusterConfig {
    pub heartbeat_interval: Duration,
    pub failure_timeout: Duration,
    pub max_tasks_per_node: usize,
    pub replication_factor: usize,
    pub load_balancing: LoadBalancingStrategy,
}

impl Default for ClusterConfig {
    fn default() -> Self {
        Self {
            heartbeat_interval: Duration::from_secs(10),
            failure_timeout: Duration::from_secs(30),
            max_tasks_per_node: 10,
            replication_factor: 2,
            load_balancing: LoadBalancingStrategy::RoundRobin,
        }
    }
}

#[derive(Debug, Clone)]
pub enum LoadBalancingStrategy {
    RoundRobin,
    LeastLoaded,
    Random,
    LocalityAware,
    Custom {
        balance_fn: fn(&[ClusterNode], &ResourceRequirements) -> Option<NodeId>,
    },
}

#[derive(Debug)]
pub struct LoadBalancer {
    strategy: LoadBalancingStrategy,
    round_robin_index: Mutex<usize>,
    node_assignments: Arc<Mutex<HashMap<TaskId, NodeId>>>,
}

impl LoadBalancer {
    #[must_use]
    pub fn new(strategy: LoadBalancingStrategy) -> Self {
        Self {
            strategy,
            round_robin_index: Mutex::new(0),
            node_assignments: Arc::new(Mutex::new(HashMap::new())),
        }
    }

    pub fn select_node(
        &self,
        nodes: &[ClusterNode],
        requirements: &ResourceRequirements,
    ) -> SklResult<NodeId> {
        let available_nodes: Vec<_> = nodes
            .iter()
            .filter(|node| {
                node.status == NodeStatus::Healthy
                    && self.can_satisfy_requirements(node, requirements)
            })
            .collect();

        if available_nodes.is_empty() {
            return Err(SklearsError::InvalidInput(
                "No available nodes satisfy requirements".to_string(),
            ));
        }

        match &self.strategy {
            LoadBalancingStrategy::RoundRobin => {
                let mut index = self.round_robin_index.lock().unwrap();
                let selected = &available_nodes[*index % available_nodes.len()];
                *index = (*index + 1) % available_nodes.len();
                Ok(selected.id.clone())
            }
            LoadBalancingStrategy::LeastLoaded => {
                let least_loaded = available_nodes
                    .iter()
                    .min_by_key(|node| {
                        (node.load.cpu_utilization * 100.0) as u32 + node.load.active_tasks as u32
                    })
                    .unwrap();
                Ok(least_loaded.id.clone())
            }
            LoadBalancingStrategy::Random => {
                use scirs2_core::random::thread_rng;
                let mut rng = thread_rng();
                let selected = &available_nodes[rng.gen_range(0..available_nodes.len())];
                Ok(selected.id.clone())
            }
            LoadBalancingStrategy::LocalityAware => {
                // Placeholder: locality-aware selection currently just picks the first
                // available node.
                Ok(available_nodes[0].id.clone())
            }
            LoadBalancingStrategy::Custom { balance_fn } => {
                let nodes_vec: Vec<ClusterNode> = available_nodes.into_iter().cloned().collect();
                balance_fn(&nodes_vec, requirements).ok_or_else(|| {
                    SklearsError::InvalidInput("Custom balancer failed to select node".to_string())
                })
            }
        }
    }

    fn can_satisfy_requirements(
        &self,
        node: &ClusterNode,
        requirements: &ResourceRequirements,
    ) -> bool {
        node.resources.cpu_cores >= requirements.cpu_cores
            && node.resources.memory_mb >= requirements.memory_mb
            && node.resources.disk_mb >= requirements.disk_mb
            && (!requirements.gpu_required || node.resources.gpu_count > 0)
            && node.load.active_tasks < 10 // hard-coded per-node task cap
    }
}

#[derive(Debug)]
pub struct FaultDetector {
    failure_history: Arc<Mutex<HashMap<NodeId, Vec<SystemTime>>>>,
    recovery_strategies: HashMap<String, RecoveryStrategy>,
}

#[derive(Debug)]
pub enum RecoveryStrategy {
    RestartSameNode,
    MigrateTask,
    ReplicateTask { replicas: usize },
    UseCachedResults,
    SkipTask,
}

impl Default for FaultDetector {
    fn default() -> Self {
        Self::new()
    }
}

impl FaultDetector {
    #[must_use]
    pub fn new() -> Self {
        let mut recovery_strategies = HashMap::new();
        recovery_strategies.insert("node_failure".to_string(), RecoveryStrategy::MigrateTask);
        recovery_strategies.insert(
            "task_failure".to_string(),
            RecoveryStrategy::RestartSameNode,
        );
        recovery_strategies.insert(
            "network_partition".to_string(),
            RecoveryStrategy::ReplicateTask { replicas: 2 },
        );

        Self {
            failure_history: Arc::new(Mutex::new(HashMap::new())),
            recovery_strategies,
        }
    }

    #[must_use]
    pub fn detect_node_failure(&self, node: &ClusterNode, timeout: Duration) -> bool {
        node.last_heartbeat.elapsed().unwrap_or(Duration::MAX) > timeout
    }

    pub fn record_failure(&self, node_id: &NodeId) {
        let mut history = self.failure_history.lock().unwrap();
        history
            .entry(node_id.clone())
            .or_default()
            .push(SystemTime::now());
    }

    #[must_use]
    pub fn get_recovery_strategy(&self, failure_type: &str) -> Option<&RecoveryStrategy> {
        self.recovery_strategies.get(failure_type)
    }
}

impl ClusterManager {
    #[must_use]
    pub fn new(config: ClusterConfig) -> Self {
        Self {
            nodes: Arc::new(RwLock::new(HashMap::new())),
            active_tasks: Arc::new(Mutex::new(HashMap::new())),
            task_results: Arc::new(Mutex::new(HashMap::new())),
            load_balancer: LoadBalancer::new(config.load_balancing.clone()),
            fault_detector: FaultDetector::new(),
            config,
        }
    }

    pub fn add_node(&self, node: ClusterNode) -> SklResult<()> {
        let mut nodes = self.nodes.write().unwrap();
        nodes.insert(node.id.clone(), node);
        Ok(())
    }

    pub fn remove_node(&self, node_id: &NodeId) -> SklResult<()> {
        let mut nodes = self.nodes.write().unwrap();
        nodes.remove(node_id);
        Ok(())
    }

    pub fn submit_task(&self, task: DistributedTask) -> SklResult<TaskId> {
        let task_id = task.id.clone();

        let nodes = self.nodes.read().unwrap();
        let available_nodes: Vec<ClusterNode> = nodes.values().cloned().collect();
        drop(nodes);

        let selected_node = self
            .load_balancer
            .select_node(&available_nodes, &task.resource_requirements)?;

        let mut active_tasks = self.active_tasks.lock().unwrap();
        active_tasks.insert(task_id.clone(), task);
        drop(active_tasks);

        self.execute_task_on_node(&task_id, &selected_node)?;

        Ok(task_id)
    }

    fn execute_task_on_node(&self, task_id: &TaskId, node_id: &NodeId) -> SklResult<()> {
        let active_tasks = self.active_tasks.lock().unwrap();
        let task = active_tasks
            .get(task_id)
            .ok_or_else(|| SklearsError::InvalidInput(format!("Task {task_id} not found")))?;

        let start_time = SystemTime::now();
        let mut metrics = ExecutionMetrics {
            start_time,
            end_time: None,
            duration: None,
            resource_usage: NodeLoad::default(),
            data_transfer: DataTransferMetrics {
                bytes_sent: 0,
                bytes_received: 0,
                transfer_time: Duration::ZERO,
                network_errors: 0,
            },
        };

        let result = self.execute_pipeline_component(&task.component, &task.input_shards);

        let end_time = SystemTime::now();
        metrics.end_time = Some(end_time);
        metrics.duration = start_time.elapsed().ok();

        let (result_data, error_info) = match result {
            Ok(data) => (Some(data), None),
            Err(e) => (None, Some(e)),
        };

        let task_result = TaskResult {
            task_id: task_id.clone(),
            status: if result_data.is_some() {
                TaskStatus::Completed
            } else {
                TaskStatus::Failed
            },
            result: result_data,
            error: error_info,
            metrics,
            node_id: node_id.clone(),
        };

        let mut results = self.task_results.lock().unwrap();
        results.insert(task_id.clone(), task_result);

        Ok(())
    }

    fn execute_pipeline_component(
        &self,
        component: &Box<dyn PipelineStep>,
        shards: &[DataShard],
    ) -> SklResult<Array2<f64>> {
        let mut all_results = Vec::new();

        for shard in shards {
            let mapped_data = shard.data.view().mapv(|v| v as Float);
            let result = component.transform(&mapped_data.view())?;
            all_results.push(result);
        }

        if all_results.is_empty() {
            return Ok(Array2::zeros((0, 0)));
        }

        let total_rows: usize = all_results
            .iter()
            .map(scirs2_core::ndarray::ArrayBase::nrows)
            .sum();
        let n_cols = all_results[0].ncols();

        let mut concatenated = Array2::zeros((total_rows, n_cols));
        let mut row_idx = 0;

        for result in all_results {
            let end_idx = row_idx + result.nrows();
            concatenated
                .slice_mut(s![row_idx..end_idx, ..])
                .assign(&result);
            row_idx = end_idx;
        }

        Ok(concatenated)
    }

    pub fn get_task_result(&self, task_id: &TaskId) -> Option<TaskResult> {
        let results = self.task_results.lock().unwrap();
        results.get(task_id).cloned()
    }

    pub fn cluster_status(&self) -> ClusterStatus {
        let nodes = self.nodes.read().unwrap();
        let active_tasks = self.active_tasks.lock().unwrap();
        let task_results = self.task_results.lock().unwrap();

        let healthy_nodes = nodes
            .values()
            .filter(|n| n.status == NodeStatus::Healthy)
            .count();
        let total_nodes = nodes.len();
        let pending_tasks = active_tasks.len();
        let completed_tasks = task_results
            .values()
            .filter(|r| r.status == TaskStatus::Completed)
            .count();
        let failed_tasks = task_results
            .values()
            .filter(|r| r.status == TaskStatus::Failed)
            .count();

        ClusterStatus {
            total_nodes,
            healthy_nodes,
            pending_tasks,
            completed_tasks,
            failed_tasks,
            cluster_load: self.calculate_cluster_load(&nodes),
        }
    }

    fn calculate_cluster_load(&self, nodes: &HashMap<NodeId, ClusterNode>) -> f64 {
        if nodes.is_empty() {
            return 0.0;
        }

        let total_load: f64 = nodes.values().map(|node| node.load.cpu_utilization).sum();

        total_load / nodes.len() as f64
    }

    pub fn start_health_monitoring(&self) -> JoinHandle<()> {
        let nodes = Arc::clone(&self.nodes);
        let fault_detector = FaultDetector::new();
        let heartbeat_interval = self.config.heartbeat_interval;
        let failure_timeout = self.config.failure_timeout;

        thread::spawn(move || {
            loop {
                thread::sleep(heartbeat_interval);

                let mut nodes_guard = nodes.write().unwrap();
                let mut failed_nodes = Vec::new();

                for (node_id, node) in nodes_guard.iter_mut() {
                    if fault_detector.detect_node_failure(node, failure_timeout) {
                        node.status = NodeStatus::Failed;
                        failed_nodes.push(node_id.clone());
                        fault_detector.record_failure(node_id);
                    }
                }

                drop(nodes_guard);

                for failed_node in failed_nodes {
                    println!("Node {failed_node} has failed");
                }
            }
        })
    }
}

#[derive(Debug, Clone)]
pub struct ClusterStatus {
    pub total_nodes: usize,
    pub healthy_nodes: usize,
    pub pending_tasks: usize,
    pub completed_tasks: usize,
    pub failed_tasks: usize,
    pub cluster_load: f64,
}

#[derive(Debug)]
pub struct MapReducePipeline<S = Untrained> {
    state: S,
    mapper: Option<Box<dyn PipelineStep>>,
    reducer: Option<Box<dyn PipelineStep>>,
    cluster_manager: Arc<ClusterManager>,
    partitioning_strategy: PartitioningStrategy,
    map_tasks: Vec<TaskId>,
    reduce_tasks: Vec<TaskId>,
}

#[derive(Debug)]
pub enum PartitioningStrategy {
    EqualSize { partition_size: usize },
    HashBased { num_partitions: usize },
    RangeBased { ranges: Vec<(f64, f64)> },
    Custom {
        partition_fn: fn(&Array2<f64>) -> Vec<DataShard>,
    },
}
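
// A minimal sketch of a custom partitioner compatible with the `Custom` variant above
// (illustrative only; `row_shards` is not part of the original API). Any non-capturing
// `fn(&Array2<f64>) -> Vec<DataShard>` works here; this one splits the matrix into
// single-row shards:
//
//     fn row_shards(data: &Array2<f64>) -> Vec<DataShard> {
//         (0..data.nrows())
//             .map(|i| DataShard {
//                 id: format!("row_{i}"),
//                 data: data.slice(s![i..i + 1, ..]).to_owned(),
//                 targets: None,
//                 metadata: HashMap::new(),
//                 source_node: None,
//             })
//             .collect()
//     }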

#[derive(Debug)]
pub struct MapReducePipelineTrained {
    fitted_mapper: Box<dyn PipelineStep>,
    fitted_reducer: Box<dyn PipelineStep>,
    cluster_manager: Arc<ClusterManager>,
    partitioning_strategy: PartitioningStrategy,
    n_features_in: usize,
    feature_names_in: Option<Vec<String>>,
}

impl MapReducePipeline<Untrained> {
    pub fn new(
        mapper: Box<dyn PipelineStep>,
        reducer: Box<dyn PipelineStep>,
        cluster_manager: Arc<ClusterManager>,
    ) -> Self {
        Self {
            state: Untrained,
            mapper: Some(mapper),
            reducer: Some(reducer),
            cluster_manager,
            partitioning_strategy: PartitioningStrategy::EqualSize {
                partition_size: 1000,
            },
            map_tasks: Vec::new(),
            reduce_tasks: Vec::new(),
        }
    }

    #[must_use]
    pub fn partitioning_strategy(mut self, strategy: PartitioningStrategy) -> Self {
        self.partitioning_strategy = strategy;
        self
    }
}

impl Estimator for MapReducePipeline<Untrained> {
    type Config = ();
    type Error = SklearsError;
    type Float = Float;

    fn config(&self) -> &Self::Config {
        &()
    }
}

impl Fit<ArrayView2<'_, Float>, Option<&ArrayView1<'_, Float>>> for MapReducePipeline<Untrained> {
    type Fitted = MapReducePipeline<MapReducePipelineTrained>;

    fn fit(
        self,
        x: &ArrayView2<'_, Float>,
        y: &Option<&ArrayView1<'_, Float>>,
    ) -> SklResult<Self::Fitted> {
        let mut mapper = self
            .mapper
            .ok_or_else(|| SklearsError::InvalidInput("No mapper provided".to_string()))?;

        let mut reducer = self
            .reducer
            .ok_or_else(|| SklearsError::InvalidInput("No reducer provided".to_string()))?;

        mapper.fit(x, y.as_ref().copied())?;
        reducer.fit(x, y.as_ref().copied())?;

        Ok(MapReducePipeline {
            state: MapReducePipelineTrained {
                fitted_mapper: mapper,
                fitted_reducer: reducer,
                cluster_manager: self.cluster_manager,
                partitioning_strategy: self.partitioning_strategy,
                n_features_in: x.ncols(),
                feature_names_in: None,
            },
            mapper: None,
            reducer: None,
            cluster_manager: Arc::new(ClusterManager::new(ClusterConfig::default())),
            partitioning_strategy: PartitioningStrategy::EqualSize {
                partition_size: 1000,
            },
            map_tasks: Vec::new(),
            reduce_tasks: Vec::new(),
        })
    }
}

impl MapReducePipeline<MapReducePipelineTrained> {
    pub fn map_reduce(&mut self, x: &ArrayView2<'_, Float>) -> SklResult<Array2<f64>> {
        let partitions = self.partition_data(x)?;

        let mut map_task_ids = Vec::new();
        for (i, partition) in partitions.into_iter().enumerate() {
            let map_task = DistributedTask {
                id: format!("map_task_{i}"),
                name: format!("Map Task {i}"),
                component: self.state.fitted_mapper.clone_step(),
                input_shards: vec![partition],
                dependencies: Vec::new(),
                resource_requirements: ResourceRequirements {
                    cpu_cores: 1,
                    memory_mb: 512,
                    disk_mb: 100,
                    gpu_required: false,
                    estimated_duration: Duration::from_secs(60),
                    priority: TaskPriority::Normal,
                },
                config: TaskConfig {
                    max_retries: 3,
                    timeout: Duration::from_secs(300),
                    failure_tolerance: FailureTolerance::RetryOnNode { max_retries: 2 },
                    checkpoint_interval: None,
                    persist_results: true,
                },
                metadata: HashMap::new(),
            };

            let task_id = self.state.cluster_manager.submit_task(map_task)?;
            map_task_ids.push(task_id);
        }

        let map_results = self.wait_for_tasks(&map_task_ids)?;

        let reduce_shard = DataShard {
            id: "reduce_input".to_string(),
            data: self.combine_map_results(map_results)?,
            targets: None,
            metadata: HashMap::new(),
            source_node: None,
        };

        let reduce_task = DistributedTask {
            id: "reduce_task".to_string(),
            name: "Reduce Task".to_string(),
            component: self.state.fitted_reducer.clone_step(),
            input_shards: vec![reduce_shard],
            dependencies: map_task_ids,
            resource_requirements: ResourceRequirements {
                cpu_cores: 2,
                memory_mb: 1024,
                disk_mb: 200,
                gpu_required: false,
                estimated_duration: Duration::from_secs(120),
                priority: TaskPriority::High,
            },
            config: TaskConfig {
                max_retries: 3,
                timeout: Duration::from_secs(600),
                failure_tolerance: FailureTolerance::RetryOnNode { max_retries: 2 },
                checkpoint_interval: None,
                persist_results: true,
            },
            metadata: HashMap::new(),
        };

        let reduce_task_id = self.state.cluster_manager.submit_task(reduce_task)?;

        let reduce_results = self.wait_for_tasks(&[reduce_task_id])?;

        if let Some(result) = reduce_results.into_iter().next() {
            Ok(result)
        } else {
            Err(SklearsError::InvalidData {
                reason: "Reduce task produced no result".to_string(),
            })
        }
    }

    fn partition_data(&self, x: &ArrayView2<'_, Float>) -> SklResult<Vec<DataShard>> {
        match &self.state.partitioning_strategy {
            PartitioningStrategy::EqualSize { partition_size } => {
                let mut partitions = Vec::new();
                let n_rows = x.nrows();

                for (i, chunk_start) in (0..n_rows).step_by(*partition_size).enumerate() {
                    let chunk_end = std::cmp::min(chunk_start + partition_size, n_rows);
                    let chunk = x.slice(s![chunk_start..chunk_end, ..]).to_owned();

                    let shard = DataShard {
                        id: format!("partition_{i}"),
                        data: chunk.mapv(|v| v),
                        targets: None,
                        metadata: HashMap::new(),
                        source_node: None,
                    };

                    partitions.push(shard);
                }

                Ok(partitions)
            }
            PartitioningStrategy::HashBased { num_partitions } => {
                let mut partitions: Vec<Vec<usize>> = vec![Vec::new(); *num_partitions];

                for i in 0..x.nrows() {
                    let hash = i % num_partitions;
                    partitions[hash].push(i);
                }

                let mut shards = Vec::new();
                for (partition_idx, indices) in partitions.into_iter().enumerate() {
                    if !indices.is_empty() {
                        let mut partition_data = Array2::zeros((indices.len(), x.ncols()));
                        for (row_idx, &original_idx) in indices.iter().enumerate() {
                            partition_data
                                .row_mut(row_idx)
                                .assign(&x.row(original_idx).mapv(|v| v));
                        }

                        let shard = DataShard {
                            id: format!("hash_partition_{partition_idx}"),
                            data: partition_data,
                            targets: None,
                            metadata: HashMap::new(),
                            source_node: None,
                        };

                        shards.push(shard);
                    }
                }

                Ok(shards)
            }
            PartitioningStrategy::RangeBased { ranges } => {
                let mut shards = Vec::new();

                for (range_idx, (min_val, max_val)) in ranges.iter().enumerate() {
                    let mut selected_rows = Vec::new();

                    for i in 0..x.nrows() {
                        let feature_val = x[[i, 0]];
                        if feature_val >= *min_val && feature_val < *max_val {
                            selected_rows.push(i);
                        }
                    }

                    if !selected_rows.is_empty() {
                        let mut partition_data = Array2::zeros((selected_rows.len(), x.ncols()));
                        for (row_idx, &original_idx) in selected_rows.iter().enumerate() {
                            partition_data
                                .row_mut(row_idx)
                                .assign(&x.row(original_idx).mapv(|v| v));
                        }

                        let shard = DataShard {
                            id: format!("range_partition_{range_idx}"),
                            data: partition_data,
                            targets: None,
                            metadata: HashMap::new(),
                            source_node: None,
                        };

                        shards.push(shard);
                    }
                }

                Ok(shards)
            }
            PartitioningStrategy::Custom { partition_fn } => Ok(partition_fn(&x.mapv(|v| v))),
        }
    }

    fn wait_for_tasks(&self, task_ids: &[TaskId]) -> SklResult<Vec<Array2<f64>>> {
        let mut results = Vec::new();

        for task_id in task_ids {
            let mut attempts = 0;
            const MAX_ATTEMPTS: usize = 100;

            loop {
                if let Some(task_result) = self.state.cluster_manager.get_task_result(task_id) {
                    match task_result.status {
                        TaskStatus::Completed => {
                            if let Some(result) = task_result.result {
                                results.push(result);
                            }
                            break;
                        }
                        TaskStatus::Failed => {
                            return Err(task_result.error.unwrap_or_else(|| {
                                SklearsError::InvalidData {
                                    reason: format!("Task {task_id} failed"),
                                }
                            }));
                        }
                        _ => {}
                    }
                }

                attempts += 1;
                if attempts >= MAX_ATTEMPTS {
                    return Err(SklearsError::InvalidData {
                        reason: format!("Task {task_id} timed out"),
                    });
                }

                thread::sleep(Duration::from_millis(100));
            }
        }

        Ok(results)
    }

    fn combine_map_results(&self, results: Vec<Array2<f64>>) -> SklResult<Array2<f64>> {
        if results.is_empty() {
            return Ok(Array2::zeros((0, 0)));
        }

        let total_rows: usize = results
            .iter()
            .map(scirs2_core::ndarray::ArrayBase::nrows)
            .sum();
        let n_cols = results[0].ncols();

        let mut combined = Array2::zeros((total_rows, n_cols));
        let mut row_idx = 0;

        for result in results {
            let end_idx = row_idx + result.nrows();
            combined.slice_mut(s![row_idx..end_idx, ..]).assign(&result);
            row_idx = end_idx;
        }

        Ok(combined)
    }

    #[must_use]
    pub fn cluster_manager(&self) -> &Arc<ClusterManager> {
        &self.state.cluster_manager
    }
}

#[allow(non_snake_case)]
#[cfg(test)]
mod tests {
    use super::*;
    use crate::MockTransformer;
    use scirs2_core::ndarray::array;
    use std::net::{IpAddr, Ipv4Addr};

    #[test]
    fn test_cluster_node_creation() {
        let node = ClusterNode {
            id: "node1".to_string(),
            address: SocketAddr::new(IpAddr::V4(Ipv4Addr::new(127, 0, 0, 1)), 8080),
            status: NodeStatus::Healthy,
            resources: NodeResources {
                cpu_cores: 4,
                memory_mb: 8192,
                disk_mb: 100000,
                gpu_count: 1,
                network_bandwidth: 1000,
            },
            load: NodeLoad::default(),
            last_heartbeat: SystemTime::now(),
            metadata: HashMap::new(),
        };

        assert_eq!(node.id, "node1");
        assert_eq!(node.status, NodeStatus::Healthy);
        assert_eq!(node.resources.cpu_cores, 4);
    }

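    // Added sanity check of the documented `ClusterConfig` defaults.
    #[test]
    fn test_cluster_config_default() {
        let config = ClusterConfig::default();

        assert_eq!(config.heartbeat_interval, Duration::from_secs(10));
        assert_eq!(config.failure_timeout, Duration::from_secs(30));
        assert_eq!(config.max_tasks_per_node, 10);
        assert_eq!(config.replication_factor, 2);
        assert!(matches!(
            config.load_balancing,
            LoadBalancingStrategy::RoundRobin
        ));
    }
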
    #[test]
    fn test_load_balancer_round_robin() {
        let balancer = LoadBalancer::new(LoadBalancingStrategy::RoundRobin);

        let nodes = vec![
            ClusterNode {
                id: "node1".to_string(),
                address: SocketAddr::new(IpAddr::V4(Ipv4Addr::new(127, 0, 0, 1)), 8080),
                status: NodeStatus::Healthy,
                resources: NodeResources {
                    cpu_cores: 4,
                    memory_mb: 8192,
                    disk_mb: 100000,
                    gpu_count: 0,
                    network_bandwidth: 1000,
                },
                load: NodeLoad::default(),
                last_heartbeat: SystemTime::now(),
                metadata: HashMap::new(),
            },
            ClusterNode {
                id: "node2".to_string(),
                address: SocketAddr::new(IpAddr::V4(Ipv4Addr::new(127, 0, 0, 1)), 8081),
                status: NodeStatus::Healthy,
                resources: NodeResources {
                    cpu_cores: 4,
                    memory_mb: 8192,
                    disk_mb: 100000,
                    gpu_count: 0,
                    network_bandwidth: 1000,
                },
                load: NodeLoad::default(),
                last_heartbeat: SystemTime::now(),
                metadata: HashMap::new(),
            },
        ];

        let requirements = ResourceRequirements {
            cpu_cores: 1,
            memory_mb: 1024,
            disk_mb: 1000,
            gpu_required: false,
            estimated_duration: Duration::from_secs(60),
            priority: TaskPriority::Normal,
        };

        let selected1 = balancer.select_node(&nodes, &requirements).unwrap();
        let selected2 = balancer.select_node(&nodes, &requirements).unwrap();

        assert_ne!(selected1, selected2);
    }

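    // Added example of the `Custom` load-balancing strategy: the stored function pointer
    // receives the filtered candidate nodes and returns the id of whichever node it prefers
    // (here, the one advertising the most memory). `pick_most_memory` and `make_node` are
    // illustrative helpers local to this test, not part of the original API.
    #[test]
    fn test_load_balancer_custom_strategy() {
        fn pick_most_memory(
            nodes: &[ClusterNode],
            _requirements: &ResourceRequirements,
        ) -> Option<NodeId> {
            nodes
                .iter()
                .max_by_key(|node| node.resources.memory_mb)
                .map(|node| node.id.clone())
        }

        fn make_node(id: &str, port: u16, memory_mb: u64) -> ClusterNode {
            ClusterNode {
                id: id.to_string(),
                address: SocketAddr::new(IpAddr::V4(Ipv4Addr::new(127, 0, 0, 1)), port),
                status: NodeStatus::Healthy,
                resources: NodeResources {
                    cpu_cores: 4,
                    memory_mb,
                    disk_mb: 100000,
                    gpu_count: 0,
                    network_bandwidth: 1000,
                },
                load: NodeLoad::default(),
                last_heartbeat: SystemTime::now(),
                metadata: HashMap::new(),
            }
        }

        let balancer = LoadBalancer::new(LoadBalancingStrategy::Custom {
            balance_fn: pick_most_memory,
        });

        let nodes = vec![make_node("small", 8080, 4096), make_node("large", 8081, 16384)];

        let requirements = ResourceRequirements {
            cpu_cores: 1,
            memory_mb: 1024,
            disk_mb: 1000,
            gpu_required: false,
            estimated_duration: Duration::from_secs(60),
            priority: TaskPriority::Normal,
        };

        let selected = balancer.select_node(&nodes, &requirements).unwrap();
        assert_eq!(selected, "large");
    }
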
    #[test]
    fn test_cluster_manager() {
        let config = ClusterConfig::default();
        let manager = ClusterManager::new(config);

        let node = ClusterNode {
            id: "test_node".to_string(),
            address: SocketAddr::new(IpAddr::V4(Ipv4Addr::new(127, 0, 0, 1)), 8080),
            status: NodeStatus::Healthy,
            resources: NodeResources {
                cpu_cores: 4,
                memory_mb: 8192,
                disk_mb: 100000,
                gpu_count: 0,
                network_bandwidth: 1000,
            },
            load: NodeLoad::default(),
            last_heartbeat: SystemTime::now(),
            metadata: HashMap::new(),
        };

        manager.add_node(node).unwrap();

        let status = manager.cluster_status();
        assert_eq!(status.total_nodes, 1);
        assert_eq!(status.healthy_nodes, 1);
    }

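    // Added check that `TaskPriority` ordering follows variant declaration order
    // (it derives `Ord`): Low < Normal < High < Critical.
    #[test]
    fn test_task_priority_ordering() {
        assert!(TaskPriority::Low < TaskPriority::Normal);
        assert!(TaskPriority::Normal < TaskPriority::High);
        assert!(TaskPriority::High < TaskPriority::Critical);
    }
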
    #[test]
    fn test_data_shard_creation() {
        let data = array![[1.0, 2.0], [3.0, 4.0]];
        let targets = array![1.0, 0.0];

        let shard = DataShard {
            id: "test_shard".to_string(),
            data: data.clone(),
            targets: Some(targets.clone()),
            metadata: HashMap::new(),
            source_node: None,
        };

        assert_eq!(shard.id, "test_shard");
        assert_eq!(shard.data, data);
        assert_eq!(shard.targets, Some(targets));
    }

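    // Added check of the `LeastLoaded` strategy: it scores each node by
    // `cpu_utilization * 100 + active_tasks` and picks the minimum. `make_node` is an
    // illustrative helper local to this test.
    #[test]
    fn test_least_loaded_prefers_idle_node() {
        fn make_node(id: &str, port: u16, cpu_utilization: f64, active_tasks: usize) -> ClusterNode {
            ClusterNode {
                id: id.to_string(),
                address: SocketAddr::new(IpAddr::V4(Ipv4Addr::new(127, 0, 0, 1)), port),
                status: NodeStatus::Healthy,
                resources: NodeResources {
                    cpu_cores: 4,
                    memory_mb: 8192,
                    disk_mb: 100000,
                    gpu_count: 0,
                    network_bandwidth: 1000,
                },
                load: NodeLoad {
                    cpu_utilization,
                    active_tasks,
                    ..NodeLoad::default()
                },
                last_heartbeat: SystemTime::now(),
                metadata: HashMap::new(),
            }
        }

        let balancer = LoadBalancer::new(LoadBalancingStrategy::LeastLoaded);
        let nodes = vec![
            make_node("busy", 8080, 0.9, 5),
            make_node("idle", 8081, 0.1, 0),
        ];

        let requirements = ResourceRequirements {
            cpu_cores: 1,
            memory_mb: 1024,
            disk_mb: 1000,
            gpu_required: false,
            estimated_duration: Duration::from_secs(60),
            priority: TaskPriority::Normal,
        };

        let selected = balancer.select_node(&nodes, &requirements).unwrap();
        assert_eq!(selected, "idle");
    }
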
    #[test]
    fn test_mapreduce_pipeline_creation() {
        let mapper = Box::new(MockTransformer::new());
        let reducer = Box::new(MockTransformer::new());
        let cluster_manager = Arc::new(ClusterManager::new(ClusterConfig::default()));

        let pipeline = MapReducePipeline::new(mapper, reducer, cluster_manager);

        assert!(matches!(
            pipeline.partitioning_strategy,
            PartitioningStrategy::EqualSize {
                partition_size: 1000
            }
        ));
    }

    #[test]
    fn test_fault_detector() {
        let detector = FaultDetector::new();

        let node = ClusterNode {
            id: "test_node".to_string(),
            address: SocketAddr::new(IpAddr::V4(Ipv4Addr::new(127, 0, 0, 1)), 8080),
            status: NodeStatus::Healthy,
            resources: NodeResources {
                cpu_cores: 4,
                memory_mb: 8192,
                disk_mb: 100000,
                gpu_count: 0,
                network_bandwidth: 1000,
            },
            load: NodeLoad::default(),
            last_heartbeat: SystemTime::now() - Duration::from_secs(60),
            metadata: HashMap::new(),
        };

        let is_failed = detector.detect_node_failure(&node, Duration::from_secs(30));
        assert!(is_failed);

        detector.record_failure(&node.id);

        let strategy = detector.get_recovery_strategy("node_failure");
        assert!(matches!(strategy, Some(RecoveryStrategy::MigrateTask)));
    }
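
    // Added check of `ClusterManager::cluster_status`: `cluster_load` is the mean CPU
    // utilization across registered nodes.
    #[test]
    fn test_cluster_load_average() {
        let manager = ClusterManager::new(ClusterConfig::default());

        for (id, port, cpu) in [("node_a", 8080, 0.2), ("node_b", 8081, 0.6)] {
            let node = ClusterNode {
                id: id.to_string(),
                address: SocketAddr::new(IpAddr::V4(Ipv4Addr::new(127, 0, 0, 1)), port),
                status: NodeStatus::Healthy,
                resources: NodeResources {
                    cpu_cores: 4,
                    memory_mb: 8192,
                    disk_mb: 100000,
                    gpu_count: 0,
                    network_bandwidth: 1000,
                },
                load: NodeLoad {
                    cpu_utilization: cpu,
                    ..NodeLoad::default()
                },
                last_heartbeat: SystemTime::now(),
                metadata: HashMap::new(),
            };
            manager.add_node(node).unwrap();
        }

        let status = manager.cluster_status();
        assert_eq!(status.total_nodes, 2);
        assert!((status.cluster_load - 0.4).abs() < 1e-9);
    }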
}