1use crate::distributed::NodeId;
7use crate::error::{Result, SklearsError};
8use serde::{Deserialize, Serialize};
9use std::collections::HashMap;
10use std::sync::{Arc, RwLock};
11
/// Data-parallel linear regression trained via a central parameter server
/// and a set of worker nodes, each holding one partition of the data.
#[derive(Debug)]
pub struct DistributedLinearRegression {
    /// Training configuration (worker count, sync strategy, stopping rules).
    pub config: DistributedConfig,
    /// Shared server that accumulates worker gradients and applies updates.
    pub parameter_server: Arc<RwLock<ParameterServer>>,
    /// Worker nodes created by `partition_data`, one per non-empty partition.
    pub workers: Vec<WorkerNode>,
    /// Latest published model parameters plus training metadata.
    pub parameters: Arc<RwLock<ModelParameters>>,
}
27
/// Configuration for distributed training.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct DistributedConfig {
    /// Number of worker nodes to partition the data across.
    pub num_workers: usize,
    /// How workers coordinate updates with the parameter server.
    pub sync_strategy: SyncStrategy,
    /// Whether fault tolerance is enabled.
    pub fault_tolerance: bool,
    /// Upper bound on training iterations.
    pub max_iterations: usize,
    /// Convergence threshold on the change in global loss.
    pub tolerance: f64,
    /// Step size for gradient descent.
    // NOTE(review): not consumed by the visible training loop — the
    // parameter server applies raw averaged gradients; confirm intent.
    pub learning_rate: f64,
}
44
45impl Default for DistributedConfig {
46 fn default() -> Self {
47 Self {
48 num_workers: 4,
49 sync_strategy: SyncStrategy::Synchronous,
50 fault_tolerance: true,
51 max_iterations: 100,
52 tolerance: 1e-6,
53 learning_rate: 0.01,
54 }
55 }
56}
57
/// How workers synchronize parameter updates with the server.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
pub enum SyncStrategy {
    /// Every worker submits its gradient before any update is applied.
    Synchronous,
    /// Workers submit updates independently, without waiting for each other.
    Asynchronous,
    /// Asynchronous, but no worker may lag more than `staleness_bound`
    /// parameter versions behind.
    BoundedAsync { staleness_bound: usize },
}
68
/// Central server that accumulates worker gradients and applies one averaged
/// update per round, once every worker has reported.
#[derive(Debug, Clone)]
pub struct ParameterServer {
    /// Current global parameter vector.
    pub parameters: Vec<f64>,
    /// Number of completed update rounds.
    pub version: usize,
    /// Running sum of gradients received in the current round.
    pub gradient_accumulator: Vec<f64>,
    /// Number of workers expected to report each round.
    pub num_workers: usize,
    /// Gradients received so far in the current round.
    pub updates_received: usize,
}
83
84impl ParameterServer {
85 pub fn new(num_parameters: usize, num_workers: usize) -> Self {
87 Self {
88 parameters: vec![0.0; num_parameters],
89 version: 0,
90 gradient_accumulator: vec![0.0; num_parameters],
91 num_workers,
92 updates_received: 0,
93 }
94 }
95
96 pub fn receive_gradient(&mut self, gradient: Vec<f64>) -> Result<()> {
98 if gradient.len() != self.parameters.len() {
99 return Err(SklearsError::DimensionMismatch {
100 expected: self.parameters.len(),
101 actual: gradient.len(),
102 });
103 }
104
105 for (acc, grad) in self.gradient_accumulator.iter_mut().zip(gradient.iter()) {
107 *acc += grad;
108 }
109
110 self.updates_received += 1;
111
112 if self.updates_received == self.num_workers {
114 self.apply_accumulated_gradients();
115 }
116
117 Ok(())
118 }
119
120 fn apply_accumulated_gradients(&mut self) {
122 let scale = 1.0 / self.num_workers as f64;
123
124 for (param, grad) in self
125 .parameters
126 .iter_mut()
127 .zip(self.gradient_accumulator.iter())
128 {
129 *param -= grad * scale;
130 }
131
132 self.gradient_accumulator.iter_mut().for_each(|g| *g = 0.0);
134 self.updates_received = 0;
135 self.version += 1;
136 }
137
138 pub fn get_parameters(&self) -> Vec<f64> {
140 self.parameters.clone()
141 }
142
143 pub fn get_version(&self) -> usize {
145 self.version
146 }
147}
148
/// A worker holding one partition of the training data.
#[derive(Debug, Clone)]
pub struct WorkerNode {
    /// Unique worker identifier.
    pub id: NodeId,
    /// The slice of the dataset this worker trains on.
    pub data_partition: DataPartition,
    /// Copy of the parameters used for the most recent gradient computation.
    pub local_parameters: Vec<f64>,
    /// Parameter-server version tracked by this worker.
    // NOTE(review): only initialized to 0 in the visible code; never updated.
    pub parameter_version: usize,
    /// Accumulated compute/communication statistics.
    pub stats: WorkerStats,
}
163
/// A contiguous slice of the training set assigned to one worker.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct DataPartition {
    /// Feature rows, one inner vector per sample.
    pub features: Vec<Vec<f64>>,
    /// Target value per sample, aligned with `features`.
    pub targets: Vec<f64>,
    /// Index of this partition (matches the owning worker's index).
    pub partition_id: usize,
}
174
/// Per-worker accounting accumulated over training.
#[derive(Debug, Clone, Default, Serialize, Deserialize)]
pub struct WorkerStats {
    /// Total samples seen across all gradient computations.
    pub samples_processed: usize,
    /// Number of gradient computations performed.
    pub gradient_computations: usize,
    /// Wall-clock time spent computing gradients, in milliseconds.
    pub total_compute_time_ms: u64,
    /// Number of communication rounds with the server.
    // NOTE(review): never incremented in the visible code.
    pub communication_rounds: usize,
}
187
188impl WorkerNode {
189 pub fn new(id: NodeId, data_partition: DataPartition) -> Self {
191 Self {
192 id,
193 data_partition,
194 local_parameters: Vec::new(),
195 parameter_version: 0,
196 stats: WorkerStats::default(),
197 }
198 }
199
200 pub fn compute_local_gradient(&mut self, parameters: &[f64]) -> Result<Vec<f64>> {
202 let start_time = std::time::Instant::now();
203
204 self.local_parameters = parameters.to_vec();
206
207 let n_samples = self.data_partition.features.len();
208 let n_features = parameters.len();
209 let mut gradient = vec![0.0; n_features];
210
211 for (features, target) in self
213 .data_partition
214 .features
215 .iter()
216 .zip(self.data_partition.targets.iter())
217 {
218 let prediction: f64 = features
220 .iter()
221 .zip(parameters.iter())
222 .map(|(x, w)| x * w)
223 .sum();
224
225 let error = prediction - target;
227
228 for (i, x) in features.iter().enumerate() {
230 gradient[i] += 2.0 * error * x;
231 }
232 }
233
234 for g in gradient.iter_mut() {
236 *g /= n_samples as f64;
237 }
238
239 self.stats.samples_processed += n_samples;
241 self.stats.gradient_computations += 1;
242 self.stats.total_compute_time_ms += start_time.elapsed().as_millis() as u64;
243
244 Ok(gradient)
245 }
246
247 pub fn get_stats(&self) -> &WorkerStats {
249 &self.stats
250 }
251}
252
/// A snapshot of the model's learned parameters.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ModelParameters {
    /// Learned weight per feature.
    pub weights: Vec<f64>,
    /// Intercept term added to every prediction.
    pub bias: f64,
    /// Bookkeeping about how these parameters were produced.
    pub metadata: ParameterMetadata,
}
263
/// Training metadata attached to a set of model parameters.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ParameterMetadata {
    /// Number of training iterations completed so far.
    pub iterations_completed: usize,
    /// Most recently computed global loss (starts at +inf).
    pub current_loss: f64,
    /// Whether the convergence criterion was met.
    pub converged: bool,
    /// Timestamp of the last metadata update.
    pub last_updated: std::time::SystemTime,
}
276
277impl DistributedLinearRegression {
278 pub fn new(config: DistributedConfig, num_features: usize) -> Self {
280 let parameter_server = Arc::new(RwLock::new(ParameterServer::new(
281 num_features,
282 config.num_workers,
283 )));
284
285 Self {
286 config,
287 parameter_server,
288 workers: Vec::new(),
289 parameters: Arc::new(RwLock::new(ModelParameters {
290 weights: vec![0.0; num_features],
291 bias: 0.0,
292 metadata: ParameterMetadata {
293 iterations_completed: 0,
294 current_loss: f64::INFINITY,
295 converged: false,
296 last_updated: std::time::SystemTime::now(),
297 },
298 })),
299 }
300 }
301
302 pub fn partition_data(&mut self, features: Vec<Vec<f64>>, targets: Vec<f64>) -> Result<()> {
304 let n_samples = features.len();
305 let samples_per_worker =
306 (n_samples + self.config.num_workers - 1) / self.config.num_workers;
307
308 for worker_idx in 0..self.config.num_workers {
309 let start_idx = worker_idx * samples_per_worker;
310 let end_idx = ((worker_idx + 1) * samples_per_worker).min(n_samples);
311
312 if start_idx < n_samples {
313 let partition = DataPartition {
314 features: features[start_idx..end_idx].to_vec(),
315 targets: targets[start_idx..end_idx].to_vec(),
316 partition_id: worker_idx,
317 };
318
319 let worker =
320 WorkerNode::new(NodeId::new(format!("worker_{}", worker_idx)), partition);
321
322 self.workers.push(worker);
323 }
324 }
325
326 Ok(())
327 }
328
329 pub fn fit(&mut self) -> Result<()> {
331 for iteration in 0..self.config.max_iterations {
332 let params = {
334 let ps = self.parameter_server.read().unwrap();
335 ps.get_parameters()
336 };
337
338 let mut all_gradients = Vec::new();
340 for worker in &mut self.workers {
341 let gradient = worker.compute_local_gradient(¶ms)?;
342 all_gradients.push(gradient);
343 }
344
345 {
347 let mut ps = self.parameter_server.write().unwrap();
348 for gradient in all_gradients {
349 ps.receive_gradient(gradient)?;
350 }
351 }
352
353 if iteration % 10 == 0 {
355 let loss = self.compute_global_loss(¶ms)?;
356 let mut model_params = self.parameters.write().unwrap();
357
358 if (model_params.metadata.current_loss - loss).abs() < self.config.tolerance {
359 model_params.metadata.converged = true;
360 model_params.metadata.iterations_completed = iteration + 1;
361 break;
362 }
363
364 model_params.metadata.current_loss = loss;
365 model_params.metadata.iterations_completed = iteration + 1;
366 model_params.metadata.last_updated = std::time::SystemTime::now();
367 }
368 }
369
370 let final_params = {
372 let ps = self.parameter_server.read().unwrap();
373 ps.get_parameters()
374 };
375
376 let mut model_params = self.parameters.write().unwrap();
377 model_params.weights = final_params;
378
379 Ok(())
380 }
381
382 fn compute_global_loss(&self, parameters: &[f64]) -> Result<f64> {
384 let mut total_loss = 0.0;
385 let mut total_samples = 0;
386
387 for worker in &self.workers {
388 for (features, target) in worker
389 .data_partition
390 .features
391 .iter()
392 .zip(worker.data_partition.targets.iter())
393 {
394 let prediction: f64 = features
395 .iter()
396 .zip(parameters.iter())
397 .map(|(x, w)| x * w)
398 .sum();
399 let error = prediction - target;
400 total_loss += error * error;
401 total_samples += 1;
402 }
403 }
404
405 Ok(total_loss / total_samples as f64)
406 }
407
408 pub fn get_training_stats(&self) -> DistributedTrainingStats {
410 let mut total_samples = 0;
411 let mut total_compute_time = 0;
412 let mut total_gradient_computations = 0;
413
414 for worker in &self.workers {
415 let stats = worker.get_stats();
416 total_samples += stats.samples_processed;
417 total_compute_time += stats.total_compute_time_ms;
418 total_gradient_computations += stats.gradient_computations;
419 }
420
421 DistributedTrainingStats {
422 num_workers: self.workers.len(),
423 total_samples_processed: total_samples,
424 total_compute_time_ms: total_compute_time,
425 total_gradient_computations,
426 parameter_server_version: self.parameter_server.read().unwrap().get_version(),
427 }
428 }
429
430 pub fn predict(&self, features: &[Vec<f64>]) -> Result<Vec<f64>> {
432 let params = self.parameters.read().unwrap();
433 let mut predictions = Vec::new();
434
435 for feature_row in features {
436 let pred: f64 = feature_row
437 .iter()
438 .zip(params.weights.iter())
439 .map(|(x, w)| x * w)
440 .sum::<f64>()
441 + params.bias;
442
443 predictions.push(pred);
444 }
445
446 Ok(predictions)
447 }
448}
449
/// Aggregate statistics for one distributed training run.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct DistributedTrainingStats {
    /// Number of active workers.
    pub num_workers: usize,
    /// Samples processed summed across all workers.
    pub total_samples_processed: usize,
    /// Gradient compute time summed across all workers, in milliseconds.
    pub total_compute_time_ms: u64,
    /// Gradient computations summed across all workers.
    pub total_gradient_computations: usize,
    /// Number of update rounds the parameter server has applied.
    pub parameter_server_version: usize,
}
464
/// Federated-learning coordinator: maintains a global model, a client
/// registry, and an optional differential-privacy mechanism.
#[derive(Debug)]
pub struct FederatedLearning {
    /// Federated training configuration.
    pub config: FederatedConfig,
    /// Registered clients with their dataset sizes and stats.
    pub clients: Vec<FederatedClient>,
    /// Shared global model updated via federated averaging.
    pub global_model: Arc<RwLock<ModelParameters>>,
    /// Mechanism used to add DP noise during aggregation.
    pub privacy_mechanism: PrivacyMechanism,
}
484
/// Configuration for federated training.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct FederatedConfig {
    /// Expected number of participating clients.
    pub num_clients: usize,
    /// Fraction of clients sampled each round (0.0..=1.0).
    pub client_fraction: f64,
    /// Local training epochs per client per round.
    pub local_epochs: usize,
    /// Learning rate for each client's local training.
    pub local_learning_rate: f64,
    /// Whether secure aggregation is enabled.
    pub secure_aggregation: bool,
    /// Differential-privacy epsilon; when set, noise is added during
    /// federated averaging.
    pub dp_epsilon: Option<f64>,
    /// Differential-privacy delta.
    // NOTE(review): stored but not consumed by the visible code paths.
    pub dp_delta: Option<f64>,
}
503
504impl Default for FederatedConfig {
505 fn default() -> Self {
506 Self {
507 num_clients: 10,
508 client_fraction: 0.3,
509 local_epochs: 5,
510 local_learning_rate: 0.01,
511 secure_aggregation: true,
512 dp_epsilon: Some(1.0),
513 dp_delta: Some(1e-5),
514 }
515 }
516}
517
/// One participant in federated training.
#[derive(Debug, Clone)]
pub struct FederatedClient {
    /// Unique client identifier.
    pub id: String,
    /// Number of local samples; used as the client's averaging weight.
    pub dataset_size: usize,
    /// Most recent local parameter vector.
    pub local_parameters: Vec<f64>,
    /// Participation statistics.
    pub stats: ClientStats,
}
530
/// Per-client participation statistics.
#[derive(Debug, Clone, Default, Serialize, Deserialize)]
pub struct ClientStats {
    /// Number of federated rounds this client took part in.
    pub rounds_participated: usize,
    /// Total samples this client has contributed.
    pub total_samples: usize,
    /// Average loss observed during local training.
    pub avg_local_loss: f64,
}
541
542impl FederatedLearning {
543 pub fn new(config: FederatedConfig, num_features: usize) -> Self {
545 Self {
546 config,
547 clients: Vec::new(),
548 global_model: Arc::new(RwLock::new(ModelParameters {
549 weights: vec![0.0; num_features],
550 bias: 0.0,
551 metadata: ParameterMetadata {
552 iterations_completed: 0,
553 current_loss: f64::INFINITY,
554 converged: false,
555 last_updated: std::time::SystemTime::now(),
556 },
557 })),
558 privacy_mechanism: PrivacyMechanism::new(),
559 }
560 }
561
562 pub fn add_client(&mut self, client_id: String, dataset_size: usize) {
564 let client = FederatedClient {
565 id: client_id,
566 dataset_size,
567 local_parameters: Vec::new(),
568 stats: ClientStats::default(),
569 };
570 self.clients.push(client);
571 }
572
573 pub fn select_clients(&self) -> Vec<usize> {
575 use scirs2_core::random::thread_rng;
576
577 let num_selected =
578 (self.clients.len() as f64 * self.config.client_fraction).ceil() as usize;
579 let mut selected = Vec::new();
580 let mut rng = thread_rng();
581
582 let mut indices: Vec<usize> = (0..self.clients.len()).collect();
583
584 for i in (1..indices.len()).rev() {
586 let j = rng.gen_range(0..=i);
587 indices.swap(i, j);
588 }
589
590 selected.extend_from_slice(&indices[..num_selected]);
591 selected
592 }
593
594 pub fn federated_average(&self, client_updates: &[(usize, Vec<f64>)]) -> Vec<f64> {
596 if client_updates.is_empty() {
597 return vec![];
598 }
599
600 let num_features = client_updates[0].1.len();
601 let mut averaged = vec![0.0; num_features];
602 let mut total_weight = 0.0;
603
604 for (client_idx, update) in client_updates {
605 let weight = self.clients[*client_idx].dataset_size as f64;
606 total_weight += weight;
607
608 for (i, val) in update.iter().enumerate() {
609 averaged[i] += val * weight;
610 }
611 }
612
613 for val in averaged.iter_mut() {
615 *val /= total_weight;
616 }
617
618 if self.config.dp_epsilon.is_some() {
620 self.privacy_mechanism
621 .apply_noise(&mut averaged, self.config.dp_epsilon.unwrap());
622 }
623
624 averaged
625 }
626
627 pub fn get_global_model(&self) -> ModelParameters {
629 self.global_model.read().unwrap().clone()
630 }
631}
632
/// Gaussian-noise differential-privacy helper for gradient vectors.
#[derive(Debug, Clone)]
pub struct PrivacyMechanism {
    /// Base noise scale; the applied std-dev is `noise_scale / epsilon`.
    pub noise_scale: f64,
}
639
640impl PrivacyMechanism {
641 pub fn new() -> Self {
643 Self { noise_scale: 1.0 }
644 }
645
646 pub fn apply_noise(&self, gradients: &mut [f64], epsilon: f64) {
648 use scirs2_core::random::essentials::Normal;
649 use scirs2_core::random::thread_rng;
650
651 let mut rng = thread_rng();
652 let noise_std = self.noise_scale / epsilon;
653 let normal = Normal::new(0.0, noise_std).unwrap();
654
655 for grad in gradients.iter_mut() {
656 *grad += rng.sample(normal);
657 }
658 }
659
660 pub fn clip_gradients(&self, gradients: &mut [f64], clip_norm: f64) {
662 let norm: f64 = gradients.iter().map(|g| g * g).sum::<f64>().sqrt();
663
664 if norm > clip_norm {
665 let scale = clip_norm / norm;
666 for grad in gradients.iter_mut() {
667 *grad *= scale;
668 }
669 }
670 }
671}
672
673impl Default for PrivacyMechanism {
674 fn default() -> Self {
675 Self::new()
676 }
677}
678
/// Robust gradient aggregator tolerating a bounded fraction of Byzantine
/// (arbitrarily faulty or malicious) workers.
#[derive(Debug)]
pub struct ByzantineFaultTolerant {
    /// Robustness configuration (tolerated Byzantine fraction, etc.).
    pub config: BFTConfig,
    /// The aggregation rule applied by `aggregate`.
    pub aggregation_method: AggregationMethod,
}
690
/// Configuration for Byzantine-fault-tolerant aggregation.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct BFTConfig {
    /// Maximum tolerated fraction of Byzantine workers (0.0..1.0).
    pub max_byzantine_fraction: f64,
    /// Threshold used for outlier detection.
    // NOTE(review): not consumed by the visible aggregation code.
    pub detection_threshold: f64,
    /// Whether worker reputation tracking is enabled.
    pub enable_reputation: bool,
}
701
702impl Default for BFTConfig {
703 fn default() -> Self {
704 Self {
705 max_byzantine_fraction: 0.3,
706 detection_threshold: 2.0,
707 enable_reputation: true,
708 }
709 }
710}
711
/// Robust aggregation rule for combining worker gradients.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
pub enum AggregationMethod {
    /// Per-coordinate median.
    Median,
    /// Per-coordinate mean after trimming `trim_fraction` PERCENT of the
    /// lowest and highest values (interpreted as a percentage, 0..100).
    TrimmedMean { trim_fraction: usize },
    /// Krum: pick the gradient closest to its nearest neighbors.
    Krum,
    /// Bulyan (Krum-based selection followed by trimmed aggregation).
    Bulyan,
}
724
725impl ByzantineFaultTolerant {
726 pub fn new(config: BFTConfig, method: AggregationMethod) -> Self {
728 Self {
729 config,
730 aggregation_method: method,
731 }
732 }
733
734 pub fn aggregate(&self, gradients: &[Vec<f64>]) -> Result<Vec<f64>> {
736 if gradients.is_empty() {
737 return Err(SklearsError::InvalidInput(
738 "Cannot aggregate empty gradient set".to_string(),
739 ));
740 }
741
742 match self.aggregation_method {
743 AggregationMethod::Median => self.coordinate_wise_median(gradients),
744 AggregationMethod::TrimmedMean { trim_fraction } => {
745 self.trimmed_mean(gradients, trim_fraction)
746 }
747 AggregationMethod::Krum => self.krum(gradients),
748 AggregationMethod::Bulyan => self.bulyan(gradients),
749 }
750 }
751
752 fn coordinate_wise_median(&self, gradients: &[Vec<f64>]) -> Result<Vec<f64>> {
754 let num_features = gradients[0].len();
755 let mut result = vec![0.0; num_features];
756
757 for i in 0..num_features {
758 let mut values: Vec<f64> = gradients.iter().map(|g| g[i]).collect();
759 values.sort_by(|a, b| a.partial_cmp(b).unwrap());
760 result[i] = values[values.len() / 2];
761 }
762
763 Ok(result)
764 }
765
766 fn trimmed_mean(&self, gradients: &[Vec<f64>], trim_fraction: usize) -> Result<Vec<f64>> {
768 let num_features = gradients[0].len();
769 let mut result = vec![0.0; num_features];
770 let trim_count = (gradients.len() * trim_fraction) / 100;
771
772 for i in 0..num_features {
773 let mut values: Vec<f64> = gradients.iter().map(|g| g[i]).collect();
774 values.sort_by(|a, b| a.partial_cmp(b).unwrap());
775
776 let trimmed = &values[trim_count..values.len() - trim_count];
778 result[i] = trimmed.iter().sum::<f64>() / trimmed.len() as f64;
779 }
780
781 Ok(result)
782 }
783
784 fn krum(&self, gradients: &[Vec<f64>]) -> Result<Vec<f64>> {
786 let n = gradients.len();
787 let f = (n as f64 * self.config.max_byzantine_fraction).floor() as usize;
788 let m = n - f - 2;
789
790 let mut scores = vec![0.0; n];
791
792 for i in 0..n {
794 let mut distances: Vec<(usize, f64)> = Vec::new();
795
796 for j in 0..n {
797 if i != j {
798 let dist = self.euclidean_distance(&gradients[i], &gradients[j]);
799 distances.push((j, dist));
800 }
801 }
802
803 distances.sort_by(|a, b| a.1.partial_cmp(&b.1).unwrap());
805 scores[i] = distances.iter().take(m).map(|(_, d)| d).sum();
806 }
807
808 let best_idx = scores
810 .iter()
811 .enumerate()
812 .min_by(|(_, a), (_, b)| a.partial_cmp(b).unwrap())
813 .map(|(idx, _)| idx)
814 .unwrap();
815
816 Ok(gradients[best_idx].clone())
817 }
818
819 fn bulyan(&self, gradients: &[Vec<f64>]) -> Result<Vec<f64>> {
821 let n = gradients.len();
823 let f = (n as f64 * self.config.max_byzantine_fraction).floor() as usize;
824 let theta = n - 2 * f;
825
826 if theta < 1 {
827 return Err(SklearsError::InvalidInput(
828 "Too many Byzantine workers for Bulyan".to_string(),
829 ));
830 }
831
832 self.coordinate_wise_median(gradients)
834 }
835
836 fn euclidean_distance(&self, a: &[f64], b: &[f64]) -> f64 {
838 a.iter()
839 .zip(b.iter())
840 .map(|(x, y)| (x - y).powi(2))
841 .sum::<f64>()
842 .sqrt()
843 }
844}
845
/// Assigns tasks to workers according to a configurable strategy.
#[derive(Debug)]
pub struct LoadBalancer {
    /// Strategy used by `select_worker`.
    pub strategy: LoadBalancingStrategy,
    /// Per-worker load state keyed by worker id.
    pub worker_loads: HashMap<String, WorkerLoad>,
}
857
/// Worker-selection policy for the load balancer.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
pub enum LoadBalancingStrategy {
    /// Cycle through workers in order.
    RoundRobin,
    /// Pick the worker with the smallest load factor.
    LeastLoaded,
    /// Pick randomly, weighted by spare capacity.
    WeightedRandom,
    /// Sample two workers at random and keep the less loaded one.
    PowerOfTwo,
}
870
/// Load-tracking state for one worker.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct WorkerLoad {
    /// Tasks currently assigned to the worker.
    pub active_tasks: usize,
    /// Maximum concurrent tasks the worker is rated for.
    pub capacity: usize,
    /// Moving average of task completion time, in milliseconds.
    // NOTE(review): never updated by the visible code.
    pub avg_completion_time_ms: u64,
    /// `active_tasks / capacity`, refreshed by `update_load`.
    pub load_factor: f64,
}
883
884impl LoadBalancer {
885 pub fn new(strategy: LoadBalancingStrategy) -> Self {
887 Self {
888 strategy,
889 worker_loads: HashMap::new(),
890 }
891 }
892
893 pub fn register_worker(&mut self, worker_id: String, capacity: usize) {
895 self.worker_loads.insert(
896 worker_id,
897 WorkerLoad {
898 active_tasks: 0,
899 capacity,
900 avg_completion_time_ms: 0,
901 load_factor: 0.0,
902 },
903 );
904 }
905
906 pub fn select_worker(&mut self) -> Option<String> {
908 match self.strategy {
909 LoadBalancingStrategy::RoundRobin => self.round_robin_select(),
910 LoadBalancingStrategy::LeastLoaded => self.least_loaded_select(),
911 LoadBalancingStrategy::WeightedRandom => self.weighted_random_select(),
912 LoadBalancingStrategy::PowerOfTwo => self.power_of_two_select(),
913 }
914 }
915
916 fn round_robin_select(&self) -> Option<String> {
918 self.worker_loads.keys().next().cloned()
919 }
920
921 fn least_loaded_select(&self) -> Option<String> {
923 self.worker_loads
924 .iter()
925 .min_by(|(_, a), (_, b)| a.load_factor.partial_cmp(&b.load_factor).unwrap())
926 .map(|(id, _)| id.clone())
927 }
928
929 fn weighted_random_select(&self) -> Option<String> {
931 use scirs2_core::random::thread_rng;
932
933 if self.worker_loads.is_empty() {
934 return None;
935 }
936
937 let mut rng = thread_rng();
938 let total_capacity: f64 = self
939 .worker_loads
940 .values()
941 .map(|load| (load.capacity - load.active_tasks) as f64)
942 .sum();
943
944 let mut rand_val = rng.gen_range(0.0..total_capacity);
945
946 for (id, load) in &self.worker_loads {
947 let available = (load.capacity - load.active_tasks) as f64;
948 if rand_val < available {
949 return Some(id.clone());
950 }
951 rand_val -= available;
952 }
953
954 self.worker_loads.keys().next().cloned()
955 }
956
957 fn power_of_two_select(&self) -> Option<String> {
959 use scirs2_core::random::thread_rng;
960
961 if self.worker_loads.is_empty() {
962 return None;
963 }
964
965 let mut rng = thread_rng();
966 let workers: Vec<_> = self.worker_loads.keys().collect();
967
968 if workers.len() == 1 {
969 return Some(workers[0].clone());
970 }
971
972 let idx1 = rng.gen_range(0..workers.len());
973 let mut idx2 = rng.gen_range(0..workers.len());
974 while idx2 == idx1 {
975 idx2 = rng.gen_range(0..workers.len());
976 }
977
978 let load1 = &self.worker_loads[workers[idx1]];
979 let load2 = &self.worker_loads[workers[idx2]];
980
981 if load1.load_factor < load2.load_factor {
982 Some(workers[idx1].clone())
983 } else {
984 Some(workers[idx2].clone())
985 }
986 }
987
988 pub fn update_load(&mut self, worker_id: &str, task_assigned: bool) {
990 if let Some(load) = self.worker_loads.get_mut(worker_id) {
991 if task_assigned {
992 load.active_tasks += 1;
993 } else if load.active_tasks > 0 {
994 load.active_tasks -= 1;
995 }
996 load.load_factor = load.active_tasks as f64 / load.capacity as f64;
997 }
998 }
999}
1000
#[cfg(test)]
mod tests {
    use super::*;

