ruqu_qflg/
protocol.rs

1//! Federated Learning Protocol
2//!
3//! This module provides the federated learning protocol implementation including
4//! client registration, key exchange, round management, and model synchronization.
5//!
6//! ## Protocol Flow
7//!
8//! 1. Clients register with the coordinator
9//! 2. Key exchange for secure communication
10//! 3. Training rounds:
11//!    - Coordinator broadcasts global model
12//!    - Clients train locally and submit gradients
13//!    - Coordinator aggregates with Byzantine tolerance
14//!    - Privacy mechanism applied to aggregate
15//! 4. Model synchronization after each round
16//!
17//! ## Example
18//!
19//! ```rust
20//! use ruqu_qflg::protocol::{FederatedCoordinator, CoordinatorConfig, ClientInfo};
21//!
22//! let config = CoordinatorConfig::default();
//! let coordinator = FederatedCoordinator::new(config);
24//!
25//! // Register clients
26//! let client = ClientInfo::new("client_1".to_string(), vec![0u8; 32]);
27//! coordinator.register_client(client).unwrap();
28//! ```
29
30use std::collections::HashMap;
31use chrono::{DateTime, Utc};
32use ndarray::Array1;
33use parking_lot::RwLock;
34use serde::{Deserialize, Serialize};
35use uuid::Uuid;
36
37use crate::aggregation::{AggregatorConfig, GradientAggregator, WeightedAverageAggregator};
38use crate::byzantine::{ByzantineDetector, DetectorConfig, KrumDetector};
39use crate::error::{ProtocolError, Result};
40use crate::privacy::{GaussianMechanism, PrivacyBudget, PrivacyConfig, PrivacyMechanism};
41
/// Configuration for the federated coordinator.
///
/// Consumed by `FederatedCoordinator`: the `byzantine_*` fields (plus `min_clients`) configure
/// Krum-based filtering, the `privacy_*` fields configure the Gaussian mechanism and the
/// privacy budget, and `model_dimension` sizes the global model.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct CoordinatorConfig {
    /// Minimum number of clients to start a round.
    /// NOTE(review): `start_round` does not check this; in this module it is only passed to
    /// `DetectorConfig::new` when Byzantine detection runs.
    pub min_clients: usize,
    /// Maximum clients per round.
    /// NOTE(review): not read anywhere in this module — confirm whether it is enforced elsewhere.
    pub max_clients: usize,
    /// Round timeout in milliseconds.
    /// NOTE(review): not read anywhere in this module — confirm whether it is enforced elsewhere.
    pub round_timeout_ms: u64,
    /// Enable Byzantine-robust filtering of submitted gradients before aggregation.
    pub byzantine_enabled: bool,
    /// Byzantine tolerance fraction passed to the detector configuration.
    pub byzantine_fraction: f64,
    /// Enable differential privacy (Gaussian noise on the aggregate plus budget accounting).
    pub privacy_enabled: bool,
    /// Privacy epsilon charged to the budget for every round that applies privacy.
    pub privacy_epsilon: f64,
    /// Privacy delta, used both for the per-round mechanism and for the budget.
    pub privacy_delta: f64,
    /// Maximum privacy budget handed to `PrivacyBudget::new` (presumably total epsilon —
    /// confirm against `PrivacyBudget`).
    pub max_privacy_budget: f64,
    /// Model dimension: length of the global model vector.
    pub model_dimension: usize,
}
66
67impl Default for CoordinatorConfig {
68    fn default() -> Self {
69        Self {
70            min_clients: 3,
71            max_clients: 100,
72            round_timeout_ms: 60000,
73            byzantine_enabled: true,
74            byzantine_fraction: 0.3,
75            privacy_enabled: true,
76            privacy_epsilon: 1.0,
77            privacy_delta: 1e-5,
78            max_privacy_budget: 10.0,
79            model_dimension: 1000,
80        }
81    }
82}
83
/// Registration record for a federated-learning client, as stored by the coordinator.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ClientInfo {
    /// Unique client identifier; the coordinator keys its client map by this value.
    pub client_id: String,
    /// Client's public key for secure communication. Stored as opaque bytes; nothing in this
    /// module interprets it.
    pub public_key: Vec<u8>,
    /// Registration timestamp (UTC).
    pub registered_at: DateTime<Utc>,
    /// Last time the client was seen (UTC); refreshed by `touch`.
    pub last_seen: DateTime<Utc>,
    /// Number of completed rounds the client participated in.
    pub rounds_participated: u64,
    /// Data contribution weight (`1.0` for a new client).
    /// NOTE(review): aggregation uses each `GradientSubmission::weight`, not this field.
    pub weight: f64,
    /// Whether the client is active; in this module only `active_client_count` reads it.
    pub active: bool,
}
102
103impl ClientInfo {
104    /// Create a new client info
105    pub fn new(client_id: String, public_key: Vec<u8>) -> Self {
106        let now = Utc::now();
107        Self {
108            client_id,
109            public_key,
110            registered_at: now,
111            last_seen: now,
112            rounds_participated: 0,
113            weight: 1.0,
114            active: true,
115        }
116    }
117
118    /// Update last seen timestamp
119    pub fn touch(&mut self) {
120        self.last_seen = Utc::now();
121    }
122}
123
/// Lifecycle state of a training round.
///
/// A `RoundInfo` is created in `Waiting`; the coordinator starts rounds in `Collecting`, moves
/// them to `Aggregating` when completion begins, and ends them in `Completed` or `Failed`.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
pub enum RoundState {
    /// Created but not accepting gradients (initial state from `RoundInfo::new`).
    Waiting,
    /// Round in progress; gradient submissions are accepted only in this state.
    Collecting,
    /// Gradients are being aggregated; new submissions are rejected.
    Aggregating,
    /// Round completed successfully; a new round may be started.
    Completed,
    /// Round failed; a new round may be started.
    Failed,
}
138
139impl std::fmt::Display for RoundState {
140    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
141        match self {
142            RoundState::Waiting => write!(f, "Waiting"),
143            RoundState::Collecting => write!(f, "Collecting"),
144            RoundState::Aggregating => write!(f, "Aggregating"),
145            RoundState::Completed => write!(f, "Completed"),
146            RoundState::Failed => write!(f, "Failed"),
147        }
148    }
149}
150
/// Information about a training round.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct RoundInfo {
    /// Round number (the coordinator numbers rounds starting from 1).
    pub round_number: u64,
    /// Unique round ID (a random UUID v4 string); submissions must quote it.
    pub round_id: String,
    /// Current lifecycle state.
    pub state: RoundState,
    /// Time this `RoundInfo` was created (UTC).
    pub started_at: DateTime<Utc>,
    /// Time the round reached `Completed` or `Failed`; `None` while unfinished.
    pub ended_at: Option<DateTime<Utc>>,
    /// IDs of clients that submitted a gradient this round, without duplicates.
    pub participants: Vec<String>,
    /// Number of gradients held for this round (length of the coordinator's submission list).
    pub gradients_received: usize,
    /// IDs of clients whose gradients Byzantine detection rejected.
    pub byzantine_detected: Vec<String>,
    /// Privacy epsilon charged for this round (`0.0` when privacy is disabled).
    pub privacy_spent: f64,
}
173
174impl RoundInfo {
175    /// Create a new round
176    pub fn new(round_number: u64) -> Self {
177        Self {
178            round_number,
179            round_id: Uuid::new_v4().to_string(),
180            state: RoundState::Waiting,
181            started_at: Utc::now(),
182            ended_at: None,
183            participants: Vec::new(),
184            gradients_received: 0,
185            byzantine_detected: Vec::new(),
186            privacy_spent: 0.0,
187        }
188    }
189}
190
/// Gradient submission from a client for one round.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct GradientSubmission {
    /// ID of the submitting client; it must be registered with the coordinator.
    pub client_id: String,
    /// ID of the round this gradient belongs to; it must equal the current round's `round_id`.
    pub round_id: String,
    /// Gradient data.
    /// NOTE(review): its length is not checked against `CoordinatorConfig::model_dimension`
    /// when submitted — confirm how a mismatch is handled downstream.
    pub gradient: Array1<f64>,
    /// Aggregation weight (the simulated client uses its number of samples).
    pub weight: f64,
    /// Signature bytes; nothing in this module verifies them.
    pub signature: Vec<u8>,
    /// Submission timestamp (UTC).
    pub submitted_at: DateTime<Utc>,
}
207
/// Federated learning coordinator.
///
/// Owns the client registry, the current round, the round history, the global model, the
/// submitted gradients and the privacy budget. Each piece of state sits behind its own
/// `parking_lot::RwLock`, which is why the public methods only need `&self`.
pub struct FederatedCoordinator {
    /// Configuration fixed at construction.
    config: CoordinatorConfig,
    /// Registered clients, keyed by `client_id`.
    clients: RwLock<HashMap<String, ClientInfo>>,
    /// Most recent round (in progress or finished); `None` before the first round starts.
    current_round: RwLock<Option<RoundInfo>>,
    /// Archived rounds; its length determines the next round number.
    round_history: RwLock<Vec<RoundInfo>>,
    /// Current global model; each completed round adds its aggregate to it.
    global_model: RwLock<Array1<f64>>,
    /// Gradient submissions for the current round; cleared when a round starts.
    gradients: RwLock<Vec<GradientSubmission>>,
    /// Differential-privacy budget, charged per round when privacy is enabled.
    privacy_budget: RwLock<PrivacyBudget>,
}
224
225impl FederatedCoordinator {
226    /// Create a new federated coordinator
227    pub fn new(config: CoordinatorConfig) -> Self {
228        let privacy_budget = PrivacyBudget::new(config.max_privacy_budget, config.privacy_delta);
229        let global_model = Array1::zeros(config.model_dimension);
230
231        Self {
232            config,
233            clients: RwLock::new(HashMap::new()),
234            current_round: RwLock::new(None),
235            round_history: RwLock::new(Vec::new()),
236            global_model: RwLock::new(global_model),
237            gradients: RwLock::new(Vec::new()),
238            privacy_budget: RwLock::new(privacy_budget),
239        }
240    }
241
242    /// Register a new client
243    pub fn register_client(&self, client: ClientInfo) -> Result<()> {
244        let mut clients = self.clients.write();
245
246        if clients.contains_key(&client.client_id) {
247            return Err(ProtocolError::DuplicateClient(client.client_id).into());
248        }
249
250        clients.insert(client.client_id.clone(), client);
251        Ok(())
252    }
253
254    /// Unregister a client
255    pub fn unregister_client(&self, client_id: &str) -> Result<()> {
256        let mut clients = self.clients.write();
257
258        if clients.remove(client_id).is_none() {
259            return Err(ProtocolError::ClientNotRegistered(client_id.to_string()).into());
260        }
261
262        Ok(())
263    }
264
265    /// Get client information
266    pub fn get_client(&self, client_id: &str) -> Option<ClientInfo> {
267        self.clients.read().get(client_id).cloned()
268    }
269
270    /// List all registered clients
271    pub fn list_clients(&self) -> Vec<ClientInfo> {
272        self.clients.read().values().cloned().collect()
273    }
274
275    /// Get number of active clients
276    pub fn active_client_count(&self) -> usize {
277        self.clients.read().values().filter(|c| c.active).count()
278    }
279
280    /// Start a new training round
281    pub fn start_round(&self) -> Result<RoundInfo> {
282        let mut current = self.current_round.write();
283
284        if let Some(ref round) = *current {
285            if round.state != RoundState::Completed && round.state != RoundState::Failed {
286                return Err(ProtocolError::RoundInProgress(round.round_number).into());
287            }
288        }
289
290        let round_number = self.round_history.read().len() as u64 + 1;
291        let mut round = RoundInfo::new(round_number);
292        round.state = RoundState::Collecting;
293
294        // Clear gradients
295        self.gradients.write().clear();
296
297        let round_info = round.clone();
298        *current = Some(round);
299
300        Ok(round_info)
301    }
302
303    /// Submit a gradient for the current round
304    pub fn submit_gradient(&self, submission: GradientSubmission) -> Result<()> {
305        // Verify client is registered
306        {
307            let clients = self.clients.read();
308            if !clients.contains_key(&submission.client_id) {
309                return Err(ProtocolError::ClientNotRegistered(submission.client_id.clone()).into());
310            }
311        }
312
313        // Verify round is active
314        {
315            let current = self.current_round.read();
316            match current.as_ref() {
317                None => return Err(ProtocolError::NoActiveRound.into()),
318                Some(round) => {
319                    if round.state != RoundState::Collecting {
320                        return Err(ProtocolError::InvalidStateTransition {
321                            from: round.state.to_string(),
322                            to: "Collecting".to_string(),
323                        }
324                        .into());
325                    }
326                    if submission.round_id != round.round_id {
327                        return Err(ProtocolError::NoActiveRound.into());
328                    }
329                }
330            }
331        }
332
333        // Add gradient
334        let mut gradients = self.gradients.write();
335        gradients.push(submission.clone());
336
337        // Update round info
338        {
339            let mut current = self.current_round.write();
340            if let Some(ref mut round) = *current {
341                round.gradients_received = gradients.len();
342                if !round.participants.contains(&submission.client_id) {
343                    round.participants.push(submission.client_id.clone());
344                }
345            }
346        }
347
348        // Update client
349        {
350            let mut clients = self.clients.write();
351            if let Some(client) = clients.get_mut(&submission.client_id) {
352                client.touch();
353            }
354        }
355
356        Ok(())
357    }
358
359    /// Complete the current round and compute aggregate
360    pub fn complete_round(&self) -> Result<Array1<f64>> {
361        // Check round state
362        {
363            let mut current = self.current_round.write();
364            match current.as_mut() {
365                None => return Err(ProtocolError::NoActiveRound.into()),
366                Some(round) => {
367                    if round.state != RoundState::Collecting {
368                        return Err(ProtocolError::InvalidStateTransition {
369                            from: round.state.to_string(),
370                            to: "Aggregating".to_string(),
371                        }
372                        .into());
373                    }
374                    round.state = RoundState::Aggregating;
375                }
376            }
377        }
378
379        let gradients = self.gradients.read();
380        let submissions: Vec<_> = gradients.iter().collect();
381
382        if submissions.is_empty() {
383            let mut current = self.current_round.write();
384            if let Some(ref mut round) = *current {
385                round.state = RoundState::Failed;
386                round.ended_at = Some(Utc::now());
387            }
388            return Err(ProtocolError::NoActiveRound.into());
389        }
390
391        // Extract gradients and weights
392        let gradient_arrays: Vec<Array1<f64>> = submissions.iter().map(|s| s.gradient.clone()).collect();
393        let weights: Vec<f64> = submissions.iter().map(|s| s.weight).collect();
394
395        // Byzantine detection
396        let (honest_indices, byzantine_indices) = if self.config.byzantine_enabled {
397            let detector_config = DetectorConfig::new(
398                self.config.min_clients,
399                self.config.byzantine_fraction,
400            );
401            let detector = KrumDetector::new(detector_config);
402            detector.detect(&gradient_arrays)?
403        } else {
404            ((0..gradient_arrays.len()).collect(), vec![])
405        };
406
407        // Record Byzantine clients
408        {
409            let mut current = self.current_round.write();
410            if let Some(ref mut round) = *current {
411                round.byzantine_detected = byzantine_indices
412                    .iter()
413                    .filter_map(|&i| submissions.get(i).map(|s| s.client_id.clone()))
414                    .collect();
415            }
416        }
417
418        // Filter to honest gradients
419        let honest_gradients: Vec<Array1<f64>> = honest_indices
420            .iter()
421            .map(|&i| gradient_arrays[i].clone())
422            .collect();
423        let honest_weights: Vec<f64> = honest_indices.iter().map(|&i| weights[i]).collect();
424
425        // Aggregate
426        let aggregator_config = AggregatorConfig::default();
427        let aggregator = WeightedAverageAggregator::new(aggregator_config);
428        let mut aggregate = aggregator.aggregate(&honest_gradients, &honest_weights)?;
429
430        // Apply privacy
431        if self.config.privacy_enabled {
432            let privacy_config = PrivacyConfig::new(
433                self.config.privacy_epsilon,
434                self.config.privacy_delta,
435                1.0, // Sensitivity after clipping
436            );
437            let mechanism = GaussianMechanism::new(privacy_config)?;
438            aggregate = mechanism.apply(&aggregate)?;
439
440            // Track privacy budget
441            {
442                let mut budget = self.privacy_budget.write();
443                budget.spend(
444                    self.config.privacy_epsilon,
445                    self.config.privacy_delta,
446                    &format!("round {}", self.current_round.read().as_ref().map(|r| r.round_number).unwrap_or(0)),
447                )?;
448            }
449
450            // Record privacy spent
451            {
452                let mut current = self.current_round.write();
453                if let Some(ref mut round) = *current {
454                    round.privacy_spent = self.config.privacy_epsilon;
455                }
456            }
457        }
458
459        // Update global model
460        {
461            let mut model = self.global_model.write();
462            *model = &*model + &aggregate;
463        }
464
465        // Complete round
466        {
467            let mut current = self.current_round.write();
468            if let Some(ref mut round) = *current {
469                round.state = RoundState::Completed;
470                round.ended_at = Some(Utc::now());
471
472                // Update client participation
473                let mut clients = self.clients.write();
474                for client_id in &round.participants {
475                    if let Some(client) = clients.get_mut(client_id) {
476                        client.rounds_participated += 1;
477                    }
478                }
479
480                // Move to history
481                self.round_history.write().push(round.clone());
482            }
483        }
484
485        Ok(aggregate)
486    }
487
488    /// Get current global model
489    pub fn get_global_model(&self) -> Array1<f64> {
490        self.global_model.read().clone()
491    }
492
493    /// Get current round info
494    pub fn get_current_round(&self) -> Option<RoundInfo> {
495        self.current_round.read().clone()
496    }
497
498    /// Get round history
499    pub fn get_round_history(&self) -> Vec<RoundInfo> {
500        self.round_history.read().clone()
501    }
502
503    /// Get remaining privacy budget
504    pub fn remaining_privacy_budget(&self) -> f64 {
505        self.privacy_budget.read().remaining()
506    }
507
508    /// Get coordinator configuration
509    pub fn config(&self) -> &CoordinatorConfig {
510        &self.config
511    }
512}
513
/// Client-side participant in federated learning.
///
/// This is a simulation: the key pair is random bytes and `train_local` produces a random
/// gradient.
pub struct FederatedClient {
    /// Client ID; reported to the coordinator via `get_info`.
    client_id: String,
    /// Simulated public key (random bytes generated in `new`).
    public_key: Vec<u8>,
    /// Simulated private key (random bytes generated in `new`); never read.
    _private_key: Vec<u8>,
    /// Local copy of the model, overwritten by `sync_model`.
    local_model: Array1<f64>,
    /// Client configuration.
    config: ClientConfig,
}
527
/// Configuration for a `FederatedClient`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ClientConfig {
    /// Local training epochs per round.
    /// NOTE(review): not read by the simulated `train_local` in this module.
    pub local_epochs: usize,
    /// Local batch size (not read by the simulated `train_local`).
    pub batch_size: usize,
    /// Learning rate (not read by the simulated `train_local`).
    pub learning_rate: f64,
    /// Model dimension: length of the local model and of each simulated gradient.
    pub model_dimension: usize,
}
540
541impl Default for ClientConfig {
542    fn default() -> Self {
543        Self {
544            local_epochs: 5,
545            batch_size: 32,
546            learning_rate: 0.01,
547            model_dimension: 1000,
548        }
549    }
550}
551
552impl FederatedClient {
553    /// Create a new federated client
554    pub fn new(client_id: String, config: ClientConfig) -> Self {
555        // Generate key pair (simulated)
556        let mut rng = rand::thread_rng();
557        let public_key: Vec<u8> = (0..32).map(|_| rand::Rng::gen(&mut rng)).collect();
558        let private_key: Vec<u8> = (0..32).map(|_| rand::Rng::gen(&mut rng)).collect();
559
560        let local_model = Array1::zeros(config.model_dimension);
561
562        Self {
563            client_id,
564            public_key,
565            _private_key: private_key,
566            local_model,
567            config,
568        }
569    }
570
571    /// Get client info for registration
572    pub fn get_info(&self) -> ClientInfo {
573        ClientInfo::new(self.client_id.clone(), self.public_key.clone())
574    }
575
576    /// Update local model with global model
577    pub fn sync_model(&mut self, global_model: &Array1<f64>) {
578        self.local_model = global_model.clone();
579    }
580
581    /// Simulate local training and return gradient
582    pub fn train_local(&self, _data_size: usize) -> GradientSubmission {
583        // Simulate gradient (in practice, this would be computed from local data)
584        let mut rng = rand::thread_rng();
585        let gradient: Vec<f64> = (0..self.config.model_dimension)
586            .map(|_| rand::Rng::gen_range(&mut rng, -0.1..0.1))
587            .collect();
588
589        GradientSubmission {
590            client_id: self.client_id.clone(),
591            round_id: String::new(), // Will be set when submitting
592            gradient: Array1::from_vec(gradient),
593            weight: _data_size as f64,
594            signature: vec![], // Would be computed in practice
595            submitted_at: Utc::now(),
596        }
597    }
598
599    /// Get client ID
600    pub fn client_id(&self) -> &str {
601        &self.client_id
602    }
603}
604
/// Aggregate statistics over a coordinator's round history (built by
/// `ProtocolStats::from_history`).
#[derive(Debug, Clone, Serialize, Deserialize, Default)]
pub struct ProtocolStats {
    /// Number of rounds in the `Completed` state.
    pub rounds_completed: u64,
    /// Number of rounds in the `Failed` state.
    pub rounds_failed: u64,
    /// Sum of `gradients_received` over all rounds.
    pub total_gradients: u64,
    /// Total number of Byzantine clients detected across all rounds.
    pub total_byzantine: u64,
    /// Sum of per-round privacy epsilon spent.
    pub total_privacy_spent: f64,
    /// Average round duration in milliseconds, computed from `started_at` / `ended_at`.
    pub avg_round_duration_ms: f64,
}
621
622impl ProtocolStats {
623    /// Compute stats from round history
624    pub fn from_history(rounds: &[RoundInfo]) -> Self {
625        let mut stats = Self::default();
626
627        for round in rounds {
628            match round.state {
629                RoundState::Completed => stats.rounds_completed += 1,
630                RoundState::Failed => stats.rounds_failed += 1,
631                _ => {}
632            }
633
634            stats.total_gradients += round.gradients_received as u64;
635            stats.total_byzantine += round.byzantine_detected.len() as u64;
636            stats.total_privacy_spent += round.privacy_spent;
637
638            if let Some(ended) = round.ended_at {
639                let duration = (ended - round.started_at).num_milliseconds() as f64;
640                stats.avg_round_duration_ms = (stats.avg_round_duration_ms
641                    * (stats.rounds_completed + stats.rounds_failed - 1) as f64
642                    + duration)
643                    / (stats.rounds_completed + stats.rounds_failed) as f64;
644            }
645        }
646
647        stats
648    }
649}
650
#[cfg(test)]
mod tests {
    use super::*;

