1use crate::distributed::NodeId;
7use crate::error::{Result, SklearsError};
8use serde::{Deserialize, Serialize};
9use std::collections::HashMap;
10use std::sync::{Arc, RwLock};
11
/// Linear regression trained with data-parallel gradient descent.
///
/// Bundles the training configuration, a shared parameter server, the worker
/// nodes that own the data partitions, and the fitted model parameters.
#[derive(Debug)]
pub struct DistributedLinearRegression {
    /// Training configuration.
    pub config: DistributedConfig,
    /// Parameter server, shared behind a read/write lock.
    pub parameter_server: Arc<RwLock<ParameterServer>>,
    /// Worker nodes; `partition_data` creates one per data partition.
    pub workers: Vec<WorkerNode>,
    /// Model weights, bias and training metadata, shared behind a read/write lock.
    pub parameters: Arc<RwLock<ModelParameters>>,
}
27
/// Configuration for distributed training.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct DistributedConfig {
    /// Number of workers; also the number of data partitions requested.
    pub num_workers: usize,
    /// Synchronization strategy between workers and the parameter server.
    /// NOTE(review): not read by any code visible in this file.
    pub sync_strategy: SyncStrategy,
    /// Whether fault tolerance is enabled.
    /// NOTE(review): not read by any code visible in this file.
    pub fault_tolerance: bool,
    /// Upper bound on the number of training iterations.
    pub max_iterations: usize,
    /// Convergence threshold on the change in loss between two checks.
    pub tolerance: f64,
    /// Learning rate for gradient updates.
    pub learning_rate: f64,
}
44
45impl Default for DistributedConfig {
46 fn default() -> Self {
47 Self {
48 num_workers: 4,
49 sync_strategy: SyncStrategy::Synchronous,
50 fault_tolerance: true,
51 max_iterations: 100,
52 tolerance: 1e-6,
53 learning_rate: 0.01,
54 }
55 }
56}
57
/// How workers synchronize with the parameter server.
///
/// NOTE(review): the training loop visible in this file does not branch on
/// this value; the variant descriptions below follow the variant names only.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
pub enum SyncStrategy {
    /// Synchronous updates.
    Synchronous,
    /// Asynchronous updates.
    Asynchronous,
    /// Asynchronous updates with a staleness bound.
    BoundedAsync { staleness_bound: usize },
}
68
69#[derive(Debug, Clone)]
71pub struct ParameterServer {
72 pub parameters: Vec<f64>,
74 pub version: usize,
76 pub gradient_accumulator: Vec<f64>,
78 pub num_workers: usize,
80 pub updates_received: usize,
82}
83
84impl ParameterServer {
85 pub fn new(num_parameters: usize, num_workers: usize) -> Self {
87 Self {
88 parameters: vec![0.0; num_parameters],
89 version: 0,
90 gradient_accumulator: vec![0.0; num_parameters],
91 num_workers,
92 updates_received: 0,
93 }
94 }
95
96 pub fn receive_gradient(&mut self, gradient: Vec<f64>) -> Result<()> {
98 if gradient.len() != self.parameters.len() {
99 return Err(SklearsError::DimensionMismatch {
100 expected: self.parameters.len(),
101 actual: gradient.len(),
102 });
103 }
104
105 for (acc, grad) in self.gradient_accumulator.iter_mut().zip(gradient.iter()) {
107 *acc += grad;
108 }
109
110 self.updates_received += 1;
111
112 if self.updates_received == self.num_workers {
114 self.apply_accumulated_gradients();
115 }
116
117 Ok(())
118 }
119
120 fn apply_accumulated_gradients(&mut self) {
122 let scale = 1.0 / self.num_workers as f64;
123
124 for (param, grad) in self
125 .parameters
126 .iter_mut()
127 .zip(self.gradient_accumulator.iter())
128 {
129 *param -= grad * scale;
130 }
131
132 self.gradient_accumulator.iter_mut().for_each(|g| *g = 0.0);
134 self.updates_received = 0;
135 self.version += 1;
136 }
137
138 pub fn get_parameters(&self) -> Vec<f64> {
140 self.parameters.clone()
141 }
142
143 pub fn get_version(&self) -> usize {
145 self.version
146 }
147}
148
/// A worker that owns one data partition and computes gradients on it.
#[derive(Debug, Clone)]
pub struct WorkerNode {
    /// Identifier of this worker.
    pub id: NodeId,
    /// The samples this worker trains on.
    pub data_partition: DataPartition,
    /// Copy of the parameters passed to the most recent gradient computation.
    pub local_parameters: Vec<f64>,
    /// Version of the parameters held locally.
    /// NOTE(review): never updated by code visible in this file.
    pub parameter_version: usize,
    /// Running counters for this worker.
    pub stats: WorkerStats,
}
163
/// A slice of the training set assigned to a single worker.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct DataPartition {
    /// Feature rows, one inner vector per sample.
    pub features: Vec<Vec<f64>>,
    /// Target values, paired with `features` by position.
    pub targets: Vec<f64>,
    /// Index of this partition.
    pub partition_id: usize,
}
174
/// Running counters kept by a worker.
#[derive(Debug, Clone, Default, Serialize, Deserialize)]
pub struct WorkerStats {
    /// Total samples visited across all gradient computations.
    pub samples_processed: usize,
    /// Number of gradient computations performed.
    pub gradient_computations: usize,
    /// Accumulated time spent computing gradients, in milliseconds.
    pub total_compute_time_ms: u64,
    /// Number of communication rounds.
    /// NOTE(review): never updated by code visible in this file.
    pub communication_rounds: usize,
}
187
188impl WorkerNode {
189 pub fn new(id: NodeId, data_partition: DataPartition) -> Self {
191 Self {
192 id,
193 data_partition,
194 local_parameters: Vec::new(),
195 parameter_version: 0,
196 stats: WorkerStats::default(),
197 }
198 }
199
200 pub fn compute_local_gradient(&mut self, parameters: &[f64]) -> Result<Vec<f64>> {
202 let start_time = std::time::Instant::now();
203
204 self.local_parameters = parameters.to_vec();
206
207 let n_samples = self.data_partition.features.len();
208 let n_features = parameters.len();
209 let mut gradient = vec![0.0; n_features];
210
211 for (features, target) in self
213 .data_partition
214 .features
215 .iter()
216 .zip(self.data_partition.targets.iter())
217 {
218 let prediction: f64 = features
220 .iter()
221 .zip(parameters.iter())
222 .map(|(x, w)| x * w)
223 .sum();
224
225 let error = prediction - target;
227
228 for (i, x) in features.iter().enumerate() {
230 gradient[i] += 2.0 * error * x;
231 }
232 }
233
234 for g in gradient.iter_mut() {
236 *g /= n_samples as f64;
237 }
238
239 self.stats.samples_processed += n_samples;
241 self.stats.gradient_computations += 1;
242 self.stats.total_compute_time_ms += start_time.elapsed().as_millis() as u64;
243
244 Ok(gradient)
245 }
246
247 pub fn get_stats(&self) -> &WorkerStats {
249 &self.stats
250 }
251}
252
/// Parameters of a linear model together with its training metadata.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ModelParameters {
    /// One weight per feature.
    pub weights: Vec<f64>,
    /// Intercept term.
    pub bias: f64,
    /// Bookkeeping about the training run that produced these parameters.
    pub metadata: ParameterMetadata,
}
263
/// Bookkeeping about a training run.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ParameterMetadata {
    /// Number of training iterations completed.
    pub iterations_completed: usize,
    /// Most recently recorded loss; starts at infinity before any training.
    pub current_loss: f64,
    /// Whether training stopped because the convergence test passed.
    pub converged: bool,
    /// Wall-clock time of the last metadata update.
    pub last_updated: std::time::SystemTime,
}
276
277impl DistributedLinearRegression {
278 pub fn new(config: DistributedConfig, num_features: usize) -> Self {
280 let parameter_server = Arc::new(RwLock::new(ParameterServer::new(
281 num_features,
282 config.num_workers,
283 )));
284
285 Self {
286 config,
287 parameter_server,
288 workers: Vec::new(),
289 parameters: Arc::new(RwLock::new(ModelParameters {
290 weights: vec![0.0; num_features],
291 bias: 0.0,
292 metadata: ParameterMetadata {
293 iterations_completed: 0,
294 current_loss: f64::INFINITY,
295 converged: false,
296 last_updated: std::time::SystemTime::now(),
297 },
298 })),
299 }
300 }
301
302 pub fn partition_data(&mut self, features: Vec<Vec<f64>>, targets: Vec<f64>) -> Result<()> {
304 let n_samples = features.len();
305 let samples_per_worker = n_samples.div_ceil(self.config.num_workers);
306
307 for worker_idx in 0..self.config.num_workers {
308 let start_idx = worker_idx * samples_per_worker;
309 let end_idx = ((worker_idx + 1) * samples_per_worker).min(n_samples);
310
311 if start_idx < n_samples {
312 let partition = DataPartition {
313 features: features[start_idx..end_idx].to_vec(),
314 targets: targets[start_idx..end_idx].to_vec(),
315 partition_id: worker_idx,
316 };
317
318 let worker =
319 WorkerNode::new(NodeId::new(format!("worker_{}", worker_idx)), partition);
320
321 self.workers.push(worker);
322 }
323 }
324
325 Ok(())
326 }
327
328 pub fn fit(&mut self) -> Result<()> {
330 for iteration in 0..self.config.max_iterations {
331 let params = {
333 let ps = self
334 .parameter_server
335 .read()
336 .unwrap_or_else(|e| e.into_inner());
337 ps.get_parameters()
338 };
339
340 let mut all_gradients = Vec::new();
342 for worker in &mut self.workers {
343 let gradient = worker.compute_local_gradient(¶ms)?;
344 all_gradients.push(gradient);
345 }
346
347 {
349 let mut ps = self
350 .parameter_server
351 .write()
352 .unwrap_or_else(|e| e.into_inner());
353 for gradient in all_gradients {
354 ps.receive_gradient(gradient)?;
355 }
356 }
357
358 if iteration % 10 == 0 {
360 let loss = self.compute_global_loss(¶ms)?;
361 let mut model_params = self.parameters.write().unwrap_or_else(|e| e.into_inner());
362
363 if (model_params.metadata.current_loss - loss).abs() < self.config.tolerance {
364 model_params.metadata.converged = true;
365 model_params.metadata.iterations_completed = iteration + 1;
366 break;
367 }
368
369 model_params.metadata.current_loss = loss;
370 model_params.metadata.iterations_completed = iteration + 1;
371 model_params.metadata.last_updated = std::time::SystemTime::now();
372 }
373 }
374
375 let final_params = {
377 let ps = self
378 .parameter_server
379 .read()
380 .unwrap_or_else(|e| e.into_inner());
381 ps.get_parameters()
382 };
383
384 let mut model_params = self.parameters.write().unwrap_or_else(|e| e.into_inner());
385 model_params.weights = final_params;
386
387 Ok(())
388 }
389
390 fn compute_global_loss(&self, parameters: &[f64]) -> Result<f64> {
392 let mut total_loss = 0.0;
393 let mut total_samples = 0;
394
395 for worker in &self.workers {
396 for (features, target) in worker
397 .data_partition
398 .features
399 .iter()
400 .zip(worker.data_partition.targets.iter())
401 {
402 let prediction: f64 = features
403 .iter()
404 .zip(parameters.iter())
405 .map(|(x, w)| x * w)
406 .sum();
407 let error = prediction - target;
408 total_loss += error * error;
409 total_samples += 1;
410 }
411 }
412
413 Ok(total_loss / total_samples as f64)
414 }
415
416 pub fn get_training_stats(&self) -> DistributedTrainingStats {
418 let mut total_samples = 0;
419 let mut total_compute_time = 0;
420 let mut total_gradient_computations = 0;
421
422 for worker in &self.workers {
423 let stats = worker.get_stats();
424 total_samples += stats.samples_processed;
425 total_compute_time += stats.total_compute_time_ms;
426 total_gradient_computations += stats.gradient_computations;
427 }
428
429 DistributedTrainingStats {
430 num_workers: self.workers.len(),
431 total_samples_processed: total_samples,
432 total_compute_time_ms: total_compute_time,
433 total_gradient_computations,
434 parameter_server_version: self
435 .parameter_server
436 .read()
437 .unwrap_or_else(|e| e.into_inner())
438 .get_version(),
439 }
440 }
441
442 pub fn predict(&self, features: &[Vec<f64>]) -> Result<Vec<f64>> {
444 let params = self.parameters.read().unwrap_or_else(|e| e.into_inner());
445 let mut predictions = Vec::new();
446
447 for feature_row in features {
448 let pred: f64 = feature_row
449 .iter()
450 .zip(params.weights.iter())
451 .map(|(x, w)| x * w)
452 .sum::<f64>()
453 + params.bias;
454
455 predictions.push(pred);
456 }
457
458 Ok(predictions)
459 }
460}
461
/// Aggregate statistics of a distributed training run.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct DistributedTrainingStats {
    /// Number of worker nodes.
    pub num_workers: usize,
    /// Samples processed, summed over all workers.
    pub total_samples_processed: usize,
    /// Compute time in milliseconds, summed over all workers.
    pub total_compute_time_ms: u64,
    /// Gradient computations, summed over all workers.
    pub total_gradient_computations: usize,
    /// Version counter of the parameter server.
    pub parameter_server_version: usize,
}
476
/// Coordinator for federated averaging across a set of clients.
#[derive(Debug)]
pub struct FederatedLearning {
    /// Federated learning configuration.
    pub config: FederatedConfig,
    /// Registered clients, addressed by their index in this vector.
    pub clients: Vec<FederatedClient>,
    /// Global model, shared behind a read/write lock.
    pub global_model: Arc<RwLock<ModelParameters>>,
    /// Noise and clipping mechanism applied to aggregated updates.
    pub privacy_mechanism: PrivacyMechanism,
}
496
/// Configuration for federated learning.
///
/// NOTE(review): only `client_fraction` and `dp_epsilon` are read by code
/// visible in this file; the other fields are documented from their names.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct FederatedConfig {
    /// Number of clients.
    pub num_clients: usize,
    /// Fraction of the registered clients selected per round (rounded up).
    pub client_fraction: f64,
    /// Number of local training epochs per round.
    pub local_epochs: usize,
    /// Learning rate for local training.
    pub local_learning_rate: f64,
    /// Whether secure aggregation is enabled.
    pub secure_aggregation: bool,
    /// Differential-privacy epsilon; when set, noise is added to averaged updates.
    pub dp_epsilon: Option<f64>,
    /// Differential-privacy delta.
    pub dp_delta: Option<f64>,
}
515
516impl Default for FederatedConfig {
517 fn default() -> Self {
518 Self {
519 num_clients: 10,
520 client_fraction: 0.3,
521 local_epochs: 5,
522 local_learning_rate: 0.01,
523 secure_aggregation: true,
524 dp_epsilon: Some(1.0),
525 dp_delta: Some(1e-5),
526 }
527 }
528}
529
/// A participant in federated learning.
#[derive(Debug, Clone)]
pub struct FederatedClient {
    /// Client identifier.
    pub id: String,
    /// Size of the client's local data set; used as its weight when averaging.
    pub dataset_size: usize,
    /// The client's local parameters.
    pub local_parameters: Vec<f64>,
    /// Participation statistics.
    pub stats: ClientStats,
}
542
/// Participation statistics of a federated client.
///
/// NOTE(review): none of these counters is updated by code visible in this file.
#[derive(Debug, Clone, Default, Serialize, Deserialize)]
pub struct ClientStats {
    /// Number of rounds the client took part in.
    pub rounds_participated: usize,
    /// Total number of samples.
    pub total_samples: usize,
    /// Average local loss.
    pub avg_local_loss: f64,
}
553
554impl FederatedLearning {
555 pub fn new(config: FederatedConfig, num_features: usize) -> Self {
557 Self {
558 config,
559 clients: Vec::new(),
560 global_model: Arc::new(RwLock::new(ModelParameters {
561 weights: vec![0.0; num_features],
562 bias: 0.0,
563 metadata: ParameterMetadata {
564 iterations_completed: 0,
565 current_loss: f64::INFINITY,
566 converged: false,
567 last_updated: std::time::SystemTime::now(),
568 },
569 })),
570 privacy_mechanism: PrivacyMechanism::new(),
571 }
572 }
573
574 pub fn add_client(&mut self, client_id: String, dataset_size: usize) {
576 let client = FederatedClient {
577 id: client_id,
578 dataset_size,
579 local_parameters: Vec::new(),
580 stats: ClientStats::default(),
581 };
582 self.clients.push(client);
583 }
584
585 pub fn select_clients(&self) -> Vec<usize> {
587 use scirs2_core::random::thread_rng;
588
589 let num_selected =
590 (self.clients.len() as f64 * self.config.client_fraction).ceil() as usize;
591 let mut selected = Vec::new();
592 let mut rng = thread_rng();
593
594 let mut indices: Vec<usize> = (0..self.clients.len()).collect();
595
596 for i in (1..indices.len()).rev() {
598 let j = rng.gen_range(0..=i);
599 indices.swap(i, j);
600 }
601
602 selected.extend_from_slice(&indices[..num_selected]);
603 selected
604 }
605
606 pub fn federated_average(&self, client_updates: &[(usize, Vec<f64>)]) -> Vec<f64> {
608 if client_updates.is_empty() {
609 return vec![];
610 }
611
612 let num_features = client_updates[0].1.len();
613 let mut averaged = vec![0.0; num_features];
614 let mut total_weight = 0.0;
615
616 for (client_idx, update) in client_updates {
617 let weight = self.clients[*client_idx].dataset_size as f64;
618 total_weight += weight;
619
620 for (i, val) in update.iter().enumerate() {
621 averaged[i] += val * weight;
622 }
623 }
624
625 for val in averaged.iter_mut() {
627 *val /= total_weight;
628 }
629
630 if let Some(dp_epsilon) = self.config.dp_epsilon {
632 self.privacy_mechanism
633 .apply_noise(&mut averaged, dp_epsilon);
634 }
635
636 averaged
637 }
638
639 pub fn get_global_model(&self) -> ModelParameters {
641 self.global_model
642 .read()
643 .unwrap_or_else(|e| e.into_inner())
644 .clone()
645 }
646}
647
/// Noise and clipping helpers used for differentially private aggregation.
#[derive(Debug, Clone)]
pub struct PrivacyMechanism {
    /// Numerator of the noise standard deviation (`noise_scale / epsilon`).
    pub noise_scale: f64,
}
654
655impl PrivacyMechanism {
656 pub fn new() -> Self {
658 Self { noise_scale: 1.0 }
659 }
660
661 pub fn apply_noise(&self, gradients: &mut [f64], epsilon: f64) {
663 use scirs2_core::random::essentials::Normal;
664 use scirs2_core::random::thread_rng;
665
666 let mut rng = thread_rng();
667 let noise_std = self.noise_scale / epsilon;
668 let normal = Normal::new(0.0, noise_std)
669 .unwrap_or_else(|_| Normal::new(0.0, 1.0).expect("default normal distribution"));
670
671 for grad in gradients.iter_mut() {
672 *grad += rng.sample(normal);
673 }
674 }
675
676 pub fn clip_gradients(&self, gradients: &mut [f64], clip_norm: f64) {
678 let norm: f64 = gradients.iter().map(|g| g * g).sum::<f64>().sqrt();
679
680 if norm > clip_norm {
681 let scale = clip_norm / norm;
682 for grad in gradients.iter_mut() {
683 *grad *= scale;
684 }
685 }
686 }
687}
688
689impl Default for PrivacyMechanism {
690 fn default() -> Self {
691 Self::new()
692 }
693}
694
/// Gradient aggregator that resists faulty or malicious workers.
#[derive(Debug)]
pub struct ByzantineFaultTolerant {
    /// Fault-tolerance settings.
    pub config: BFTConfig,
    /// Aggregation rule applied by `aggregate`.
    pub aggregation_method: AggregationMethod,
}
706
/// Settings for Byzantine-fault-tolerant aggregation.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct BFTConfig {
    /// Assumed upper bound on the fraction of Byzantine workers; multiplied
    /// by the number of gradients and rounded down by Krum and Bulyan.
    pub max_byzantine_fraction: f64,
    /// Detection threshold.
    /// NOTE(review): not read by any code visible in this file.
    pub detection_threshold: f64,
    /// Whether reputation tracking is enabled.
    /// NOTE(review): not read by any code visible in this file.
    pub enable_reputation: bool,
}
717
718impl Default for BFTConfig {
719 fn default() -> Self {
720 Self {
721 max_byzantine_fraction: 0.3,
722 detection_threshold: 2.0,
723 enable_reputation: true,
724 }
725 }
726}
727
/// Rule used to combine worker gradients.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
pub enum AggregationMethod {
    /// Coordinate-wise median.
    Median,
    /// Coordinate-wise mean after trimming both ends. Despite its name,
    /// `trim_fraction` is a percentage: the aggregator divides it by 100.
    TrimmedMean { trim_fraction: usize },
    /// Krum: returns the single gradient closest to its nearest neighbours.
    Krum,
    /// Bulyan; the implementation in this file reduces to the coordinate-wise
    /// median after a feasibility check.
    Bulyan,
}
740
741impl ByzantineFaultTolerant {
742 pub fn new(config: BFTConfig, method: AggregationMethod) -> Self {
744 Self {
745 config,
746 aggregation_method: method,
747 }
748 }
749
750 pub fn aggregate(&self, gradients: &[Vec<f64>]) -> Result<Vec<f64>> {
752 if gradients.is_empty() {
753 return Err(SklearsError::InvalidInput(
754 "Cannot aggregate empty gradient set".to_string(),
755 ));
756 }
757
758 match self.aggregation_method {
759 AggregationMethod::Median => self.coordinate_wise_median(gradients),
760 AggregationMethod::TrimmedMean { trim_fraction } => {
761 self.trimmed_mean(gradients, trim_fraction)
762 }
763 AggregationMethod::Krum => self.krum(gradients),
764 AggregationMethod::Bulyan => self.bulyan(gradients),
765 }
766 }
767
768 fn coordinate_wise_median(&self, gradients: &[Vec<f64>]) -> Result<Vec<f64>> {
770 let num_features = gradients[0].len();
771 let mut result = vec![0.0; num_features];
772
773 for i in 0..num_features {
774 let mut values: Vec<f64> = gradients.iter().map(|g| g[i]).collect();
775 values.sort_by(|a, b| a.partial_cmp(b).unwrap_or(std::cmp::Ordering::Equal));
776 result[i] = values[values.len() / 2];
777 }
778
779 Ok(result)
780 }
781
782 fn trimmed_mean(&self, gradients: &[Vec<f64>], trim_fraction: usize) -> Result<Vec<f64>> {
784 let num_features = gradients[0].len();
785 let mut result = vec![0.0; num_features];
786 let trim_count = (gradients.len() * trim_fraction) / 100;
787
788 for i in 0..num_features {
789 let mut values: Vec<f64> = gradients.iter().map(|g| g[i]).collect();
790 values.sort_by(|a, b| a.partial_cmp(b).unwrap_or(std::cmp::Ordering::Equal));
791
792 let trimmed = &values[trim_count..values.len() - trim_count];
794 result[i] = trimmed.iter().sum::<f64>() / trimmed.len() as f64;
795 }
796
797 Ok(result)
798 }
799
800 fn krum(&self, gradients: &[Vec<f64>]) -> Result<Vec<f64>> {
802 let n = gradients.len();
803 let f = (n as f64 * self.config.max_byzantine_fraction).floor() as usize;
804 let m = n - f - 2;
805
806 let mut scores = vec![0.0; n];
807
808 for i in 0..n {
810 let mut distances: Vec<(usize, f64)> = Vec::new();
811
812 for j in 0..n {
813 if i != j {
814 let dist = self.euclidean_distance(&gradients[i], &gradients[j]);
815 distances.push((j, dist));
816 }
817 }
818
819 distances.sort_by(|a, b| a.1.partial_cmp(&b.1).unwrap_or(std::cmp::Ordering::Equal));
821 scores[i] = distances.iter().take(m).map(|(_, d)| d).sum();
822 }
823
824 let best_idx = scores
826 .iter()
827 .enumerate()
828 .min_by(|(_, a), (_, b)| a.partial_cmp(b).unwrap_or(std::cmp::Ordering::Equal))
829 .map(|(idx, _)| idx)
830 .expect("expected valid value");
831
832 Ok(gradients[best_idx].clone())
833 }
834
835 fn bulyan(&self, gradients: &[Vec<f64>]) -> Result<Vec<f64>> {
837 let n = gradients.len();
839 let f = (n as f64 * self.config.max_byzantine_fraction).floor() as usize;
840 let theta = n - 2 * f;
841
842 if theta < 1 {
843 return Err(SklearsError::InvalidInput(
844 "Too many Byzantine workers for Bulyan".to_string(),
845 ));
846 }
847
848 self.coordinate_wise_median(gradients)
850 }
851
852 fn euclidean_distance(&self, a: &[f64], b: &[f64]) -> f64 {
854 a.iter()
855 .zip(b.iter())
856 .map(|(x, y)| (x - y).powi(2))
857 .sum::<f64>()
858 .sqrt()
859 }
860}
861
/// Chooses a worker for each new task according to a configurable strategy.
#[derive(Debug)]
pub struct LoadBalancer {
    /// Strategy used by `select_worker`.
    pub strategy: LoadBalancingStrategy,
    /// Load bookkeeping per registered worker, keyed by worker id.
    pub worker_loads: HashMap<String, WorkerLoad>,
}
873
/// Worker selection strategy.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
pub enum LoadBalancingStrategy {
    /// Round-robin selection.
    /// NOTE(review): the implementation in this file returns the first map
    /// key and keeps no rotation cursor; confirm the intended behavior.
    RoundRobin,
    /// Pick the worker with the lowest load factor.
    LeastLoaded,
    /// Pick at random, weighted by each worker's spare capacity.
    WeightedRandom,
    /// Sample two distinct workers at random and keep the less loaded one.
    PowerOfTwo,
}
886
/// Load bookkeeping for a single worker.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct WorkerLoad {
    /// Tasks currently assigned to the worker.
    pub active_tasks: usize,
    /// Number of tasks the worker is registered to handle.
    pub capacity: usize,
    /// Average task completion time in milliseconds.
    /// NOTE(review): never updated by code visible in this file.
    pub avg_completion_time_ms: u64,
    /// `active_tasks / capacity`, refreshed by `LoadBalancer::update_load`.
    pub load_factor: f64,
}
899
900impl LoadBalancer {
901 pub fn new(strategy: LoadBalancingStrategy) -> Self {
903 Self {
904 strategy,
905 worker_loads: HashMap::new(),
906 }
907 }
908
909 pub fn register_worker(&mut self, worker_id: String, capacity: usize) {
911 self.worker_loads.insert(
912 worker_id,
913 WorkerLoad {
914 active_tasks: 0,
915 capacity,
916 avg_completion_time_ms: 0,
917 load_factor: 0.0,
918 },
919 );
920 }
921
922 pub fn select_worker(&mut self) -> Option<String> {
924 match self.strategy {
925 LoadBalancingStrategy::RoundRobin => self.round_robin_select(),
926 LoadBalancingStrategy::LeastLoaded => self.least_loaded_select(),
927 LoadBalancingStrategy::WeightedRandom => self.weighted_random_select(),
928 LoadBalancingStrategy::PowerOfTwo => self.power_of_two_select(),
929 }
930 }
931
932 fn round_robin_select(&self) -> Option<String> {
934 self.worker_loads.keys().next().cloned()
935 }
936
937 fn least_loaded_select(&self) -> Option<String> {
939 self.worker_loads
940 .iter()
941 .min_by(|(_, a), (_, b)| {
942 a.load_factor
943 .partial_cmp(&b.load_factor)
944 .unwrap_or(std::cmp::Ordering::Equal)
945 })
946 .map(|(id, _)| id.clone())
947 }
948
949 fn weighted_random_select(&self) -> Option<String> {
951 use scirs2_core::random::thread_rng;
952
953 if self.worker_loads.is_empty() {
954 return None;
955 }
956
957 let mut rng = thread_rng();
958 let total_capacity: f64 = self
959 .worker_loads
960 .values()
961 .map(|load| (load.capacity - load.active_tasks) as f64)
962 .sum();
963
964 let mut rand_val = rng.gen_range(0.0..total_capacity);
965
966 for (id, load) in &self.worker_loads {
967 let available = (load.capacity - load.active_tasks) as f64;
968 if rand_val < available {
969 return Some(id.clone());
970 }
971 rand_val -= available;
972 }
973
974 self.worker_loads.keys().next().cloned()
975 }
976
977 fn power_of_two_select(&self) -> Option<String> {
979 use scirs2_core::random::thread_rng;
980
981 if self.worker_loads.is_empty() {
982 return None;
983 }
984
985 let mut rng = thread_rng();
986 let workers: Vec<_> = self.worker_loads.keys().collect();
987
988 if workers.len() == 1 {
989 return Some(workers[0].clone());
990 }
991
992 let idx1 = rng.gen_range(0..workers.len());
993 let mut idx2 = rng.gen_range(0..workers.len());
994 while idx2 == idx1 {
995 idx2 = rng.gen_range(0..workers.len());
996 }
997
998 let load1 = &self.worker_loads[workers[idx1]];
999 let load2 = &self.worker_loads[workers[idx2]];
1000
1001 if load1.load_factor < load2.load_factor {
1002 Some(workers[idx1].clone())
1003 } else {
1004 Some(workers[idx2].clone())
1005 }
1006 }
1007
1008 pub fn update_load(&mut self, worker_id: &str, task_assigned: bool) {
1010 if let Some(load) = self.worker_loads.get_mut(worker_id) {
1011 if task_assigned {
1012 load.active_tasks += 1;
1013 } else if load.active_tasks > 0 {
1014 load.active_tasks -= 1;
1015 }
1016 load.load_factor = load.active_tasks as f64 / load.capacity as f64;
1017 }
1018 }
1019}
1020
/// Unit tests for the parameter server, workers, distributed regression,
/// federated learning, Byzantine-tolerant aggregation and load balancing.
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_parameter_server_creation() {
        let ps = ParameterServer::new(5, 3);
        assert_eq!(ps.parameters.len(), 5);
        assert_eq!(ps.num_workers, 3);
        assert_eq!(ps.version, 0);
    }

    #[test]
    fn test_gradient_accumulation() {
        let mut ps = ParameterServer::new(3, 2);

        let grad1 = vec![1.0, 2.0, 3.0];
        let grad2 = vec![2.0, 3.0, 4.0];

        ps.receive_gradient(grad1)
            .expect("receive_gradient should succeed");
        ps.receive_gradient(grad2)
            .expect("receive_gradient should succeed");

        // Both gradients arrived, so their mean has been subtracted from zero.
        let params = ps.get_parameters();
        assert_eq!(ps.version, 1);
        assert_eq!(params, vec![-1.5, -2.5, -3.5]);
    }

    #[test]
    fn test_worker_node_creation() {
        let partition = DataPartition {
            features: vec![vec![1.0, 2.0], vec![3.0, 4.0]],
            targets: vec![1.0, 2.0],
            partition_id: 0,
        };

        let worker = WorkerNode::new(NodeId::new("worker_0"), partition);
        assert_eq!(worker.id.0, "worker_0");
        assert_eq!(worker.stats.samples_processed, 0);
    }

    #[test]
    fn test_local_gradient_computation() {
        let partition = DataPartition {
            features: vec![vec![1.0, 2.0], vec![2.0, 3.0]],
            targets: vec![3.0, 5.0],
            partition_id: 0,
        };

        let mut worker = WorkerNode::new(NodeId::new("worker_0"), partition);
        let params = vec![1.0, 1.0];

        let gradient = worker
            .compute_local_gradient(&params)
            .expect("compute_local_gradient should succeed");
        // These parameters fit the data exactly, so the gradient vanishes.
        assert_eq!(gradient, vec![0.0, 0.0]);
        assert_eq!(worker.stats.gradient_computations, 1);
        assert_eq!(worker.stats.samples_processed, 2);
    }

    #[test]
    fn test_distributed_regression_creation() {
        let config = DistributedConfig::default();
        let model = DistributedLinearRegression::new(config, 5);

        assert_eq!(model.workers.len(), 0);
        assert_eq!(
            model
                .parameters
                .read()
                .unwrap_or_else(|e| e.into_inner())
                .weights
                .len(),
            5
        );
    }

    #[test]
    fn test_data_partitioning() {
        let config = DistributedConfig {
            num_workers: 2,
            ..Default::default()
        };

        let mut model = DistributedLinearRegression::new(config, 2);

        let features = vec![
            vec![1.0, 2.0],
            vec![3.0, 4.0],
            vec![5.0, 6.0],
            vec![7.0, 8.0],
        ];
        let targets = vec![3.0, 7.0, 11.0, 15.0];

        model
            .partition_data(features, targets)
            .expect("partition_data should succeed");

        assert_eq!(model.workers.len(), 2);
        assert_eq!(model.workers[0].data_partition.features.len(), 2);
        assert_eq!(model.workers[1].data_partition.targets, vec![11.0, 15.0]);
    }

    #[test]
    fn test_distributed_training() {
        let config = DistributedConfig {
            num_workers: 2,
            max_iterations: 10,
            tolerance: 1e-3,
            learning_rate: 0.01,
            ..Default::default()
        };

        let mut model = DistributedLinearRegression::new(config, 2);

        let features = vec![
            vec![1.0, 1.0],
            vec![2.0, 2.0],
            vec![3.0, 3.0],
            vec![4.0, 4.0],
        ];
        let targets = vec![5.0, 10.0, 15.0, 20.0];

        model
            .partition_data(features, targets)
            .expect("partition_data should succeed");
        model.fit().expect("model fitting should succeed");

        let stats = model.get_training_stats();
        assert!(stats.total_samples_processed > 0);
        assert!(stats.parameter_server_version > 0);
    }

    #[test]
    fn test_prediction() {
        let config = DistributedConfig::default();
        let model = DistributedLinearRegression::new(config, 2);

        let test_features = vec![vec![1.0, 2.0], vec![3.0, 4.0]];
        let predictions = model
            .predict(&test_features)
            .expect("prediction should succeed");

        // An untrained model has zero weights and zero bias.
        assert_eq!(predictions, vec![0.0, 0.0]);
    }

    #[test]
    fn test_training_stats() {
        let config = DistributedConfig {
            num_workers: 3,
            ..Default::default()
        };

        let model = DistributedLinearRegression::new(config, 2);
        let stats = model.get_training_stats();

        // Workers only exist once data has been partitioned.
        assert_eq!(stats.num_workers, 0);
    }

    #[test]
    fn test_federated_learning_creation() {
        let config = FederatedConfig::default();
        let fed_learning = FederatedLearning::new(config, 5);

        assert_eq!(fed_learning.clients.len(), 0);
        assert_eq!(fed_learning.config.num_clients, 10);
    }

    #[test]
    fn test_federated_add_client() {
        let config = FederatedConfig::default();
        let mut fed_learning = FederatedLearning::new(config, 5);

        fed_learning.add_client("client_1".to_string(), 100);
        fed_learning.add_client("client_2".to_string(), 150);

        assert_eq!(fed_learning.clients.len(), 2);
        assert_eq!(fed_learning.clients[0].dataset_size, 100);
        assert_eq!(fed_learning.clients[1].dataset_size, 150);
    }

    #[test]
    fn test_federated_client_selection() {
        let config = FederatedConfig {
            client_fraction: 0.5,
            ..Default::default()
        };
        let mut fed_learning = FederatedLearning::new(config, 5);

        for i in 0..10 {
            fed_learning.add_client(format!("client_{}", i), 100);
        }

        let mut selected = fed_learning.select_clients();
        // ceil(10 * 0.5) = 5 distinct, valid indices.
        assert_eq!(selected.len(), 5);
        selected.sort_unstable();
        selected.dedup();
        assert_eq!(selected.len(), 5);
        assert!(selected.iter().all(|&idx| idx < 10));
    }

    #[test]
    fn test_federated_averaging() {
        let config = FederatedConfig {
            // Disable noise so the average can be checked exactly.
            dp_epsilon: None,
            ..Default::default()
        };
        let mut fed_learning = FederatedLearning::new(config, 3);

        fed_learning.add_client("client_1".to_string(), 100);
        fed_learning.add_client("client_2".to_string(), 100);

        let updates = vec![(0, vec![1.0, 2.0, 3.0]), (1, vec![2.0, 4.0, 6.0])];

        let averaged = fed_learning.federated_average(&updates);
        assert_eq!(averaged.len(), 3);
        assert!((averaged[0] - 1.5).abs() < 1e-6);
        assert!((averaged[1] - 3.0).abs() < 1e-6);
        assert!((averaged[2] - 4.5).abs() < 1e-6);
    }

    #[test]
    fn test_privacy_mechanism_noise() {
        let privacy = PrivacyMechanism::new();
        let mut gradients = vec![1.0, 2.0, 3.0];
        let original = gradients.clone();

        privacy.apply_noise(&mut gradients, 1.0);

        assert_ne!(gradients, original);
    }

    #[test]
    fn test_privacy_mechanism_clipping() {
        let privacy = PrivacyMechanism::new();
        // L2 norm of 5.0 before clipping.
        let mut gradients = vec![3.0, 4.0];
        privacy.clip_gradients(&mut gradients, 1.0);

        let norm: f64 = gradients.iter().map(|g| g * g).sum::<f64>().sqrt();
        assert!((norm - 1.0).abs() < 1e-6);
    }

    #[test]
    fn test_byzantine_fault_tolerant_creation() {
        let config = BFTConfig::default();
        let bft = ByzantineFaultTolerant::new(config, AggregationMethod::Median);

        assert_eq!(bft.aggregation_method, AggregationMethod::Median);
    }

    #[test]
    fn test_byzantine_median_aggregation() {
        let config = BFTConfig::default();
        let bft = ByzantineFaultTolerant::new(config, AggregationMethod::Median);

        let gradients = vec![
            vec![1.0, 2.0, 3.0],
            vec![2.0, 3.0, 4.0],
            vec![3.0, 4.0, 5.0],
            // Byzantine outlier.
            vec![100.0, 100.0, 100.0],
        ];

        let result = bft.aggregate(&gradients).expect("aggregate should succeed");
        assert_eq!(result.len(), 3);
        assert!(result[0] < 50.0);
        assert!(result[1] < 50.0);
        assert!(result[2] < 50.0);
    }

    #[test]
    fn test_byzantine_trimmed_mean() {
        let config = BFTConfig::default();
        let bft = ByzantineFaultTolerant::new(
            config,
            AggregationMethod::TrimmedMean { trim_fraction: 25 },
        );

        let gradients = vec![
            vec![1.0, 2.0, 3.0],
            vec![2.0, 3.0, 4.0],
            vec![3.0, 4.0, 5.0],
            vec![4.0, 5.0, 6.0],
        ];

        // One value is trimmed from each end, leaving the two middle ones.
        let result = bft.aggregate(&gradients).expect("aggregate should succeed");
        assert_eq!(result, vec![2.5, 3.5, 4.5]);
    }

    #[test]
    fn test_byzantine_krum() {
        let config = BFTConfig::default();
        let bft = ByzantineFaultTolerant::new(config, AggregationMethod::Krum);

        let gradients = vec![
            vec![1.0, 2.0, 3.0],
            vec![1.1, 2.1, 3.1],
            vec![1.2, 2.2, 3.2],
            // Byzantine outlier.
            vec![100.0, 100.0, 100.0],
        ];

        let result = bft.aggregate(&gradients).expect("aggregate should succeed");
        assert_eq!(result.len(), 3);
        assert!(result[0] < 10.0);
    }

    #[test]
    fn test_byzantine_empty_gradients() {
        let config = BFTConfig::default();
        let bft = ByzantineFaultTolerant::new(config, AggregationMethod::Median);

        let gradients: Vec<Vec<f64>> = vec![];
        let result = bft.aggregate(&gradients);

        assert!(result.is_err());
    }

    #[test]
    fn test_load_balancer_creation() {
        let lb = LoadBalancer::new(LoadBalancingStrategy::RoundRobin);
        assert_eq!(lb.strategy, LoadBalancingStrategy::RoundRobin);
        assert_eq!(lb.worker_loads.len(), 0);
    }

    #[test]
    fn test_load_balancer_register_worker() {
        let mut lb = LoadBalancer::new(LoadBalancingStrategy::LeastLoaded);

        lb.register_worker("worker_1".to_string(), 10);
        lb.register_worker("worker_2".to_string(), 20);

        assert_eq!(lb.worker_loads.len(), 2);
        assert_eq!(
            lb.worker_loads
                .get("worker_1")
                .expect("key should exist")
                .capacity,
            10
        );
        assert_eq!(
            lb.worker_loads
                .get("worker_2")
                .expect("key should exist")
                .capacity,
            20
        );
    }

    #[test]
    fn test_load_balancer_least_loaded() {
        let mut lb = LoadBalancer::new(LoadBalancingStrategy::LeastLoaded);

        lb.register_worker("worker_1".to_string(), 10);
        lb.register_worker("worker_2".to_string(), 10);

        // Both workers are idle, so either may be chosen.
        let selected = lb.select_worker();
        assert!(selected.is_some());

        lb.update_load("worker_1", true);
        lb.update_load("worker_1", true);

        // worker_1 now carries load, so worker_2 must win.
        let selected = lb.select_worker();
        assert_eq!(selected.as_deref(), Some("worker_2"));
    }

    #[test]
    fn test_load_balancer_update_load() {
        let mut lb = LoadBalancer::new(LoadBalancingStrategy::LeastLoaded);

        lb.register_worker("worker_1".to_string(), 10);

        lb.update_load("worker_1", true);
        assert_eq!(
            lb.worker_loads
                .get("worker_1")
                .expect("key should exist")
                .active_tasks,
            1
        );
        assert!(
            (lb.worker_loads
                .get("worker_1")
                .expect("key should exist")
                .load_factor
                - 0.1)
                .abs()
                < 1e-6
        );

        lb.update_load("worker_1", false);
        assert_eq!(
            lb.worker_loads
                .get("worker_1")
                .expect("key should exist")
                .active_tasks,
            0
        );
        assert!(
            (lb.worker_loads
                .get("worker_1")
                .expect("key should exist")
                .load_factor)
                .abs()
                < 1e-6
        );
    }

    #[test]
    fn test_load_balancer_power_of_two() {
        let mut lb = LoadBalancer::new(LoadBalancingStrategy::PowerOfTwo);

        lb.register_worker("worker_1".to_string(), 10);
        lb.register_worker("worker_2".to_string(), 10);
        lb.register_worker("worker_3".to_string(), 10);

        let selected = lb.select_worker().expect("a worker should be selected");
        assert!(lb.worker_loads.contains_key(&selected));
    }

    #[test]
    fn test_federated_config_default() {
        let config = FederatedConfig::default();
        assert_eq!(config.num_clients, 10);
        assert!((config.client_fraction - 0.3).abs() < 1e-6);
        assert_eq!(config.local_epochs, 5);
        assert!(config.secure_aggregation);
    }

    #[test]
    fn test_bft_config_default() {
        let config = BFTConfig::default();
        assert!((config.max_byzantine_fraction - 0.3).abs() < 1e-6);
        assert!((config.detection_threshold - 2.0).abs() < 1e-6);
        assert!(config.enable_reputation);
    }

    #[test]
    fn test_aggregation_method_equality() {
        assert_eq!(AggregationMethod::Median, AggregationMethod::Median);
        assert_eq!(
            AggregationMethod::TrimmedMean { trim_fraction: 25 },
            AggregationMethod::TrimmedMean { trim_fraction: 25 }
        );
        assert_ne!(AggregationMethod::Median, AggregationMethod::Krum);
    }

    #[test]
    fn test_load_balancing_strategy_equality() {
        assert_eq!(
            LoadBalancingStrategy::RoundRobin,
            LoadBalancingStrategy::RoundRobin
        );
        assert_eq!(
            LoadBalancingStrategy::LeastLoaded,
            LoadBalancingStrategy::LeastLoaded
        );
        assert_ne!(
            LoadBalancingStrategy::RoundRobin,
            LoadBalancingStrategy::LeastLoaded
        );
    }
}