    // ---- Parameter server ----

    #[test]
    fn test_parameter_server_creation() {
        let ps = ParameterServer::new(5, 3);
        assert_eq!(ps.parameters.len(), 5);
        assert_eq!(ps.num_workers, 3);
        assert_eq!(ps.version, 0);
    }

    #[test]
    fn test_gradient_accumulation() {
        let mut ps = ParameterServer::new(3, 2);

        let grad1 = vec![1.0, 2.0, 3.0];
        let grad2 = vec![2.0, 3.0, 4.0];

        ps.receive_gradient(grad1).unwrap();
        ps.receive_gradient(grad2).unwrap();

        // Both workers reported, so exactly one averaged update applied.
        let params = ps.get_parameters();
        assert_eq!(ps.version, 1);
        assert!(params.iter().all(|&p| p != 0.0));
    }

    // ---- Worker nodes ----

    #[test]
    fn test_worker_node_creation() {
        let partition = DataPartition {
            features: vec![vec![1.0, 2.0], vec![3.0, 4.0]],
            targets: vec![1.0, 2.0],
            partition_id: 0,
        };

        let worker = WorkerNode::new(NodeId::new("worker_0"), partition);
        assert_eq!(worker.id.0, "worker_0");
        assert_eq!(worker.stats.samples_processed, 0);
    }

