1use core::fmt;
43
/// Identifier of a federated client, stored as an owned `String`.
///
/// Equality, hashing and cloning are all derived from the wrapped string.
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub struct ClientId(String);

impl ClientId {
    /// Builds an identifier from any value convertible into a `String`.
    pub fn new(id: impl Into<String>) -> Self {
        let raw: String = id.into();
        ClientId(raw)
    }

    /// Borrows the identifier as a string slice.
    pub fn as_str(&self) -> &str {
        self.0.as_str()
    }
}

impl fmt::Display for ClientId {
    /// Writes the identifier text verbatim.
    ///
    /// Width, fill and precision flags of the outer formatter are not applied.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        f.write_str(self.as_str())
    }
}
65
66#[derive(Debug, Clone)]
71pub struct FederatedClient {
72 id: ClientId,
74 num_samples: usize,
76 availability: f64,
78 compute_capacity: f64,
80 bandwidth_mbps: f64,
82 data_distribution: DataDistribution,
84 privacy_budget: Option<f64>,
86}
87
88impl FederatedClient {
89 pub fn new(id: impl Into<String>, num_samples: usize, availability: f64) -> Self {
91 Self {
92 id: ClientId::new(id),
93 num_samples,
94 availability: availability.max(0.0).min(1.0),
95 compute_capacity: 1.0,
96 bandwidth_mbps: 10.0,
97 data_distribution: DataDistribution::Unknown,
98 privacy_budget: None,
99 }
100 }
101
102 pub fn with_compute_capacity(mut self, capacity: f64) -> Self {
104 self.compute_capacity = capacity;
105 self
106 }
107
108 pub fn with_bandwidth(mut self, bandwidth_mbps: f64) -> Self {
110 self.bandwidth_mbps = bandwidth_mbps;
111 self
112 }
113
114 pub fn with_data_distribution(mut self, distribution: DataDistribution) -> Self {
116 self.data_distribution = distribution;
117 self
118 }
119
120 pub fn with_privacy_budget(mut self, epsilon: f64) -> Self {
122 self.privacy_budget = Some(epsilon);
123 self
124 }
125
126 pub fn id(&self) -> &ClientId {
128 &self.id
129 }
130
131 pub fn num_samples(&self) -> usize {
133 self.num_samples
134 }
135
136 pub fn availability(&self) -> f64 {
138 self.availability
139 }
140
141 pub fn compute_capacity(&self) -> f64 {
143 self.compute_capacity
144 }
145
146 pub fn bandwidth_mbps(&self) -> f64 {
148 self.bandwidth_mbps
149 }
150
151 pub fn data_distribution(&self) -> &DataDistribution {
153 &self.data_distribution
154 }
155
156 pub fn privacy_budget(&self) -> Option<f64> {
158 self.privacy_budget
159 }
160
161 pub fn weight(&self) -> f64 {
163 self.num_samples as f64
164 }
165}
166
/// How a client's local data is distributed.
///
/// This file only stores the value; nothing here interprets the variants or
/// their `skew_factor` payloads.
#[derive(Debug, Clone, Copy, PartialEq)]
pub enum DataDistribution {
    /// IID data (presumably "independent and identically distributed").
    IID,
    /// Skew across labels; the scale of `skew_factor` is not defined in this file.
    LabelSkew { skew_factor: f64 },
    /// Skew across features; the scale of `skew_factor` is not defined in this file.
    FeatureSkew { skew_factor: f64 },
    /// Skew in the amount of data (per the variant name).
    QuantitySkew,
    /// Distribution not known; the default assigned by `FederatedClient::new`.
    Unknown,
}
181
/// Named strategies for aggregating client updates.
///
/// This file only names the strategies and provides their display labels;
/// no aggregation logic is implemented here. The variant names follow common
/// federated-learning terminology.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum AggregationStrategy {
    /// Displayed as `FedAvg`.
    FedAvg,
    /// Displayed as `FedProx`.
    FedProx,
    /// Displayed as `FedAdaptive`.
    FedAdaptive,
    /// Displayed as `SecureAggregation`.
    SecureAggregation,
    /// Displayed as `WeightedBySize`.
    WeightedBySize,
    /// Displayed as `WeightedByPerformance`.
    WeightedByPerformance,
}

