1use crate::error::{NeuralError, Result};
43use scirs2_core::ndarray::{Array, ArrayD, IxDyn, ScalarOperand};
44use scirs2_core::numeric::{Float, FromPrimitive, NumAssign, ToPrimitive};
45use scirs2_core::random::rngs::SmallRng;
46use scirs2_core::random::{Rng, RngExt, SeedableRng};
47use std::fmt::{self, Debug, Display};
48
/// How the server chooses which clients take part in a training round.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum ClientSelectionStrategy {
    /// Sample `clients_per_round` clients uniformly, without replacement.
    Random,
    /// Sample clients without replacement, favouring clients that report more
    /// local samples.
    ImportanceBased,
    /// Take every available client, regardless of `clients_per_round`.
    All,
}
63
64impl Display for ClientSelectionStrategy {
65 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
66 match self {
67 Self::Random => write!(f, "Random"),
68 Self::ImportanceBased => write!(f, "ImportanceBased"),
69 Self::All => write!(f, "All"),
70 }
71 }
72}
73
/// Rule used to merge the clients' parameter tensors into the new global model.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum AggregationMethod {
    /// Mean weighted by each client's `num_samples`.
    FedAvg,
    /// Unweighted arithmetic mean over clients.
    SimpleMean,
    /// Element-wise median over clients.
    Median,
}
84
85impl Display for AggregationMethod {
86 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
87 match self {
88 Self::FedAvg => write!(f, "FedAvg"),
89 Self::SimpleMean => write!(f, "SimpleMean"),
90 Self::Median => write!(f, "Median"),
91 }
92 }
93}
94
/// Settings for the server-side Gaussian noise mechanism.
#[derive(Debug, Clone)]
pub struct DifferentialPrivacyConfig {
    /// When `true`, Gaussian noise is added to the aggregated model every round.
    pub enabled: bool,
    /// Noise standard deviation expressed as a multiple of `max_grad_norm`.
    pub noise_multiplier: f64,
    /// L2 sensitivity bound; the noise standard deviation is
    /// `noise_multiplier * max_grad_norm`.
    pub max_grad_norm: f64,
    /// Target δ of the (ε, δ) guarantee.
    /// NOTE(review): not consumed by the aggregation code in this module — confirm where it is used.
    pub delta: f64,
}
111
112impl Default for DifferentialPrivacyConfig {
113 fn default() -> Self {
114 Self {
115 enabled: false,
116 noise_multiplier: 1.0,
117 max_grad_norm: 1.0,
118 delta: 1e-5,
119 }
120 }
121}
122
/// Settings for top-k sparsification of client updates.
#[derive(Debug, Clone)]
pub struct GradientCompressionConfig {
    /// When `true`, each client's delta from the global model is sparsified before aggregation.
    pub enabled: bool,
    /// Fraction of each tensor's delta entries to keep, largest magnitude first.
    pub top_k_fraction: f64,
}
131
132impl Default for GradientCompressionConfig {
133 fn default() -> Self {
134 Self {
135 enabled: false,
136 top_k_fraction: 0.1,
137 }
138 }
139}
140
141#[derive(Debug, Clone)]
147pub struct FederatedConfig {
148 pub num_rounds: usize,
150 pub clients_per_round: usize,
152 pub client_selection: ClientSelectionStrategy,
154 pub aggregation: AggregationMethod,
156 pub dp_config: DifferentialPrivacyConfig,
158 pub compression: GradientCompressionConfig,
160 pub local_epochs: usize,
162 pub local_lr: f64,
164 pub seed: Option<u64>,
166}
167
168impl Default for FederatedConfig {
169 fn default() -> Self {
170 Self {
171 num_rounds: 100,
172 clients_per_round: 10,
173 client_selection: ClientSelectionStrategy::Random,
174 aggregation: AggregationMethod::FedAvg,
175 dp_config: DifferentialPrivacyConfig::default(),
176 compression: GradientCompressionConfig::default(),
177 local_epochs: 1,
178 local_lr: 0.01,
179 seed: None,
180 }
181 }
182}
183
184impl FederatedConfig {
185 pub fn builder() -> FederatedConfigBuilder {
187 FederatedConfigBuilder::default()
188 }
189
190 pub fn validate(&self) -> Result<()> {
192 if self.num_rounds == 0 {
193 return Err(NeuralError::InvalidArgument(
194 "num_rounds must be > 0".into(),
195 ));
196 }
197 if self.clients_per_round == 0 {
198 return Err(NeuralError::InvalidArgument(
199 "clients_per_round must be > 0".into(),
200 ));
201 }
202 if self.local_epochs == 0 {
203 return Err(NeuralError::InvalidArgument(
204 "local_epochs must be > 0".into(),
205 ));
206 }
207 if self.local_lr <= 0.0 {
208 return Err(NeuralError::InvalidArgument(
209 "local_lr must be positive".into(),
210 ));
211 }
212 if self.dp_config.enabled && self.dp_config.noise_multiplier <= 0.0 {
213 return Err(NeuralError::InvalidArgument(
214 "noise_multiplier must be positive when DP is enabled".into(),
215 ));
216 }
217 if self.dp_config.enabled && self.dp_config.max_grad_norm <= 0.0 {
218 return Err(NeuralError::InvalidArgument(
219 "max_grad_norm must be positive when DP is enabled".into(),
220 ));
221 }
222 if self.compression.enabled && !(0.0..=1.0).contains(&self.compression.top_k_fraction) {
223 return Err(NeuralError::InvalidArgument(
224 "top_k_fraction must be in [0.0, 1.0]".into(),
225 ));
226 }
227 Ok(())
228 }
229}
230
231#[derive(Debug, Clone, Default)]
237pub struct FederatedConfigBuilder {
238 config: FederatedConfig,
239}
240
241impl FederatedConfigBuilder {
242 pub fn num_rounds(mut self, n: usize) -> Self {
244 self.config.num_rounds = n;
245 self
246 }
247
248 pub fn clients_per_round(mut self, n: usize) -> Self {
250 self.config.clients_per_round = n;
251 self
252 }
253
254 pub fn client_selection(mut self, s: ClientSelectionStrategy) -> Self {
256 self.config.client_selection = s;
257 self
258 }
259
260 pub fn aggregation(mut self, a: AggregationMethod) -> Self {
262 self.config.aggregation = a;
263 self
264 }
265
266 pub fn differential_privacy(mut self, noise_multiplier: f64, max_grad_norm: f64) -> Self {
268 self.config.dp_config.enabled = true;
269 self.config.dp_config.noise_multiplier = noise_multiplier;
270 self.config.dp_config.max_grad_norm = max_grad_norm;
271 self
272 }
273
274 pub fn dp_delta(mut self, delta: f64) -> Self {
276 self.config.dp_config.delta = delta;
277 self
278 }
279
280 pub fn gradient_compression(mut self, top_k_fraction: f64) -> Self {
282 self.config.compression.enabled = true;
283 self.config.compression.top_k_fraction = top_k_fraction;
284 self
285 }
286
287 pub fn local_epochs(mut self, n: usize) -> Self {
289 self.config.local_epochs = n;
290 self
291 }
292
293 pub fn local_lr(mut self, lr: f64) -> Self {
295 self.config.local_lr = lr;
296 self
297 }
298
299 pub fn seed(mut self, s: u64) -> Self {
301 self.config.seed = Some(s);
302 self
303 }
304
305 pub fn build(self) -> Result<FederatedConfig> {
307 self.config.validate()?;
308 Ok(self.config)
309 }
310}
311
312#[derive(Debug, Clone)]
318pub struct ClientUpdate {
319 pub client_id: usize,
321 pub parameters: Vec<ArrayD<f64>>,
323 pub num_samples: usize,
325 pub local_loss: Option<f64>,
327 pub metrics: std::collections::HashMap<String, f64>,
329}
330
331impl ClientUpdate {
332 pub fn new(client_id: usize, parameters: Vec<ArrayD<f64>>, num_samples: usize) -> Self {
334 Self {
335 client_id,
336 parameters,
337 num_samples,
338 local_loss: None,
339 metrics: std::collections::HashMap::new(),
340 }
341 }
342
343 pub fn with_loss(mut self, loss: f64) -> Self {
345 self.local_loss = Some(loss);
346 self
347 }
348
349 pub fn with_metric(mut self, name: &str, value: f64) -> Self {
351 self.metrics.insert(name.to_string(), value);
352 self
353 }
354}
355
/// Summary of one completed aggregation round.
#[derive(Debug, Clone)]
pub struct RoundStats {
    /// Zero-based round index.
    pub round: usize,
    /// Number of updates aggregated.
    pub num_clients: usize,
    /// Sum of the participating clients' `num_samples`.
    pub total_samples: usize,
    /// Mean of the reported local losses, or `None` if no client reported one.
    pub avg_loss: Option<f64>,
    /// Ids of the clients whose updates were aggregated, in submission order.
    pub client_ids: Vec<usize>,
}
374
375#[derive(Debug, Clone)]
384pub struct FederatedServer {
385 config: FederatedConfig,
387 global_params: Vec<ArrayD<f64>>,
389 current_round: usize,
391 round_history: Vec<RoundStats>,
393 rng: SmallRng,
395}
396
397impl FederatedServer {
398 pub fn new(config: FederatedConfig, global_params: Vec<ArrayD<f64>>) -> Self {
400 let rng = match config.seed {
401 Some(s) => SmallRng::seed_from_u64(s),
402 None => SmallRng::seed_from_u64(42),
403 };
404 Self {
405 config,
406 global_params,
407 current_round: 0,
408 round_history: Vec::new(),
409 rng,
410 }
411 }
412
413 pub fn global_params(&self) -> &[ArrayD<f64>] {
415 &self.global_params
416 }
417
418 pub fn current_round(&self) -> usize {
420 self.current_round
421 }
422
423 pub fn round_history(&self) -> &[RoundStats] {
425 &self.round_history
426 }
427
428 pub fn is_complete(&self) -> bool {
430 self.current_round >= self.config.num_rounds
431 }
432
433 pub fn select_clients(&mut self, available_clients: &[(usize, usize)]) -> Vec<usize> {
440 if available_clients.is_empty() {
441 return Vec::new();
442 }
443
444 let k = self.config.clients_per_round.min(available_clients.len());
445
446 match self.config.client_selection {
447 ClientSelectionStrategy::All => available_clients.iter().map(|&(id, _)| id).collect(),
448 ClientSelectionStrategy::Random => {
449 let mut indices: Vec<usize> = (0..available_clients.len()).collect();
451 for i in 0..k {
452 let j = i + self.rng.random_range(0..indices.len() - i);
453 indices.swap(i, j);
454 }
455 indices[..k]
456 .iter()
457 .map(|&i| available_clients[i].0)
458 .collect()
459 }
460 ClientSelectionStrategy::ImportanceBased => {
461 let total: usize = available_clients.iter().map(|&(_, n)| n).sum();
463 if total == 0 {
464 return available_clients
465 .iter()
466 .take(k)
467 .map(|&(id, _)| id)
468 .collect();
469 }
470
471 let mut selected = Vec::with_capacity(k);
472 let mut used = vec![false; available_clients.len()];
473
474 for _ in 0..k {
475 let threshold = self.rng.random_range(0..total);
476 let mut cumulative = 0usize;
477 for (idx, &(client_id, n)) in available_clients.iter().enumerate() {
478 if used[idx] {
479 continue;
480 }
481 cumulative += n;
482 if cumulative > threshold {
483 selected.push(client_id);
484 used[idx] = true;
485 break;
486 }
487 }
488 if selected.len() < selected.capacity()
490 && selected.len() < k
491 && selected.len() == selected.len()
492 {
493 }
495 }
496
497 if selected.len() < k {
499 for (idx, &(client_id, _)) in available_clients.iter().enumerate() {
500 if selected.len() >= k {
501 break;
502 }
503 if !used[idx] {
504 selected.push(client_id);
505 used[idx] = true;
506 }
507 }
508 }
509
510 selected
511 }
512 }
513 }
514
515 pub fn aggregate_round(&mut self, updates: &[ClientUpdate]) -> Result<()> {
517 if updates.is_empty() {
518 return Err(NeuralError::InvalidArgument(
519 "No client updates to aggregate".into(),
520 ));
521 }
522
523 for update in updates {
525 if update.parameters.len() != self.global_params.len() {
526 return Err(NeuralError::ShapeMismatch(format!(
527 "Client {} has {} parameter tensors, expected {}",
528 update.client_id,
529 update.parameters.len(),
530 self.global_params.len()
531 )));
532 }
533 for (i, param) in update.parameters.iter().enumerate() {
534 if param.shape() != self.global_params[i].shape() {
535 return Err(NeuralError::ShapeMismatch(format!(
536 "Client {} param[{}] shape {:?} != global {:?}",
537 update.client_id,
538 i,
539 param.shape(),
540 self.global_params[i].shape()
541 )));
542 }
543 }
544 }
545
546 let processed_updates = if self.config.compression.enabled {
548 updates
549 .iter()
550 .map(|u| {
551 let compressed = compress_gradients(
552 &u.parameters,
553 &self.global_params,
554 self.config.compression.top_k_fraction,
555 );
556 ClientUpdate {
557 client_id: u.client_id,
558 parameters: apply_compressed_delta(&self.global_params, &compressed),
559 num_samples: u.num_samples,
560 local_loss: u.local_loss,
561 metrics: u.metrics.clone(),
562 }
563 })
564 .collect::<Vec<_>>()
565 } else {
566 updates.to_vec()
567 };
568
569 match self.config.aggregation {
571 AggregationMethod::FedAvg => self.fedavg_aggregate(&processed_updates),
572 AggregationMethod::SimpleMean => self.simple_mean_aggregate(&processed_updates),
573 AggregationMethod::Median => self.median_aggregate(&processed_updates),
574 }?;
575
576 if self.config.dp_config.enabled {
578 self.apply_dp_noise()?;
579 }
580
581 let total_samples: usize = updates.iter().map(|u| u.num_samples).sum();
583 let avg_loss = {
584 let losses: Vec<f64> = updates.iter().filter_map(|u| u.local_loss).collect();
585 if losses.is_empty() {
586 None
587 } else {
588 Some(losses.iter().sum::<f64>() / losses.len() as f64)
589 }
590 };
591
592 self.round_history.push(RoundStats {
593 round: self.current_round,
594 num_clients: updates.len(),
595 total_samples,
596 avg_loss,
597 client_ids: updates.iter().map(|u| u.client_id).collect(),
598 });
599
600 self.current_round += 1;
601 Ok(())
602 }
603
604 fn fedavg_aggregate(&mut self, updates: &[ClientUpdate]) -> Result<()> {
606 let total_samples: f64 = updates.iter().map(|u| u.num_samples as f64).sum();
607 if total_samples < f64::EPSILON {
608 return Err(NeuralError::ComputationError(
609 "Total samples is zero".into(),
610 ));
611 }
612
613 for p_idx in 0..self.global_params.len() {
614 let mut aggregated = ArrayD::<f64>::zeros(self.global_params[p_idx].raw_dim());
615 for update in updates {
616 let weight = update.num_samples as f64 / total_samples;
617 aggregated = aggregated + &update.parameters[p_idx] * weight;
618 }
619 self.global_params[p_idx] = aggregated;
620 }
621 Ok(())
622 }
623
624 fn simple_mean_aggregate(&mut self, updates: &[ClientUpdate]) -> Result<()> {
626 let n = updates.len() as f64;
627 for p_idx in 0..self.global_params.len() {
628 let mut aggregated = ArrayD::<f64>::zeros(self.global_params[p_idx].raw_dim());
629 for update in updates {
630 aggregated += &update.parameters[p_idx];
631 }
632 self.global_params[p_idx] = aggregated / n;
633 }
634 Ok(())
635 }
636
637 fn median_aggregate(&mut self, updates: &[ClientUpdate]) -> Result<()> {
639 for p_idx in 0..self.global_params.len() {
640 let shape = self.global_params[p_idx].raw_dim();
641 let flat_len = self.global_params[p_idx].len();
642 let mut result = ArrayD::<f64>::zeros(shape);
643
644 for elem_idx in 0..flat_len {
645 let mut values: Vec<f64> = updates
646 .iter()
647 .map(|u| {
648 u.parameters[p_idx]
649 .as_slice()
650 .map(|s| s[elem_idx])
651 .unwrap_or(0.0)
652 })
653 .collect();
654 values.sort_by(|a, b| a.partial_cmp(b).unwrap_or(std::cmp::Ordering::Equal));
655
656 let median = if values.len().is_multiple_of(2) && values.len() >= 2 {
657 (values[values.len() / 2 - 1] + values[values.len() / 2]) / 2.0
658 } else {
659 values[values.len() / 2]
660 };
661
662 if let Some(slice) = result.as_slice_mut() {
663 slice[elem_idx] = median;
664 }
665 }
666 self.global_params[p_idx] = result;
667 }
668 Ok(())
669 }
670
671 fn apply_dp_noise(&mut self) -> Result<()> {
673 let sigma = self.config.dp_config.noise_multiplier * self.config.dp_config.max_grad_norm;
674
675 for param in &mut self.global_params {
676 let noise = generate_gaussian_noise(param.len(), 0.0, sigma, &mut self.rng);
677 let noise_arr = ArrayD::from_shape_vec(param.raw_dim(), noise).map_err(|e| {
678 NeuralError::ComputationError(format!("Failed to create noise array: {e}"))
679 })?;
680 *param = &*param + &noise_arr;
681 }
682 Ok(())
683 }
684
685 pub fn summary(&self) -> String {
687 let mut out = String::new();
688 out.push_str("=== Federated Learning Summary ===\n");
689 out.push_str(&format!("Aggregation: {}\n", self.config.aggregation));
690 out.push_str(&format!("Selection: {}\n", self.config.client_selection));
691 out.push_str(&format!(
692 "Rounds: {} / {}\n",
693 self.current_round, self.config.num_rounds
694 ));
695 out.push_str(&format!("DP enabled: {}\n", self.config.dp_config.enabled));
696 out.push_str(&format!(
697 "Compression enabled: {}\n",
698 self.config.compression.enabled
699 ));
700
701 if let Some(last) = self.round_history.last() {
702 out.push_str(&format!(
703 "Last round: {} clients, {} samples",
704 last.num_clients, last.total_samples
705 ));
706 if let Some(loss) = last.avg_loss {
707 out.push_str(&format!(", avg_loss={loss:.6}"));
708 }
709 out.push('\n');
710 }
711 out
712 }
713}
714
715fn compress_gradients(
724 client_params: &[ArrayD<f64>],
725 global_params: &[ArrayD<f64>],
726 top_k_fraction: f64,
727) -> Vec<Vec<(usize, f64)>> {
728 let mut compressed = Vec::with_capacity(client_params.len());
729
730 for (cp, gp) in client_params.iter().zip(global_params.iter()) {
731 let delta = cp - gp;
732 let flat: Vec<f64> = delta
733 .as_slice()
734 .map(|s| s.to_vec())
735 .unwrap_or_else(|| delta.iter().copied().collect());
736
737 let k = ((flat.len() as f64 * top_k_fraction).ceil() as usize)
738 .max(1)
739 .min(flat.len());
740
741 let mut indexed: Vec<(usize, f64)> = flat.into_iter().enumerate().collect();
743 indexed.sort_by(|a, b| {
744 b.1.abs()
745 .partial_cmp(&a.1.abs())
746 .unwrap_or(std::cmp::Ordering::Equal)
747 });
748 indexed.truncate(k);
749 compressed.push(indexed);
750 }
751
752 compressed
753}
754
755fn apply_compressed_delta(
757 global_params: &[ArrayD<f64>],
758 compressed: &[Vec<(usize, f64)>],
759) -> Vec<ArrayD<f64>> {
760 let mut result = global_params.to_vec();
761 for (p_idx, deltas) in compressed.iter().enumerate() {
762 if let Some(slice) = result[p_idx].as_slice_mut() {
763 for &(idx, val) in deltas {
764 if idx < slice.len() {
765 slice[idx] += val;
766 }
767 }
768 }
769 }
770 result
771}
772
773pub fn clip_l2_norm(params: &mut [ArrayD<f64>], max_norm: f64) {
775 let norm_sq: f64 = params
776 .iter()
777 .map(|p| p.iter().map(|&x| x * x).sum::<f64>())
778 .sum();
779 let norm = norm_sq.sqrt();
780 if norm > max_norm && norm > f64::EPSILON {
781 let scale = max_norm / norm;
782 for p in params.iter_mut() {
783 p.mapv_inplace(|x| x * scale);
784 }
785 }
786}
787
788fn generate_gaussian_noise(len: usize, mean: f64, std_dev: f64, rng: &mut SmallRng) -> Vec<f64> {
790 let mut result = Vec::with_capacity(len);
792 let mut i = 0;
793 while i < len {
794 let u1: f64 = rng.random_range(f64::EPSILON..1.0);
795 let u2: f64 = rng.random_range(0.0..std::f64::consts::TAU);
796 let r = (-2.0 * u1.ln()).sqrt();
797 let z0 = r * u2.cos() * std_dev + mean;
798 let z1 = r * u2.sin() * std_dev + mean;
799 result.push(z0);
800 i += 1;
801 if i < len {
802 result.push(z1);
803 i += 1;
804 }
805 }
806 result
807}
808
#[cfg(test)]
mod tests {
    use super::*;
    use scirs2_core::ndarray::array;

