Skip to main content

voirs_evaluation/
federated.rs

1//! Federated Evaluation System
2//!
3//! Distributed speech synthesis evaluation across multiple nodes without centralizing data.
4//! Enables privacy-preserving collaborative evaluation with secure aggregation.
5//!
6//! # Features
7//!
8//! - **Distributed Processing**: Evaluate models across geographically distributed datasets
9//! - **Secure Aggregation**: Combine results without exposing individual node data
10//! - **Privacy Preservation**: Built-in differential privacy and encryption
11//! - **Fault Tolerance**: Handle node failures gracefully with redundancy
12//! - **Load Balancing**: Distribute evaluation workload efficiently
13//! - **Progress Tracking**: Monitor federated evaluation progress in real-time
14//!
15//! # Example
16//!
17//! ```rust
18//! use voirs_evaluation::federated::{
19//!     FederatedCoordinator, FederatedNode, NodeConfig, CoordinatorConfig
20//! };
21//!
22//! # async fn example() -> Result<(), Box<dyn std::error::Error>> {
23//! // Setup coordinator
24//! let coordinator = FederatedCoordinator::new(CoordinatorConfig::default()).await?;
25//!
26//! // Register evaluation nodes
27//! coordinator.register_node("node1", "http://node1:8080").await?;
28//! coordinator.register_node("node2", "http://node2:8080").await?;
29//!
30//! // Start federated evaluation
31//! let result = coordinator.evaluate_federated().await?;
32//! println!("Federated evaluation score: {:.3}", result.aggregated_score);
33//! # Ok(())
34//! # }
35//! ```
36
37use async_trait::async_trait;
38use serde::{Deserialize, Serialize};
39use std::collections::HashMap;
40use std::sync::Arc;
41use std::time::{Duration, SystemTime, UNIX_EPOCH};
42use thiserror::Error;
43use tokio::sync::{RwLock, Semaphore};
44use tracing::{debug, error, info, warn};
45use voirs_sdk::{AudioBuffer, VoirsError};
46
47use crate::privacy::{PrivacyConfig, PrivacyPreservingEvaluator, PrivateEvaluationResult};
48use crate::quality::QualityEvaluator;
49use crate::traits::{QualityEvaluationConfig, QualityEvaluator as QualityEvaluatorTrait};
50
/// Errors produced by federated nodes and the federated coordinator.
///
/// Errors from the VoiRS SDK, the privacy layer, the evaluation layer and
/// `serde_json` convert into this type through `#[from]`, so `?` can be used
/// on them directly.
#[derive(Error, Debug)]
pub enum FederatedError {
    /// Failure while communicating with, or scheduling work on, a node
    #[error("Node communication error with '{node}': {message}")]
    CommunicationError {
        /// Identifier of the node involved
        node: String,
        /// Human-readable description of the failure
        message: String,
    },

    /// Results could not be aggregated (for example too few nodes or no results)
    #[error("Aggregation error: {message}")]
    AggregationError {
        /// Error message
        message: String,
    },

    /// A node could not be registered (for example the node limit was reached)
    #[error("Node registration error: {message}")]
    RegistrationError {
        /// Error message
        message: String,
    },

    /// Nodes failed to reach consensus
    ///
    /// NOTE(review): not constructed anywhere in this file.
    #[error("Consensus error: {message}")]
    ConsensusError {
        /// Error message
        message: String,
    },

    /// An operation exceeded its time limit
    ///
    /// NOTE(review): not constructed anywhere in this file.
    #[error("Operation timed out after {duration:?}")]
    TimeoutError {
        /// Timeout duration
        duration: Duration,
    },

    /// Invalid or inconsistent configuration
    #[error("Invalid configuration: {message}")]
    ConfigError {
        /// Error message
        message: String,
    },