    /// Coordinator with the default configuration.
    fn default_coordinator() -> FederatedCoordinator {
        FederatedCoordinator::new(CoordinatorConfig::default())
    }

    /// Register `client_id` with an all-zero 32-byte public key.
    fn register(coordinator: &FederatedCoordinator, client_id: &str) {
        let client = ClientInfo::new(client_id.to_string(), vec![0u8; 32]);
        coordinator.register_client(client).unwrap();
    }

    /// Gradient submission with weight 100 and an empty signature.
    fn submission(client_id: &str, round_id: &str, gradient: Array1<f64>) -> GradientSubmission {
        GradientSubmission {
            client_id: client_id.to_string(),
            round_id: round_id.to_string(),
            gradient,
            weight: 100.0,
            signature: vec![],
            submitted_at: Utc::now(),
        }
    }

    #[test]
    fn test_coordinator_creation() {
        let coordinator = default_coordinator();
        assert_eq!(coordinator.active_client_count(), 0);
    }

    #[test]
    fn test_client_registration() {
        let coordinator = default_coordinator();
        register(&coordinator, "client_1");

        assert_eq!(coordinator.active_client_count(), 1);
        assert!(coordinator.get_client("client_1").is_some());
    }

    #[test]
    fn test_duplicate_registration() {
        let coordinator = default_coordinator();
        let client = ClientInfo::new("client_1".to_string(), vec![0u8; 32]);
        coordinator.register_client(client.clone()).unwrap();

        assert!(coordinator.register_client(client).is_err());
    }

