1use std::collections::HashMap;
31use chrono::{DateTime, Utc};
32use ndarray::Array1;
33use parking_lot::RwLock;
34use serde::{Deserialize, Serialize};
35use uuid::Uuid;
36
37use crate::aggregation::{AggregatorConfig, GradientAggregator, WeightedAverageAggregator};
38use crate::byzantine::{ByzantineDetector, DetectorConfig, KrumDetector};
39use crate::error::{ProtocolError, Result};
40use crate::privacy::{GaussianMechanism, PrivacyBudget, PrivacyConfig, PrivacyMechanism};
41
42#[derive(Debug, Clone, Serialize, Deserialize)]
44pub struct CoordinatorConfig {
45 pub min_clients: usize,
47 pub max_clients: usize,
49 pub round_timeout_ms: u64,
51 pub byzantine_enabled: bool,
53 pub byzantine_fraction: f64,
55 pub privacy_enabled: bool,
57 pub privacy_epsilon: f64,
59 pub privacy_delta: f64,
61 pub max_privacy_budget: f64,
63 pub model_dimension: usize,
65}
66
67impl Default for CoordinatorConfig {
68 fn default() -> Self {
69 Self {
70 min_clients: 3,
71 max_clients: 100,
72 round_timeout_ms: 60000,
73 byzantine_enabled: true,
74 byzantine_fraction: 0.3,
75 privacy_enabled: true,
76 privacy_epsilon: 1.0,
77 privacy_delta: 1e-5,
78 max_privacy_budget: 10.0,
79 model_dimension: 1000,
80 }
81 }
82}
83
84#[derive(Debug, Clone, Serialize, Deserialize)]
86pub struct ClientInfo {
87 pub client_id: String,
89 pub public_key: Vec<u8>,
91 pub registered_at: DateTime<Utc>,
93 pub last_seen: DateTime<Utc>,
95 pub rounds_participated: u64,
97 pub weight: f64,
99 pub active: bool,
101}
102
103impl ClientInfo {
104 pub fn new(client_id: String, public_key: Vec<u8>) -> Self {
106 let now = Utc::now();
107 Self {
108 client_id,
109 public_key,
110 registered_at: now,
111 last_seen: now,
112 rounds_participated: 0,
113 weight: 1.0,
114 active: true,
115 }
116 }
117
118 pub fn touch(&mut self) {
120 self.last_seen = Utc::now();
121 }
122}
123
124#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
126pub enum RoundState {
127 Waiting,
129 Collecting,
131 Aggregating,
133 Completed,
135 Failed,
137}
138
139impl std::fmt::Display for RoundState {
140 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
141 match self {
142 RoundState::Waiting => write!(f, "Waiting"),
143 RoundState::Collecting => write!(f, "Collecting"),
144 RoundState::Aggregating => write!(f, "Aggregating"),
145 RoundState::Completed => write!(f, "Completed"),
146 RoundState::Failed => write!(f, "Failed"),
147 }
148 }
149}
150
151#[derive(Debug, Clone, Serialize, Deserialize)]
153pub struct RoundInfo {
154 pub round_number: u64,
156 pub round_id: String,
158 pub state: RoundState,
160 pub started_at: DateTime<Utc>,
162 pub ended_at: Option<DateTime<Utc>>,
164 pub participants: Vec<String>,
166 pub gradients_received: usize,
168 pub byzantine_detected: Vec<String>,
170 pub privacy_spent: f64,
172}
173
174impl RoundInfo {
175 pub fn new(round_number: u64) -> Self {
177 Self {
178 round_number,
179 round_id: Uuid::new_v4().to_string(),
180 state: RoundState::Waiting,
181 started_at: Utc::now(),
182 ended_at: None,
183 participants: Vec::new(),
184 gradients_received: 0,
185 byzantine_detected: Vec::new(),
186 privacy_spent: 0.0,
187 }
188 }
189}
190
191#[derive(Debug, Clone, Serialize, Deserialize)]
193pub struct GradientSubmission {
194 pub client_id: String,
196 pub round_id: String,
198 pub gradient: Array1<f64>,
200 pub weight: f64,
202 pub signature: Vec<u8>,
204 pub submitted_at: DateTime<Utc>,
206}
207
208pub struct FederatedCoordinator {
210 config: CoordinatorConfig,
211 clients: RwLock<HashMap<String, ClientInfo>>,
213 current_round: RwLock<Option<RoundInfo>>,
215 round_history: RwLock<Vec<RoundInfo>>,
217 global_model: RwLock<Array1<f64>>,
219 gradients: RwLock<Vec<GradientSubmission>>,
221 privacy_budget: RwLock<PrivacyBudget>,
223}
224
225impl FederatedCoordinator {
226 pub fn new(config: CoordinatorConfig) -> Self {
228 let privacy_budget = PrivacyBudget::new(config.max_privacy_budget, config.privacy_delta);
229 let global_model = Array1::zeros(config.model_dimension);
230
231 Self {
232 config,
233 clients: RwLock::new(HashMap::new()),
234 current_round: RwLock::new(None),
235 round_history: RwLock::new(Vec::new()),
236 global_model: RwLock::new(global_model),
237 gradients: RwLock::new(Vec::new()),
238 privacy_budget: RwLock::new(privacy_budget),
239 }
240 }
241
242 pub fn register_client(&self, client: ClientInfo) -> Result<()> {
244 let mut clients = self.clients.write();
245
246 if clients.contains_key(&client.client_id) {
247 return Err(ProtocolError::DuplicateClient(client.client_id).into());
248 }
249
250 clients.insert(client.client_id.clone(), client);
251 Ok(())
252 }
253
254 pub fn unregister_client(&self, client_id: &str) -> Result<()> {
256 let mut clients = self.clients.write();
257
258 if clients.remove(client_id).is_none() {
259 return Err(ProtocolError::ClientNotRegistered(client_id.to_string()).into());
260 }
261
262 Ok(())
263 }
264
265 pub fn get_client(&self, client_id: &str) -> Option<ClientInfo> {
267 self.clients.read().get(client_id).cloned()
268 }
269
270 pub fn list_clients(&self) -> Vec<ClientInfo> {
272 self.clients.read().values().cloned().collect()
273 }
274
275 pub fn active_client_count(&self) -> usize {
277 self.clients.read().values().filter(|c| c.active).count()
278 }
279
280 pub fn start_round(&self) -> Result<RoundInfo> {
282 let mut current = self.current_round.write();
283
284 if let Some(ref round) = *current {
285 if round.state != RoundState::Completed && round.state != RoundState::Failed {
286 return Err(ProtocolError::RoundInProgress(round.round_number).into());
287 }
288 }
289
290 let round_number = self.round_history.read().len() as u64 + 1;
291 let mut round = RoundInfo::new(round_number);
292 round.state = RoundState::Collecting;
293
294 self.gradients.write().clear();
296
297 let round_info = round.clone();
298 *current = Some(round);
299
300 Ok(round_info)
301 }
302
303 pub fn submit_gradient(&self, submission: GradientSubmission) -> Result<()> {
305 {
307 let clients = self.clients.read();
308 if !clients.contains_key(&submission.client_id) {
309 return Err(ProtocolError::ClientNotRegistered(submission.client_id.clone()).into());
310 }
311 }
312
313 {
315 let current = self.current_round.read();
316 match current.as_ref() {
317 None => return Err(ProtocolError::NoActiveRound.into()),
318 Some(round) => {
319 if round.state != RoundState::Collecting {
320 return Err(ProtocolError::InvalidStateTransition {
321 from: round.state.to_string(),
322 to: "Collecting".to_string(),
323 }
324 .into());
325 }
326 if submission.round_id != round.round_id {
327 return Err(ProtocolError::NoActiveRound.into());
328 }
329 }
330 }
331 }
332
333 let mut gradients = self.gradients.write();
335 gradients.push(submission.clone());
336
337 {
339 let mut current = self.current_round.write();
340 if let Some(ref mut round) = *current {
341 round.gradients_received = gradients.len();
342 if !round.participants.contains(&submission.client_id) {
343 round.participants.push(submission.client_id.clone());
344 }
345 }
346 }
347
348 {
350 let mut clients = self.clients.write();
351 if let Some(client) = clients.get_mut(&submission.client_id) {
352 client.touch();
353 }
354 }
355
356 Ok(())
357 }
358
359 pub fn complete_round(&self) -> Result<Array1<f64>> {
361 {
363 let mut current = self.current_round.write();
364 match current.as_mut() {
365 None => return Err(ProtocolError::NoActiveRound.into()),
366 Some(round) => {
367 if round.state != RoundState::Collecting {
368 return Err(ProtocolError::InvalidStateTransition {
369 from: round.state.to_string(),
370 to: "Aggregating".to_string(),
371 }
372 .into());
373 }
374 round.state = RoundState::Aggregating;
375 }
376 }
377 }
378
379 let gradients = self.gradients.read();
380 let submissions: Vec<_> = gradients.iter().collect();
381
382 if submissions.is_empty() {
383 let mut current = self.current_round.write();
384 if let Some(ref mut round) = *current {
385 round.state = RoundState::Failed;
386 round.ended_at = Some(Utc::now());
387 }
388 return Err(ProtocolError::NoActiveRound.into());
389 }
390
391 let gradient_arrays: Vec<Array1<f64>> = submissions.iter().map(|s| s.gradient.clone()).collect();
393 let weights: Vec<f64> = submissions.iter().map(|s| s.weight).collect();
394
395 let (honest_indices, byzantine_indices) = if self.config.byzantine_enabled {
397 let detector_config = DetectorConfig::new(
398 self.config.min_clients,
399 self.config.byzantine_fraction,
400 );
401 let detector = KrumDetector::new(detector_config);
402 detector.detect(&gradient_arrays)?
403 } else {
404 ((0..gradient_arrays.len()).collect(), vec![])
405 };
406
407 {
409 let mut current = self.current_round.write();
410 if let Some(ref mut round) = *current {
411 round.byzantine_detected = byzantine_indices
412 .iter()
413 .filter_map(|&i| submissions.get(i).map(|s| s.client_id.clone()))
414 .collect();
415 }
416 }
417
418 let honest_gradients: Vec<Array1<f64>> = honest_indices
420 .iter()
421 .map(|&i| gradient_arrays[i].clone())
422 .collect();
423 let honest_weights: Vec<f64> = honest_indices.iter().map(|&i| weights[i]).collect();
424
425 let aggregator_config = AggregatorConfig::default();
427 let aggregator = WeightedAverageAggregator::new(aggregator_config);
428 let mut aggregate = aggregator.aggregate(&honest_gradients, &honest_weights)?;
429
430 if self.config.privacy_enabled {
432 let privacy_config = PrivacyConfig::new(
433 self.config.privacy_epsilon,
434 self.config.privacy_delta,
435 1.0, );
437 let mechanism = GaussianMechanism::new(privacy_config)?;
438 aggregate = mechanism.apply(&aggregate)?;
439
440 {
442 let mut budget = self.privacy_budget.write();
443 budget.spend(
444 self.config.privacy_epsilon,
445 self.config.privacy_delta,
446 &format!("round {}", self.current_round.read().as_ref().map(|r| r.round_number).unwrap_or(0)),
447 )?;
448 }
449
450 {
452 let mut current = self.current_round.write();
453 if let Some(ref mut round) = *current {
454 round.privacy_spent = self.config.privacy_epsilon;
455 }
456 }
457 }
458
459 {
461 let mut model = self.global_model.write();
462 *model = &*model + &aggregate;
463 }
464
465 {
467 let mut current = self.current_round.write();
468 if let Some(ref mut round) = *current {
469 round.state = RoundState::Completed;
470 round.ended_at = Some(Utc::now());
471
472 let mut clients = self.clients.write();
474 for client_id in &round.participants {
475 if let Some(client) = clients.get_mut(client_id) {
476 client.rounds_participated += 1;
477 }
478 }
479
480 self.round_history.write().push(round.clone());
482 }
483 }
484
485 Ok(aggregate)
486 }
487
488 pub fn get_global_model(&self) -> Array1<f64> {
490 self.global_model.read().clone()
491 }
492
493 pub fn get_current_round(&self) -> Option<RoundInfo> {
495 self.current_round.read().clone()
496 }
497
498 pub fn get_round_history(&self) -> Vec<RoundInfo> {
500 self.round_history.read().clone()
501 }
502
503 pub fn remaining_privacy_budget(&self) -> f64 {
505 self.privacy_budget.read().remaining()
506 }
507
508 pub fn config(&self) -> &CoordinatorConfig {
510 &self.config
511 }
512}
513
514pub struct FederatedClient {
516 client_id: String,
518 public_key: Vec<u8>,
520 _private_key: Vec<u8>,
522 local_model: Array1<f64>,
524 config: ClientConfig,
526}
527
528#[derive(Debug, Clone, Serialize, Deserialize)]
530pub struct ClientConfig {
531 pub local_epochs: usize,
533 pub batch_size: usize,
535 pub learning_rate: f64,
537 pub model_dimension: usize,
539}
540
541impl Default for ClientConfig {
542 fn default() -> Self {
543 Self {
544 local_epochs: 5,
545 batch_size: 32,
546 learning_rate: 0.01,
547 model_dimension: 1000,
548 }
549 }
550}
551
552impl FederatedClient {
553 pub fn new(client_id: String, config: ClientConfig) -> Self {
555 let mut rng = rand::thread_rng();
557 let public_key: Vec<u8> = (0..32).map(|_| rand::Rng::gen(&mut rng)).collect();
558 let private_key: Vec<u8> = (0..32).map(|_| rand::Rng::gen(&mut rng)).collect();
559
560 let local_model = Array1::zeros(config.model_dimension);
561
562 Self {
563 client_id,
564 public_key,
565 _private_key: private_key,
566 local_model,
567 config,
568 }
569 }
570
571 pub fn get_info(&self) -> ClientInfo {
573 ClientInfo::new(self.client_id.clone(), self.public_key.clone())
574 }
575
576 pub fn sync_model(&mut self, global_model: &Array1<f64>) {
578 self.local_model = global_model.clone();
579 }
580
581 pub fn train_local(&self, _data_size: usize) -> GradientSubmission {
583 let mut rng = rand::thread_rng();
585 let gradient: Vec<f64> = (0..self.config.model_dimension)
586 .map(|_| rand::Rng::gen_range(&mut rng, -0.1..0.1))
587 .collect();
588
589 GradientSubmission {
590 client_id: self.client_id.clone(),
591 round_id: String::new(), gradient: Array1::from_vec(gradient),
593 weight: _data_size as f64,
594 signature: vec![], submitted_at: Utc::now(),
596 }
597 }
598
599 pub fn client_id(&self) -> &str {
601 &self.client_id
602 }
603}
604
605#[derive(Debug, Clone, Serialize, Deserialize, Default)]
607pub struct ProtocolStats {
608 pub rounds_completed: u64,
610 pub rounds_failed: u64,
612 pub total_gradients: u64,
614 pub total_byzantine: u64,
616 pub total_privacy_spent: f64,
618 pub avg_round_duration_ms: f64,
620}
621
622impl ProtocolStats {
623 pub fn from_history(rounds: &[RoundInfo]) -> Self {
625 let mut stats = Self::default();
626
627 for round in rounds {
628 match round.state {
629 RoundState::Completed => stats.rounds_completed += 1,
630 RoundState::Failed => stats.rounds_failed += 1,
631 _ => {}
632 }
633
634 stats.total_gradients += round.gradients_received as u64;
635 stats.total_byzantine += round.byzantine_detected.len() as u64;
636 stats.total_privacy_spent += round.privacy_spent;
637
638 if let Some(ended) = round.ended_at {
639 let duration = (ended - round.started_at).num_milliseconds() as f64;
640 stats.avg_round_duration_ms = (stats.avg_round_duration_ms
641 * (stats.rounds_completed + stats.rounds_failed - 1) as f64
642 + duration)
643 / (stats.rounds_completed + stats.rounds_failed) as f64;
644 }
645 }
646
647 stats
648 }
649}
650
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_coordinator_creation() {
        let config = CoordinatorConfig::default();
        let coordinator = FederatedCoordinator::new(config);

        assert_eq!(coordinator.active_client_count(), 0);
    }

    #[test]
    fn test_client_registration() {
        let config = CoordinatorConfig::default();
        let coordinator = FederatedCoordinator::new(config);

        let client = ClientInfo::new("client_1".to_string(), vec![0u8; 32]);
        coordinator.register_client(client).unwrap();

        assert_eq!(coordinator.active_client_count(), 1);
        assert!(coordinator.get_client("client_1").is_some());
    }

    #[test]
    fn test_duplicate_registration() {
        let config = CoordinatorConfig::default();
        let coordinator = FederatedCoordinator::new(config);

        let client = ClientInfo::new("client_1".to_string(), vec![0u8; 32]);
        coordinator.register_client(client.clone()).unwrap();

        let result = coordinator.register_client(client);
        assert!(result.is_err());
    }

    #[test]
    fn test_client_unregistration() {
        let config = CoordinatorConfig::default();
        let coordinator = FederatedCoordinator::new(config);

        let client = ClientInfo::new("client_1".to_string(), vec![0u8; 32]);
        coordinator.register_client(client).unwrap();
        coordinator.unregister_client("client_1").unwrap();

        assert_eq!(coordinator.active_client_count(), 0);
    }

    #[test]
    fn test_start_round() {
        let config = CoordinatorConfig::default();
        let coordinator = FederatedCoordinator::new(config);

        let round = coordinator.start_round().unwrap();
        assert_eq!(round.round_number, 1);
        assert_eq!(round.state, RoundState::Collecting);
    }

    #[test]
    fn test_round_already_in_progress() {
        let config = CoordinatorConfig::default();
        let coordinator = FederatedCoordinator::new(config);

        coordinator.start_round().unwrap();
        let result = coordinator.start_round();
        assert!(result.is_err());
    }

    #[test]
    fn test_submit_gradient() {
        let config = CoordinatorConfig::default();
        let coordinator = FederatedCoordinator::new(config);

        let client = ClientInfo::new("client_1".to_string(), vec![0u8; 32]);
        coordinator.register_client(client).unwrap();

        let round = coordinator.start_round().unwrap();

        let submission = GradientSubmission {
            client_id: "client_1".to_string(),
            round_id: round.round_id,
            gradient: Array1::zeros(1000),
            weight: 100.0,
            signature: vec![],
            submitted_at: Utc::now(),
        };

        coordinator.submit_gradient(submission).unwrap();

        let current = coordinator.get_current_round().unwrap();
        assert_eq!(current.gradients_received, 1);
    }

    #[test]
    fn test_submit_unregistered_client() {
        let config = CoordinatorConfig::default();
        let coordinator = FederatedCoordinator::new(config);

        coordinator.start_round().unwrap();

        let submission = GradientSubmission {
            client_id: "unknown".to_string(),
            round_id: "test".to_string(),
            gradient: Array1::zeros(1000),
            weight: 100.0,
            signature: vec![],
            submitted_at: Utc::now(),
        };

        let result = coordinator.submit_gradient(submission);
        assert!(result.is_err());
    }

    /// A submission quoting a round id other than the active one is rejected
    /// and not buffered.
    #[test]
    fn test_submit_wrong_round_id() {
        let coordinator = FederatedCoordinator::new(CoordinatorConfig::default());
        coordinator
            .register_client(ClientInfo::new("client_1".to_string(), vec![0u8; 32]))
            .unwrap();
        coordinator.start_round().unwrap();

        let submission = GradientSubmission {
            client_id: "client_1".to_string(),
            round_id: "not-the-current-round".to_string(),
            gradient: Array1::zeros(1000),
            weight: 1.0,
            signature: vec![],
            submitted_at: Utc::now(),
        };

        assert!(coordinator.submit_gradient(submission).is_err());
        assert_eq!(coordinator.get_current_round().unwrap().gradients_received, 0);
    }

    #[test]
    fn test_complete_round_without_active_round() {
        let coordinator = FederatedCoordinator::new(CoordinatorConfig::default());
        assert!(coordinator.complete_round().is_err());
    }

    /// Completing a round with no submissions fails the round, and a new round
    /// can be started afterwards.
    #[test]
    fn test_empty_round_fails_and_allows_restart() {
        let coordinator = FederatedCoordinator::new(CoordinatorConfig::default());
        coordinator.start_round().unwrap();

        assert!(coordinator.complete_round().is_err());
        assert_eq!(
            coordinator.get_current_round().unwrap().state,
            RoundState::Failed
        );
        assert!(coordinator.start_round().is_ok());
    }

    #[test]
    fn test_full_round() {
        let mut config = CoordinatorConfig::default();
        config.byzantine_enabled = false;
        config.privacy_enabled = false;
        config.model_dimension = 10;

        let coordinator = FederatedCoordinator::new(config);

        for i in 0..5 {
            let client = ClientInfo::new(format!("client_{}", i), vec![0u8; 32]);
            coordinator.register_client(client).unwrap();
        }

        let round = coordinator.start_round().unwrap();

        for i in 0..5 {
            let submission = GradientSubmission {
                client_id: format!("client_{}", i),
                round_id: round.round_id.clone(),
                gradient: Array1::from_vec(vec![0.1; 10]),
                weight: 100.0,
                signature: vec![],
                submitted_at: Utc::now(),
            };
            coordinator.submit_gradient(submission).unwrap();
        }

        let aggregate = coordinator.complete_round().unwrap();
        assert_eq!(aggregate.len(), 10);

        let history = coordinator.get_round_history();
        assert_eq!(history.len(), 1);
        assert_eq!(history[0].state, RoundState::Completed);
    }

    #[test]
    fn test_federated_client() {
        let config = ClientConfig::default();
        let client = FederatedClient::new("test_client".to_string(), config);

        let info = client.get_info();
        assert_eq!(info.client_id, "test_client");
        assert_eq!(info.public_key.len(), 32);
    }

    #[test]
    fn test_client_train_local() {
        let config = ClientConfig {
            model_dimension: 10,
            ..Default::default()
        };
        let client = FederatedClient::new("test".to_string(), config);

        let submission = client.train_local(100);
        assert_eq!(submission.gradient.len(), 10);
        assert_eq!(submission.weight, 100.0);
    }

    #[test]
    fn test_client_sync_model() {
        let config = ClientConfig {
            model_dimension: 10,
            ..Default::default()
        };
        let mut client = FederatedClient::new("test".to_string(), config);

        let global = Array1::from_vec(vec![1.0; 10]);
        client.sync_model(&global);

        assert_eq!(client.local_model.len(), 10);
        // The synced model must equal the global one, not merely match its length.
        assert_eq!(client.local_model, global);
    }

    #[test]
    fn test_protocol_stats() {
        let rounds = vec![RoundInfo {
            round_number: 1,
            round_id: "r1".to_string(),
            state: RoundState::Completed,
            started_at: Utc::now(),
            ended_at: Some(Utc::now()),
            participants: vec!["a".to_string()],
            gradients_received: 5,
            byzantine_detected: vec!["b".to_string()],
            privacy_spent: 1.0,
        }];

        let stats = ProtocolStats::from_history(&rounds);
        assert_eq!(stats.rounds_completed, 1);
        assert_eq!(stats.total_gradients, 5);
        assert_eq!(stats.total_byzantine, 1);
        assert!((stats.total_privacy_spent - 1.0).abs() < 1e-10);
    }

    #[test]
    fn test_round_state_display() {
        assert_eq!(RoundState::Waiting.to_string(), "Waiting");
        assert_eq!(RoundState::Collecting.to_string(), "Collecting");
        assert_eq!(RoundState::Completed.to_string(), "Completed");
    }

    #[test]
    fn test_client_info_touch() {
        let mut client = ClientInfo::new("test".to_string(), vec![]);
        let original = client.last_seen;

        std::thread::sleep(std::time::Duration::from_millis(10));
        client.touch();

        assert!(client.last_seen > original);
    }

    #[test]
    fn test_coordinator_config_defaults() {
        let config = CoordinatorConfig::default();
        assert_eq!(config.min_clients, 3);
        assert_eq!(config.max_clients, 100);
        assert!(config.byzantine_enabled);
        assert!(config.privacy_enabled);
    }

    #[test]
    fn test_list_clients() {
        let config = CoordinatorConfig::default();
        let coordinator = FederatedCoordinator::new(config);

        for i in 0..3 {
            let client = ClientInfo::new(format!("client_{}", i), vec![]);
            coordinator.register_client(client).unwrap();
        }

        let clients = coordinator.list_clients();
        assert_eq!(clients.len(), 3);
    }
}