    /// Absolute tolerance used by every floating-point check in this module.
    const TOL: f64 = 1e-10;

    /// Asserts that `actual` lies within `TOL` of `expected`.
    fn assert_close(actual: f64, expected: f64) {
        assert!(
            (actual - expected).abs() < TOL,
            "expected {expected} (+/- {TOL}), got {actual}"
        );
    }

    /// Views a tensor as a flat slice, failing the test if it is not contiguous.
    fn flat(arr: &ArrayD<f64>) -> &[f64] {
        arr.as_slice().expect("contiguous")
    }

    #[test]
    fn test_config_defaults() {
        let cfg = FederatedConfig::default();
        assert_eq!(cfg.num_rounds, 100);
        assert_eq!(cfg.clients_per_round, 10);
        assert_eq!(cfg.aggregation, AggregationMethod::FedAvg);
    }

    #[test]
    fn test_config_builder() {
        let cfg = FederatedConfig::builder()
            .num_rounds(50)
            .clients_per_round(5)
            .aggregation(AggregationMethod::SimpleMean)
            .local_epochs(3)
            .local_lr(0.1)
            .seed(123)
            .build()
            .expect("valid config");

        assert_eq!(
            (cfg.num_rounds, cfg.clients_per_round, cfg.local_epochs),
            (50, 5, 3)
        );
    }

