// torsh_core/federated.rs

// Copyright (c) 2025 ToRSh Contributors
//
// Federated Learning Metadata Management
//
// This module provides data structures and abstractions for federated learning,
// enabling privacy-preserving distributed training across multiple clients without
// centralizing sensitive data.
//
// # Key Features
//
// - **Client Management**: Track and manage federated learning clients
// - **Model Aggregation**: FedAvg, FedProx, and other aggregation strategies
// - **Privacy Mechanisms**: Differential privacy, secure aggregation
// - **Client Selection**: Smart client sampling strategies
// - **Communication Efficiency**: Gradient compression, quantization
//
// # Design Principles
//
// 1. **Privacy First**: Built-in differential privacy support
// 2. **Heterogeneity**: Handle non-IID data distributions
// 3. **Efficiency**: Minimize communication overhead
// 4. **Fairness**: Ensure equitable contribution from all clients
//
// # Examples
//
// ```rust
// use torsh_core::federated::{
//     AggregationStrategy, ClientSelectionStrategy, ClientSelector, FederatedClient,
//     FederatedCoordinator,
// };
//
// // Create federated learning clients
// let clients = vec![
//     FederatedClient::new("client_1", 1000, 0.8),
//     FederatedClient::new("client_2", 500, 0.6),
// ];
//
// // Select clients for a training round
// let selector = ClientSelector::new(ClientSelectionStrategy::Random);
// let selected = selector.select(&clients, 10);
//
// // Record the round on the coordinator
// let mut coordinator = FederatedCoordinator::new(AggregationStrategy::FedAvg);
// coordinator.start_round(selected.len());
// coordinator.complete_round(0.5, selected.len());
// ```
41
42use core::fmt;
43
/// Unique identifier for a federated learning client.
///
/// A newtype over `String` so client IDs are not confused with arbitrary strings.
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub struct ClientId(String);

impl ClientId {
    /// Create a new client ID from anything convertible into a `String`.
    pub fn new(id: impl Into<String>) -> Self {
        Self(id.into())
    }

    /// Borrow the client ID as a string slice.
    pub fn as_str(&self) -> &str {
        &self.0
    }
}