    #[test]
    fn test_client_unregistration() {
        let coordinator = default_coordinator();
        register(&coordinator, "client_1");
        coordinator.unregister_client("client_1").unwrap();

        assert_eq!(coordinator.active_client_count(), 0);
        assert!(coordinator.unregister_client("client_1").is_err());
    }

    #[test]
    fn test_start_round() {
        let coordinator = default_coordinator();

        let round = coordinator.start_round().unwrap();
        assert_eq!(round.round_number, 1);
        assert_eq!(round.state, RoundState::Collecting);
    }

    #[test]
    fn test_round_already_in_progress() {
        let coordinator = default_coordinator();

        coordinator.start_round().unwrap();
        assert!(coordinator.start_round().is_err());
    }

    #[test]
    fn test_submit_gradient() {
        let coordinator = default_coordinator();
        register(&coordinator, "client_1");
        let round = coordinator.start_round().unwrap();

        coordinator
            .submit_gradient(submission("client_1", &round.round_id, Array1::zeros(1000)))
            .unwrap();

        let current = coordinator.get_current_round().unwrap();
        assert_eq!(current.gradients_received, 1);
        assert_eq!(current.participants, vec!["client_1".to_string()]);
    }

    #[test]
    fn test_submit_unregistered_client() {
        let coordinator = default_coordinator();
        coordinator.start_round().unwrap();

        let result =
            coordinator.submit_gradient(submission("unknown", "test", Array1::zeros(1000)));
        assert!(result.is_err());
    }