    #[test]
    fn test_config_validation_errors() {
        let rejected = vec![
            FederatedConfig::builder().num_rounds(0),
            FederatedConfig::builder().clients_per_round(0),
            FederatedConfig::builder().local_epochs(0),
            FederatedConfig::builder().local_lr(-1.0),
            FederatedConfig::builder().differential_privacy(0.0, 1.0),
            FederatedConfig::builder().gradient_compression(-0.1),
        ];
        for (case, builder) in rejected.into_iter().enumerate() {
            assert!(builder.build().is_err(), "case {case} should fail validation");
        }
    }

    #[test]
    fn test_fedavg_aggregation() {
        let config = FederatedConfig::builder()
            .num_rounds(10)
            .clients_per_round(2)
            .aggregation(AggregationMethod::FedAvg)
            .build()
            .expect("valid");
        let mut server =
            FederatedServer::new(config, vec![array![0.0_f64, 0.0, 0.0].into_dyn()]);

        // Client 1 holds three times as many samples, so it carries weight 0.75.
        let updates = vec![
            ClientUpdate::new(0, vec![array![1.0, 2.0, 3.0].into_dyn()], 100),
            ClientUpdate::new(1, vec![array![3.0, 2.0, 1.0].into_dyn()], 300),
        ];
        server.aggregate_round(&updates).expect("ok");

        let merged = flat(&server.global_params()[0]);
        assert_close(merged[0], 2.5);
        assert_close(merged[1], 2.0);
        assert_close(merged[2], 1.5);
        assert_eq!(server.current_round(), 1);
    }