impl fmt::Display for ClientId {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        // `pad` honors width/fill/alignment/precision flags (e.g. `{:>12}`), which a
        // nested `write!(f, "{}", ..)` would silently discard.
        f.pad(&self.0)
    }
}
65
66/// Federated learning client metadata
67///
68/// Represents a client participating in federated learning with
69/// information about their capabilities and data characteristics.
70#[derive(Debug, Clone)]
71pub struct FederatedClient {
72    /// Unique client identifier
73    id: ClientId,
74    /// Number of training samples
75    num_samples: usize,
76    /// Client availability (0.0 to 1.0)
77    availability: f64,
78    /// Computational capacity (relative score)
79    compute_capacity: f64,
80    /// Network bandwidth (MB/s)
81    bandwidth_mbps: f64,
82    /// Data distribution characteristics
83    data_distribution: DataDistribution,
84    /// Privacy budget (epsilon for differential privacy)
85    privacy_budget: Option<f64>,
86}
87
88impl FederatedClient {
89    /// Create a new federated client
90    pub fn new(id: impl Into<String>, num_samples: usize, availability: f64) -> Self {
91        Self {
92            id: ClientId::new(id),
93            num_samples,
94            availability: availability.max(0.0).min(1.0),
95            compute_capacity: 1.0,
96            bandwidth_mbps: 10.0,
97            data_distribution: DataDistribution::Unknown,
98            privacy_budget: None,
99        }
100    }
101
102    /// Set computational capacity
103    pub fn with_compute_capacity(mut self, capacity: f64) -> Self {
104        self.compute_capacity = capacity;
105        self
106    }
107
108    /// Set network bandwidth
109    pub fn with_bandwidth(mut self, bandwidth_mbps: f64) -> Self {
110        self.bandwidth_mbps = bandwidth_mbps;
111        self
112    }
113
114    /// Set data distribution
115    pub fn with_data_distribution(mut self, distribution: DataDistribution) -> Self {
116        self.data_distribution = distribution;
117        self
118    }
119
120    /// Set privacy budget
121    pub fn with_privacy_budget(mut self, epsilon: f64) -> Self {
122        self.privacy_budget = Some(epsilon);
123        self
124    }
125
126    /// Get client ID
127    pub fn id(&self) -> &ClientId {
128        &self.id
129    }
130
131    /// Get number of samples
132    pub fn num_samples(&self) -> usize {
133        self.num_samples
134    }
135
136    /// Get availability
137    pub fn availability(&self) -> f64 {
138        self.availability
139    }
140
141    /// Get compute capacity
142    pub fn compute_capacity(&self) -> f64 {
143        self.compute_capacity
144    }
145
146    /// Get bandwidth
147    pub fn bandwidth_mbps(&self) -> f64 {
148        self.bandwidth_mbps
149    }
150
151    /// Get data distribution
152    pub fn data_distribution(&self) -> &DataDistribution {
153        &self.data_distribution
154    }
155
156    /// Get privacy budget
157    pub fn privacy_budget(&self) -> Option<f64> {
158        self.privacy_budget
159    }
160
161    /// Calculate client weight for aggregation
162    pub fn weight(&self) -> f64 {
163        self.num_samples as f64
164    }
165}
166
/// Data distribution characteristics for non-IID data.
///
/// `Default` yields [`DataDistribution::Unknown`], the same value
/// `FederatedClient::new` assigns before a distribution is specified.
#[derive(Debug, Clone, Copy, PartialEq, Default)]
pub enum DataDistribution {
    /// IID (Independent and Identically Distributed)
    IID,
    /// Non-IID with label skew
    LabelSkew { skew_factor: f64 },
    /// Non-IID with feature skew
    FeatureSkew { skew_factor: f64 },
    /// Non-IID with quantity skew
    QuantitySkew,
    /// Unknown distribution (the default)
    #[default]
    Unknown,
}
181
/// Aggregation strategies for federated learning.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum AggregationStrategy {
    /// Federated Averaging (FedAvg)
    FedAvg,
    /// Federated Proximal (FedProx)
    FedProx,
    /// Federated Adaptive (FedAdam, FedYogi, etc.)
    FedAdaptive,
    /// Secure Aggregation with encryption
    SecureAggregation,
    /// Weighted aggregation by data size
    WeightedBySize,
    /// Weighted aggregation by client performance
    WeightedByPerformance,
}