    /// Error propagated from the VoiRS SDK
    #[error("VoiRS error: {0}")]
    VoirsError(#[from] VoirsError),

    /// Error propagated from the privacy-preserving evaluation layer
    #[error("Privacy error: {0}")]
    PrivacyError(#[from] crate::privacy::PrivacyError),

    /// Error propagated from the quality evaluation layer
    #[error("Evaluation error: {0}")]
    EvaluationError(#[from] crate::EvaluationError),

    /// JSON (de)serialization failure
    #[error("Serialization error: {0}")]
    SerializationError(#[from] serde_json::Error),
}
114
/// Health / activity state of a federated node.
///
/// Only `Online` nodes are counted as active by the coordinator.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
pub enum NodeStatus {
    /// Node is online and ready
    Online,
    /// Node is busy processing an evaluation
    Busy,
    /// Node is offline
    Offline,
    /// Node failed
    Failed,
}
127
/// Snapshot of a node's identity, health and workload.
///
/// A [`FederatedNode`] keeps one for itself; the [`FederatedCoordinator`]
/// keeps one per registered node.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct NodeInfo {
    /// Node identifier
    pub node_id: String,
    /// Node endpoint URL
    ///
    /// NOTE(review): when built by `FederatedNode::new` this holds the
    /// coordinator endpoint from `NodeConfig`; when built by
    /// `FederatedCoordinator::register_node` it holds the URL passed there.
    pub endpoint: String,
    /// Current node status
    pub status: NodeStatus,
    /// Node capabilities
    pub capabilities: NodeCapabilities,
    /// Last heartbeat time, in whole seconds since the UNIX epoch
    pub last_heartbeat: u64,
    /// Total evaluations completed
    pub total_evaluations: u64,
    /// Current load (0.0-1.0)
    pub current_load: f64,
}
146
/// Capabilities a node advertises for scheduling and load balancing.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct NodeCapabilities {
    /// Supported quality metrics, as lowercase metric names (e.g. `"pesq"`)
    pub metrics: Vec<String>,
    /// Maximum concurrent evaluations
    pub max_concurrent: usize,
    /// Whether GPU acceleration is available
    pub gpu_enabled: bool,
    /// Available memory, in megabytes
    pub available_memory_mb: usize,
    /// Processing speed factor relative to a baseline node (1.0 = baseline)
    pub speed_factor: f64,
}
161
162impl Default for NodeCapabilities {
163    fn default() -> Self {
164        Self {
165            metrics: vec!["pesq".to_string(), "stoi".to_string(), "mcd".to_string()],
166            max_concurrent: 4,
167            gpu_enabled: false,
168            available_memory_mb: 4096,
169            speed_factor: 1.0,
170        }
171    }
172}
173
/// Configuration for a [`FederatedNode`].
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct NodeConfig {
    /// Node identifier
    pub node_id: String,
    /// Coordinator endpoint URL
    pub coordinator_endpoint: String,
    /// Evaluate through the privacy-preserving evaluator instead of the standard one
    pub enable_privacy: bool,
    /// Privacy configuration; when privacy is enabled and this is `None`,
    /// `PrivacyConfig::default()` is used
    pub privacy_config: Option<PrivacyConfig>,
    /// Heartbeat interval (seconds)
    ///
    /// NOTE(review): not read anywhere in this file.
    pub heartbeat_interval_seconds: u64,
    /// Maximum number of evaluations the node runs concurrently
    pub max_concurrent_evaluations: usize,
}
190
191impl Default for NodeConfig {
192    fn default() -> Self {
193        Self {
194            node_id: "default_node".to_string(),
195            coordinator_endpoint: "http://localhost:8080".to_string(),
196            enable_privacy: true,
197            privacy_config: Some(PrivacyConfig::new()),
198            heartbeat_interval_seconds: 30,
199            max_concurrent_evaluations: 4,
200        }
201    }
202}
203
/// Configuration for a [`FederatedCoordinator`].
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct CoordinatorConfig {
    /// Minimum number of nodes required to run a federated evaluation
    pub min_nodes: usize,
    /// Maximum number of nodes that may be registered
    pub max_nodes: usize,
    /// Strategy used to combine per-node scores
    pub aggregation_strategy: AggregationStrategy,
    /// Enable secure aggregation
    ///
    /// NOTE(review): not read anywhere in this file.
    pub enable_secure_aggregation: bool,
    /// Consensus threshold (fraction of nodes that must agree)
    ///
    /// NOTE(review): no consensus decision in this file uses it yet.
    pub consensus_threshold: f64,
    /// Timeout for node responses (seconds)
    ///
    /// NOTE(review): not read anywhere in this file.
    pub node_timeout_seconds: u64,
    /// Enable fault tolerance
    ///
    /// NOTE(review): not read anywhere in this file.
    pub enable_fault_tolerance: bool,
}
222
223impl Default for CoordinatorConfig {
224    fn default() -> Self {
225        Self {
226            min_nodes: 2,
227            max_nodes: 100,
228            aggregation_strategy: AggregationStrategy::WeightedAverage,
229            enable_secure_aggregation: true,
230            consensus_threshold: 0.67, // 2/3 majority
231            node_timeout_seconds: 300, // 5 minutes
232            enable_fault_tolerance: true,
233        }
234    }
235}
236
/// Strategy the coordinator uses to combine per-node quality scores.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
pub enum AggregationStrategy {
    /// Simple arithmetic mean of all scores
    Average,
    /// Mean weighted by each result's `weight` (intended to reflect node reliability)
    WeightedAverage,
    /// Median of all scores
    Median,
    /// Trimmed mean: drops 10% of the scores from each end before averaging
    TrimmedMean,
    /// Federated learning style aggregation
    ///
    /// NOTE(review): currently computed as a plain arithmetic mean.
    FederatedLearning,
}
251
/// A unit of evaluation work tracked by the coordinator.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct EvaluationTask {
    /// Task identifier (a UUID v4 string when created by `create_task`)
    pub task_id: String,
    /// Free-form task type label (e.g. `"quality_evaluation"`)
    pub task_type: String,
    /// Task-specific parameters as arbitrary JSON values
    pub parameters: HashMap<String, serde_json::Value>,
    /// Priority (higher = more important)
    pub priority: u32,
    /// Creation time, in whole seconds since the UNIX epoch
    pub created_at: u64,
}
266
/// Result of one node evaluating one task.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct NodeEvaluationResult {
    /// Identifier of the node that produced the result
    pub node_id: String,
    /// Identifier of the evaluated task
    pub task_id: String,
    /// Quality score reported by the node
    ///
    /// NOTE(review): the scale depends on the evaluator used; the
    /// coordinator's confidence interval clamps to 0-5.
    pub quality_score: f64,
    /// Privacy-preserving evaluation details; `Some` only for results
    /// produced with privacy enabled
    pub private_result: Option<PrivateEvaluationResult>,
    /// Time the evaluation took, in milliseconds
    pub duration_ms: u64,
    /// Weight used by [`AggregationStrategy::WeightedAverage`]
    pub weight: f64,
    /// Completion time, in whole seconds since the UNIX epoch
    pub completed_at: u64,
}
285
/// Outcome of a federated evaluation across several nodes.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct FederatedEvaluationResult {
    /// Task identifier
    pub task_id: String,
    /// Aggregated quality score, computed with `aggregation_strategy`
    pub aggregated_score: f64,
    /// 95% confidence interval `(lower, upper)` around the mean node score
    pub confidence_interval: (f64, f64),
    /// Number of nodes that contributed a result
    pub participant_count: usize,
    /// Individual node results
    pub node_results: Vec<NodeEvaluationResult>,
    /// Aggregation strategy used
    pub aggregation_strategy: AggregationStrategy,
    /// Total evaluation duration (ms)
    pub total_duration_ms: u64,
    /// Human-readable labels for the privacy guarantees claimed for this result
    pub privacy_guarantees: Vec<String>,
}
306
/// A single federated evaluation node.
///
/// Evaluates audio locally, either through a [`PrivacyPreservingEvaluator`]
/// (privacy enabled) or a plain [`QualityEvaluator`]. The number of
/// concurrent evaluations is bounded by a semaphore sized from
/// `NodeConfig::max_concurrent_evaluations`.
pub struct FederatedNode {
    /// Configuration this node was created with
    config: NodeConfig,
    /// Standard (non-private) quality evaluator
    evaluator: Arc<RwLock<QualityEvaluator>>,
    /// Privacy-preserving evaluator; built only when `config.enable_privacy` is set
    privacy_evaluator: Option<Arc<RwLock<PrivacyPreservingEvaluator>>>,
    /// Self-reported status and statistics, returned by `get_info`
    info: Arc<RwLock<NodeInfo>>,
    /// Limits how many `execute_task` calls run at once
    semaphore: Arc<Semaphore>,
}
315
316impl FederatedNode {
317    /// Create new federated node
318    pub async fn new(config: NodeConfig) -> Result<Self, FederatedError> {
319        let evaluator = QualityEvaluator::new().await?;
320
321        let privacy_evaluator = if config.enable_privacy {
322            let privacy_config = config.privacy_config.clone().unwrap_or_default();
323            let pe = PrivacyPreservingEvaluator::new(privacy_config).await?;
324            Some(Arc::new(RwLock::new(pe)))
325        } else {
326            None
327        };
328
329        let info = NodeInfo {
330            node_id: config.node_id.clone(),
331            endpoint: config.coordinator_endpoint.clone(),
332            status: NodeStatus::Online,
333            capabilities: NodeCapabilities::default(),
334            last_heartbeat: SystemTime::now()
335                .duration_since(UNIX_EPOCH)
336                .unwrap()
337                .as_secs(),
338            total_evaluations: 0,
339            current_load: 0.0,
340        };
341
342        let semaphore = Arc::new(Semaphore::new(config.max_concurrent_evaluations));
343
344        Ok(Self {
345            config,
346            evaluator: Arc::new(RwLock::new(evaluator)),
347            privacy_evaluator,
348            info: Arc::new(RwLock::new(info)),
349            semaphore,
350        })
351    }
352
353    /// Execute evaluation task
354    pub async fn execute_task(
355        &self,
356        task: &EvaluationTask,
357        audio: &AudioBuffer,
358        reference: Option<&AudioBuffer>,
359    ) -> Result<NodeEvaluationResult, FederatedError> {
360        let _permit =
361            self.semaphore
362                .acquire()
363                .await
364                .map_err(|e| FederatedError::CommunicationError {
365                    node: self.config.node_id.clone(),
366                    message: format!("Failed to acquire semaphore: {}", e),
367                })?;
368
369        let start = SystemTime::now();
370
371        // Update status to busy
372        {
373            let mut info = self.info.write().await;
374            info.status = NodeStatus::Busy;
375        }
376
377        let result = if self.config.enable_privacy {
378            // Use privacy-preserving evaluation
379            if let Some(ref pe) = self.privacy_evaluator {
380                let pe_lock = pe.read().await;
381                let private_result = pe_lock.evaluate_with_privacy(audio, reference).await?;
382                let quality_score = private_result.score;
383
384                let duration_ms = SystemTime::now()
385                    .duration_since(start)
386                    .unwrap_or(Duration::ZERO)
387                    .as_millis() as u64;
388
389                NodeEvaluationResult {
390                    node_id: self.config.node_id.clone(),
391                    task_id: task.task_id.clone(),
392                    quality_score,
393                    private_result: Some(private_result),
394                    duration_ms,
395                    weight: 1.0,
396                    completed_at: SystemTime::now()
397                        .duration_since(UNIX_EPOCH)
398                        .unwrap()
399                        .as_secs(),
400                }
401            } else {
402                return Err(FederatedError::ConfigError {
403                    message: "Privacy enabled but evaluator not initialized".to_string(),
404                });
405            }
406        } else {
407            // Standard evaluation
408            let evaluator = self.evaluator.read().await;
409            let eval_config = QualityEvaluationConfig::default();
410            let quality = evaluator
411                .evaluate_quality(audio, reference, Some(&eval_config))
412                .await?;
413
414            let duration_ms = SystemTime::now()
415                .duration_since(start)
416                .unwrap_or(Duration::ZERO)
417                .as_millis() as u64;
418
419            NodeEvaluationResult {
420                node_id: self.config.node_id.clone(),
421                task_id: task.task_id.clone(),
422                quality_score: quality.overall_score as f64,
423                private_result: None,
424                duration_ms,
425                weight: 1.0,
426                completed_at: SystemTime::now()
427                    .duration_since(UNIX_EPOCH)
428                    .unwrap()
429                    .as_secs(),
430            }
431        };
432
433        // Update status and statistics
434        {
435            let mut info = self.info.write().await;
436            info.status = NodeStatus::Online;
437            info.total_evaluations += 1;
438            info.last_heartbeat = SystemTime::now()
439                .duration_since(UNIX_EPOCH)
440                .unwrap()
441                .as_secs();
442        }
443
444        Ok(result)
445    }
446
447    /// Get node information
448    pub async fn get_info(&self) -> NodeInfo {
449        let info = self.info.read().await;
450        info.clone()
451    }
452
453    /// Send heartbeat
454    pub async fn heartbeat(&self) {
455        let mut info = self.info.write().await;
456        info.last_heartbeat = SystemTime::now()
457            .duration_since(UNIX_EPOCH)
458            .unwrap()
459            .as_secs();
460    }
461}
462
/// Coordinator that tracks registered nodes and aggregates their results.
pub struct FederatedCoordinator {
    /// Coordinator configuration
    config: CoordinatorConfig,
    /// Registered nodes, keyed by node id
    nodes: Arc<RwLock<HashMap<String, NodeInfo>>>,
    /// Tasks created through `create_task`
    tasks: Arc<RwLock<Vec<EvaluationTask>>>,
    /// Per-task node results (presumably keyed by task id)
    ///
    /// NOTE(review): nothing in this file reads or writes this map yet.
    results: Arc<RwLock<HashMap<String, Vec<NodeEvaluationResult>>>>,
}
470
471impl FederatedCoordinator {
472    /// Create new federated coordinator
473    pub async fn new(config: CoordinatorConfig) -> Result<Self, FederatedError> {
474        Ok(Self {
475            config,
476            nodes: Arc::new(RwLock::new(HashMap::new())),
477            tasks: Arc::new(RwLock::new(Vec::new())),
478            results: Arc::new(RwLock::new(HashMap::new())),
479        })
480    }
481
482    /// Register a new node
483    pub async fn register_node(&self, node_id: &str, endpoint: &str) -> Result<(), FederatedError> {
484        let mut nodes = self.nodes.write().await;
485
486        if nodes.len() >= self.config.max_nodes {
487            return Err(FederatedError::RegistrationError {
488                message: format!("Maximum node limit reached: {}", self.config.max_nodes),
489            });
490        }
491
492        let info = NodeInfo {
493            node_id: node_id.to_string(),
494            endpoint: endpoint.to_string(),
495            status: NodeStatus::Online,
496            capabilities: NodeCapabilities::default(),
497            last_heartbeat: SystemTime::now()
498                .duration_since(UNIX_EPOCH)
499                .unwrap()
500                .as_secs(),
501            total_evaluations: 0,
502            current_load: 0.0,
503        };
504
505        nodes.insert(node_id.to_string(), info);
506        info!("Registered node: {}", node_id);
507        Ok(())
508    }
509
510    /// Unregister a node
511    pub async fn unregister_node(&self, node_id: &str) -> Result<(), FederatedError> {
512        let mut nodes = self.nodes.write().await;
513        nodes.remove(node_id);
514        info!("Unregistered node: {}", node_id);
515        Ok(())
516    }
517
518    /// Get active node count
519    pub async fn active_node_count(&self) -> usize {
520        let nodes = self.nodes.read().await;
521        nodes
522            .values()
523            .filter(|n| n.status == NodeStatus::Online)
524            .count()
525    }
526
527    /// Create evaluation task
528    pub async fn create_task(&self, task_type: String) -> Result<String, FederatedError> {
529        let task_id = uuid::Uuid::new_v4().to_string();
530        let task = EvaluationTask {
531            task_id: task_id.clone(),
532            task_type,
533            parameters: HashMap::new(),
534            priority: 1,
535            created_at: SystemTime::now()
536                .duration_since(UNIX_EPOCH)
537                .unwrap()
538                .as_secs(),
539        };
540
541        let mut tasks = self.tasks.write().await;
542        tasks.push(task);
543        Ok(task_id)
544    }
545
546    /// Evaluate using federated nodes
547    pub async fn evaluate_federated(&self) -> Result<FederatedEvaluationResult, FederatedError> {
548        let nodes = self.nodes.read().await;
549
550        if nodes.len() < self.config.min_nodes {
551            return Err(FederatedError::AggregationError {
552                message: format!(
553                    "Insufficient nodes: {} < {}",
554                    nodes.len(),
555                    self.config.min_nodes
556                ),
557            });
558        }
559
560        // For now, return a mock result
561        // In production, this would coordinate with actual nodes
562        let task_id = uuid::Uuid::new_v4().to_string();
563        let node_results: Vec<NodeEvaluationResult> = nodes
564            .iter()
565            .map(|(node_id, _)| NodeEvaluationResult {
566                node_id: node_id.clone(),
567                task_id: task_id.clone(),
568                quality_score: 4.2,
569                private_result: None,
570                duration_ms: 1000,
571                weight: 1.0,
572                completed_at: SystemTime::now()
573                    .duration_since(UNIX_EPOCH)
574                    .unwrap()
575                    .as_secs(),
576            })
577            .collect();
578
579        let aggregated_score = self.aggregate_results(&node_results).await?;
580        let confidence_interval = self.calculate_confidence_interval(&node_results);
581
582        Ok(FederatedEvaluationResult {
583            task_id,
584            aggregated_score,
585            confidence_interval,
586            participant_count: node_results.len(),
587            node_results,
588            aggregation_strategy: self.config.aggregation_strategy,
589            total_duration_ms: 1000,
590            privacy_guarantees: vec!["Differential privacy enabled".to_string()],
591        })
592    }
593
594    /// Aggregate results from multiple nodes
595    async fn aggregate_results(
596        &self,
597        results: &[NodeEvaluationResult],
598    ) -> Result<f64, FederatedError> {
599        if results.is_empty() {
600            return Err(FederatedError::AggregationError {
601                message: "No results to aggregate".to_string(),
602            });
603        }
604
605        let score = match self.config.aggregation_strategy {
606            AggregationStrategy::Average => {
607                results.iter().map(|r| r.quality_score).sum::<f64>() / results.len() as f64
608            }
609            AggregationStrategy::WeightedAverage => {
610                let total_weight: f64 = results.iter().map(|r| r.weight).sum();
611                results
612                    .iter()
613                    .map(|r| r.quality_score * r.weight)
614                    .sum::<f64>()
615                    / total_weight
616            }
617            AggregationStrategy::Median => {
618                let mut scores: Vec<f64> = results.iter().map(|r| r.quality_score).collect();
619                scores.sort_by(|a, b| a.partial_cmp(b).unwrap());
620                let mid = scores.len() / 2;
621                if scores.len() % 2 == 0 {
622                    (scores[mid - 1] + scores[mid]) / 2.0
623                } else {
624                    scores[mid]
625                }
626            }
627            AggregationStrategy::TrimmedMean => {
628                let mut scores: Vec<f64> = results.iter().map(|r| r.quality_score).collect();
629                scores.sort_by(|a, b| a.partial_cmp(b).unwrap());
630                let trim_count = (scores.len() as f64 * 0.1) as usize; // Trim 10% from each end
631                if scores.len() > 2 * trim_count {
632                    let trimmed: Vec<f64> = scores
633                        .iter()
634                        .skip(trim_count)
635                        .take(scores.len() - 2 * trim_count)
636                        .copied()
637                        .collect();
638                    trimmed.iter().sum::<f64>() / trimmed.len() as f64
639                } else {
640                    scores.iter().sum::<f64>() / scores.len() as f64
641                }
642            }
643            AggregationStrategy::FederatedLearning => {
644                // Simplified FL aggregation (in production, use proper FL algorithms)
645                results.iter().map(|r| r.quality_score).sum::<f64>() / results.len() as f64
646            }
647        };
648
649        Ok(score)
650    }
651
652    /// Calculate confidence interval
653    fn calculate_confidence_interval(&self, results: &[NodeEvaluationResult]) -> (f64, f64) {
654        if results.is_empty() {
655            return (0.0, 0.0);
656        }
657
658        let scores: Vec<f64> = results.iter().map(|r| r.quality_score).collect();
659        let mean = scores.iter().sum::<f64>() / scores.len() as f64;
660        let variance = scores.iter().map(|s| (s - mean).powi(2)).sum::<f64>() / scores.len() as f64;
661        let std_dev = variance.sqrt();
662        let margin = 1.96 * std_dev / (scores.len() as f64).sqrt(); // 95% CI
663
664        ((mean - margin).max(0.0), (mean + margin).min(5.0))
665    }
666
667    /// Get coordinator statistics
668    pub async fn get_statistics(&self) -> CoordinatorStatistics {
669        let nodes = self.nodes.read().await;
670        let tasks = self.tasks.read().await;
671
672        CoordinatorStatistics {
673            total_nodes: nodes.len(),
674            active_nodes: nodes
675                .values()
676                .filter(|n| n.status == NodeStatus::Online)
677                .count(),
678            total_tasks: tasks.len(),
679            completed_tasks: 0, // Would track this in production
680        }
681    }
682}
683
/// Summary counters reported by `FederatedCoordinator::get_statistics`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct CoordinatorStatistics {
    /// Total registered nodes
    pub total_nodes: usize,
    /// Registered nodes whose status is `Online`
    pub active_nodes: usize,
    /// Total tasks created through `create_task`
    pub total_tasks: usize,
    /// Completed tasks
    ///
    /// NOTE(review): not tracked yet; `get_statistics` always reports 0.
    pub completed_tasks: usize,
}
696
#[cfg(test)]
mod tests {
    use super::*;