    #[test]
    fn test_simple_mean_aggregation() {
        let config = FederatedConfig::builder()
            .num_rounds(10)
            .clients_per_round(3)
            .aggregation(AggregationMethod::SimpleMean)
            .build()
            .expect("valid");
        let mut server = FederatedServer::new(config, vec![array![0.0_f64, 0.0].into_dyn()]);

        let updates: Vec<ClientUpdate> = [[1.0, 4.0], [2.0, 5.0], [3.0, 6.0]]
            .iter()
            .enumerate()
            .map(|(id, &[a, b])| ClientUpdate::new(id, vec![array![a, b].into_dyn()], 10))
            .collect();
        server.aggregate_round(&updates).expect("ok");

        let merged = flat(&server.global_params()[0]);
        assert_close(merged[0], 2.0);
        assert_close(merged[1], 5.0);
    }

    #[test]
    fn test_median_aggregation() {
        let config = FederatedConfig::builder()
            .num_rounds(10)
            .clients_per_round(3)
            .aggregation(AggregationMethod::Median)
            .build()
            .expect("valid");
        let mut server = FederatedServer::new(config, vec![array![0.0_f64, 0.0].into_dyn()]);

        let updates = vec![
            ClientUpdate::new(0, vec![array![1.0, 100.0].into_dyn()], 10),
            ClientUpdate::new(1, vec![array![2.0, 5.0].into_dyn()], 10),
            ClientUpdate::new(2, vec![array![3.0, 6.0].into_dyn()], 10),
        ];
        server.aggregate_round(&updates).expect("ok");

        let merged = flat(&server.global_params()[0]);
        assert_close(merged[0], 2.0);
        // The 100.0 outlier does not drag the median.
        assert_close(merged[1], 6.0);
    }

