use crate::error::{Result, SklearsError};
use futures_core::future::BoxFuture;
use serde::{Deserialize, Serialize};
use std::collections::HashMap;
use std::sync::{Arc, RwLock};
use std::time::{Duration, SystemTime};

/// Unique identifier for a node participating in a distributed cluster.
#[derive(Debug, Clone, PartialEq, Eq, Hash, Serialize, Deserialize)]
pub struct NodeId(pub String);

impl NodeId {
    /// Creates a new node identifier from anything convertible into a `String`.
    pub fn new(id: impl Into<String>) -> Self {
        Self(id.into())
    }

    /// Returns the identifier as a string slice.
    pub fn as_str(&self) -> &str {
        &self.0
    }
}

impl std::fmt::Display for NodeId {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        write!(f, "{}", self.0)
    }
}

/// Envelope for messages exchanged between cluster nodes.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct DistributedMessage {
    pub id: String,
    pub sender: NodeId,
    pub receiver: NodeId,
    pub message_type: MessageType,
    pub payload: Vec<u8>,
    pub timestamp: SystemTime,
    pub priority: MessagePriority,
    pub retry_count: u32,
}

/// Categories of messages used for coordination and data exchange.
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
pub enum MessageType {
    DataTransfer,
    ParameterSync,
    GradientAggregation,
    Coordination,
    HealthCheck,
    FaultRecovery,
    LoadBalance,
    Custom(String),
}

/// Delivery priority for distributed messages, ordered from lowest to highest.
#[derive(Debug, Clone, PartialEq, Eq, PartialOrd, Ord, Serialize, Deserialize)]
pub enum MessagePriority {
    Low = 0,
    Normal = 1,
    High = 2,
    Critical = 3,
}

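// Illustrative sketch only (not part of the original API): constructing a
// health-check message with the types above. The id, payload, and priority
// values are assumptions chosen for the example.
#[allow(dead_code)]
fn example_health_check_message(sender: NodeId, receiver: NodeId) -> DistributedMessage {
    DistributedMessage {
        id: "health-check-0001".to_string(),
        sender,
        receiver,
        message_type: MessageType::HealthCheck,
        payload: Vec::new(),
        timestamp: SystemTime::now(),
        priority: MessagePriority::High,
        retry_count: 0,
    }
}
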
/// Asynchronous message passing between nodes in a distributed cluster.
pub trait MessagePassing: Send + Sync {
    /// Sends a message to the target node.
    fn send_message(
        &self,
        target: NodeId,
        message: DistributedMessage,
    ) -> BoxFuture<'_, Result<()>>;

    /// Receives the next incoming message.
    fn receive_message(&self) -> BoxFuture<'_, Result<DistributedMessage>>;

    /// Broadcasts a message to all known nodes.
    fn broadcast_message(&self, message: DistributedMessage) -> BoxFuture<'_, Result<()>>;

    /// Sends a message and waits for the corresponding reply.
    fn send_and_receive(
        &self,
        target: NodeId,
        message: DistributedMessage,
    ) -> BoxFuture<'_, Result<DistributedMessage>>;

    /// Returns whether any messages are waiting to be received.
    fn has_pending_messages(&self) -> BoxFuture<'_, Result<bool>>;

    /// Returns the number of messages waiting to be received.
    fn pending_message_count(&self) -> BoxFuture<'_, Result<usize>>;

    /// Flushes any buffered outgoing messages.
    fn flush_outgoing(&self) -> BoxFuture<'_, Result<()>>;
}

/// A node that participates in a distributed cluster.
pub trait ClusterNode: MessagePassing + Send + Sync {
    /// Returns this node's identifier.
    fn node_id(&self) -> &NodeId;

    /// Lists the identifiers of all nodes currently known to the cluster.
    fn cluster_nodes(&self) -> BoxFuture<'_, Result<Vec<NodeId>>>;

    /// Returns whether this node is the cluster coordinator.
    fn is_coordinator(&self) -> bool;

    /// Reports the current health of this node.
    fn health_status(&self) -> BoxFuture<'_, Result<NodeHealth>>;

    /// Reports the compute resources available on this node.
    fn resources(&self) -> BoxFuture<'_, Result<NodeResources>>;

    /// Joins the cluster managed by the given coordinator.
    fn join_cluster(&mut self, coordinator: NodeId) -> BoxFuture<'_, Result<()>>;

    /// Leaves the cluster gracefully.
    fn leave_cluster(&mut self) -> BoxFuture<'_, Result<()>>;

    /// Reacts to the failure of another node.
    fn handle_node_failure(&mut self, failed_node: NodeId) -> BoxFuture<'_, Result<()>>;
}

/// Snapshot of a node's health as reported by monitoring.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct NodeHealth {
    pub health_score: f64,
    pub cpu_usage: f64,
    pub memory_usage: f64,
    pub network_latency: Duration,
    pub last_heartbeat: SystemTime,
    pub recent_errors: u32,
    pub uptime: Duration,
}

/// Compute resources available on a node.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct NodeResources {
    pub cpu_cores: u32,
    pub total_memory: u64,
    pub available_memory: u64,
    pub gpu_devices: Vec<GpuDevice>,
    pub network_bandwidth: u64,
    pub storage_capacity: u64,
    pub tags: HashMap<String, String>,
}

/// Description of a GPU device attached to a node.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct GpuDevice {
    pub device_id: u32,
    pub name: String,
    pub total_memory: u64,
    pub available_memory: u64,
    pub compute_capability: String,
}

/// An estimator that can be trained and used for prediction across a cluster.
pub trait DistributedEstimator: Send + Sync {
    /// Training data consumed by `fit_distributed`.
    type TrainingData;

    /// Input accepted by `predict_distributed`.
    type PredictionInput;

    /// Output produced by `predict_distributed`.
    type PredictionOutput;

    /// Serializable model parameters exchanged between nodes.
    type Parameters: Serialize + for<'de> Deserialize<'de>;

    /// Trains the estimator using the given cluster.
    fn fit_distributed<'a>(
        &'a mut self,
        cluster: &'a dyn DistributedCluster,
        training_data: &Self::TrainingData,
    ) -> BoxFuture<'a, Result<()>>;

    /// Produces predictions for the given input using the cluster.
    fn predict_distributed<'a>(
        &'a self,
        cluster: &dyn DistributedCluster,
        input: &'a Self::PredictionInput,
    ) -> BoxFuture<'a, Result<Self::PredictionOutput>>;

    /// Returns the current model parameters.
    fn get_parameters(&self) -> Result<Self::Parameters>;

    /// Replaces the current model parameters.
    fn set_parameters(&mut self, params: Self::Parameters) -> Result<()>;

    /// Synchronizes parameters with the rest of the cluster.
    fn sync_parameters(&mut self, cluster: &dyn DistributedCluster) -> BoxFuture<'_, Result<()>>;

    /// Returns a snapshot of the current training progress.
    fn training_progress(&self) -> DistributedTrainingProgress;
}

/// Progress report for a distributed training run.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct DistributedTrainingProgress {
    pub epoch: u32,
    pub total_epochs: u32,
    pub training_loss: f64,
    pub validation_loss: Option<f64>,
    pub samples_processed: u64,
    pub start_time: SystemTime,
    pub estimated_completion: Option<SystemTime>,
    pub active_nodes: Vec<NodeId>,
    pub node_statistics: HashMap<NodeId, NodeTrainingStats>,
}

/// Per-node statistics collected during training.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct NodeTrainingStats {
    pub samples_processed: u64,
    pub processing_rate: f64,
    pub current_loss: f64,
    pub memory_usage: u64,
    pub cpu_utilization: f64,
}

/// Management interface for a cluster of distributed nodes.
pub trait DistributedCluster: Send + Sync {
    /// Lists the nodes currently active in the cluster.
    fn active_nodes(&self) -> BoxFuture<'_, Result<Vec<NodeId>>>;

    /// Returns the coordinator node's identifier.
    fn coordinator(&self) -> &NodeId;

    /// Returns the cluster configuration.
    fn configuration(&self) -> &ClusterConfiguration;

    /// Adds a node to the cluster.
    fn add_node(&mut self, node: NodeId) -> BoxFuture<'_, Result<()>>;

    /// Removes a node from the cluster.
    fn remove_node(&mut self, node: NodeId) -> BoxFuture<'_, Result<()>>;

    /// Rebalances work across the remaining nodes.
    fn rebalance_load(&mut self) -> BoxFuture<'_, Result<()>>;

    /// Reports the aggregate health of the cluster.
    fn cluster_health(&self) -> BoxFuture<'_, Result<ClusterHealth>>;

    /// Creates a checkpoint of the current cluster state.
    fn create_checkpoint(&self) -> BoxFuture<'_, Result<ClusterCheckpoint>>;

    /// Restores the cluster from a previously created checkpoint.
    fn restore_checkpoint(&mut self, checkpoint: ClusterCheckpoint) -> BoxFuture<'_, Result<()>>;
}

/// Configuration parameters governing cluster behaviour.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ClusterConfiguration {
    pub max_nodes: u32,
    pub heartbeat_interval: Duration,
    pub failure_timeout: Duration,
    pub max_retries: u32,
    pub load_balancing: LoadBalancingStrategy,
    pub fault_tolerance: FaultToleranceMode,
    pub consistency_level: ConsistencyLevel,
}

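// Minimal sketch (an assumption, not part of the original API): one way to
// combine `NodeHealth::last_heartbeat` with `ClusterConfiguration::failure_timeout`
// to decide whether a node should be treated as unresponsive.
#[allow(dead_code)]
fn is_node_timed_out(health: &NodeHealth, config: &ClusterConfiguration) -> bool {
    SystemTime::now()
        .duration_since(health.last_heartbeat)
        .map(|elapsed| elapsed > config.failure_timeout)
        .unwrap_or(false)
}
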
/// Strategy used to distribute work across cluster nodes.
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
pub enum LoadBalancingStrategy {
    RoundRobin,
    ResourceBased,
    LoadBased,
    LocalityAware,
    Custom(String),
}

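// Illustrative sketch only: round-robin assignment of work items to nodes, the
// behaviour suggested by `LoadBalancingStrategy::RoundRobin`. This helper is an
// assumption for documentation purposes and is not used by the cluster below.
#[allow(dead_code)]
fn round_robin_assignment(nodes: &[NodeId], work_items: u32) -> HashMap<NodeId, Vec<u32>> {
    let mut assignment: HashMap<NodeId, Vec<u32>> = HashMap::new();
    if nodes.is_empty() {
        return assignment;
    }
    for item in 0..work_items {
        let node = &nodes[item as usize % nodes.len()];
        assignment.entry(node.clone()).or_default().push(item);
    }
    assignment
}
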
/// Degree of fault tolerance applied to distributed computations.
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
pub enum FaultToleranceMode {
    None,
    BasicRetry,
    CheckpointRecovery,
    RedundantComputation,
    Byzantine,
}

/// Consistency guarantees for replicated state.
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
pub enum ConsistencyLevel {
    None,
    Eventual,
    Strong,
    Causal,
    Sequential,
}

/// Aggregate health information for the whole cluster.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ClusterHealth {
    pub overall_health: f64,
    pub healthy_nodes: u32,
    pub failed_nodes: u32,
    pub average_response_time: Duration,
    pub total_throughput: f64,
    pub resource_utilization: ClusterResourceUtilization,
}

/// Cluster-wide resource utilization metrics.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ClusterResourceUtilization {
    pub cpu_utilization: f64,
    pub memory_utilization: f64,
    pub network_utilization: f64,
    pub storage_utilization: f64,
}

/// Serializable checkpoint of cluster-wide state.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ClusterCheckpoint {
    pub checkpoint_id: String,
    pub timestamp: SystemTime,
    pub configuration: ClusterConfiguration,
    pub node_states: HashMap<NodeId, NodeCheckpoint>,
    pub cluster_state: Vec<u8>,
}

/// Serializable checkpoint of a single node's state.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct NodeCheckpoint {
    pub node_id: NodeId,
    pub state_data: Vec<u8>,
    pub health: NodeHealth,
    pub resources: NodeResources,
}

/// A dataset that can be partitioned and processed across cluster nodes.
pub trait DistributedDataset: Send + Sync {
    /// Type of the individual items stored in the dataset.
    type Item;

    /// Strategy used to split the dataset into partitions.
    type PartitionStrategy;

    /// Returns the total number of items in the dataset.
    fn size(&self) -> u64;

    /// Returns the number of partitions currently created.
    fn partition_count(&self) -> u32;

    /// Splits the dataset into partitions and assigns them to cluster nodes.
    fn partition<'a>(
        &'a mut self,
        cluster: &'a dyn DistributedCluster,
        strategy: Self::PartitionStrategy,
    ) -> BoxFuture<'a, Result<Vec<DistributedPartition<Self::Item>>>>;

    /// Retrieves a single partition by its identifier.
    fn get_partition(
        &self,
        partition_id: u32,
    ) -> BoxFuture<'_, Result<DistributedPartition<Self::Item>>>;

    /// Re-partitions the dataset using a new strategy.
    fn repartition<'a>(
        &'a mut self,
        cluster: &'a dyn DistributedCluster,
        new_strategy: Self::PartitionStrategy,
    ) -> BoxFuture<'a, Result<()>>;

    /// Gathers all partitioned items back into a single collection.
    fn collect(&self, cluster: &dyn DistributedCluster) -> BoxFuture<'_, Result<Vec<Self::Item>>>;

    /// Returns the current mapping from nodes to partition identifiers.
    fn partition_assignment(&self) -> HashMap<NodeId, Vec<u32>>;
}

/// A partition of a distributed dataset assigned to a specific node.
#[derive(Debug, Clone)]
pub struct DistributedPartition<T> {
    pub partition_id: u32,
    pub node_id: NodeId,
    pub data: Vec<T>,
    pub metadata: PartitionMetadata,
}

/// Metadata describing a single dataset partition.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct PartitionMetadata {
    pub item_count: u64,
    pub size_bytes: u64,
    pub schema: Option<String>,
    pub created_at: SystemTime,
    pub modified_at: SystemTime,
    pub checksum: String,
}

/// Built-in strategies for splitting a dataset into partitions.
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
pub enum PartitioningStrategy {
    EvenSplit,
    HashBased(u32),
    RangeBased,
    Random,
    Stratified,
    Custom(String),
}

/// Central parameter server for synchronizing model parameters and gradients.
pub trait ParameterServer: Send + Sync {
    /// Serializable parameter type managed by the server.
    type Parameters: Serialize + for<'de> Deserialize<'de>;

    /// Initializes the server with an initial set of parameters.
    fn initialize(&mut self, initial_params: Self::Parameters) -> BoxFuture<'_, Result<()>>;

    /// Returns the current parameters.
    fn get_parameters(&self) -> BoxFuture<'_, Result<Self::Parameters>>;

    /// Updates the stored parameters from a batch of gradients.
    fn update_parameters(&mut self, gradients: Vec<Self::Parameters>) -> BoxFuture<'_, Result<()>>;

    /// Pushes the current parameters out to the cluster.
    fn push_parameters(&self, cluster: &dyn DistributedCluster) -> BoxFuture<'_, Result<()>>;

    /// Pulls the latest parameters from the cluster.
    fn pull_parameters(&mut self, cluster: &dyn DistributedCluster) -> BoxFuture<'_, Result<()>>;

    /// Aggregates gradients reported by worker nodes into a single update.
    fn aggregate_gradients(
        &mut self,
        gradients: Vec<Self::Parameters>,
    ) -> BoxFuture<'_, Result<Self::Parameters>>;

    /// Applies an aggregated gradient update to the stored parameters.
    fn apply_optimization(
        &mut self,
        aggregated_gradients: Self::Parameters,
    ) -> BoxFuture<'_, Result<()>>;
}

/// Strategies for aggregating gradients collected from worker nodes.
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
pub enum GradientAggregation {
    Average,
    WeightedAverage,
    FederatedAveraging,
    ByzantineRobust,
    Compressed,
}

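// Illustrative sketch only (not referenced by the traits above): element-wise
// averaging of dense gradient vectors, the behaviour suggested by
// `GradientAggregation::Average`. Gradient shapes are assumed to match.
#[allow(dead_code)]
fn average_dense_gradients(gradients: &[Vec<f64>]) -> Option<Vec<f64>> {
    let first = gradients.first()?;
    let mut sum = vec![0.0; first.len()];
    for gradient in gradients {
        for (acc, value) in sum.iter_mut().zip(gradient.iter()) {
            *acc += value;
        }
    }
    let count = gradients.len() as f64;
    for value in sum.iter_mut() {
        *value /= count;
    }
    Some(sum)
}
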
/// Failure detection, recovery, and integrity checking for a cluster.
pub trait FaultTolerance: Send + Sync {
    /// Detects failed nodes in the cluster.
    fn detect_failure(
        &self,
        cluster: &dyn DistributedCluster,
    ) -> BoxFuture<'_, Result<Vec<NodeId>>>;

    /// Recovers the cluster after the given nodes have failed.
    fn recover_from_failure(
        &mut self,
        cluster: &mut dyn DistributedCluster,
        failed_nodes: Vec<NodeId>,
    ) -> BoxFuture<'_, Result<()>>;

    /// Creates a checkpoint that can later be used for recovery.
    fn create_checkpoint(
        &self,
        cluster: &dyn DistributedCluster,
    ) -> BoxFuture<'_, Result<FaultToleranceCheckpoint>>;

    /// Restores the cluster from a previously created checkpoint.
    fn restore_checkpoint(
        &mut self,
        cluster: &mut dyn DistributedCluster,
        checkpoint: FaultToleranceCheckpoint,
    ) -> BoxFuture<'_, Result<()>>;

    /// Replicates data across nodes for redundancy.
    fn replicate_data(
        &self,
        cluster: &dyn DistributedCluster,
        data: Vec<u8>,
    ) -> BoxFuture<'_, Result<()>>;

    /// Validates the integrity of data and parameters across the cluster.
    fn validate_integrity(
        &self,
        cluster: &dyn DistributedCluster,
    ) -> BoxFuture<'_, Result<IntegrityReport>>;
}

/// Checkpoint data captured by a fault-tolerance implementation.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct FaultToleranceCheckpoint {
    pub id: String,
    pub timestamp: SystemTime,
    pub training_state: Vec<u8>,
    pub model_parameters: Vec<u8>,
    pub node_assignments: HashMap<NodeId, Vec<u32>>,
    pub replication_map: HashMap<String, Vec<NodeId>>,
}

/// Result of an integrity validation pass over the cluster.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct IntegrityReport {
    pub integrity_score: f64,
    pub data_consistency: bool,
    pub parameter_sync: bool,
    pub replication_health: f64,
    pub inconsistencies: Vec<String>,
    pub recommendations: Vec<String>,
}

/// Default in-memory implementation of [`DistributedCluster`].
pub struct DefaultDistributedCluster {
    configuration: ClusterConfiguration,
    coordinator: NodeId,
    nodes: Arc<RwLock<HashMap<NodeId, Arc<dyn ClusterNode>>>>,
    health_monitor: Arc<RwLock<ClusterHealth>>,
}

impl std::fmt::Debug for DefaultDistributedCluster {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("DefaultDistributedCluster")
            .field("configuration", &self.configuration)
            .field("coordinator", &self.coordinator)
            .field("nodes", &"<HashMap<NodeId, Arc<dyn ClusterNode>>>")
            .field("health_monitor", &self.health_monitor)
            .finish()
    }
}

impl DefaultDistributedCluster {
    /// Creates a new cluster with the given coordinator and configuration.
    pub fn new(coordinator: NodeId, configuration: ClusterConfiguration) -> Self {
        Self {
            configuration,
            coordinator,
            nodes: Arc::new(RwLock::new(HashMap::new())),
            health_monitor: Arc::new(RwLock::new(ClusterHealth {
                overall_health: 1.0,
                healthy_nodes: 0,
                failed_nodes: 0,
                average_response_time: Duration::from_millis(10),
                total_throughput: 0.0,
                resource_utilization: ClusterResourceUtilization {
                    cpu_utilization: 0.0,
                    memory_utilization: 0.0,
                    network_utilization: 0.0,
                    storage_utilization: 0.0,
                },
            })),
        }
    }
}

impl DistributedCluster for DefaultDistributedCluster {
    fn active_nodes(&self) -> BoxFuture<'_, Result<Vec<NodeId>>> {
        Box::pin(async move {
            let nodes = self.nodes.read().map_err(|_| {
                SklearsError::InvalidOperation("Failed to acquire read lock on nodes".to_string())
            })?;
            Ok(nodes.keys().cloned().collect())
        })
    }

    fn coordinator(&self) -> &NodeId {
        &self.coordinator
    }

    fn configuration(&self) -> &ClusterConfiguration {
        &self.configuration
    }

    fn add_node(&mut self, _node_id: NodeId) -> BoxFuture<'_, Result<()>> {
        Box::pin(async move {
            // A bare `NodeId` carries no node handle to register, so this
            // reference implementation accepts the request as a no-op.
            Ok(())
        })
    }

    fn remove_node(&mut self, node_id: NodeId) -> BoxFuture<'_, Result<()>> {
        Box::pin(async move {
            let mut nodes = self.nodes.write().map_err(|_| {
                SklearsError::InvalidOperation("Failed to acquire write lock on nodes".to_string())
            })?;
            nodes.remove(&node_id);
            Ok(())
        })
    }

    fn rebalance_load(&mut self) -> BoxFuture<'_, Result<()>> {
        Box::pin(async move {
            // Load rebalancing is a no-op in this reference implementation.
            Ok(())
        })
    }

    fn cluster_health(&self) -> BoxFuture<'_, Result<ClusterHealth>> {
        Box::pin(async move {
            let health = self.health_monitor.read().map_err(|_| {
                SklearsError::InvalidOperation(
                    "Failed to acquire read lock on health monitor".to_string(),
                )
            })?;
            Ok(health.clone())
        })
    }

    fn create_checkpoint(&self) -> BoxFuture<'_, Result<ClusterCheckpoint>> {
        Box::pin(async move {
            let checkpoint = ClusterCheckpoint {
                checkpoint_id: format!("checkpoint_{}", chrono::Utc::now().timestamp()),
                timestamp: SystemTime::now(),
                configuration: self.configuration.clone(),
                node_states: HashMap::new(),
                cluster_state: Vec::new(),
            };
            Ok(checkpoint)
        })
    }

    fn restore_checkpoint(&mut self, _checkpoint: ClusterCheckpoint) -> BoxFuture<'_, Result<()>> {
        Box::pin(async move {
            // Checkpoint restoration is a no-op in this reference implementation.
            Ok(())
        })
    }
}

impl Default for ClusterConfiguration {
    fn default() -> Self {
        Self {
            max_nodes: 64,
            heartbeat_interval: Duration::from_secs(30),
            failure_timeout: Duration::from_secs(120),
            max_retries: 3,
            load_balancing: LoadBalancingStrategy::ResourceBased,
            fault_tolerance: FaultToleranceMode::CheckpointRecovery,
            consistency_level: ConsistencyLevel::Eventual,
        }
    }
}

/// Reference distributed estimator: a simplified linear regression model.
#[derive(Debug)]
pub struct DistributedLinearRegression {
    parameters: Option<Vec<f64>>,
    config: DistributedTrainingConfig,
    progress: DistributedTrainingProgress,
}

/// Hyperparameters controlling a distributed training run.
#[derive(Debug, Clone)]
pub struct DistributedTrainingConfig {
    pub learning_rate: f64,
    pub epochs: u32,
    pub batch_size: u32,
    pub aggregation: GradientAggregation,
    pub checkpoint_frequency: u32,
}

impl Default for DistributedLinearRegression {
    fn default() -> Self {
        Self::new()
    }
}

impl DistributedLinearRegression {
    /// Creates an untrained model with the default training configuration.
    pub fn new() -> Self {
        Self {
            parameters: None,
            config: DistributedTrainingConfig::default(),
            progress: DistributedTrainingProgress {
                epoch: 0,
                total_epochs: 0,
                training_loss: 0.0,
                validation_loss: None,
                samples_processed: 0,
                start_time: SystemTime::now(),
                estimated_completion: None,
                active_nodes: Vec::new(),
                node_statistics: HashMap::new(),
            },
        }
    }

    /// Replaces the training configuration (builder style).
    pub fn with_config(mut self, config: DistributedTrainingConfig) -> Self {
        self.config = config;
        self
    }
}

impl Default for DistributedTrainingConfig {
    fn default() -> Self {
        Self {
            learning_rate: 0.01,
            epochs: 100,
            batch_size: 32,
            aggregation: GradientAggregation::Average,
            checkpoint_frequency: 10,
        }
    }
}

impl DistributedEstimator for DistributedLinearRegression {
    type TrainingData = (Vec<Vec<f64>>, Vec<f64>);
    type PredictionInput = Vec<Vec<f64>>;
    type PredictionOutput = Vec<f64>;
    type Parameters = Vec<f64>;

    fn fit_distributed<'a>(
        &'a mut self,
        _cluster: &'a dyn DistributedCluster,
        training_data: &Self::TrainingData,
    ) -> BoxFuture<'a, Result<()>> {
        let training_data = training_data.clone();
        Box::pin(async move {
            let (x, _y) = &training_data;

            // Initialize weights plus a bias term on the first call.
            if self.parameters.is_none() {
                let feature_count = x.first().map(|row| row.len()).unwrap_or(0);
                self.parameters = Some(vec![0.0; feature_count + 1]);
            }

            self.progress.total_epochs = self.config.epochs;
            self.progress.start_time = SystemTime::now();
            // Active node tracking is omitted in this simplified implementation.
            self.progress.active_nodes = vec![];

            for epoch in 0..self.config.epochs {
                self.progress.epoch = epoch;

                // Simplified local update standing in for a real distributed
                // gradient step.
                if let Some(ref mut params) = self.parameters {
                    for param in params.iter_mut() {
                        *param += self.config.learning_rate * 0.1;
                    }
                }

                self.progress.samples_processed += x.len() as u64;
                // Simulated, monotonically decreasing training loss.
                self.progress.training_loss = (epoch as f64 * 0.1).exp().recip();

                if epoch % self.config.checkpoint_frequency == 0 {
                    // Checkpointing hook; intentionally left empty in this
                    // reference implementation.
                }
            }

            Ok(())
        })
    }

    fn predict_distributed<'a>(
        &'a self,
        _cluster: &dyn DistributedCluster,
        input: &'a Self::PredictionInput,
    ) -> BoxFuture<'a, Result<Self::PredictionOutput>> {
        Box::pin(async move {
            let Some(ref params) = self.parameters else {
                return Err(SklearsError::InvalidOperation(
                    "Model not trained. Call fit_distributed first.".to_string(),
                ));
            };

            let predictions = input
                .iter()
                .map(|features| {
                    // The last parameter acts as the bias term.
                    let mut prediction = *params.last().unwrap_or(&0.0);
                    for (feature, weight) in features.iter().zip(params.iter()) {
                        prediction += feature * weight;
                    }
                    prediction
                })
                .collect();

            Ok(predictions)
        })
    }

    fn get_parameters(&self) -> Result<Self::Parameters> {
        self.parameters
            .clone()
            .ok_or_else(|| SklearsError::InvalidOperation("Model not trained".to_string()))
    }

    fn set_parameters(&mut self, params: Self::Parameters) -> Result<()> {
        self.parameters = Some(params);
        Ok(())
    }

    fn sync_parameters(&mut self, _cluster: &dyn DistributedCluster) -> BoxFuture<'_, Result<()>> {
        Box::pin(async move {
            // Parameter synchronization is a no-op in this reference model.
            Ok(())
        })
    }

    fn training_progress(&self) -> DistributedTrainingProgress {
        self.progress.clone()
    }
}

/// Reference distributed dataset backed by an in-memory matrix of `f64` rows.
#[derive(Debug)]
pub struct DistributedNumericalDataset {
    data: Vec<Vec<f64>>,
    partitions: Vec<DistributedPartition<Vec<f64>>>,
    assignment: HashMap<NodeId, Vec<u32>>,
}

impl DistributedNumericalDataset {
    /// Creates an unpartitioned dataset from in-memory rows.
    pub fn new(data: Vec<Vec<f64>>) -> Self {
        Self {
            data,
            partitions: Vec::new(),
            assignment: HashMap::new(),
        }
    }
}

impl DistributedDataset for DistributedNumericalDataset {
    type Item = Vec<f64>;
    type PartitionStrategy = PartitioningStrategy;

    fn size(&self) -> u64 {
        self.data.len() as u64
    }

    fn partition_count(&self) -> u32 {
        self.partitions.len() as u32
    }

    fn partition<'a>(
        &'a mut self,
        cluster: &'a dyn DistributedCluster,
        strategy: Self::PartitionStrategy,
    ) -> BoxFuture<'a, Result<Vec<DistributedPartition<Self::Item>>>> {
        Box::pin(async move {
            let nodes = cluster.active_nodes().await?;
            let num_nodes = nodes.len();

            if num_nodes == 0 {
                return Err(SklearsError::InvalidOperation(
                    "No active nodes in cluster".to_string(),
                ));
            }

            self.partitions.clear();
            self.assignment.clear();

            match strategy {
                PartitioningStrategy::EvenSplit => {
                    // Ceiling division so every row is assigned to some node.
                    let chunk_size = (self.data.len() + num_nodes - 1) / num_nodes;

                    for (i, node_id) in nodes.iter().enumerate() {
                        let start = i * chunk_size;
                        let end = std::cmp::min(start + chunk_size, self.data.len());

                        if start < self.data.len() {
                            let partition_data = self.data[start..end].to_vec();
                            let partition = DistributedPartition {
                                partition_id: i as u32,
                                node_id: node_id.clone(),
                                data: partition_data.clone(),
                                metadata: PartitionMetadata {
                                    item_count: partition_data.len() as u64,
                                    size_bytes: partition_data
                                        .iter()
                                        .map(|row| {
                                            (row.len() * std::mem::size_of::<f64>()) as u64
                                        })
                                        .sum(),
                                    schema: Some("numerical_array".to_string()),
                                    created_at: SystemTime::now(),
                                    modified_at: SystemTime::now(),
                                    checksum: format!("checksum_{}", i),
                                },
                            };

                            self.partitions.push(partition);
                            self.assignment
                                .entry(node_id.clone())
                                .or_default()
                                .push(i as u32);
                        }
                    }
                }
                _ => {
                    return Err(SklearsError::InvalidOperation(
                        "Partitioning strategy not yet implemented".to_string(),
                    ));
                }
            }

            Ok(self.partitions.clone())
        })
    }

    fn get_partition(
        &self,
        partition_id: u32,
    ) -> BoxFuture<'_, Result<DistributedPartition<Self::Item>>> {
        Box::pin(async move {
            self.partitions
                .get(partition_id as usize)
                .cloned()
                .ok_or_else(|| {
                    SklearsError::InvalidOperation(format!("Partition {} not found", partition_id))
                })
        })
    }

    fn repartition<'a>(
        &'a mut self,
        cluster: &'a dyn DistributedCluster,
        new_strategy: Self::PartitionStrategy,
    ) -> BoxFuture<'a, Result<()>> {
        Box::pin(async move {
            // Gather existing partitions back together, then split again with
            // the new strategy. This assumes the dataset was partitioned first.
            let collected_data = self.collect(cluster).await?;
            self.data = collected_data;

            self.partition(cluster, new_strategy).await?;

            Ok(())
        })
    }

    fn collect(&self, _cluster: &dyn DistributedCluster) -> BoxFuture<'_, Result<Vec<Self::Item>>> {
        Box::pin(async move {
            // Concatenates partition contents; an unpartitioned dataset yields
            // an empty collection.
            let mut collected = Vec::new();
            for partition in &self.partitions {
                collected.extend(partition.data.clone());
            }
            Ok(collected)
        })
    }

    fn partition_assignment(&self) -> HashMap<NodeId, Vec<u32>> {
        self.assignment.clone()
    }
}

#[allow(non_snake_case)]
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_node_id_creation() {
        let node_id = NodeId::new("worker-01");
        assert_eq!(node_id.as_str(), "worker-01");
        assert_eq!(node_id.to_string(), "worker-01");
    }

    #[test]
    fn test_message_priority_ordering() {
        assert!(MessagePriority::Critical > MessagePriority::High);
        assert!(MessagePriority::High > MessagePriority::Normal);
        assert!(MessagePriority::Normal > MessagePriority::Low);
    }

    #[test]
    fn test_cluster_configuration_default() {
        let config = ClusterConfiguration::default();
        assert_eq!(config.max_nodes, 64);
        assert_eq!(config.load_balancing, LoadBalancingStrategy::ResourceBased);
        assert_eq!(
            config.fault_tolerance,
            FaultToleranceMode::CheckpointRecovery
        );
    }

    #[test]
    fn test_distributed_linear_regression_creation() {
        let model = DistributedLinearRegression::new();
        assert!(model.parameters.is_none());
        assert_eq!(model.progress.epoch, 0);
    }

    #[test]
    fn test_distributed_dataset_size() {
        let data = vec![vec![1.0, 2.0], vec![3.0, 4.0], vec![5.0, 6.0]];
        let dataset = DistributedNumericalDataset::new(data);
        assert_eq!(dataset.size(), 3);
        // No partitions exist until `partition` has been called.
        assert_eq!(dataset.partition_count(), 0);
    }

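    // Added sketch: partitioning against a cluster with no registered nodes is
    // expected to fail, matching the guard in `DistributedNumericalDataset::partition`.
    #[cfg(feature = "async_support")]
    #[tokio::test]
    async fn test_partition_requires_active_nodes() {
        let mut dataset =
            DistributedNumericalDataset::new(vec![vec![1.0, 2.0], vec![3.0, 4.0]]);
        let cluster = DefaultDistributedCluster::new(
            NodeId::new("coordinator"),
            ClusterConfiguration::default(),
        );

        let result = dataset
            .partition(&cluster, PartitioningStrategy::EvenSplit)
            .await;
        assert!(result.is_err());
    }
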
    #[test]
    fn test_message_type_serialization() {
        let msg_type = MessageType::ParameterSync;
        let serialized = serde_json::to_string(&msg_type).unwrap();
        let deserialized: MessageType = serde_json::from_str(&serialized).unwrap();
        assert_eq!(msg_type, deserialized);
    }

    #[test]
    fn test_partitioning_strategy_variants() {
        let strategies = vec![
            PartitioningStrategy::EvenSplit,
            PartitioningStrategy::HashBased(4),
            PartitioningStrategy::RangeBased,
            PartitioningStrategy::Random,
            PartitioningStrategy::Stratified,
            PartitioningStrategy::Custom("custom_strategy".to_string()),
        ];

        for strategy in strategies {
            let serialized = serde_json::to_string(&strategy).unwrap();
            let _deserialized: PartitioningStrategy = serde_json::from_str(&serialized).unwrap();
        }
    }

    #[test]
    fn test_distributed_training_config() {
        let config = DistributedTrainingConfig::default();
        assert_eq!(config.learning_rate, 0.01);
        assert_eq!(config.epochs, 100);
        assert_eq!(config.batch_size, 32);
    }

    #[cfg(feature = "async_support")]
    #[tokio::test]
    async fn test_default_cluster_operations() {
        let coordinator = NodeId::new("coordinator");
        let config = ClusterConfiguration::default();
        let cluster = DefaultDistributedCluster::new(coordinator.clone(), config);

        assert_eq!(cluster.coordinator(), &coordinator);

        let nodes = cluster.active_nodes().await.unwrap();
        // No nodes have been registered yet.
        assert!(nodes.is_empty());

        let health = cluster.cluster_health().await.unwrap();
        assert_eq!(health.overall_health, 1.0);
    }
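
    // Added sketch: end-to-end fit/predict round trip. The reference estimator
    // trains locally and does not consult the cluster, so an empty default
    // cluster is sufficient here.
    #[cfg(feature = "async_support")]
    #[tokio::test]
    async fn test_distributed_linear_regression_fit_predict() {
        let cluster = DefaultDistributedCluster::new(
            NodeId::new("coordinator"),
            ClusterConfiguration::default(),
        );
        let mut model = DistributedLinearRegression::new().with_config(DistributedTrainingConfig {
            epochs: 5,
            ..DistributedTrainingConfig::default()
        });

        let training_data = (vec![vec![1.0, 2.0], vec![3.0, 4.0]], vec![1.0, 0.0]);
        model.fit_distributed(&cluster, &training_data).await.unwrap();
        assert!(model.get_parameters().is_ok());

        let input = vec![vec![5.0, 6.0]];
        let predictions = model.predict_distributed(&cluster, &input).await.unwrap();
        assert_eq!(predictions.len(), 1);
    }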
}