impl fmt::Display for AggregationStrategy {
    /// Writes the variant name verbatim; formatter width/fill flags are not applied.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let label = match *self {
            Self::FedAvg => "FedAvg",
            Self::FedProx => "FedProx",
            Self::FedAdaptive => "FedAdaptive",
            Self::SecureAggregation => "SecureAggregation",
            Self::WeightedBySize => "WeightedBySize",
            Self::WeightedByPerformance => "WeightedByPerformance",
        };
        f.write_str(label)
    }
}
211
212#[derive(Debug, Clone)]
214pub struct ClientUpdate {
215 client_id: ClientId,
217 round: u64,
219 num_steps: usize,
221 loss: f64,
223 accuracy: Option<f64>,
225 metadata: Vec<(String, String)>,
227}
228
229impl ClientUpdate {
230 pub fn new(client_id: ClientId, round: u64, num_steps: usize, loss: f64) -> Self {
232 Self {
233 client_id,
234 round,
235 num_steps,
236 loss,
237 accuracy: None,
238 metadata: Vec::new(),
239 }
240 }
241
242 pub fn with_accuracy(mut self, accuracy: f64) -> Self {
244 self.accuracy = Some(accuracy);
245 self
246 }
247
248 pub fn add_metadata(&mut self, key: impl Into<String>, value: impl Into<String>) {
250 self.metadata.push((key.into(), value.into()));
251 }
252
253 pub fn client_id(&self) -> &ClientId {
255 &self.client_id
256 }
257
258 pub fn round(&self) -> u64 {
260 self.round
261 }
262
263 pub fn num_steps(&self) -> usize {
265 self.num_steps
266 }
267
268 pub fn loss(&self) -> f64 {
270 self.loss
271 }
272
273 pub fn accuracy(&self) -> Option<f64> {
275 self.accuracy
276 }
277
278 pub fn metadata(&self) -> &[(String, String)] {
280 &self.metadata
281 }
282}
283
/// Policy used by `ClientSelector::select` to pick clients for a round.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum ClientSelectionStrategy {
    /// Nominally random selection. NOTE(review): `ClientSelector::select` in
    /// this file takes the leading clients in slice order; no randomness is used.
    Random,
    /// Prefer clients with the highest availability.
    ByAvailability,
    /// Prefer clients with the most samples.
    ByDataSize,
    /// Prefer clients with the highest compute capacity.
    ByComputeCapacity,
    /// Power-of-choice selection. NOTE(review): `choices` is not read by
    /// `ClientSelector::select` in this file.
    PowerOfChoice { choices: usize },
    /// Select every client, regardless of the requested count.
    All,
}
300
301#[derive(Debug, Clone)]
303pub struct ClientSelector {
304 strategy: ClientSelectionStrategy,
305}
306
307impl ClientSelector {
308 pub fn new(strategy: ClientSelectionStrategy) -> Self {
310 Self { strategy }
311 }
312
313 pub fn select(&self, clients: &[FederatedClient], num_select: usize) -> Vec<ClientId> {
315 match self.strategy {
316 ClientSelectionStrategy::Random => {
317 clients
319 .iter()
320 .take(num_select.min(clients.len()))
321 .map(|c| c.id().clone())
322 .collect()
323 }
324 ClientSelectionStrategy::ByAvailability => {
325 let mut sorted: Vec<_> = clients.iter().collect();
326 sorted.sort_by(|a, b| {
327 b.availability()
328 .partial_cmp(&a.availability())
329 .unwrap_or(core::cmp::Ordering::Equal)
330 });
331 sorted
332 .iter()
333 .take(num_select.min(clients.len()))
334 .map(|c| c.id().clone())
335 .collect()
336 }
337 ClientSelectionStrategy::ByDataSize => {
338 let mut sorted: Vec<_> = clients.iter().collect();
339 sorted.sort_by_key(|c| core::cmp::Reverse(c.num_samples()));
340 sorted
341 .iter()
342 .take(num_select.min(clients.len()))
343 .map(|c| c.id().clone())
344 .collect()
345 }
346 ClientSelectionStrategy::ByComputeCapacity => {
347 let mut sorted: Vec<_> = clients.iter().collect();
348 sorted.sort_by(|a, b| {
349 b.compute_capacity()
350 .partial_cmp(&a.compute_capacity())
351 .unwrap_or(core::cmp::Ordering::Equal)
352 });
353 sorted
354 .iter()
355 .take(num_select.min(clients.len()))
356 .map(|c| c.id().clone())
357 .collect()
358 }
359 ClientSelectionStrategy::PowerOfChoice { choices: _ } => {
360 clients
362 .iter()
363 .take(num_select.min(clients.len()))
364 .map(|c| c.id().clone())
365 .collect()
366 }
367 ClientSelectionStrategy::All => clients.iter().map(|c| c.id().clone()).collect(),
368 }
369 }
370
371 pub fn strategy(&self) -> ClientSelectionStrategy {
373 self.strategy
374 }
375}
376
/// Privacy settings: an `(epsilon, delta)` pair plus clipping and noise knobs.
///
/// NOTE(review): the field names suggest differential-privacy parameters;
/// nothing in this file applies them, they are only stored and returned.
#[derive(Debug, Clone, Copy)]
pub struct PrivacyParameters {
    /// Privacy budget; see [`PrivacyParameters::is_exhausted`].
    epsilon: f64,
    /// Delta parameter, stored as given.
    delta: f64,
    /// Clipping norm; defaults to `1.0`.
    clip_norm: f64,
    /// Noise multiplier; defaults to `1.0`.
    noise_multiplier: f64,
}