    #[test]
    fn test_empty_updates_error() {
        let config = FederatedConfig::builder()
            .num_rounds(10)
            .clients_per_round(2)
            .build()
            .expect("valid");
        let mut server = FederatedServer::new(config, vec![array![1.0_f64, 2.0].into_dyn()]);

        assert!(server.aggregate_round(&[]).is_err());
    }

    #[test]
    fn test_shape_mismatch_error() {
        let config = FederatedConfig::builder()
            .num_rounds(10)
            .clients_per_round(1)
            .build()
            .expect("valid");
        let mut server = FederatedServer::new(config, vec![array![1.0_f64, 2.0].into_dyn()]);

        // The client sends one tensor more than the model has.
        let extra_tensor = vec![array![1.0, 2.0].into_dyn(), array![3.0].into_dyn()];
        let updates = vec![ClientUpdate::new(0, extra_tensor, 10)];
        assert!(server.aggregate_round(&updates).is_err());
    }

    #[test]
    fn test_client_selection_random() {
        let config = FederatedConfig::builder()
            .num_rounds(10)
            .clients_per_round(3)
            .client_selection(ClientSelectionStrategy::Random)
            .seed(42)
            .build()
            .expect("valid");
        let mut server = FederatedServer::new(config, vec![array![0.0_f64].into_dyn()]);

        let clients = vec![(0, 100), (1, 200), (2, 300), (3, 400), (4, 500)];
        let picked = server.select_clients(&clients);

        assert_eq!(picked.len(), 3);
        assert!(picked.iter().all(|&id| id <= 4));
    }