impl fmt::Display for AggregationStrategy {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let name = match self {
            AggregationStrategy::FedAvg => "FedAvg",
            AggregationStrategy::FedProx => "FedProx",
            AggregationStrategy::FedAdaptive => "FedAdaptive",
            AggregationStrategy::SecureAggregation => "SecureAggregation",
            AggregationStrategy::WeightedBySize => "WeightedBySize",
            AggregationStrategy::WeightedByPerformance => "WeightedByPerformance",
        };
        // `pad` honors width/alignment flags (e.g. `{:<20}` in tables); `write!` would not.
        f.pad(name)
    }
}
211
212/// Client update from a training round
213#[derive(Debug, Clone)]
214pub struct ClientUpdate {
215    /// Client that produced this update
216    client_id: ClientId,
217    /// Training round number
218    round: u64,
219    /// Number of local training steps
220    num_steps: usize,
221    /// Local training loss
222    loss: f64,
223    /// Local training accuracy
224    accuracy: Option<f64>,
225    /// Metadata for the update
226    metadata: Vec<(String, String)>,
227}
228
229impl ClientUpdate {
230    /// Create a new client update
231    pub fn new(client_id: ClientId, round: u64, num_steps: usize, loss: f64) -> Self {
232        Self {
233            client_id,
234            round,
235            num_steps,
236            loss,
237            accuracy: None,
238            metadata: Vec::new(),
239        }
240    }
241
242    /// Set accuracy
243    pub fn with_accuracy(mut self, accuracy: f64) -> Self {
244        self.accuracy = Some(accuracy);
245        self
246    }
247
248    /// Add metadata
249    pub fn add_metadata(&mut self, key: impl Into<String>, value: impl Into<String>) {
250        self.metadata.push((key.into(), value.into()));
251    }
252
253    /// Get client ID
254    pub fn client_id(&self) -> &ClientId {
255        &self.client_id
256    }
257
258    /// Get round number
259    pub fn round(&self) -> u64 {
260        self.round
261    }
262
263    /// Get number of steps
264    pub fn num_steps(&self) -> usize {
265        self.num_steps
266    }
267
268    /// Get loss
269    pub fn loss(&self) -> f64 {
270        self.loss
271    }
272
273    /// Get accuracy
274    pub fn accuracy(&self) -> Option<f64> {
275        self.accuracy
276    }
277
278    /// Get metadata
279    pub fn metadata(&self) -> &[(String, String)] {
280        &self.metadata
281    }
282}
283
284/// Client selection strategies
285#[derive(Debug, Clone, Copy, PartialEq, Eq)]
286pub enum ClientSelectionStrategy {
287    /// Random selection
288    Random,
289    /// Select based on availability
290    ByAvailability,
291    /// Select based on data size
292    ByDataSize,
293    /// Select based on compute capacity
294    ByComputeCapacity,
295    /// Power-of-choice (select best from random sample)
296    PowerOfChoice { choices: usize },
297    /// All clients participate
298    All,
299}
300
301/// Client selector for federated learning rounds
302#[derive(Debug, Clone)]
303pub struct ClientSelector {
304    strategy: ClientSelectionStrategy,
305}
306
307impl ClientSelector {
308    /// Create a new client selector
309    pub fn new(strategy: ClientSelectionStrategy) -> Self {
310        Self { strategy }
311    }
312
313    /// Select clients for a training round
314    pub fn select(&self, clients: &[FederatedClient], num_select: usize) -> Vec<ClientId> {
315        match self.strategy {
316            ClientSelectionStrategy::Random => {
317                // Simple deterministic selection (in practice, would use RNG)
318                clients
319                    .iter()
320                    .take(num_select.min(clients.len()))
321                    .map(|c| c.id().clone())
322                    .collect()
323            }
324            ClientSelectionStrategy::ByAvailability => {
325                let mut sorted: Vec<_> = clients.iter().collect();
326                sorted.sort_by(|a, b| {
327                    b.availability()
328                        .partial_cmp(&a.availability())
329                        .unwrap_or(core::cmp::Ordering::Equal)
330                });
331                sorted
332                    .iter()
333                    .take(num_select.min(clients.len()))
334                    .map(|c| c.id().clone())
335                    .collect()
336            }
337            ClientSelectionStrategy::ByDataSize => {
338                let mut sorted: Vec<_> = clients.iter().collect();
339                sorted.sort_by_key(|c| core::cmp::Reverse(c.num_samples()));
340                sorted
341                    .iter()
342                    .take(num_select.min(clients.len()))
343                    .map(|c| c.id().clone())
344                    .collect()
345            }
346            ClientSelectionStrategy::ByComputeCapacity => {
347                let mut sorted: Vec<_> = clients.iter().collect();
348                sorted.sort_by(|a, b| {
349                    b.compute_capacity()
350                        .partial_cmp(&a.compute_capacity())
351                        .unwrap_or(core::cmp::Ordering::Equal)
352                });
353                sorted
354                    .iter()
355                    .take(num_select.min(clients.len()))
356                    .map(|c| c.id().clone())
357                    .collect()
358            }
359            ClientSelectionStrategy::PowerOfChoice { choices: _ } => {
360                // Simplified: select by availability from first subset
361                clients
362                    .iter()
363                    .take(num_select.min(clients.len()))
364                    .map(|c| c.id().clone())
365                    .collect()
366            }
367            ClientSelectionStrategy::All => clients.iter().map(|c| c.id().clone()).collect(),
368        }
369    }
370
371    /// Get selection strategy
372    pub fn strategy(&self) -> ClientSelectionStrategy {
373        self.strategy
374    }
375}
376
/// Differential privacy parameters.
///
/// Holds the (epsilon, delta) budget together with the gradient clipping norm and
/// noise multiplier. `new` sets both the clipping norm and the noise multiplier to `1.0`.
#[derive(Debug, Clone, Copy)]
pub struct PrivacyParameters {
    /// Privacy budget (epsilon)
    epsilon: f64,
    /// Privacy loss probability (delta)
    delta: f64,
    /// Clipping threshold for gradient norm
    clip_norm: f64,
    /// Noise multiplier
    noise_multiplier: f64,
}