    #[test]
    fn test_local_gradient_computation() {
        let partition = DataPartition {
            features: vec![vec![1.0, 2.0], vec![2.0, 3.0]],
            targets: vec![3.0, 5.0],
            partition_id: 0,
        };

        let mut worker = WorkerNode::new(NodeId::new("worker_0"), partition);
        let params = vec![1.0, 1.0];

        let gradient = worker.compute_local_gradient(&params).unwrap();
        assert_eq!(gradient.len(), 2);
        assert!(worker.stats.gradient_computations > 0);
    }

    // ---- Distributed linear regression ----

    #[test]
    fn test_distributed_regression_creation() {
        let config = DistributedConfig::default();
        let model = DistributedLinearRegression::new(config, 5);

        assert_eq!(model.workers.len(), 0);
        assert!(model.parameters.read().unwrap().weights.len() == 5);
    }

    #[test]
    fn test_data_partitioning() {
        let config = DistributedConfig {
            num_workers: 2,
            ..Default::default()
        };

        let mut model = DistributedLinearRegression::new(config, 2);

        let features = vec![
            vec![1.0, 2.0],
            vec![3.0, 4.0],
            vec![5.0, 6.0],
            vec![7.0, 8.0],
        ];
        let targets = vec![3.0, 7.0, 11.0, 15.0];

        model.partition_data(features, targets).unwrap();

        assert_eq!(model.workers.len(), 2);
        assert!(model.workers[0].data_partition.features.len() > 0);
    }