impl PrivacyParameters {
    /// Creates parameters with `clip_norm` and `noise_multiplier` both set to `1.0`.
    ///
    /// `epsilon` and `delta` are stored as given; no validation is performed.
    pub fn new(epsilon: f64, delta: f64) -> Self {
        Self {
            epsilon,
            delta,
            clip_norm: 1.0,
            noise_multiplier: 1.0,
        }
    }

    /// Returns the parameters with the clipping norm replaced by `clip_norm`.
    pub fn with_clip_norm(mut self, clip_norm: f64) -> Self {
        self.clip_norm = clip_norm;
        self
    }

    /// Returns the parameters with the noise multiplier replaced by `noise_multiplier`.
    pub fn with_noise_multiplier(mut self, noise_multiplier: f64) -> Self {
        self.noise_multiplier = noise_multiplier;
        self
    }

    /// Privacy budget as stored.
    pub fn epsilon(&self) -> f64 {
        self.epsilon
    }

    /// Delta parameter as stored.
    pub fn delta(&self) -> f64 {
        self.delta
    }

    /// Clipping norm as stored.
    pub fn clip_norm(&self) -> f64 {
        self.clip_norm
    }

    /// Noise multiplier as stored.
    pub fn noise_multiplier(&self) -> f64 {
        self.noise_multiplier
    }