    /// Build a standalone node result with the given score and weight.
    fn node_result(node_id: &str, quality_score: f64, weight: f64) -> NodeEvaluationResult {
        NodeEvaluationResult {
            node_id: node_id.to_string(),
            task_id: "task".to_string(),
            quality_score,
            private_result: None,
            duration_ms: 0,
            weight,
            completed_at: 0,
        }
    }

    /// Coordinator with default settings except for the aggregation strategy.
    async fn coordinator_with(strategy: AggregationStrategy) -> FederatedCoordinator {
        let config = CoordinatorConfig {
            aggregation_strategy: strategy,
            ..CoordinatorConfig::default()
        };
        FederatedCoordinator::new(config).await.unwrap()
    }

    /// Float comparison tolerant of rounding in the aggregation arithmetic.
    fn approx_eq(a: f64, b: f64) -> bool {
        (a - b).abs() < 1e-9
    }

    #[test]
    fn test_node_config_default() {
        let config = NodeConfig::default();
        assert_eq!(config.node_id, "default_node");
        assert!(config.enable_privacy);
        assert_eq!(config.max_concurrent_evaluations, 4);
    }

    #[test]
    fn test_coordinator_config_default() {
        let config = CoordinatorConfig::default();
        assert_eq!(config.min_nodes, 2);
        assert_eq!(config.max_nodes, 100);
        assert_eq!(config.consensus_threshold, 0.67);
    }