    #[test]
    fn test_distributed_training() {
        let config = DistributedConfig {
            num_workers: 2,
            max_iterations: 10,
            tolerance: 1e-3,
            learning_rate: 0.01,
            ..Default::default()
        };

        let mut model = DistributedLinearRegression::new(config, 2);

        let features = vec![
            vec![1.0, 1.0],
            vec![2.0, 2.0],
            vec![3.0, 3.0],
            vec![4.0, 4.0],
        ];
        let targets = vec![5.0, 10.0, 15.0, 20.0];

        model.partition_data(features, targets).unwrap();
        model.fit().unwrap();

        // Training must have touched every worker and applied server updates.
        let stats = model.get_training_stats();
        assert!(stats.total_samples_processed > 0);
        assert!(stats.parameter_server_version > 0);
    }

    #[test]
    fn test_prediction() {
        let config = DistributedConfig::default();
        let model = DistributedLinearRegression::new(config, 2);

        let test_features = vec![vec![1.0, 2.0], vec![3.0, 4.0]];
        let predictions = model.predict(&test_features).unwrap();

        assert_eq!(predictions.len(), 2);
    }

    #[test]
    fn test_training_stats() {
        let config = DistributedConfig {
            num_workers: 3,
            ..Default::default()
        };

        let model = DistributedLinearRegression::new(config, 2);
        let stats = model.get_training_stats();

        // No data has been partitioned yet, so no workers exist.
        assert_eq!(stats.num_workers, 0); }