    /// Returns `true` when no usable budget remains.
    ///
    /// The budget is exhausted when `epsilon` is zero, negative, or NaN. A NaN
    /// budget is treated as exhausted so that an invalid value fails closed
    /// rather than being reported as available.
    pub fn is_exhausted(&self) -> bool {
        self.epsilon.is_nan() || self.epsilon <= 0.0
    }
}
438
/// Technique used to compress model updates.
///
/// This file only stores the chosen technique; no compression is performed here.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum CompressionTechnique {
    /// No compression; the default used by `FederatedCoordinator::new`.
    None,
    /// Quantization with the given number of `bits`.
    Quantization { bits: u8 },
    /// Sparsification parameterized by `k` (presumably the number of entries kept; confirm with the consumer).
    Sparsification { k: usize },
    /// Sketching (no parameters).
    Sketching,
    /// Low-rank compression with the given `rank`.
    LowRank { rank: usize },
}
453
/// Bookkeeping record for a single training round.
#[derive(Debug, Clone)]
pub struct TrainingRound {
    /// Round number.
    round: u64,
    /// Number of clients the round was started with.
    num_clients: usize,
    /// Number of clients recorded as completed; starts at `0`.
    num_completed: usize,
    /// Average loss; starts at `0.0`.
    avg_loss: f64,
    /// Average accuracy; `None` until set.
    avg_accuracy: Option<f64>,
    /// Communication cost (unit not defined in this file); starts at `0`.
    communication_cost: usize,
    /// Duration in seconds; starts at `0.0`.
    duration_secs: f64,
}

impl TrainingRound {
    /// Creates a record for `round` with `num_clients` participants and all
    /// results zeroed or unset.
    pub fn new(round: u64, num_clients: usize) -> Self {
        TrainingRound {
            round,
            num_clients,
            num_completed: 0,
            avg_loss: 0.0,
            avg_accuracy: None,
            communication_cost: 0,
            duration_secs: 0.0,
        }
    }

    /// Round number.
    pub fn round(&self) -> u64 {
        self.round
    }

    /// Number of clients the round was started with.
    pub fn num_clients(&self) -> usize {
        self.num_clients
    }

    /// Number of clients recorded as completed.
    pub fn num_completed(&self) -> usize {
        self.num_completed
    }

    /// Average loss as last recorded.
    pub fn avg_loss(&self) -> f64 {
        self.avg_loss
    }

    /// Average accuracy, if one was recorded.
    pub fn avg_accuracy(&self) -> Option<f64> {
        self.avg_accuracy
    }

    /// Communication cost as last recorded.
    pub fn communication_cost(&self) -> usize {
        self.communication_cost
    }

    /// Duration in seconds as last recorded.
    pub fn duration_secs(&self) -> f64 {
        self.duration_secs
    }

    /// Fraction of clients that completed the round.
    ///
    /// Returns `0.0` when the round has no clients. The value is not capped:
    /// it exceeds `1.0` if more completions were recorded than clients.
    pub fn completion_rate(&self) -> f64 {
        match self.num_clients {
            0 => 0.0,
            total => self.num_completed as f64 / total as f64,
        }
    }

    /// Records how many clients completed the round.
    pub fn set_completed(&mut self, num_completed: usize) {
        self.num_completed = num_completed;
    }

    /// Records the average loss.
    pub fn set_avg_loss(&mut self, avg_loss: f64) {
        self.avg_loss = avg_loss;
    }

    /// Records the average accuracy.
    pub fn set_avg_accuracy(&mut self, avg_accuracy: f64) {
        self.avg_accuracy = Some(avg_accuracy);
    }

    /// Records the communication cost.
    pub fn set_communication_cost(&mut self, cost: usize) {
        self.communication_cost = cost;
    }

    /// Records the duration in seconds.
    pub fn set_duration(&mut self, duration_secs: f64) {
        self.duration_secs = duration_secs;
    }
}
556
/// Summary of how evenly accuracy is spread.
///
/// All values are supplied by the caller; this type stores them and offers a
/// threshold check on `jains_index`.
#[derive(Debug, Clone)]
pub struct FairnessMetrics {
    /// Variance of accuracy.
    accuracy_variance: f64,
    /// Lowest accuracy.
    min_accuracy: f64,
    /// Highest accuracy.
    max_accuracy: f64,
    /// Fairness index (presumably Jain's fairness index, per the name).
    jains_index: f64,
}

impl FairnessMetrics {
    /// Threshold that `jains_index` must exceed for [`FairnessMetrics::is_fair`].
    pub const DEFAULT_FAIRNESS_THRESHOLD: f64 = 0.8;