    #[test]
    fn test_node_capabilities_default() {
        let caps = NodeCapabilities::default();
        assert!(caps.metrics.contains(&"pesq".to_string()));
        assert_eq!(caps.max_concurrent, 4);
        assert_eq!(caps.speed_factor, 1.0);
    }

    #[test]
    fn test_aggregation_strategies() {
        assert_eq!(AggregationStrategy::Average, AggregationStrategy::Average);
        assert_ne!(AggregationStrategy::Average, AggregationStrategy::Median);
    }

    #[tokio::test]
    async fn test_coordinator_creation() {
        let config = CoordinatorConfig::default();
        let coordinator = FederatedCoordinator::new(config).await;
        assert!(coordinator.is_ok());
    }

    #[tokio::test]
    async fn test_node_registration() {
        let config = CoordinatorConfig::default();
        let coordinator = FederatedCoordinator::new(config).await.unwrap();

        coordinator
            .register_node("node1", "http://node1:8080")
            .await
            .unwrap();
        assert_eq!(coordinator.active_node_count().await, 1);

        coordinator
            .register_node("node2", "http://node2:8080")
            .await
            .unwrap();
        assert_eq!(coordinator.active_node_count().await, 2);
    }

    #[tokio::test]
    async fn test_registration_respects_max_nodes() {
        let config = CoordinatorConfig {
            min_nodes: 1,
            max_nodes: 1,
            ..CoordinatorConfig::default()
        };
        let coordinator = FederatedCoordinator::new(config).await.unwrap();

        coordinator
            .register_node("node1", "http://node1:8080")
            .await
            .unwrap();
        let err = coordinator
            .register_node("node2", "http://node2:8080")
            .await
            .unwrap_err();
        assert!(matches!(err, FederatedError::RegistrationError { .. }));
    }