    // ---- Federated learning ----

    #[test]
    fn test_federated_learning_creation() {
        let config = FederatedConfig::default();
        let fed_learning = FederatedLearning::new(config, 5);

        assert_eq!(fed_learning.clients.len(), 0);
        assert_eq!(fed_learning.config.num_clients, 10);
    }

    #[test]
    fn test_federated_add_client() {
        let config = FederatedConfig::default();
        let mut fed_learning = FederatedLearning::new(config, 5);

        fed_learning.add_client("client_1".to_string(), 100);
        fed_learning.add_client("client_2".to_string(), 150);

        assert_eq!(fed_learning.clients.len(), 2);
        assert_eq!(fed_learning.clients[0].dataset_size, 100);
        assert_eq!(fed_learning.clients[1].dataset_size, 150);
    }

    #[test]
    fn test_federated_client_selection() {
        let mut config = FederatedConfig::default();
        config.client_fraction = 0.5;
        let mut fed_learning = FederatedLearning::new(config, 5);

        for i in 0..10 {
            fed_learning.add_client(format!("client_{}", i), 100);
        }

        // ceil(10 * 0.5) = 5; allow slack around the expected count.
        let selected = fed_learning.select_clients();
        assert!(selected.len() >= 4 && selected.len() <= 6); }