    #[test]
    fn test_submit_wrong_round_id() {
        let coordinator = default_coordinator();
        register(&coordinator, "client_1");
        coordinator.start_round().unwrap();

        let result = coordinator.submit_gradient(submission(
            "client_1",
            "not-this-round",
            Array1::zeros(1000),
        ));
        assert!(result.is_err());
        assert_eq!(coordinator.get_current_round().unwrap().gradients_received, 0);
    }

    #[test]
    fn test_submit_without_active_round() {
        let coordinator = default_coordinator();
        register(&coordinator, "client_1");

        let result = coordinator.submit_gradient(submission("client_1", "any", Array1::zeros(1000)));
        assert!(result.is_err());
    }

    #[test]
    fn test_complete_round_without_active_round() {
        let coordinator = default_coordinator();
        assert!(coordinator.complete_round().is_err());
    }

    #[test]
    fn test_empty_round_fails_and_allows_restart() {
        let coordinator = default_coordinator();
        coordinator.start_round().unwrap();

        assert!(coordinator.complete_round().is_err());
        let current = coordinator.get_current_round().unwrap();
        assert_eq!(current.state, RoundState::Failed);
        assert!(current.ended_at.is_some());

        // A failed round must not block the next one.
        assert!(coordinator.start_round().is_ok());
    }