    #[test]
    fn test_client_selection_all() {
        let config = FederatedConfig::builder()
            .num_rounds(10)
            .clients_per_round(2)
            .client_selection(ClientSelectionStrategy::All)
            .build()
            .expect("valid");
        let mut server = FederatedServer::new(config, vec![array![0.0_f64].into_dyn()]);

        // `All` ignores `clients_per_round`.
        let clients = vec![(0, 100), (1, 200), (2, 300)];
        assert_eq!(server.select_clients(&clients).len(), 3);
    }

    #[test]
    fn test_client_selection_importance() {
        let config = FederatedConfig::builder()
            .num_rounds(10)
            .clients_per_round(2)
            .client_selection(ClientSelectionStrategy::ImportanceBased)
            .seed(42)
            .build()
            .expect("valid");
        let mut server = FederatedServer::new(config, vec![array![0.0_f64].into_dyn()]);

        let clients = vec![(0, 1), (1, 1000), (2, 1)];
        assert_eq!(server.select_clients(&clients).len(), 2);
    }

    #[test]
    fn test_dp_noise_application() {
        let config = FederatedConfig::builder()
            .num_rounds(10)
            .clients_per_round(1)
            .differential_privacy(1.0, 1.0)
            .seed(42)
            .build()
            .expect("valid");
        let mut server =
            FederatedServer::new(config, vec![array![0.0_f64, 0.0, 0.0].into_dyn()]);

        let updates = vec![ClientUpdate::new(
            0,
            vec![array![1.0, 2.0, 3.0].into_dyn()],
            100,
        )];
        server.aggregate_round(&updates).expect("ok");

        let noisy = flat(&server.global_params()[0]);
        let perturbed = noisy[0] != 1.0 || noisy[1] != 2.0 || noisy[2] != 3.0;
        assert!(perturbed, "DP noise should perturb the result");
    }