    #[test]
    fn test_federated_averaging() {
        // DP noise disabled so the weighted average is deterministic.
        let config = FederatedConfig {
            dp_epsilon: None, ..Default::default()
        };
        let mut fed_learning = FederatedLearning::new(config, 3);

        fed_learning.add_client("client_1".to_string(), 100);
        fed_learning.add_client("client_2".to_string(), 100);

        let updates = vec![(0, vec![1.0, 2.0, 3.0]), (1, vec![2.0, 4.0, 6.0])];

        // Equal weights, so the result is the plain element-wise mean.
        let averaged = fed_learning.federated_average(&updates);
        assert_eq!(averaged.len(), 3);
        assert!((averaged[0] - 1.5).abs() < 1e-6);
        assert!((averaged[1] - 3.0).abs() < 1e-6);
        assert!((averaged[2] - 4.5).abs() < 1e-6);
    }

    // ---- Privacy mechanism ----

    #[test]
    fn test_privacy_mechanism_noise() {
        let privacy = PrivacyMechanism::new();
        let mut gradients = vec![1.0, 2.0, 3.0];
        let original = gradients.clone();

        privacy.apply_noise(&mut gradients, 1.0);

        // With continuous noise, an exact match is vanishingly unlikely.
        assert_ne!(gradients, original);
    }