    #[tokio::test]
    async fn test_node_unregistration() {
        let config = CoordinatorConfig::default();
        let coordinator = FederatedCoordinator::new(config).await.unwrap();

        coordinator
            .register_node("node1", "http://node1:8080")
            .await
            .unwrap();
        coordinator.unregister_node("node1").await.unwrap();
        assert_eq!(coordinator.active_node_count().await, 0);
    }

    #[tokio::test]
    async fn test_unregister_unknown_node_is_ok() {
        let coordinator = FederatedCoordinator::new(CoordinatorConfig::default())
            .await
            .unwrap();
        assert!(coordinator.unregister_node("missing").await.is_ok());
    }

    #[tokio::test]
    async fn test_task_creation() {
        let config = CoordinatorConfig::default();
        let coordinator = FederatedCoordinator::new(config).await.unwrap();

        let task_id = coordinator
            .create_task("quality_evaluation".to_string())
            .await
            .unwrap();
        assert!(!task_id.is_empty());
    }

    #[tokio::test]
    async fn test_coordinator_statistics() {
        let config = CoordinatorConfig::default();
        let coordinator = FederatedCoordinator::new(config).await.unwrap();

        coordinator
            .register_node("node1", "http://node1:8080")
            .await
            .unwrap();
        coordinator
            .register_node("node2", "http://node2:8080")
            .await
            .unwrap();

        let stats = coordinator.get_statistics().await;
        assert_eq!(stats.total_nodes, 2);
        assert_eq!(stats.active_nodes, 2);
    }