    /// Stores the four metrics as given; no validation is performed.
    pub fn new(
        accuracy_variance: f64,
        min_accuracy: f64,
        max_accuracy: f64,
        jains_index: f64,
    ) -> Self {
        Self {
            accuracy_variance,
            min_accuracy,
            max_accuracy,
            jains_index,
        }
    }

    /// Variance of accuracy.
    pub fn accuracy_variance(&self) -> f64 {
        self.accuracy_variance
    }

    /// Lowest accuracy.
    pub fn min_accuracy(&self) -> f64 {
        self.min_accuracy
    }

    /// Highest accuracy.
    pub fn max_accuracy(&self) -> f64 {
        self.max_accuracy
    }

    /// Fairness index as stored.
    pub fn jains_index(&self) -> f64 {
        self.jains_index
    }

    /// Returns `true` when `jains_index` is strictly greater than
    /// [`FairnessMetrics::DEFAULT_FAIRNESS_THRESHOLD`] (`0.8`).
    pub fn is_fair(&self) -> bool {
        self.is_fair_with_threshold(Self::DEFAULT_FAIRNESS_THRESHOLD)
    }

    /// Returns `true` when `jains_index` is strictly greater than `threshold`.
    ///
    /// A NaN index or a NaN threshold always yields `false`.
    pub fn is_fair_with_threshold(&self, threshold: f64) -> bool {
        self.jains_index > threshold
    }
}
611
612#[derive(Debug, Clone)]
614pub struct FederatedCoordinator {
615 current_round: u64,
617 strategy: AggregationStrategy,
619 privacy: Option<PrivacyParameters>,
621 compression: CompressionTechnique,
623 rounds: Vec<TrainingRound>,
625}
626
627impl FederatedCoordinator {
628 pub fn new(strategy: AggregationStrategy) -> Self {
630 Self {
631 current_round: 0,
632 strategy,
633 privacy: None,
634 compression: CompressionTechnique::None,
635 rounds: Vec::new(),
636 }
637 }
638
639 pub fn with_privacy(mut self, privacy: PrivacyParameters) -> Self {
641 self.privacy = Some(privacy);
642 self
643 }
644
645 pub fn with_compression(mut self, compression: CompressionTechnique) -> Self {
647 self.compression = compression;
648 self
649 }
650
651 pub fn start_round(&mut self, num_clients: usize) -> u64 {
653 self.current_round += 1;
654 self.rounds
655 .push(TrainingRound::new(self.current_round, num_clients));
656 self.current_round
657 }
658
659 pub fn complete_round(&mut self, avg_loss: f64, num_completed: usize) {
661 if let Some(round) = self.rounds.last_mut() {
662 round.set_avg_loss(avg_loss);
663 round.set_completed(num_completed);
664 }
665 }
666
667 pub fn current_round(&self) -> u64 {
669 self.current_round
670 }
671
672 pub fn strategy(&self) -> AggregationStrategy {
674 self.strategy
675 }
676
677 pub fn privacy(&self) -> Option<&PrivacyParameters> {
679 self.privacy.as_ref()
680 }
681
682 pub fn compression(&self) -> CompressionTechnique {
684 self.compression
685 }
686
687 pub fn rounds(&self) -> &[TrainingRound] {
689 &self.rounds
690 }
691
692 pub fn statistics(&self) -> CoordinatorStatistics {
694 let total_rounds = self.rounds.len();
695 let avg_completion_rate = if total_rounds > 0 {
696 self.rounds.iter().map(|r| r.completion_rate()).sum::<f64>() / total_rounds as f64
697 } else {
698 0.0
699 };
700 let total_communication = self.rounds.iter().map(|r| r.communication_cost()).sum();
701
702 CoordinatorStatistics {
703 total_rounds,
704 avg_completion_rate,
705 total_communication,
706 }
707 }
708}
709
/// Aggregate figures produced by `FederatedCoordinator::statistics`.
#[derive(Debug, Clone)]
pub struct CoordinatorStatistics {
    /// Number of rounds recorded by the coordinator.
    pub total_rounds: usize,
    /// Mean of the per-round completion rates; `0.0` when there are no rounds.
    pub avg_completion_rate: f64,
    /// Sum of the per-round communication costs.
    pub total_communication: usize,
}
720
#[cfg(test)]
mod tests {
    use super::*;