    #[test]
    fn test_gradient_compression() {
        let config = FederatedConfig::builder()
            .num_rounds(10)
            .clients_per_round(1)
            .gradient_compression(0.5)
            .build()
            .expect("valid");
        let mut server =
            FederatedServer::new(config, vec![array![1.0_f64, 2.0, 3.0, 4.0].into_dyn()]);

        let updates = vec![ClientUpdate::new(
            0,
            vec![array![10.0, 2.1, 3.1, 14.0].into_dyn()],
            100,
        )];
        server.aggregate_round(&updates).expect("ok");

        // k = 2 of 4: only the two largest deltas (indices 0 and 3) survive.
        let merged = flat(&server.global_params()[0]);
        assert_close(merged[0], 10.0);
        assert_close(merged[1], 2.0);
        assert_close(merged[2], 3.0);
        assert_close(merged[3], 14.0);
    }

    #[test]
    fn test_clip_l2_norm() {
        let mut params = vec![array![3.0_f64, 4.0].into_dyn()];
        clip_l2_norm(&mut params, 1.0);

        let v = flat(&params[0]);
        assert_close(v[0].hypot(v[1]), 1.0);
    }

    #[test]
    fn test_clip_l2_norm_no_clip_needed() {
        let mut params = vec![array![0.3_f64, 0.4].into_dyn()];
        clip_l2_norm(&mut params, 1.0);

        let v = flat(&params[0]);
        assert_close(v[0], 0.3);
        assert_close(v[1], 0.4);
    }

    #[test]
    fn test_multiple_rounds() {
        let config = FederatedConfig::builder()
            .num_rounds(3)
            .clients_per_round(2)
            .aggregation(AggregationMethod::SimpleMean)
            .build()
            .expect("valid");
        let mut server = FederatedServer::new(config, vec![array![0.0_f64, 0.0].into_dyn()]);

        for step in 1..=3 {
            let v = step as f64;
            let updates = vec![
                ClientUpdate::new(0, vec![array![v, v * 2.0].into_dyn()], 10),
                ClientUpdate::new(1, vec![array![v * 3.0, v * 4.0].into_dyn()], 10),
            ];
            server.aggregate_round(&updates).expect("ok");
        }

        assert_eq!(server.current_round(), 3);
        assert!(server.is_complete());
        assert_eq!(server.round_history().len(), 3);
    }