    #[test]
    fn test_full_round() {
        let config = CoordinatorConfig {
            byzantine_enabled: false, // Disable for simple test
            privacy_enabled: false,
            model_dimension: 10,
            ..CoordinatorConfig::default()
        };
        let coordinator = FederatedCoordinator::new(config);

        let client_ids: Vec<String> = (0..5).map(|i| format!("client_{}", i)).collect();
        for id in &client_ids {
            register(&coordinator, id);
        }

        let round = coordinator.start_round().unwrap();
        for id in &client_ids {
            coordinator
                .submit_gradient(submission(id, &round.round_id, Array1::from_vec(vec![0.1; 10])))
                .unwrap();
        }

        let aggregate = coordinator.complete_round().unwrap();
        assert_eq!(aggregate.len(), 10);
        // The model starts at zero, so after one round it equals the aggregate.
        assert_eq!(coordinator.get_global_model(), aggregate);

        let history = coordinator.get_round_history();
        assert_eq!(history.len(), 1);
        assert_eq!(history[0].state, RoundState::Completed);
        assert_eq!(history[0].gradients_received, 5);
        assert_eq!(history[0].participants.len(), 5);
        assert_eq!(
            coordinator.get_client("client_0").unwrap().rounds_participated,
            1
        );
    }

    #[test]
    fn test_federated_client() {
        let client = FederatedClient::new("test_client".to_string(), ClientConfig::default());

        let info = client.get_info();
        assert_eq!(info.client_id, "test_client");
        assert_eq!(info.public_key.len(), 32);
    }