    /// Three clients with distinct sample counts and availabilities.
    fn three_clients() -> Vec<FederatedClient> {
        vec![
            FederatedClient::new("client_1", 1000, 0.8),
            FederatedClient::new("client_2", 500, 0.6),
            FederatedClient::new("client_3", 800, 0.9),
        ]
    }

    #[test]
    fn test_client_id() {
        let client_id = ClientId::new("client_1");

        assert_eq!(client_id.as_str(), "client_1");
        assert_eq!(client_id.to_string(), "client_1");
    }

    #[test]
    fn test_federated_client_creation() {
        let created = FederatedClient::new("client_1", 1000, 0.8);

        assert_eq!(created.id().as_str(), "client_1");
        assert_eq!(created.num_samples(), 1000);
        assert_eq!(created.availability(), 0.8);
        assert_eq!(created.weight(), 1000.0);
    }

    #[test]
    fn test_client_with_builder() {
        let base = FederatedClient::new("client_1", 1000, 0.8);
        let configured = base
            .with_compute_capacity(2.0)
            .with_bandwidth(50.0)
            .with_privacy_budget(1.0);

        assert_eq!(configured.compute_capacity(), 2.0);
        assert_eq!(configured.bandwidth_mbps(), 50.0);
        assert_eq!(configured.privacy_budget(), Some(1.0));
    }

    #[test]
    fn test_data_distribution() {
        let uniform = DataDistribution::IID;
        let by_label = DataDistribution::LabelSkew { skew_factor: 0.5 };
        let _by_feature = DataDistribution::FeatureSkew { skew_factor: 0.3 };

        assert_eq!(uniform, DataDistribution::IID);
        assert_ne!(uniform, by_label);
    }

    #[test]
    fn test_aggregation_strategy_display() {
        assert_eq!(AggregationStrategy::FedAvg.to_string(), "FedAvg");
        assert_eq!(AggregationStrategy::FedProx.to_string(), "FedProx");
    }

    #[test]
    fn test_client_update() {
        let expected_id = ClientId::new("client_1");
        let mut report =
            ClientUpdate::new(expected_id.clone(), 5, 100, 0.5).with_accuracy(0.85);
        report.add_metadata("dataset", "mnist");

        assert_eq!(report.client_id(), &expected_id);
        assert_eq!(report.round(), 5);
        assert_eq!(report.num_steps(), 100);
        assert_eq!(report.loss(), 0.5);
        assert_eq!(report.accuracy(), Some(0.85));
        assert_eq!(report.metadata().len(), 1);
    }

    #[test]
    fn test_client_selector_random() {
        let pool = three_clients();

        let chosen = ClientSelector::new(ClientSelectionStrategy::Random).select(&pool, 2);

        assert_eq!(chosen.len(), 2);
    }

    #[test]
    fn test_client_selector_by_data_size() {
        let pool = three_clients();

        let chosen = ClientSelector::new(ClientSelectionStrategy::ByDataSize).select(&pool, 2);

        assert_eq!(chosen.len(), 2);
        // The client with the most samples comes first.
        assert_eq!(chosen[0].as_str(), "client_1");
    }

    #[test]
    fn test_client_selector_all() {
        let pool = vec![
            FederatedClient::new("client_1", 1000, 0.8),
            FederatedClient::new("client_2", 500, 0.6),
        ];

        let chosen = ClientSelector::new(ClientSelectionStrategy::All).select(&pool, 10);

        // `All` returns every client even though more were requested.
        assert_eq!(chosen.len(), 2);
    }

