1use async_trait::async_trait;
38use serde::{Deserialize, Serialize};
39use std::collections::HashMap;
40use std::sync::Arc;
41use std::time::{Duration, SystemTime, UNIX_EPOCH};
42use thiserror::Error;
43use tokio::sync::{RwLock, Semaphore};
44use tracing::{debug, error, info, warn};
45use voirs_sdk::{AudioBuffer, VoirsError};
46
47use crate::privacy::{PrivacyConfig, PrivacyPreservingEvaluator, PrivateEvaluationResult};
48use crate::quality::QualityEvaluator;
49use crate::traits::{QualityEvaluationConfig, QualityEvaluator as QualityEvaluatorTrait};
50
/// Errors produced by federated nodes and the federated coordinator.
#[derive(Error, Debug)]
pub enum FederatedError {
    /// A node-level operation failed; raised e.g. when a node's concurrency semaphore
    /// cannot be acquired.
    #[error("Node communication error with '{node}': {message}")]
    CommunicationError {
        /// Identifier of the node involved.
        node: String,
        /// Human-readable description of the failure.
        message: String,
    },

    /// Node results could not be aggregated (e.g. too few nodes, or no results at all).
    #[error("Aggregation error: {message}")]
    AggregationError {
        /// Human-readable description of the failure.
        message: String,
    },

    /// A node could not be registered with the coordinator (e.g. the node limit was reached).
    #[error("Node registration error: {message}")]
    RegistrationError {
        /// Human-readable description of the failure.
        message: String,
    },

    /// Consensus among nodes could not be reached. Not raised by the code visible here.
    #[error("Consensus error: {message}")]
    ConsensusError {
        /// Human-readable description of the failure.
        message: String,
    },

    /// An operation exceeded its time budget. Not raised by the code visible here.
    #[error("Operation timed out after {duration:?}")]
    TimeoutError {
        /// How long the operation ran before giving up.
        duration: Duration,
    },

    /// The node/coordinator configuration is inconsistent, e.g. privacy is enabled but no
    /// privacy evaluator was initialised.
    #[error("Invalid configuration: {message}")]
    ConfigError {
        /// Human-readable description of the problem.
        message: String,
    },