impl PrivacyParameters {
    /// Create new privacy parameters with clip norm and noise multiplier of `1.0`.
    pub fn new(epsilon: f64, delta: f64) -> Self {
        Self {
            epsilon,
            delta,
            clip_norm: 1.0,
            noise_multiplier: 1.0,
        }
    }

    /// Set the gradient clipping norm (builder style).
    #[must_use]
    pub fn with_clip_norm(mut self, clip_norm: f64) -> Self {
        self.clip_norm = clip_norm;
        self
    }

    /// Set the noise multiplier (builder style).
    #[must_use]
    pub fn with_noise_multiplier(mut self, noise_multiplier: f64) -> Self {
        self.noise_multiplier = noise_multiplier;
        self
    }

    /// Get epsilon
    pub fn epsilon(&self) -> f64 {
        self.epsilon
    }

    /// Get delta
    pub fn delta(&self) -> f64 {
        self.delta
    }

    /// Get clip norm
    pub fn clip_norm(&self) -> f64 {
        self.clip_norm
    }

    /// Get noise multiplier
    pub fn noise_multiplier(&self) -> f64 {
        self.noise_multiplier
    }

    /// Whether the privacy budget is exhausted (`epsilon <= 0.0`).
    ///
    /// A NaN epsilon is not reported as exhausted.
    pub fn is_exhausted(&self) -> bool {
        self.epsilon <= 0.0
    }
}
438
/// Techniques for making client/server communication cheaper.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum CompressionTechnique {
    /// Updates are sent as-is
    None,
    /// Gradient quantization (reduce precision)
    Quantization { bits: u8 },
    /// Sparsification: keep only the top-k gradients
    Sparsification { k: usize },
    /// Gradient sketching
    Sketching,
    /// Low-rank approximation
    LowRank { rank: usize },
}
453
/// Bookkeeping for a single federated learning round.
///
/// A round starts with the number of selected clients; outcome fields (completions,
/// averages, cost, duration) start at zero / `None` and are filled in via setters.
#[derive(Debug, Clone)]
pub struct TrainingRound {
    /// Round number
    round: u64,
    /// Number of clients selected for the round
    num_clients: usize,
    /// Number of selected clients that finished
    num_completed: usize,
    /// Mean loss over the reporting clients
    avg_loss: f64,
    /// Mean accuracy over the reporting clients, if known
    avg_accuracy: Option<f64>,
    /// Total communication cost in bytes
    communication_cost: usize,
    /// Wall-clock duration in seconds
    duration_secs: f64,
}

impl TrainingRound {
    /// Begin tracking round `round` with `num_clients` selected clients.
    pub fn new(round: u64, num_clients: usize) -> Self {
        Self {
            avg_accuracy: None,
            avg_loss: 0.0,
            communication_cost: 0,
            duration_secs: 0.0,
            num_clients,
            num_completed: 0,
            round,
        }
    }

    /// Record how many clients completed.
    pub fn set_completed(&mut self, num_completed: usize) {
        self.num_completed = num_completed;
    }

    /// Record the mean loss.
    pub fn set_avg_loss(&mut self, avg_loss: f64) {
        self.avg_loss = avg_loss;
    }

    /// Record the mean accuracy.
    pub fn set_avg_accuracy(&mut self, avg_accuracy: f64) {
        self.avg_accuracy = Some(avg_accuracy);
    }

    /// Record the communication cost in bytes.
    pub fn set_communication_cost(&mut self, cost: usize) {
        self.communication_cost = cost;
    }

    /// Record the round duration in seconds.
    pub fn set_duration(&mut self, duration_secs: f64) {
        self.duration_secs = duration_secs;
    }

    /// Round number.
    pub fn round(&self) -> u64 {
        self.round
    }

    /// Number of selected clients.
    pub fn num_clients(&self) -> usize {
        self.num_clients
    }

    /// Number of clients that completed.
    pub fn num_completed(&self) -> usize {
        self.num_completed
    }

    /// Mean loss.
    pub fn avg_loss(&self) -> f64 {
        self.avg_loss
    }

    /// Mean accuracy, if recorded.
    pub fn avg_accuracy(&self) -> Option<f64> {
        self.avg_accuracy
    }