    #[tokio::test]
    async fn test_evaluate_federated_requires_min_nodes() {
        let coordinator = FederatedCoordinator::new(CoordinatorConfig::default())
            .await
            .unwrap();
        coordinator
            .register_node("node1", "http://node1:8080")
            .await
            .unwrap();

        let err = coordinator.evaluate_federated().await.unwrap_err();
        assert!(matches!(err, FederatedError::AggregationError { .. }));
    }

    #[tokio::test]
    async fn test_evaluate_federated_aggregates_nodes() {
        let coordinator = FederatedCoordinator::new(CoordinatorConfig::default())
            .await
            .unwrap();
        coordinator
            .register_node("node1", "http://node1:8080")
            .await
            .unwrap();
        coordinator
            .register_node("node2", "http://node2:8080")
            .await
            .unwrap();

        let result = coordinator.evaluate_federated().await.unwrap();
        assert_eq!(result.participant_count, 2);
        assert_eq!(result.node_results.len(), 2);
        assert!(approx_eq(result.aggregated_score, 4.2));
        assert!(result.confidence_interval.0 <= result.confidence_interval.1);
    }

    #[tokio::test]
    async fn test_aggregate_average() {
        let coordinator = coordinator_with(AggregationStrategy::Average).await;
        let results: Vec<_> = [1.0, 2.0, 3.0, 4.0]
            .iter()
            .map(|&s| node_result("n", s, 1.0))
            .collect();
        let score = coordinator.aggregate_results(&results).await.unwrap();
        assert!(approx_eq(score, 2.5));
    }