    #[test]
    fn test_privacy_parameters() {
        let params = PrivacyParameters::new(1.0, 1e-5)
            .with_clip_norm(2.0)
            .with_noise_multiplier(0.5);

        assert_eq!(params.epsilon(), 1.0);
        assert_eq!(params.delta(), 1e-5);
        assert_eq!(params.clip_norm(), 2.0);
        assert_eq!(params.noise_multiplier(), 0.5);
        assert!(!params.is_exhausted());
    }

    #[test]
    fn test_compression_techniques() {
        // Construction-only check: every variant can be built.
        let _every_variant = [
            CompressionTechnique::None,
            CompressionTechnique::Quantization { bits: 8 },
            CompressionTechnique::Sparsification { k: 100 },
            CompressionTechnique::Sketching,
            CompressionTechnique::LowRank { rank: 10 },
        ];
    }

    #[test]
    fn test_training_round() {
        let mut record = TrainingRound::new(1, 10);
        record.set_completed(8);
        record.set_avg_loss(0.5);
        record.set_avg_accuracy(0.85);
        record.set_communication_cost(1024);
        record.set_duration(60.0);

        assert_eq!(record.round(), 1);
        assert_eq!(record.num_clients(), 10);
        assert_eq!(record.num_completed(), 8);
        assert_eq!(record.avg_loss(), 0.5);
        assert_eq!(record.avg_accuracy(), Some(0.85));
        assert_eq!(record.communication_cost(), 1024);
        assert_eq!(record.duration_secs(), 60.0);
        assert_eq!(record.completion_rate(), 0.8);
    }

    #[test]
    fn test_fairness_metrics() {
        let fairness = FairnessMetrics::new(0.01, 0.80, 0.90, 0.85);

        assert_eq!(fairness.accuracy_variance(), 0.01);
        assert_eq!(fairness.min_accuracy(), 0.80);
        assert_eq!(fairness.max_accuracy(), 0.90);
        assert_eq!(fairness.jains_index(), 0.85);
        assert!(fairness.is_fair());
    }

    #[test]
    fn test_federated_coordinator() {
        let privacy = PrivacyParameters::new(1.0, 1e-5);
        let compression = CompressionTechnique::Quantization { bits: 8 };
        let mut coord = FederatedCoordinator::new(AggregationStrategy::FedAvg)
            .with_privacy(privacy)
            .with_compression(compression);

        assert_eq!(coord.current_round(), 0);

        let first_round = coord.start_round(10);
        assert_eq!(first_round, 1);

        coord.complete_round(0.5, 8);

        let summary = coord.statistics();
        assert_eq!(summary.total_rounds, 1);
    }

    #[test]
    fn test_coordinator_multiple_rounds() {
        let mut coord = FederatedCoordinator::new(AggregationStrategy::FedAvg);

        // (average loss, completed clients) for three consecutive rounds of ten clients.
        for &(loss, completed) in [(0.6, 9), (0.4, 8), (0.3, 10)].iter() {
            coord.start_round(10);
            coord.complete_round(loss, completed);
        }

        assert_eq!(coord.current_round(), 3);
        assert_eq!(coord.rounds().len(), 3);

        let summary = coord.statistics();
        assert_eq!(summary.total_rounds, 3);
        assert!(summary.avg_completion_rate > 0.8);
    }

    #[test]
    fn test_client_selection_strategies() {
        // Construction-only check: every variant can be built.
        let _every_variant = [
            ClientSelectionStrategy::Random,
            ClientSelectionStrategy::ByAvailability,
            ClientSelectionStrategy::ByDataSize,
            ClientSelectionStrategy::ByComputeCapacity,
            ClientSelectionStrategy::PowerOfChoice { choices: 3 },
            ClientSelectionStrategy::All,
        ];
    }

    #[test]
    fn test_availability_clamping() {
        let above_range = FederatedClient::new("c1", 100, 1.5);
        let below_range = FederatedClient::new("c2", 100, -0.1);

        assert_eq!(above_range.availability(), 1.0);
        assert_eq!(below_range.availability(), 0.0);
    }
}