    /// Communication cost in bytes.
    pub fn communication_cost(&self) -> usize {
        self.communication_cost
    }

    /// Duration in seconds.
    pub fn duration_secs(&self) -> f64 {
        self.duration_secs
    }

    /// Fraction of selected clients that completed; `0.0` when none were selected.
    pub fn completion_rate(&self) -> f64 {
        match self.num_clients {
            0 => 0.0,
            selected => self.num_completed as f64 / selected as f64,
        }
    }
}
556
/// Fairness metrics for federated learning.
///
/// Summarizes how evenly model quality is spread across clients. Build directly with
/// [`FairnessMetrics::new`] or compute from raw per-client accuracies with
/// [`FairnessMetrics::from_accuracies`].
#[derive(Debug, Clone)]
pub struct FairnessMetrics {
    /// Variance in client accuracy
    accuracy_variance: f64,
    /// Minimum client accuracy
    min_accuracy: f64,
    /// Maximum client accuracy
    max_accuracy: f64,
    /// Jain's fairness index
    jains_index: f64,
}

impl FairnessMetrics {
    /// Create fairness metrics from precomputed values.
    pub fn new(
        accuracy_variance: f64,
        min_accuracy: f64,
        max_accuracy: f64,
        jains_index: f64,
    ) -> Self {
        Self {
            accuracy_variance,
            min_accuracy,
            max_accuracy,
            jains_index,
        }
    }

    /// Compute fairness metrics from per-client accuracies.
    ///
    /// Returns `None` for an empty slice. The variance is the population variance.
    /// Jain's index is `(Σx)² / (n · Σx²)`, which assumes non-negative values; when
    /// every accuracy is zero the ratio is undefined and `1.0` is reported, since all
    /// clients are treated identically.
    pub fn from_accuracies(accuracies: &[f64]) -> Option<Self> {
        if accuracies.is_empty() {
            return None;
        }
        let n = accuracies.len() as f64;
        let sum: f64 = accuracies.iter().sum();
        let sum_sq: f64 = accuracies.iter().map(|a| a * a).sum();
        let mean = sum / n;
        let variance = accuracies.iter().map(|a| (a - mean).powi(2)).sum::<f64>() / n;
        let min = accuracies.iter().copied().fold(f64::INFINITY, f64::min);
        let max = accuracies.iter().copied().fold(f64::NEG_INFINITY, f64::max);
        let jains = if sum_sq == 0.0 {
            1.0
        } else {
            sum * sum / (n * sum_sq)
        };
        Some(Self::new(variance, min, max, jains))
    }

    /// Get accuracy variance
    pub fn accuracy_variance(&self) -> f64 {
        self.accuracy_variance
    }

    /// Get minimum accuracy
    pub fn min_accuracy(&self) -> f64 {
        self.min_accuracy
    }

    /// Get maximum accuracy
    pub fn max_accuracy(&self) -> f64 {
        self.max_accuracy
    }

    /// Get Jain's fairness index
    pub fn jains_index(&self) -> f64 {
        self.jains_index
    }