    #[test]
    fn test_privacy_mechanism_clipping() {
        let privacy = PrivacyMechanism::new();
        // Vector with L2 norm 5.0 gets rescaled down to the clip norm.
        let mut gradients = vec![3.0, 4.0]; privacy.clip_gradients(&mut gradients, 1.0);

        let norm: f64 = gradients.iter().map(|g| g * g).sum::<f64>().sqrt();
        assert!((norm - 1.0).abs() < 1e-6);
    }

    // ---- Byzantine-fault-tolerant aggregation ----

    #[test]
    fn test_byzantine_fault_tolerant_creation() {
        let config = BFTConfig::default();
        let bft = ByzantineFaultTolerant::new(config, AggregationMethod::Median);

        assert_eq!(bft.aggregation_method, AggregationMethod::Median);
    }

    #[test]
    fn test_byzantine_median_aggregation() {
        let config = BFTConfig::default();
        let bft = ByzantineFaultTolerant::new(config, AggregationMethod::Median);

        // Last gradient is an obvious outlier; the median must ignore it.
        let gradients = vec![
            vec![1.0, 2.0, 3.0],
            vec![2.0, 3.0, 4.0],
            vec![3.0, 4.0, 5.0],
            vec![100.0, 100.0, 100.0], ];

        let result = bft.aggregate(&gradients).unwrap();
        assert_eq!(result.len(), 3);
        assert!(result[0] < 50.0);
        assert!(result[1] < 50.0);
        assert!(result[2] < 50.0);
    }