    #[test]
    fn test_client_update_with_metrics() {
        let update = ClientUpdate::new(0, vec![array![1.0_f64].into_dyn()], 100)
            .with_loss(0.5)
            .with_metric("accuracy", 0.95);

        assert_eq!(update.local_loss, Some(0.5));
        assert_close(update.metrics["accuracy"], 0.95);
    }

    #[test]
    fn test_round_stats_avg_loss() {
        let config = FederatedConfig::builder()
            .num_rounds(10)
            .clients_per_round(2)
            .build()
            .expect("valid");
        let mut server = FederatedServer::new(config, vec![array![0.0_f64].into_dyn()]);

        let updates = vec![
            ClientUpdate::new(0, vec![array![1.0].into_dyn()], 10).with_loss(0.3),
            ClientUpdate::new(1, vec![array![2.0].into_dyn()], 10).with_loss(0.7),
        ];
        server.aggregate_round(&updates).expect("ok");

        let stats = &server.round_history()[0];
        assert_eq!(stats.num_clients, 2);
        assert_eq!(stats.total_samples, 20);
        assert_close(stats.avg_loss.expect("has loss"), 0.5);
    }

    #[test]
    fn test_summary_generation() {
        let config = FederatedConfig::builder()
            .num_rounds(10)
            .clients_per_round(2)
            .build()
            .expect("valid");
        let mut server = FederatedServer::new(config, vec![array![0.0_f64].into_dyn()]);

        let updates = vec![ClientUpdate::new(0, vec![array![1.0].into_dyn()], 10)];
        server.aggregate_round(&updates).expect("ok");

        let text = server.summary();
        assert!(text.contains("Federated Learning Summary"));
        assert!(text.contains("FedAvg"));
    }

    #[test]
    fn test_display_types() {
        let cases = [
            (ClientSelectionStrategy::Random.to_string(), "Random"),
            (
                ClientSelectionStrategy::ImportanceBased.to_string(),
                "ImportanceBased",
            ),
            (ClientSelectionStrategy::All.to_string(), "All"),
            (AggregationMethod::FedAvg.to_string(), "FedAvg"),
            (AggregationMethod::SimpleMean.to_string(), "SimpleMean"),
            (AggregationMethod::Median.to_string(), "Median"),
        ];
        for (rendered, expected) in cases {
            assert_eq!(rendered, expected);
        }
    }

    #[test]
    fn test_gaussian_noise_generation() {
        let mut rng = SmallRng::seed_from_u64(42);
        let noise = generate_gaussian_noise(1000, 0.0, 1.0, &mut rng);
        assert_eq!(noise.len(), 1000);

        let count = noise.len() as f64;
        let mean = noise.iter().sum::<f64>() / count;
        assert!(mean.abs() < 0.2, "mean={mean}, expected ~0");

        let variance = noise.iter().map(|&x| (x - mean).powi(2)).sum::<f64>() / count;
        let std_dev = variance.sqrt();
        assert!((std_dev - 1.0).abs() < 0.2, "std={std_dev}, expected ~1");
    }

    #[test]
    fn test_multi_param_tensors() {
        let config = FederatedConfig::builder()
            .num_rounds(10)
            .clients_per_round(2)
            .aggregation(AggregationMethod::SimpleMean)
            .build()
            .expect("valid");
        let global = vec![
            array![1.0_f64, 2.0].into_dyn(),
            array![3.0_f64, 4.0, 5.0].into_dyn(),
        ];
        let mut server = FederatedServer::new(config, global);

        let updates = vec![
            ClientUpdate::new(
                0,
                vec![
                    array![2.0, 4.0].into_dyn(),
                    array![6.0, 8.0, 10.0].into_dyn(),
                ],
                10,
            ),
            ClientUpdate::new(
                1,
                vec![
                    array![4.0, 6.0].into_dyn(),
                    array![9.0, 12.0, 15.0].into_dyn(),
                ],
                10,
            ),
        ];
        server.aggregate_round(&updates).expect("ok");

        let first = flat(&server.global_params()[0]);
        assert_close(first[0], 3.0);
        assert_close(first[1], 5.0);

        let second = flat(&server.global_params()[1]);
        assert_close(second[0], 7.5);
        assert_close(second[1], 10.0);
        assert_close(second[2], 12.5);
    }
}