    /// Whether learning is considered fair: Jain's index strictly greater than `0.8`.
    pub fn is_fair(&self) -> bool {
        self.jains_index > 0.8
    }
}
611
612/// Federated learning coordinator
613#[derive(Debug, Clone)]
614pub struct FederatedCoordinator {
615    /// Current round number
616    current_round: u64,
617    /// Aggregation strategy
618    strategy: AggregationStrategy,
619    /// Privacy parameters
620    privacy: Option<PrivacyParameters>,
621    /// Compression technique
622    compression: CompressionTechnique,
623    /// Training history
624    rounds: Vec<TrainingRound>,
625}
626
627impl FederatedCoordinator {
628    /// Create a new federated coordinator
629    pub fn new(strategy: AggregationStrategy) -> Self {
630        Self {
631            current_round: 0,
632            strategy,
633            privacy: None,
634            compression: CompressionTechnique::None,
635            rounds: Vec::new(),
636        }
637    }
638
639    /// Set privacy parameters
640    pub fn with_privacy(mut self, privacy: PrivacyParameters) -> Self {
641        self.privacy = Some(privacy);
642        self
643    }
644
645    /// Set compression technique
646    pub fn with_compression(mut self, compression: CompressionTechnique) -> Self {
647        self.compression = compression;
648        self
649    }
650
651    /// Start a new training round
652    pub fn start_round(&mut self, num_clients: usize) -> u64 {
653        self.current_round += 1;
654        self.rounds
655            .push(TrainingRound::new(self.current_round, num_clients));
656        self.current_round
657    }
658
659    /// Complete the current round
660    pub fn complete_round(&mut self, avg_loss: f64, num_completed: usize) {
661        if let Some(round) = self.rounds.last_mut() {
662            round.set_avg_loss(avg_loss);
663            round.set_completed(num_completed);
664        }
665    }
666
667    /// Get current round
668    pub fn current_round(&self) -> u64 {
669        self.current_round
670    }
671
672    /// Get strategy
673    pub fn strategy(&self) -> AggregationStrategy {
674        self.strategy
675    }
676
677    /// Get privacy parameters
678    pub fn privacy(&self) -> Option<&PrivacyParameters> {
679        self.privacy.as_ref()
680    }
681
682    /// Get compression technique
683    pub fn compression(&self) -> CompressionTechnique {
684        self.compression
685    }
686
687    /// Get training history
688    pub fn rounds(&self) -> &[TrainingRound] {
689        &self.rounds
690    }
691
692    /// Get statistics
693    pub fn statistics(&self) -> CoordinatorStatistics {
694        let total_rounds = self.rounds.len();
695        let avg_completion_rate = if total_rounds > 0 {
696            self.rounds.iter().map(|r| r.completion_rate()).sum::<f64>() / total_rounds as f64
697        } else {
698            0.0
699        };
700        let total_communication = self.rounds.iter().map(|r| r.communication_cost()).sum();
701
702        CoordinatorStatistics {
703            total_rounds,
704            avg_completion_rate,
705            total_communication,
706        }
707    }
708}
709
710/// Coordinator statistics
711#[derive(Debug, Clone)]
712pub struct CoordinatorStatistics {
713    /// Total number of rounds
714    pub total_rounds: usize,
715    /// Average completion rate
716    pub avg_completion_rate: f64,
717    /// Total communication cost (bytes)
718    pub total_communication: usize,
719}
720
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_client_id() {
        let id = ClientId::new("client_1");
        assert_eq!(id.as_str(), "client_1");
        assert_eq!(format!("{}", id), "client_1");
    }

    #[test]
    fn test_federated_client_creation() {
        let client = FederatedClient::new("client_1", 1000, 0.8);
        assert_eq!(client.id().as_str(), "client_1");
        assert_eq!(client.num_samples(), 1000);
        assert_eq!(client.availability(), 0.8);
        assert_eq!(client.weight(), 1000.0);
    }

    #[test]
    fn test_client_with_builder() {
        let client = FederatedClient::new("client_1", 1000, 0.8)
            .with_compute_capacity(2.0)
            .with_bandwidth(50.0)
            .with_privacy_budget(1.0);

        assert_eq!(client.compute_capacity(), 2.0);
        assert_eq!(client.bandwidth_mbps(), 50.0);
        assert_eq!(client.privacy_budget(), Some(1.0));
    }

    #[test]
    fn test_data_distribution() {
        let iid = DataDistribution::IID;
        let label_skew = DataDistribution::LabelSkew { skew_factor: 0.5 };
        let _feature_skew = DataDistribution::FeatureSkew { skew_factor: 0.3 };

        assert_eq!(iid, DataDistribution::IID);
        assert_ne!(iid, label_skew);
    }

    #[test]
    fn test_aggregation_strategy_display() {
        assert_eq!(format!("{}", AggregationStrategy::FedAvg), "FedAvg");
        assert_eq!(format!("{}", AggregationStrategy::FedProx), "FedProx");
    }

    #[test]
    fn test_client_update() {
        let id = ClientId::new("client_1");
        let mut update = ClientUpdate::new(id.clone(), 5, 100, 0.5).with_accuracy(0.85);

        update.add_metadata("dataset", "mnist");
        assert_eq!(update.client_id(), &id);
        assert_eq!(update.round(), 5);
        assert_eq!(update.num_steps(), 100);
        assert_eq!(update.loss(), 0.5);
        assert_eq!(update.accuracy(), Some(0.85));
        assert_eq!(update.metadata().len(), 1);
    }

    #[test]
    fn test_client_selector_random() {
        let clients = vec![
            FederatedClient::new("client_1", 1000, 0.8),
            FederatedClient::new("client_2", 500, 0.6),
            FederatedClient::new("client_3", 800, 0.9),
        ];

        let selector = ClientSelector::new(ClientSelectionStrategy::Random);
        let selected = selector.select(&clients, 2);
        assert_eq!(selected.len(), 2);
    }

    #[test]
    fn test_client_selector_by_data_size() {
        let clients = vec![
            FederatedClient::new("client_1", 1000, 0.8),
            FederatedClient::new("client_2", 500, 0.6),
            FederatedClient::new("client_3", 800, 0.9),
        ];

        let selector = ClientSelector::new(ClientSelectionStrategy::ByDataSize);
        let selected = selector.select(&clients, 2);
        assert_eq!(selected.len(), 2);
        assert_eq!(selected[0].as_str(), "client_1"); // Largest dataset
    }

    #[test]
    fn test_client_selector_by_availability() {
        let clients = vec![
            FederatedClient::new("client_1", 1000, 0.8),
            FederatedClient::new("client_2", 500, 0.6),
            FederatedClient::new("client_3", 800, 0.9),
        ];

        let selector = ClientSelector::new(ClientSelectionStrategy::ByAvailability);
        let selected = selector.select(&clients, 2);
        let names: Vec<&str> = selected.iter().map(ClientId::as_str).collect();
        assert_eq!(names, vec!["client_3", "client_1"]);
    }

    #[test]
    fn test_client_selector_by_compute_capacity() {
        let clients = vec![
            FederatedClient::new("client_1", 100, 0.5).with_compute_capacity(1.0),
            FederatedClient::new("client_2", 100, 0.5).with_compute_capacity(3.0),
            FederatedClient::new("client_3", 100, 0.5).with_compute_capacity(2.0),
        ];

        let selector = ClientSelector::new(ClientSelectionStrategy::ByComputeCapacity);
        let selected = selector.select(&clients, 2);
        let names: Vec<&str> = selected.iter().map(ClientId::as_str).collect();
        assert_eq!(names, vec!["client_2", "client_3"]);
    }

    #[test]
    fn test_client_selector_caps_and_empty_input() {
        let clients = vec![
            FederatedClient::new("client_1", 1000, 0.8),
            FederatedClient::new("client_2", 500, 0.6),
        ];

        for strategy in [
            ClientSelectionStrategy::Random,
            ClientSelectionStrategy::ByAvailability,
            ClientSelectionStrategy::ByDataSize,
            ClientSelectionStrategy::ByComputeCapacity,
            ClientSelectionStrategy::PowerOfChoice { choices: 3 },
        ] {
            let selector = ClientSelector::new(strategy);
            assert_eq!(selector.select(&clients, 10).len(), 2);
            assert!(selector.select(&[], 3).is_empty());
        }
    }

    #[test]
    fn test_client_selector_all() {
        let clients = vec![
            FederatedClient::new("client_1", 1000, 0.8),
            FederatedClient::new("client_2", 500, 0.6),
        ];

        let selector = ClientSelector::new(ClientSelectionStrategy::All);
        let selected = selector.select(&clients, 10);
        assert_eq!(selected.len(), 2); // All clients
    }

    #[test]
    fn test_privacy_parameters() {
        let privacy = PrivacyParameters::new(1.0, 1e-5)
            .with_clip_norm(2.0)
            .with_noise_multiplier(0.5);

        assert_eq!(privacy.epsilon(), 1.0);
        assert_eq!(privacy.delta(), 1e-5);
        assert_eq!(privacy.clip_norm(), 2.0);
        assert_eq!(privacy.noise_multiplier(), 0.5);
        assert!(!privacy.is_exhausted());
    }

    #[test]
    fn test_privacy_budget_exhausted() {
        assert!(PrivacyParameters::new(0.0, 1e-5).is_exhausted());
        assert!(PrivacyParameters::new(-0.5, 1e-5).is_exhausted());
    }

    #[test]
    fn test_compression_techniques() {
        let _none = CompressionTechnique::None;
        let _quant = CompressionTechnique::Quantization { bits: 8 };
        let _sparse = CompressionTechnique::Sparsification { k: 100 };
        let _sketch = CompressionTechnique::Sketching;
        let _low_rank = CompressionTechnique::LowRank { rank: 10 };
    }

    #[test]
    fn test_training_round() {
        let mut round = TrainingRound::new(1, 10);
        round.set_completed(8);
        round.set_avg_loss(0.5);
        round.set_avg_accuracy(0.85);
        round.set_communication_cost(1024);
        round.set_duration(60.0);

        assert_eq!(round.round(), 1);
        assert_eq!(round.num_clients(), 10);
        assert_eq!(round.num_completed(), 8);
        assert_eq!(round.avg_loss(), 0.5);
        assert_eq!(round.avg_accuracy(), Some(0.85));
        assert_eq!(round.communication_cost(), 1024);
        assert_eq!(round.duration_secs(), 60.0);
        assert_eq!(round.completion_rate(), 0.8);
    }

    #[test]
    fn test_completion_rate_without_clients() {
        assert_eq!(TrainingRound::new(1, 0).completion_rate(), 0.0);
    }

    #[test]
    fn test_fairness_metrics() {
        let metrics = FairnessMetrics::new(0.01, 0.80, 0.90, 0.85);
        assert_eq!(metrics.accuracy_variance(), 0.01);
        assert_eq!(metrics.min_accuracy(), 0.80);
        assert_eq!(metrics.max_accuracy(), 0.90);
        assert_eq!(metrics.jains_index(), 0.85);
        assert!(metrics.is_fair());
    }

    #[test]
    fn test_fairness_threshold_is_strict() {
        assert!(!FairnessMetrics::new(0.0, 0.5, 0.5, 0.8).is_fair());
    }

    #[test]
    fn test_federated_coordinator() {
        let mut coordinator = FederatedCoordinator::new(AggregationStrategy::FedAvg)
            .with_privacy(PrivacyParameters::new(1.0, 1e-5))
            .with_compression(CompressionTechnique::Quantization { bits: 8 });

        assert_eq!(coordinator.current_round(), 0);

        let round1 = coordinator.start_round(10);
        assert_eq!(round1, 1);

        coordinator.complete_round(0.5, 8);

        let stats = coordinator.statistics();
        assert_eq!(stats.total_rounds, 1);
    }

    #[test]
    fn test_complete_round_without_start_is_noop() {
        let mut coordinator = FederatedCoordinator::new(AggregationStrategy::FedProx);
        coordinator.complete_round(0.5, 3);
        assert!(coordinator.rounds().is_empty());

        let stats = coordinator.statistics();
        assert_eq!(stats.total_rounds, 0);
        assert_eq!(stats.avg_completion_rate, 0.0);
        assert_eq!(stats.total_communication, 0);
    }

    #[test]
    fn test_coordinator_multiple_rounds() {
        let mut coordinator = FederatedCoordinator::new(AggregationStrategy::FedAvg);

        coordinator.start_round(10);
        coordinator.complete_round(0.6, 9);

        coordinator.start_round(10);
        coordinator.complete_round(0.4, 8);

        coordinator.start_round(10);
        coordinator.complete_round(0.3, 10);

        assert_eq!(coordinator.current_round(), 3);
        assert_eq!(coordinator.rounds().len(), 3);

        let stats = coordinator.statistics();
        assert_eq!(stats.total_rounds, 3);
        assert!(stats.avg_completion_rate > 0.8);
    }

    #[test]
    fn test_client_selection_strategies() {
        let _random = ClientSelectionStrategy::Random;
        let _avail = ClientSelectionStrategy::ByAvailability;
        let _size = ClientSelectionStrategy::ByDataSize;
        let _compute = ClientSelectionStrategy::ByComputeCapacity;
        let _power = ClientSelectionStrategy::PowerOfChoice { choices: 3 };
        let _all = ClientSelectionStrategy::All;
    }

    #[test]
    fn test_availability_clamping() {
        let client1 = FederatedClient::new("c1", 100, 1.5); // > 1.0
        let client2 = FederatedClient::new("c2", 100, -0.1); // < 0.0

        assert_eq!(client1.availability(), 1.0);
        assert_eq!(client2.availability(), 0.0);
    }
}