    #[test]
    fn test_client_train_local() {
        let config = ClientConfig {
            model_dimension: 10,
            ..Default::default()
        };
        let client = FederatedClient::new("test".to_string(), config);

        let submission = client.train_local(100);
        assert_eq!(submission.gradient.len(), 10);
        assert_eq!(submission.weight, 100.0);
        assert_eq!(submission.client_id, "test");
        assert!(submission.round_id.is_empty());
        assert!(submission.gradient.iter().all(|g| (-0.1..0.1).contains(g)));
    }

    #[test]
    fn test_client_sync_model() {
        let config = ClientConfig {
            model_dimension: 10,
            ..Default::default()
        };
        let mut client = FederatedClient::new("test".to_string(), config);

        let global = Array1::from_vec(vec![1.0; 10]);
        client.sync_model(&global);

        // The local model must now hold the global values, not just have the right length.
        assert_eq!(client.local_model, global);
    }

    #[test]
    fn test_protocol_stats() {
        let start = Utc::now();

        let mut completed = RoundInfo::new(1);
        completed.state = RoundState::Completed;
        completed.started_at = start;
        completed.ended_at = Some(start + chrono::Duration::milliseconds(250));
        completed.participants = vec!["a".to_string()];
        completed.gradients_received = 5;
        completed.byzantine_detected = vec!["b".to_string()];
        completed.privacy_spent = 1.0;

        let mut failed = RoundInfo::new(2);
        failed.state = RoundState::Failed;
        failed.started_at = start;
        failed.ended_at = Some(start + chrono::Duration::milliseconds(750));

        let stats = ProtocolStats::from_history(&[completed, failed]);
        assert_eq!(stats.rounds_completed, 1);
        assert_eq!(stats.rounds_failed, 1);
        assert_eq!(stats.total_gradients, 5);
        assert_eq!(stats.total_byzantine, 1);
        assert!((stats.total_privacy_spent - 1.0).abs() < 1e-10);
        assert!((stats.avg_round_duration_ms - 500.0).abs() < 1e-10);
    }