    /// Error bubbled up from the VoiRS SDK.
    #[error("VoiRS error: {0}")]
    VoirsError(#[from] VoirsError),

    /// Error bubbled up from the privacy-preserving evaluator.
    #[error("Privacy error: {0}")]
    PrivacyError(#[from] crate::privacy::PrivacyError),

    /// Error bubbled up from the quality evaluator.
    #[error("Evaluation error: {0}")]
    EvaluationError(#[from] crate::EvaluationError),

    /// JSON (de)serialization failure.
    #[error("Serialization error: {0}")]
    SerializationError(#[from] serde_json::Error),
}
114
/// Lifecycle state of a federated node as tracked in [`NodeInfo`].
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
pub enum NodeStatus {
    /// Available for work; the only state counted as "active" by the coordinator.
    Online,
    /// Currently running at least one evaluation task.
    Busy,
    /// Presumably reachable but not accepting work — not set by the code visible here.
    Offline,
    /// Presumably marks a node that has errored — not set by the code visible here.
    Failed,
}
127
/// Snapshot of a node's identity, state and counters, shared between nodes and the
/// coordinator's registry.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct NodeInfo {
    /// Unique identifier of the node (also the coordinator's registry key).
    pub node_id: String,
    /// Endpoint string associated with the node.
    pub endpoint: String,
    /// Current lifecycle state.
    pub status: NodeStatus,
    /// What the node advertises it can do.
    pub capabilities: NodeCapabilities,
    /// Time of the last heartbeat, in whole seconds since the Unix epoch.
    pub last_heartbeat: u64,
    /// Number of evaluations this node has completed successfully.
    pub total_evaluations: u64,
    /// Reported load, initialised to 0.0 (presumably a 0.0–1.0 utilisation fraction —
    /// TODO confirm).
    pub current_load: f64,
}
146
/// Capabilities a node advertises to the federation.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct NodeCapabilities {
    /// Names of the quality metrics the node can compute (e.g. "pesq", "stoi", "mcd").
    pub metrics: Vec<String>,
    /// Maximum number of evaluations the node runs concurrently.
    pub max_concurrent: usize,
    /// Whether GPU acceleration is available.
    pub gpu_enabled: bool,
    /// Memory available for evaluation, in megabytes.
    pub available_memory_mb: usize,
    /// Relative processing speed; 1.0 is the baseline (presumably used for scheduling —
    /// not read by the code visible here).
    pub speed_factor: f64,
}
161
162impl Default for NodeCapabilities {
163 fn default() -> Self {
164 Self {
165 metrics: vec!["pesq".to_string(), "stoi".to_string(), "mcd".to_string()],
166 max_concurrent: 4,
167 gpu_enabled: false,
168 available_memory_mb: 4096,
169 speed_factor: 1.0,
170 }
171 }
172}
173
/// Configuration for a single [`FederatedNode`].
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct NodeConfig {
    /// Unique identifier of this node.
    pub node_id: String,
    /// Endpoint of the coordinator; `FederatedNode::new` copies it into `NodeInfo::endpoint`.
    pub coordinator_endpoint: String,
    /// When true, evaluations go through the privacy-preserving evaluator.
    pub enable_privacy: bool,
    /// Privacy settings; when `None` and privacy is enabled, `PrivacyConfig`'s default is used.
    pub privacy_config: Option<PrivacyConfig>,
    /// Intended heartbeat period in seconds (not read by the code visible here).
    pub heartbeat_interval_seconds: u64,
    /// Number of evaluations the node runs at once (size of its concurrency semaphore).
    pub max_concurrent_evaluations: usize,
}
190
191impl Default for NodeConfig {
192 fn default() -> Self {
193 Self {
194 node_id: "default_node".to_string(),
195 coordinator_endpoint: "http://localhost:8080".to_string(),
196 enable_privacy: true,
197 privacy_config: Some(PrivacyConfig::new()),
198 heartbeat_interval_seconds: 30,
199 max_concurrent_evaluations: 4,
200 }
201 }
202}
203
/// Configuration for the [`FederatedCoordinator`].
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct CoordinatorConfig {
    /// Minimum number of nodes required before a federated evaluation may run.
    pub min_nodes: usize,
    /// Maximum number of nodes that may be registered at once.
    pub max_nodes: usize,
    /// How per-node scores are combined into one federated score.
    pub aggregation_strategy: AggregationStrategy,
    /// Whether secure aggregation is requested (not read by the code visible here).
    pub enable_secure_aggregation: bool,
    /// Consensus threshold, presumably a fraction of nodes in 0.0–1.0 (not read by the code
    /// visible here — TODO confirm semantics).
    pub consensus_threshold: f64,
    /// Node timeout in seconds (not read by the code visible here).
    pub node_timeout_seconds: u64,
    /// Whether fault tolerance is requested (not read by the code visible here).
    pub enable_fault_tolerance: bool,
}
222
223impl Default for CoordinatorConfig {
224 fn default() -> Self {
225 Self {
226 min_nodes: 2,
227 max_nodes: 100,
228 aggregation_strategy: AggregationStrategy::WeightedAverage,
229 enable_secure_aggregation: true,
230 consensus_threshold: 0.67, node_timeout_seconds: 300, enable_fault_tolerance: true,
233 }
234 }
235}
236
/// Strategy the coordinator uses to combine per-node quality scores.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
pub enum AggregationStrategy {
    /// Unweighted arithmetic mean of all node scores.
    Average,
    /// Mean of node scores weighted by each result's `weight`.
    WeightedAverage,
    /// Median node score (mean of the two middle values for an even count).
    Median,
    /// Mean after discarding the lowest and highest 10% of scores.
    TrimmedMean,
    /// Reserved for a federated-learning aggregator; currently computed as the plain mean.
    FederatedLearning,
}
251
/// A unit of evaluation work created by the coordinator and executed by nodes.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct EvaluationTask {
    /// Unique task identifier (a UUID v4 string when created by the coordinator).
    pub task_id: String,
    /// Free-form task kind, e.g. "quality_evaluation".
    pub task_type: String,
    /// Free-form JSON parameters (empty when created via `create_task`).
    pub parameters: HashMap<String, serde_json::Value>,
    /// Task priority; `create_task` uses 1 (ordering semantics not defined here — TODO confirm).
    pub priority: u32,
    /// Creation time in whole seconds since the Unix epoch.
    pub created_at: u64,
}
266
/// Outcome of one node evaluating one task.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct NodeEvaluationResult {
    /// Node that produced the result.
    pub node_id: String,
    /// Task the result belongs to.
    pub task_id: String,
    /// Overall quality score reported by the node.
    pub quality_score: f64,
    /// Full privacy-preserving result, present only when the node evaluated with privacy.
    pub private_result: Option<PrivateEvaluationResult>,
    /// Wall-clock evaluation time in milliseconds.
    pub duration_ms: u64,
    /// Relative weight used by `AggregationStrategy::WeightedAverage`.
    pub weight: f64,
    /// Completion time in whole seconds since the Unix epoch.
    pub completed_at: u64,
}
285
/// Aggregated outcome of a federated evaluation across several nodes.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct FederatedEvaluationResult {
    /// Identifier of the federated evaluation.
    pub task_id: String,
    /// Score obtained by combining node scores with `aggregation_strategy`.
    pub aggregated_score: f64,
    /// `(lower, upper)` ~95% interval around the mean node score, clamped by the coordinator
    /// to the 0.0–5.0 range.
    pub confidence_interval: (f64, f64),
    /// Number of nodes that contributed a result.
    pub participant_count: usize,
    /// The individual per-node results.
    pub node_results: Vec<NodeEvaluationResult>,
    /// Strategy used to produce `aggregated_score`.
    pub aggregation_strategy: AggregationStrategy,
    /// Duration of the federated evaluation in milliseconds.
    pub total_duration_ms: u64,
    /// Human-readable descriptions of the privacy guarantees claimed for this result.
    pub privacy_guarantees: Vec<String>,
}
306
/// A federation participant that runs quality evaluations locally, optionally through a
/// privacy-preserving evaluator.
pub struct FederatedNode {
    /// Node configuration as supplied to `FederatedNode::new`.
    config: NodeConfig,
    /// Plain quality evaluator, used when privacy is disabled.
    evaluator: Arc<RwLock<QualityEvaluator>>,
    /// Privacy-preserving evaluator; `Some` only when `config.enable_privacy` was set at
    /// construction.
    privacy_evaluator: Option<Arc<RwLock<PrivacyPreservingEvaluator>>>,
    /// Mutable status/counters reported to callers via `get_info`.
    info: Arc<RwLock<NodeInfo>>,
    /// Limits concurrent evaluations to `config.max_concurrent_evaluations`.
    semaphore: Arc<Semaphore>,
}
315
316impl FederatedNode {
317 pub async fn new(config: NodeConfig) -> Result<Self, FederatedError> {
319 let evaluator = QualityEvaluator::new().await?;
320
321 let privacy_evaluator = if config.enable_privacy {
322 let privacy_config = config.privacy_config.clone().unwrap_or_default();
323 let pe = PrivacyPreservingEvaluator::new(privacy_config).await?;
324 Some(Arc::new(RwLock::new(pe)))
325 } else {
326 None
327 };
328
329 let info = NodeInfo {
330 node_id: config.node_id.clone(),
331 endpoint: config.coordinator_endpoint.clone(),
332 status: NodeStatus::Online,
333 capabilities: NodeCapabilities::default(),
334 last_heartbeat: SystemTime::now()
335 .duration_since(UNIX_EPOCH)
336 .unwrap()
337 .as_secs(),
338 total_evaluations: 0,
339 current_load: 0.0,
340 };
341
342 let semaphore = Arc::new(Semaphore::new(config.max_concurrent_evaluations));
343
344 Ok(Self {
345 config,
346 evaluator: Arc::new(RwLock::new(evaluator)),
347 privacy_evaluator,
348 info: Arc::new(RwLock::new(info)),
349 semaphore,
350 })
351 }
352
353 pub async fn execute_task(
355 &self,
356 task: &EvaluationTask,
357 audio: &AudioBuffer,
358 reference: Option<&AudioBuffer>,
359 ) -> Result<NodeEvaluationResult, FederatedError> {
360 let _permit =
361 self.semaphore
362 .acquire()
363 .await
364 .map_err(|e| FederatedError::CommunicationError {
365 node: self.config.node_id.clone(),
366 message: format!("Failed to acquire semaphore: {}", e),
367 })?;
368
369 let start = SystemTime::now();
370
371 {
373 let mut info = self.info.write().await;
374 info.status = NodeStatus::Busy;
375 }
376
377 let result = if self.config.enable_privacy {
378 if let Some(ref pe) = self.privacy_evaluator {
380 let pe_lock = pe.read().await;
381 let private_result = pe_lock.evaluate_with_privacy(audio, reference).await?;
382 let quality_score = private_result.score;
383
384 let duration_ms = SystemTime::now()
385 .duration_since(start)
386 .unwrap_or(Duration::ZERO)
387 .as_millis() as u64;
388
389 NodeEvaluationResult {
390 node_id: self.config.node_id.clone(),
391 task_id: task.task_id.clone(),
392 quality_score,
393 private_result: Some(private_result),
394 duration_ms,
395 weight: 1.0,
396 completed_at: SystemTime::now()
397 .duration_since(UNIX_EPOCH)
398 .unwrap()
399 .as_secs(),
400 }
401 } else {
402 return Err(FederatedError::ConfigError {
403 message: "Privacy enabled but evaluator not initialized".to_string(),
404 });
405 }
406 } else {
407 let evaluator = self.evaluator.read().await;
409 let eval_config = QualityEvaluationConfig::default();
410 let quality = evaluator
411 .evaluate_quality(audio, reference, Some(&eval_config))
412 .await?;
413
414 let duration_ms = SystemTime::now()
415 .duration_since(start)
416 .unwrap_or(Duration::ZERO)
417 .as_millis() as u64;
418
419 NodeEvaluationResult {
420 node_id: self.config.node_id.clone(),
421 task_id: task.task_id.clone(),
422 quality_score: quality.overall_score as f64,
423 private_result: None,
424 duration_ms,
425 weight: 1.0,
426 completed_at: SystemTime::now()
427 .duration_since(UNIX_EPOCH)
428 .unwrap()
429 .as_secs(),
430 }
431 };
432
433 {
435 let mut info = self.info.write().await;
436 info.status = NodeStatus::Online;
437 info.total_evaluations += 1;
438 info.last_heartbeat = SystemTime::now()
439 .duration_since(UNIX_EPOCH)
440 .unwrap()
441 .as_secs();
442 }
443
444 Ok(result)
445 }
446
447 pub async fn get_info(&self) -> NodeInfo {
449 let info = self.info.read().await;
450 info.clone()
451 }
452
453 pub async fn heartbeat(&self) {
455 let mut info = self.info.write().await;
456 info.last_heartbeat = SystemTime::now()
457 .duration_since(UNIX_EPOCH)
458 .unwrap()
459 .as_secs();
460 }
461}
462
/// Central coordinator that tracks registered nodes and tasks and aggregates node results.
pub struct FederatedCoordinator {
    /// Coordinator configuration as supplied to `FederatedCoordinator::new`.
    config: CoordinatorConfig,
    /// Registry of nodes keyed by node id.
    nodes: Arc<RwLock<HashMap<String, NodeInfo>>>,
    /// Tasks created via `create_task`, in creation order.
    tasks: Arc<RwLock<Vec<EvaluationTask>>>,
    /// Per-node results grouped by a String key (presumably the task id — TODO confirm).
    results: Arc<RwLock<HashMap<String, Vec<NodeEvaluationResult>>>>,
}
470
471impl FederatedCoordinator {
472 pub async fn new(config: CoordinatorConfig) -> Result<Self, FederatedError> {
474 Ok(Self {
475 config,
476 nodes: Arc::new(RwLock::new(HashMap::new())),
477 tasks: Arc::new(RwLock::new(Vec::new())),
478 results: Arc::new(RwLock::new(HashMap::new())),
479 })
480 }
481
482 pub async fn register_node(&self, node_id: &str, endpoint: &str) -> Result<(), FederatedError> {
484 let mut nodes = self.nodes.write().await;
485
486 if nodes.len() >= self.config.max_nodes {
487 return Err(FederatedError::RegistrationError {
488 message: format!("Maximum node limit reached: {}", self.config.max_nodes),
489 });
490 }
491
492 let info = NodeInfo {
493 node_id: node_id.to_string(),
494 endpoint: endpoint.to_string(),
495 status: NodeStatus::Online,
496 capabilities: NodeCapabilities::default(),
497 last_heartbeat: SystemTime::now()
498 .duration_since(UNIX_EPOCH)
499 .unwrap()
500 .as_secs(),
501 total_evaluations: 0,
502 current_load: 0.0,
503 };
504
505 nodes.insert(node_id.to_string(), info);
506 info!("Registered node: {}", node_id);
507 Ok(())
508 }
509
510 pub async fn unregister_node(&self, node_id: &str) -> Result<(), FederatedError> {
512 let mut nodes = self.nodes.write().await;
513 nodes.remove(node_id);
514 info!("Unregistered node: {}", node_id);
515 Ok(())
516 }
517
518 pub async fn active_node_count(&self) -> usize {
520 let nodes = self.nodes.read().await;
521 nodes
522 .values()
523 .filter(|n| n.status == NodeStatus::Online)
524 .count()
525 }
526
527 pub async fn create_task(&self, task_type: String) -> Result<String, FederatedError> {
529 let task_id = uuid::Uuid::new_v4().to_string();
530 let task = EvaluationTask {
531 task_id: task_id.clone(),
532 task_type,
533 parameters: HashMap::new(),
534 priority: 1,
535 created_at: SystemTime::now()
536 .duration_since(UNIX_EPOCH)
537 .unwrap()
538 .as_secs(),
539 };
540
541 let mut tasks = self.tasks.write().await;
542 tasks.push(task);
543 Ok(task_id)
544 }
545
546 pub async fn evaluate_federated(&self) -> Result<FederatedEvaluationResult, FederatedError> {
548 let nodes = self.nodes.read().await;
549
550 if nodes.len() < self.config.min_nodes {
551 return Err(FederatedError::AggregationError {
552 message: format!(
553 "Insufficient nodes: {} < {}",
554 nodes.len(),
555 self.config.min_nodes
556 ),
557 });
558 }
559
560 let task_id = uuid::Uuid::new_v4().to_string();
563 let node_results: Vec<NodeEvaluationResult> = nodes
564 .iter()
565 .map(|(node_id, _)| NodeEvaluationResult {
566 node_id: node_id.clone(),
567 task_id: task_id.clone(),
568 quality_score: 4.2,
569 private_result: None,
570 duration_ms: 1000,
571 weight: 1.0,
572 completed_at: SystemTime::now()
573 .duration_since(UNIX_EPOCH)
574 .unwrap()
575 .as_secs(),
576 })
577 .collect();
578
579 let aggregated_score = self.aggregate_results(&node_results).await?;
580 let confidence_interval = self.calculate_confidence_interval(&node_results);
581
582 Ok(FederatedEvaluationResult {
583 task_id,
584 aggregated_score,
585 confidence_interval,
586 participant_count: node_results.len(),
587 node_results,
588 aggregation_strategy: self.config.aggregation_strategy,
589 total_duration_ms: 1000,
590 privacy_guarantees: vec!["Differential privacy enabled".to_string()],
591 })
592 }
593
594 async fn aggregate_results(
596 &self,
597 results: &[NodeEvaluationResult],
598 ) -> Result<f64, FederatedError> {
599 if results.is_empty() {
600 return Err(FederatedError::AggregationError {
601 message: "No results to aggregate".to_string(),
602 });
603 }
604
605 let score = match self.config.aggregation_strategy {
606 AggregationStrategy::Average => {
607 results.iter().map(|r| r.quality_score).sum::<f64>() / results.len() as f64
608 }
609 AggregationStrategy::WeightedAverage => {
610 let total_weight: f64 = results.iter().map(|r| r.weight).sum();
611 results
612 .iter()
613 .map(|r| r.quality_score * r.weight)
614 .sum::<f64>()
615 / total_weight
616 }
617 AggregationStrategy::Median => {
618 let mut scores: Vec<f64> = results.iter().map(|r| r.quality_score).collect();
619 scores.sort_by(|a, b| a.partial_cmp(b).unwrap());
620 let mid = scores.len() / 2;
621 if scores.len() % 2 == 0 {
622 (scores[mid - 1] + scores[mid]) / 2.0
623 } else {
624 scores[mid]
625 }
626 }
627 AggregationStrategy::TrimmedMean => {
628 let mut scores: Vec<f64> = results.iter().map(|r| r.quality_score).collect();
629 scores.sort_by(|a, b| a.partial_cmp(b).unwrap());
630 let trim_count = (scores.len() as f64 * 0.1) as usize; if scores.len() > 2 * trim_count {
632 let trimmed: Vec<f64> = scores
633 .iter()
634 .skip(trim_count)
635 .take(scores.len() - 2 * trim_count)
636 .copied()
637 .collect();
638 trimmed.iter().sum::<f64>() / trimmed.len() as f64
639 } else {
640 scores.iter().sum::<f64>() / scores.len() as f64
641 }
642 }
643 AggregationStrategy::FederatedLearning => {
644 results.iter().map(|r| r.quality_score).sum::<f64>() / results.len() as f64
646 }
647 };
648
649 Ok(score)
650 }
651
652 fn calculate_confidence_interval(&self, results: &[NodeEvaluationResult]) -> (f64, f64) {
654 if results.is_empty() {
655 return (0.0, 0.0);
656 }
657
658 let scores: Vec<f64> = results.iter().map(|r| r.quality_score).collect();
659 let mean = scores.iter().sum::<f64>() / scores.len() as f64;
660 let variance = scores.iter().map(|s| (s - mean).powi(2)).sum::<f64>() / scores.len() as f64;
661 let std_dev = variance.sqrt();
662 let margin = 1.96 * std_dev / (scores.len() as f64).sqrt(); ((mean - margin).max(0.0), (mean + margin).min(5.0))
665 }
666
667 pub async fn get_statistics(&self) -> CoordinatorStatistics {
669 let nodes = self.nodes.read().await;
670 let tasks = self.tasks.read().await;
671
672 CoordinatorStatistics {
673 total_nodes: nodes.len(),
674 active_nodes: nodes
675 .values()
676 .filter(|n| n.status == NodeStatus::Online)
677 .count(),
678 total_tasks: tasks.len(),
679 completed_tasks: 0, }
681 }
682}
683
/// Summary counters reported by `FederatedCoordinator::get_statistics`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct CoordinatorStatistics {
    /// Number of registered nodes, regardless of status.
    pub total_nodes: usize,
    /// Number of registered nodes whose status is `Online`.
    pub active_nodes: usize,
    /// Number of tasks created via `create_task`.
    pub total_tasks: usize,
    /// Number of tasks/evaluations the coordinator reports as completed.
    pub completed_tasks: usize,
}
696
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_node_config_default() {
        let config = NodeConfig::default();
        assert_eq!(config.node_id, "default_node");
        assert!(config.enable_privacy);
        assert_eq!(config.max_concurrent_evaluations, 4);
    }

    #[test]
    fn test_coordinator_config_default() {
        let config = CoordinatorConfig::default();
        assert_eq!(config.min_nodes, 2);
        assert_eq!(config.max_nodes, 100);
        assert_eq!(config.consensus_threshold, 0.67);
    }

    #[test]
    fn test_node_capabilities_default() {
        let caps = NodeCapabilities::default();
        assert!(caps.metrics.contains(&"pesq".to_string()));
        assert_eq!(caps.max_concurrent, 4);
        assert_eq!(caps.speed_factor, 1.0);
    }

    #[test]
    fn test_aggregation_strategies() {
        assert_eq!(AggregationStrategy::Average, AggregationStrategy::Average);
        assert_ne!(AggregationStrategy::Average, AggregationStrategy::Median);
    }

    #[tokio::test]
    async fn test_coordinator_creation() {
        let config = CoordinatorConfig::default();
        let coordinator = FederatedCoordinator::new(config).await;
        assert!(coordinator.is_ok());
    }

    #[tokio::test]
    async fn test_node_registration() {
        let config = CoordinatorConfig::default();
        let coordinator = FederatedCoordinator::new(config).await.unwrap();

        coordinator
            .register_node("node1", "http://node1:8080")
            .await
            .unwrap();
        let count = coordinator.active_node_count().await;
        assert_eq!(count, 1);

        coordinator
            .register_node("node2", "http://node2:8080")
            .await
            .unwrap();
        let count = coordinator.active_node_count().await;
        assert_eq!(count, 2);
    }

    #[tokio::test]
    async fn test_registration_respects_max_nodes() {
        let config = CoordinatorConfig {
            min_nodes: 1,
            max_nodes: 1,
            ..CoordinatorConfig::default()
        };
        let coordinator = FederatedCoordinator::new(config).await.unwrap();

        coordinator
            .register_node("node1", "http://node1:8080")
            .await
            .unwrap();
        let second = coordinator
            .register_node("node2", "http://node2:8080")
            .await;
        assert!(matches!(
            second,
            Err(FederatedError::RegistrationError { .. })
        ));
        assert_eq!(coordinator.active_node_count().await, 1);
    }

    #[tokio::test]
    async fn test_node_unregistration() {
        let config = CoordinatorConfig::default();
        let coordinator = FederatedCoordinator::new(config).await.unwrap();

        coordinator
            .register_node("node1", "http://node1:8080")
            .await
            .unwrap();
        coordinator.unregister_node("node1").await.unwrap();
        let count = coordinator.active_node_count().await;
        assert_eq!(count, 0);
    }

    #[tokio::test]
    async fn test_unregister_unknown_node_is_ok() {
        let config = CoordinatorConfig::default();
        let coordinator = FederatedCoordinator::new(config).await.unwrap();

        assert!(coordinator.unregister_node("missing").await.is_ok());
        assert_eq!(coordinator.active_node_count().await, 0);
    }

    #[tokio::test]
    async fn test_task_creation() {
        let config = CoordinatorConfig::default();
        let coordinator = FederatedCoordinator::new(config).await.unwrap();

        let task_id = coordinator
            .create_task("quality_evaluation".to_string())
            .await
            .unwrap();
        assert!(!task_id.is_empty());
    }

    #[tokio::test]
    async fn test_evaluate_federated_requires_min_nodes() {
        let config = CoordinatorConfig::default();
        let coordinator = FederatedCoordinator::new(config).await.unwrap();

        coordinator
            .register_node("node1", "http://node1:8080")
            .await
            .unwrap();
        let result = coordinator.evaluate_federated().await;
        assert!(matches!(
            result,
            Err(FederatedError::AggregationError { .. })
        ));
    }

    #[tokio::test]
    async fn test_evaluate_federated_aggregates_with_every_strategy() {
        let strategies = [
            AggregationStrategy::Average,
            AggregationStrategy::WeightedAverage,
            AggregationStrategy::Median,
            AggregationStrategy::TrimmedMean,
            AggregationStrategy::FederatedLearning,
        ];

        for strategy in strategies {
            let config = CoordinatorConfig {
                aggregation_strategy: strategy,
                ..CoordinatorConfig::default()
            };
            let coordinator = FederatedCoordinator::new(config).await.unwrap();
            coordinator
                .register_node("node1", "http://node1:8080")
                .await
                .unwrap();
            coordinator
                .register_node("node2", "http://node2:8080")
                .await
                .unwrap();

            let result = coordinator.evaluate_federated().await.unwrap();
            assert_eq!(result.participant_count, 2);
            assert_eq!(result.node_results.len(), 2);
            assert_eq!(result.aggregation_strategy, strategy);
            // Simulated nodes all report the same score, so every strategy agrees.
            assert!((result.aggregated_score - 4.2).abs() < 1e-9);
            let (low, high) = result.confidence_interval;
            assert!(low <= high);
            assert!(low <= result.aggregated_score + 1e-9);
            assert!(result.aggregated_score <= high + 1e-9);
        }
    }

    #[tokio::test]
    async fn test_coordinator_statistics() {
        let config = CoordinatorConfig::default();
        let coordinator = FederatedCoordinator::new(config).await.unwrap();

        coordinator
            .register_node("node1", "http://node1:8080")
            .await
            .unwrap();
        coordinator
            .register_node("node2", "http://node2:8080")
            .await
            .unwrap();

        let stats = coordinator.get_statistics().await;
        assert_eq!(stats.total_nodes, 2);
        assert_eq!(stats.active_nodes, 2);
    }

    #[tokio::test]
    async fn test_federated_node_creation() {
        let config = NodeConfig::default();
        let node = FederatedNode::new(config).await;
        assert!(node.is_ok());
    }

    #[tokio::test]
    async fn test_node_heartbeat() {
        let config = NodeConfig::default();
        let node = FederatedNode::new(config).await.unwrap();

        // Heartbeats have one-second resolution and the check is `>=`, so no sleep is needed.
        let info_before = node.get_info().await;
        node.heartbeat().await;
        let info_after = node.get_info().await;

        assert!(info_after.last_heartbeat >= info_before.last_heartbeat);
    }
}