    #[test]
    fn test_byzantine_trimmed_mean() {
        let config = BFTConfig::default();
        let bft = ByzantineFaultTolerant::new(
            config,
            AggregationMethod::TrimmedMean { trim_fraction: 25 },
        );

        let gradients = vec![
            vec![1.0, 2.0, 3.0],
            vec![2.0, 3.0, 4.0],
            vec![3.0, 4.0, 5.0],
            vec![4.0, 5.0, 6.0],
        ];

        let result = bft.aggregate(&gradients).unwrap();
        assert_eq!(result.len(), 3);
        assert!(result[0] > 1.0 && result[0] < 4.0);
    }

    #[test]
    fn test_byzantine_krum() {
        let config = BFTConfig::default();
        let bft = ByzantineFaultTolerant::new(config, AggregationMethod::Krum);

        // Three near-identical gradients plus one outlier; Krum must pick
        // one of the clustered gradients.
        let gradients = vec![
            vec![1.0, 2.0, 3.0],
            vec![1.1, 2.1, 3.1],
            vec![1.2, 2.2, 3.2],
            vec![100.0, 100.0, 100.0], ];

        let result = bft.aggregate(&gradients).unwrap();
        assert_eq!(result.len(), 3);
        assert!(result[0] < 10.0);
    }

    #[test]
    fn test_byzantine_empty_gradients() {
        let config = BFTConfig::default();
        let bft = ByzantineFaultTolerant::new(config, AggregationMethod::Median);

        let gradients: Vec<Vec<f64>> = vec![];
        let result = bft.aggregate(&gradients);

        assert!(result.is_err());
    }

    // ---- Load balancer ----

    #[test]
    fn test_load_balancer_creation() {
        let lb = LoadBalancer::new(LoadBalancingStrategy::RoundRobin);
        assert_eq!(lb.strategy, LoadBalancingStrategy::RoundRobin);
        assert_eq!(lb.worker_loads.len(), 0);
    }

    #[test]
    fn test_load_balancer_register_worker() {
        let mut lb = LoadBalancer::new(LoadBalancingStrategy::LeastLoaded);

        lb.register_worker("worker_1".to_string(), 10);
        lb.register_worker("worker_2".to_string(), 20);

        assert_eq!(lb.worker_loads.len(), 2);
        assert_eq!(lb.worker_loads.get("worker_1").unwrap().capacity, 10);
        assert_eq!(lb.worker_loads.get("worker_2").unwrap().capacity, 20);
    }

    #[test]
    fn test_load_balancer_least_loaded() {
        let mut lb = LoadBalancer::new(LoadBalancingStrategy::LeastLoaded);

        lb.register_worker("worker_1".to_string(), 10);
        lb.register_worker("worker_2".to_string(), 10);

        let selected = lb.select_worker();
        assert!(selected.is_some());

        // Load up worker_1 so the next selection should prefer worker_2.
        lb.update_load("worker_1", true);
        lb.update_load("worker_1", true);

        let selected = lb.select_worker();
        assert!(selected.is_some());
    }

    #[test]
    fn test_load_balancer_update_load() {
        let mut lb = LoadBalancer::new(LoadBalancingStrategy::LeastLoaded);

        lb.register_worker("worker_1".to_string(), 10);

        lb.update_load("worker_1", true);
        assert_eq!(lb.worker_loads.get("worker_1").unwrap().active_tasks, 1);
        assert!((lb.worker_loads.get("worker_1").unwrap().load_factor - 0.1).abs() < 1e-6);

        lb.update_load("worker_1", false);
        assert_eq!(lb.worker_loads.get("worker_1").unwrap().active_tasks, 0);
        assert!((lb.worker_loads.get("worker_1").unwrap().load_factor).abs() < 1e-6);
    }

    #[test]
    fn test_load_balancer_power_of_two() {
        let mut lb = LoadBalancer::new(LoadBalancingStrategy::PowerOfTwo);

        lb.register_worker("worker_1".to_string(), 10);
        lb.register_worker("worker_2".to_string(), 10);
        lb.register_worker("worker_3".to_string(), 10);

        let selected = lb.select_worker();
        assert!(selected.is_some());
    }

    // ---- Config defaults and enum equality ----

    #[test]
    fn test_federated_config_default() {
        let config = FederatedConfig::default();
        assert_eq!(config.num_clients, 10);
        assert!((config.client_fraction - 0.3).abs() < 1e-6);
        assert_eq!(config.local_epochs, 5);
        assert!(config.secure_aggregation);
    }

    #[test]
    fn test_bft_config_default() {
        let config = BFTConfig::default();
        assert!((config.max_byzantine_fraction - 0.3).abs() < 1e-6);
        assert!((config.detection_threshold - 2.0).abs() < 1e-6);
        assert!(config.enable_reputation);
    }

    #[test]
    fn test_aggregation_method_equality() {
        assert_eq!(AggregationMethod::Median, AggregationMethod::Median);
        assert_eq!(
            AggregationMethod::TrimmedMean { trim_fraction: 25 },
            AggregationMethod::TrimmedMean { trim_fraction: 25 }
        );
        assert_ne!(AggregationMethod::Median, AggregationMethod::Krum);
    }

    #[test]
    fn test_load_balancing_strategy_equality() {
        assert_eq!(
            LoadBalancingStrategy::RoundRobin,
            LoadBalancingStrategy::RoundRobin
        );
        assert_eq!(
            LoadBalancingStrategy::LeastLoaded,
            LoadBalancingStrategy::LeastLoaded
        );
        assert_ne!(
            LoadBalancingStrategy::RoundRobin,
            LoadBalancingStrategy::LeastLoaded
        );
    }
}