    #[test]
    fn test_round_state_display() {
        assert_eq!(RoundState::Waiting.to_string(), "Waiting");
        assert_eq!(RoundState::Collecting.to_string(), "Collecting");
        assert_eq!(RoundState::Aggregating.to_string(), "Aggregating");
        assert_eq!(RoundState::Completed.to_string(), "Completed");
        assert_eq!(RoundState::Failed.to_string(), "Failed");
    }

    #[test]
    fn test_client_info_touch() {
        let mut client = ClientInfo::new("test".to_string(), vec![]);
        let original = client.last_seen;

        std::thread::sleep(std::time::Duration::from_millis(10));
        client.touch();

        assert!(client.last_seen > original);
    }

    #[test]
    fn test_coordinator_config_defaults() {
        let config = CoordinatorConfig::default();
        assert_eq!(config.min_clients, 3);
        assert_eq!(config.max_clients, 100);
        assert_eq!(config.round_timeout_ms, 60000);
        assert_eq!(config.model_dimension, 1000);
        assert!(config.byzantine_enabled);
        assert!(config.privacy_enabled);
    }

    #[test]
    fn test_list_clients() {
        let coordinator = default_coordinator();

        for i in 0..3 {
            let client = ClientInfo::new(format!("client_{}", i), vec![]);
            coordinator.register_client(client).unwrap();
        }

        assert_eq!(coordinator.list_clients().len(), 3);
    }
}