    #[tokio::test]
    async fn test_aggregate_weighted_average() {
        let coordinator = coordinator_with(AggregationStrategy::WeightedAverage).await;
        let results = vec![node_result("a", 1.0, 1.0), node_result("b", 3.0, 3.0)];
        let score = coordinator.aggregate_results(&results).await.unwrap();
        assert!(approx_eq(score, 2.5));
    }

    #[tokio::test]
    async fn test_aggregate_median() {
        let coordinator = coordinator_with(AggregationStrategy::Median).await;

        let even: Vec<_> = [1.0, 4.0, 2.0, 3.0]
            .iter()
            .map(|&s| node_result("n", s, 1.0))
            .collect();
        assert!(approx_eq(
            coordinator.aggregate_results(&even).await.unwrap(),
            2.5
        ));

        let odd: Vec<_> = [5.0, 1.0, 3.0]
            .iter()
            .map(|&s| node_result("n", s, 1.0))
            .collect();
        assert!(approx_eq(
            coordinator.aggregate_results(&odd).await.unwrap(),
            3.0
        ));
    }

    #[tokio::test]
    async fn test_aggregate_trimmed_mean() {
        let coordinator = coordinator_with(AggregationStrategy::TrimmedMean).await;
        // Ten scores: one low and one high outlier are trimmed away.
        let mut scores = vec![10.0, 0.0];
        scores.extend(std::iter::repeat(1.0).take(8));
        let results: Vec<_> = scores.iter().map(|&s| node_result("n", s, 1.0)).collect();
        let score = coordinator.aggregate_results(&results).await.unwrap();
        assert!(approx_eq(score, 1.0));
    }

    #[tokio::test]
    async fn test_aggregate_empty_is_error() {
        let coordinator = coordinator_with(AggregationStrategy::Average).await;
        assert!(coordinator.aggregate_results(&[]).await.is_err());
    }

    #[tokio::test]
    async fn test_federated_node_creation() {
        let config = NodeConfig::default();
        let node = FederatedNode::new(config).await;
        assert!(node.is_ok());
    }

    #[tokio::test]
    async fn test_node_heartbeat() {
        let config = NodeConfig::default();
        let node = FederatedNode::new(config).await.unwrap();

        // Force a stale timestamp so the update is observable without
        // sleeping past the one-second timestamp resolution.
        node.info.write().await.last_heartbeat = 0;
        node.heartbeat().await;

        assert!(node.get_info().await.last_heartbeat > 0);
    }
}