1#![allow(dead_code)]
15use crate::{TorshDistributedError, TorshResult};
16use serde::{Deserialize, Serialize};
17use std::collections::{HashMap, VecDeque};
18use std::sync::{Arc, Mutex, RwLock};
19use std::time::{Duration, SystemTime};
20use tokio::time::interval;
21
22#[derive(Debug, Clone, Serialize, Deserialize)]
24pub struct EdgeComputingConfig {
25 pub heterogeneous_devices: bool,
27 pub adaptive_communication: bool,
29 pub federated_learning: bool,
31 pub device_discovery: DeviceDiscoveryConfig,
33 pub bandwidth_adaptation: BandwidthAdaptationConfig,
35 pub federated_config: FederatedLearningConfig,
37 pub edge_optimizations: EdgeOptimizationConfig,
39 pub hierarchical_training: HierarchicalTrainingConfig,
41 pub privacy_config: PrivacyConfig,
43}
44
45impl Default for EdgeComputingConfig {
46 fn default() -> Self {
47 Self {
48 heterogeneous_devices: true,
49 adaptive_communication: true,
50 federated_learning: true,
51 device_discovery: DeviceDiscoveryConfig::default(),
52 bandwidth_adaptation: BandwidthAdaptationConfig::default(),
53 federated_config: FederatedLearningConfig::default(),
54 edge_optimizations: EdgeOptimizationConfig::default(),
55 hierarchical_training: HierarchicalTrainingConfig::default(),
56 privacy_config: PrivacyConfig::default(),
57 }
58 }
59}
60
61#[derive(Debug, Clone, Serialize, Deserialize)]
63pub struct DeviceDiscoveryConfig {
64 pub auto_discovery: bool,
66 pub discovery_protocol: DiscoveryProtocol,
68 pub discovery_interval: u64,
70 pub max_devices: usize,
72 pub heartbeat_interval: u64,
74 pub device_timeout: u64,
76}
77
78impl Default for DeviceDiscoveryConfig {
79 fn default() -> Self {
80 Self {
81 auto_discovery: true,
82 discovery_protocol: DiscoveryProtocol::Mdns,
83 discovery_interval: 30,
84 max_devices: 100,
85 heartbeat_interval: 10,
86 device_timeout: 60,
87 }
88 }
89}
90
/// Mechanism used by `EdgeComputingManager::discover_devices`.
///
/// In this module every protocol except `Manual` returns a fixed set of mock
/// devices; `Manual` discovers nothing, so devices must be registered explicitly.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
pub enum DiscoveryProtocol {
    /// Multicast DNS discovery (mocked in this module).
    Mdns,
    /// UPnP discovery (mocked in this module).
    Upnp,
    /// Bluetooth Low Energy discovery (mocked in this module).
    Ble,
    /// Network broadcast discovery (mocked in this module).
    Broadcast,
    /// No automatic discovery.
    Manual,
}
105
106#[derive(Debug, Clone, Serialize, Deserialize)]
108pub struct BandwidthAdaptationConfig {
109 pub auto_detection: bool,
111 pub min_bandwidth: f64,
113 pub measurement_interval: u64,
115 pub compression_threshold: f64,
117 pub adaptive_batch_size: bool,
119 pub max_timeout: u64,
121}
122
123impl Default for BandwidthAdaptationConfig {
124 fn default() -> Self {
125 Self {
126 auto_detection: true,
127 min_bandwidth: 1.0, measurement_interval: 30,
129 compression_threshold: 10.0, adaptive_batch_size: true,
131 max_timeout: 300, }
133 }
134}
135
136#[derive(Debug, Clone, Serialize, Deserialize)]
138pub struct FederatedLearningConfig {
139 pub algorithm: FederatedAlgorithm,
141 pub local_rounds: usize,
143 pub client_selection: ClientSelectionStrategy,
145 pub min_clients_per_round: usize,
147 pub max_clients_per_round: usize,
149 pub aggregation: AggregationStrategy,
151 pub privacy_mechanism: PrivacyMechanism,
153 pub communication_rounds: usize,
155}
156
157impl Default for FederatedLearningConfig {
158 fn default() -> Self {
159 Self {
160 algorithm: FederatedAlgorithm::FedAvg,
161 local_rounds: 5,
162 client_selection: ClientSelectionStrategy::Random,
163 min_clients_per_round: 10,
164 max_clients_per_round: 100,
165 aggregation: AggregationStrategy::FedAvg,
166 privacy_mechanism: PrivacyMechanism::DifferentialPrivacy,
167 communication_rounds: 100,
168 }
169 }
170}
171
/// Federated optimization algorithm selected in `FederatedLearningConfig::algorithm`.
// NOTE(review): the code in this module section does not branch on this value.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
pub enum FederatedAlgorithm {
    /// Federated Averaging.
    FedAvg,
    /// FedProx.
    FedProx,
    /// FedNova.
    FedNova,
    /// Momentum-based federated optimization (presumably).
    FedMom,
    /// FedAdam.
    FedAdam,
}
186
/// Policy used by `EdgeComputingManager::select_clients` to choose participants.
// NOTE(review): see `select_clients` for how each variant is actually handled.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
pub enum ClientSelectionStrategy {
    /// Random selection.
    Random,
    /// Round-robin selection across rounds.
    RoundRobin,
    /// Selection guided by the clients' local data.
    DataBased,
    /// Prefer devices with the highest `estimated_flops`.
    ComputeBased,
    /// Prefer devices with the best `bandwidth / (latency + 1)` score.
    NetworkBased,
    /// Adaptive selection policy.
    Adaptive,
}
203
/// Rule used by `FederatedLearningCoordinator::aggregate_updates` to combine
/// client updates.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
pub enum AggregationStrategy {
    /// Unweighted element-wise mean of the client updates.
    FedAvg,
    /// Weighted mean; currently computed exactly like `FedAvg`.
    WeightedAvg,
    /// Element-wise median of the client updates.
    Median,
    /// Trimmed mean; not implemented, falls back to `FedAvg`.
    TrimmedMean,
    /// Krum; not implemented, falls back to `FedAvg`.
    Krum,
    /// Byzantine-robust aggregation; not implemented, falls back to `FedAvg`.
    ByzantineRobust,
}
220
/// Privacy mechanism selected in `FederatedLearningConfig::privacy_mechanism`.
// NOTE(review): the code in this module section does not branch on this value.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
pub enum PrivacyMechanism {
    /// No privacy mechanism.
    None,
    /// Differential privacy (noise added to updates).
    DifferentialPrivacy,
    /// Homomorphic encryption.
    HomomorphicEncryption,
    /// Secure multi-party computation.
    SecureMultipartyComputation,
    /// Secure aggregation.
    SecureAggregation,
}
235
236#[derive(Debug, Clone, Serialize, Deserialize)]
238pub struct EdgeOptimizationConfig {
239 pub model_compression: bool,
241 pub gradient_compression: bool,
243 pub quantization: bool,
245 pub pruning: bool,
247 pub knowledge_distillation: bool,
249 pub compression_ratio: f64,
251 pub quantization_bits: u8,
253 pub pruning_sparsity: f64,
255}
256
257impl Default for EdgeOptimizationConfig {
258 fn default() -> Self {
259 Self {
260 model_compression: true,
261 gradient_compression: true,
262 quantization: true,
263 pruning: false,
264 knowledge_distillation: false,
265 compression_ratio: 0.1, quantization_bits: 8,
267 pruning_sparsity: 0.5, }
269 }
270}
271
272#[derive(Debug, Clone, Serialize, Deserialize)]
274pub struct HierarchicalTrainingConfig {
275 pub enable_hierarchical: bool,
277 pub tiers: Vec<TrainingTier>,
279 pub aggregation_schedule: AggregationSchedule,
281 pub load_balancing: bool,
283}
284
285impl Default for HierarchicalTrainingConfig {
286 fn default() -> Self {
287 Self {
288 enable_hierarchical: true,
289 tiers: vec![TrainingTier::Edge, TrainingTier::Fog, TrainingTier::Cloud],
290 aggregation_schedule: AggregationSchedule::default(),
291 load_balancing: true,
292 }
293 }
294}
295
/// Level of the training hierarchy a device is assigned to by
/// `HierarchicalTrainingCoordinator::assign_device_tier`.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
pub enum TrainingTier {
    /// Phones, tablets, IoT sensors and embedded devices.
    Edge,
    /// Laptops, edge servers and automotive units.
    Fog,
    /// Fog nodes are mapped to this tier.
    Cloud,
}
306
307#[derive(Debug, Clone, Serialize, Deserialize)]
309pub struct AggregationSchedule {
310 pub edge_to_fog_frequency: u64,
312 pub fog_to_cloud_frequency: u64,
314 pub global_aggregation_frequency: u64,
316}
317
318impl Default for AggregationSchedule {
319 fn default() -> Self {
320 Self {
321 edge_to_fog_frequency: 5, fog_to_cloud_frequency: 10, global_aggregation_frequency: 20, }
325 }
326}
327
328#[derive(Debug, Clone, Serialize, Deserialize)]
330pub struct PrivacyConfig {
331 pub differential_privacy: bool,
333 pub privacy_budget: f64,
335 pub secure_aggregation: bool,
337 pub data_anonymization: bool,
339 pub local_training_only: bool,
341}
342
343impl Default for PrivacyConfig {
344 fn default() -> Self {
345 Self {
346 differential_privacy: true,
347 privacy_budget: 1.0,
348 secure_aggregation: true,
349 data_anonymization: true,
350 local_training_only: true,
351 }
352 }
353}
354
/// A device known to `EdgeComputingManager`, keyed in its registry by `device_id`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct EdgeDevice {
    /// Unique identifier; used as the registry key.
    pub device_id: String,
    /// Kind of device; also determines its training tier.
    pub device_type: DeviceType,
    /// Hardware capabilities.
    pub compute_capability: ComputeCapability,
    /// Network connection characteristics.
    pub network_info: NetworkInfo,
    /// Current status; only `Available` devices are eligible for client selection.
    pub status: DeviceStatus,
    /// Currently available resources.
    pub resources: DeviceResources,
    /// Description of the device's local training data.
    pub data_info: DataInfo,
    /// Last time the device was seen; refreshed by `update_device_status` and
    /// compared against `device_timeout` by the health check.
    pub last_seen: SystemTime,
    /// Optional geographic location.
    pub location: Option<DeviceLocation>,
}
377
/// Category of an edge device.
///
/// `HierarchicalTrainingCoordinator::assign_device_tier` maps each category to a
/// `TrainingTier`.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
pub enum DeviceType {
    /// Smartphone.
    Smartphone,
    /// Tablet.
    Tablet,
    /// Laptop.
    Laptop,
    /// IoT sensor.
    IoTSensor,
    /// Edge server.
    EdgeServer,
    /// Fog node.
    FogNode,
    /// Embedded device.
    Embedded,
    /// Automotive unit.
    Automotive,
}
398
/// Hardware capabilities of a device.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ComputeCapability {
    /// Number of CPU cores.
    pub cpu_cores: u32,
    /// CPU clock frequency, presumably in MHz (mock values use 1800–3200).
    pub cpu_frequency: u32,
    /// Installed RAM in MiB.
    pub ram_mb: u32,
    /// Whether a GPU is present.
    pub has_gpu: bool,
    /// GPU memory in MiB (0 when there is no GPU in the mock devices).
    pub gpu_memory_mb: u32,
    /// Whether a dedicated accelerator is present.
    pub has_accelerator: bool,
    /// Estimated throughput in FLOPS; ranks devices for compute-based selection.
    pub estimated_flops: f64,
    /// Power draw, presumably in watts.
    pub power_consumption: f64,
}
419
/// Network characteristics of a device.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct NetworkInfo {
    /// Link technology.
    pub connection_type: ConnectionType,
    /// Available bandwidth, presumably in Mbps; used in the network-based
    /// selection score `bandwidth / (latency + 1)`.
    pub bandwidth: f64,
    /// Latency, presumably in milliseconds.
    pub latency: f64,
    /// Packet loss, presumably as a fraction (mock value 0.01).
    pub packet_loss: f64,
    /// Whether the connection is considered stable.
    pub is_stable: bool,
    /// Data-plan limits, if any (set for smartphones in the mock devices).
    pub data_limits: Option<DataLimits>,
}
436
/// Network link technology.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
pub enum ConnectionType {
    /// Wi-Fi.
    WiFi,
    /// 4G cellular.
    Cellular4G,
    /// 5G cellular.
    Cellular5G,
    /// Wired Ethernet.
    Ethernet,
    /// Bluetooth.
    Bluetooth,
    /// Satellite link.
    Satellite,
    /// LoRa low-power radio.
    LoRa,
}
455
/// Data-plan limits for a metered connection.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct DataLimits {
    /// Monthly allowance in MiB.
    pub monthly_limit_mb: u64,
    /// Data already used this period in MiB.
    pub used_data_mb: u64,
    /// Whether the plan is unlimited.
    pub unlimited: bool,
}
466
/// Lifecycle state of a registered device.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
pub enum DeviceStatus {
    /// Idle and eligible for client selection.
    Available,
    /// Currently training.
    Training,
    /// Temporarily unavailable.
    Unavailable,
    /// Not seen within `device_timeout`; set by the health check.
    Disconnected,
    /// Sleeping / power saving.
    Sleeping,
    /// Under maintenance.
    Maintenance,
}
483
/// Resources a device currently has free.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct DeviceResources {
    /// Available CPU, presumably as a percentage (mock value 80.0).
    pub cpu_available: f64,
    /// Free memory in MiB.
    pub memory_available_mb: u32,
    /// Free storage in MiB.
    pub storage_available_mb: u64,
    /// Battery level, presumably a percentage; `None` for devices without a battery.
    pub battery_level: Option<f64>,
    /// Whether the device is charging; `None` for devices without a battery.
    pub is_charging: Option<bool>,
    /// Current thermal condition.
    pub thermal_state: ThermalState,
}
500
/// Thermal condition of a device, from coolest to hottest.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
pub enum ThermalState {
    /// Normal operating temperature.
    Normal,
    /// Warm.
    Warm,
    /// Hot.
    Hot,
    /// Critical temperature.
    Critical,
}
513
/// Summary of the training data held locally by a device.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct DataInfo {
    /// Number of local samples.
    pub sample_count: usize,
    /// Data quality score, presumably in `0.0..=1.0` (mock value 0.8).
    pub quality_score: f64,
    /// Data diversity score, presumably in `0.0..=1.0` (mock value 0.6).
    pub diversity_score: f64,
    /// Per-label distribution; presumably label → fraction of samples.
    pub label_distribution: HashMap<String, f64>,
    /// Age of the data in hours.
    pub freshness_hours: f64,
    /// Sensitivity classification of the data.
    pub privacy_level: PrivacyLevel,
}
530
/// Sensitivity classification of a device's data, from least to most sensitive.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
pub enum PrivacyLevel {
    /// Public data.
    Public,
    /// Internal data.
    Internal,
    /// Confidential data.
    Confidential,
    /// Highly sensitive data.
    HighlySensitive,
}
543
/// Geographic location of a device.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct DeviceLocation {
    /// Latitude, presumably in decimal degrees.
    pub latitude: f64,
    /// Longitude, presumably in decimal degrees.
    pub longitude: f64,
    /// Country code (mock value "US").
    pub country: String,
    /// Region code (mock value "CA").
    pub region: String,
    /// Free-form timezone label (mock value "UTC-8").
    pub timezone: String,
}
558
559pub struct EdgeComputingManager {
561 config: EdgeComputingConfig,
562 devices: Arc<RwLock<HashMap<String, EdgeDevice>>>,
563 device_groups: Arc<RwLock<HashMap<String, Vec<String>>>>,
564 communication_manager: Arc<Mutex<CommunicationManager>>,
565 federated_coordinator: Option<FederatedLearningCoordinator>,
566 bandwidth_monitor: Arc<Mutex<BandwidthMonitor>>,
567 privacy_manager: Arc<Mutex<PrivacyManager>>,
568 hierarchical_coordinator: Option<HierarchicalTrainingCoordinator>,
569}
570
571impl EdgeComputingManager {
572 pub fn new(config: EdgeComputingConfig) -> TorshResult<Self> {
574 let federated_coordinator = if config.federated_learning {
575 Some(FederatedLearningCoordinator::new(&config.federated_config)?)
576 } else {
577 None
578 };
579
580 let hierarchical_coordinator = if config.hierarchical_training.enable_hierarchical {
581 Some(HierarchicalTrainingCoordinator::new(
582 &config.hierarchical_training,
583 )?)
584 } else {
585 None
586 };
587
588 Ok(Self {
589 config: config.clone(),
590 devices: Arc::new(RwLock::new(HashMap::new())),
591 device_groups: Arc::new(RwLock::new(HashMap::new())),
592 communication_manager: Arc::new(Mutex::new(CommunicationManager::new(
593 &config.bandwidth_adaptation,
594 )?)),
595 federated_coordinator,
596 bandwidth_monitor: Arc::new(Mutex::new(BandwidthMonitor::new(
597 &config.bandwidth_adaptation,
598 )?)),
599 privacy_manager: Arc::new(Mutex::new(PrivacyManager::new(&config.privacy_config)?)),
600 hierarchical_coordinator,
601 })
602 }
603
604 pub fn register_device(&self, device: EdgeDevice) -> TorshResult<()> {
606 let mut devices = self.devices.write().map_err(|_| {
607 TorshDistributedError::InternalError("Failed to acquire devices lock".to_string())
608 })?;
609
610 tracing::info!(
611 "Registering edge device: {} (type: {:?})",
612 device.device_id,
613 device.device_type
614 );
615 devices.insert(device.device_id.clone(), device);
616
617 Ok(())
618 }
619
620 pub async fn discover_devices(&self) -> TorshResult<Vec<EdgeDevice>> {
622 if !self.config.device_discovery.auto_discovery {
623 return Ok(Vec::new());
624 }
625
626 let discovered = match self.config.device_discovery.discovery_protocol {
628 DiscoveryProtocol::Mdns => self.discover_mdns_devices().await?,
629 DiscoveryProtocol::Upnp => self.discover_upnp_devices().await?,
630 DiscoveryProtocol::Ble => self.discover_ble_devices().await?,
631 DiscoveryProtocol::Broadcast => self.discover_broadcast_devices().await?,
632 DiscoveryProtocol::Manual => Vec::new(),
633 };
634
635 for device in &discovered {
637 self.register_device(device.clone())?;
638 }
639
640 tracing::info!("Discovered {} edge devices", discovered.len());
641 Ok(discovered)
642 }
643
644 async fn discover_mdns_devices(&self) -> TorshResult<Vec<EdgeDevice>> {
646 let mock_devices = vec![
648 self.create_mock_device("edge-phone-1", DeviceType::Smartphone),
649 self.create_mock_device("edge-tablet-1", DeviceType::Tablet),
650 self.create_mock_device("fog-server-1", DeviceType::FogNode),
651 ];
652
653 Ok(mock_devices)
654 }
655
656 async fn discover_upnp_devices(&self) -> TorshResult<Vec<EdgeDevice>> {
658 let mock_devices = vec![
659 self.create_mock_device("edge-laptop-1", DeviceType::Laptop),
660 self.create_mock_device("edge-server-1", DeviceType::EdgeServer),
661 ];
662
663 Ok(mock_devices)
664 }
665
666 async fn discover_ble_devices(&self) -> TorshResult<Vec<EdgeDevice>> {
668 let mock_devices = vec![
669 self.create_mock_device("iot-sensor-1", DeviceType::IoTSensor),
670 self.create_mock_device("embedded-1", DeviceType::Embedded),
671 ];
672
673 Ok(mock_devices)
674 }
675
676 async fn discover_broadcast_devices(&self) -> TorshResult<Vec<EdgeDevice>> {
678 let mock_devices = vec![self.create_mock_device("auto-ecu-1", DeviceType::Automotive)];
679
680 Ok(mock_devices)
681 }
682
683 fn create_mock_device(&self, device_id: &str, device_type: DeviceType) -> EdgeDevice {
685 EdgeDevice {
686 device_id: device_id.to_string(),
687 device_type,
688 compute_capability: match device_type {
689 DeviceType::Smartphone => ComputeCapability {
690 cpu_cores: 8,
691 cpu_frequency: 2400,
692 ram_mb: 6144,
693 has_gpu: true,
694 gpu_memory_mb: 1024,
695 has_accelerator: false,
696 estimated_flops: 1e9,
697 power_consumption: 5.0,
698 },
699 DeviceType::FogNode => ComputeCapability {
700 cpu_cores: 16,
701 cpu_frequency: 3200,
702 ram_mb: 32768,
703 has_gpu: true,
704 gpu_memory_mb: 8192,
705 has_accelerator: true,
706 estimated_flops: 1e12,
707 power_consumption: 200.0,
708 },
709 _ => ComputeCapability {
710 cpu_cores: 4,
711 cpu_frequency: 1800,
712 ram_mb: 4096,
713 has_gpu: false,
714 gpu_memory_mb: 0,
715 has_accelerator: false,
716 estimated_flops: 1e8,
717 power_consumption: 10.0,
718 },
719 },
720 network_info: NetworkInfo {
721 connection_type: match device_type {
722 DeviceType::Smartphone => ConnectionType::Cellular5G,
723 DeviceType::FogNode => ConnectionType::Ethernet,
724 _ => ConnectionType::WiFi,
725 },
726 bandwidth: match device_type {
727 DeviceType::Smartphone => 50.0,
728 DeviceType::FogNode => 1000.0,
729 _ => 100.0,
730 },
731 latency: 20.0,
732 packet_loss: 0.01,
733 is_stable: true,
734 data_limits: if device_type == DeviceType::Smartphone {
735 Some(DataLimits {
736 monthly_limit_mb: 10240,
737 used_data_mb: 2048,
738 unlimited: false,
739 })
740 } else {
741 None
742 },
743 },
744 status: DeviceStatus::Available,
745 resources: DeviceResources {
746 cpu_available: 80.0,
747 memory_available_mb: 2048,
748 storage_available_mb: 5120,
749 battery_level: if device_type == DeviceType::Smartphone {
750 Some(85.0)
751 } else {
752 None
753 },
754 is_charging: if device_type == DeviceType::Smartphone {
755 Some(false)
756 } else {
757 None
758 },
759 thermal_state: ThermalState::Normal,
760 },
761 data_info: DataInfo {
762 sample_count: 1000,
763 quality_score: 0.8,
764 diversity_score: 0.6,
765 label_distribution: HashMap::new(),
766 freshness_hours: 2.0,
767 privacy_level: PrivacyLevel::Internal,
768 },
769 last_seen: SystemTime::now(),
770 location: Some(DeviceLocation {
771 latitude: 37.7749,
772 longitude: -122.4194,
773 country: "US".to_string(),
774 region: "CA".to_string(),
775 timezone: "UTC-8".to_string(),
776 }),
777 }
778 }
779
780 pub fn select_clients(&self, round: usize) -> TorshResult<Vec<String>> {
782 let devices = self.devices.read().map_err(|_| {
783 TorshDistributedError::InternalError("Failed to acquire devices lock".to_string())
784 })?;
785
786 let available_devices: Vec<&EdgeDevice> = devices
787 .values()
788 .filter(|d| d.status == DeviceStatus::Available)
789 .collect();
790
791 let selection_strategy = self.config.federated_config.client_selection;
792 let min_clients = self.config.federated_config.min_clients_per_round;
793 let max_clients = self.config.federated_config.max_clients_per_round;
794
795 let selected = match selection_strategy {
796 ClientSelectionStrategy::Random => {
797 let mut selected = Vec::new();
798 let count = (available_devices.len()).min(max_clients).max(min_clients);
799
800 for (i, device) in available_devices.iter().enumerate() {
802 if i < count {
803 selected.push(device.device_id.clone());
804 }
805 }
806 selected
807 }
808 ClientSelectionStrategy::ComputeBased => {
809 let mut sorted_devices = available_devices.clone();
811 sorted_devices.sort_by(|a, b| {
812 b.compute_capability
813 .estimated_flops
814 .partial_cmp(&a.compute_capability.estimated_flops)
815 .unwrap_or(std::cmp::Ordering::Equal)
816 });
817
818 sorted_devices
819 .iter()
820 .take(max_clients.min(available_devices.len()))
821 .map(|d| d.device_id.clone())
822 .collect()
823 }
824 ClientSelectionStrategy::NetworkBased => {
825 let mut sorted_devices = available_devices.clone();
827 sorted_devices.sort_by(|a, b| {
828 let score_a = a.network_info.bandwidth / (a.network_info.latency + 1.0);
829 let score_b = b.network_info.bandwidth / (b.network_info.latency + 1.0);
830 score_b
831 .partial_cmp(&score_a)
832 .unwrap_or(std::cmp::Ordering::Equal)
833 });
834
835 sorted_devices
836 .iter()
837 .take(max_clients.min(available_devices.len()))
838 .map(|d| d.device_id.clone())
839 .collect()
840 }
841 _ => {
842 available_devices
844 .iter()
845 .take(max_clients.min(available_devices.len()))
846 .map(|d| d.device_id.clone())
847 .collect()
848 }
849 };
850
851 tracing::info!(
852 "Selected {} clients for federated learning round {}",
853 selected.len(),
854 round
855 );
856 Ok(selected)
857 }
858
859 pub fn get_device(&self, device_id: &str) -> TorshResult<Option<EdgeDevice>> {
861 let devices = self.devices.read().map_err(|_| {
862 TorshDistributedError::InternalError("Failed to acquire devices lock".to_string())
863 })?;
864
865 Ok(devices.get(device_id).cloned())
866 }
867
868 pub fn get_all_devices(&self) -> TorshResult<Vec<EdgeDevice>> {
870 let devices = self.devices.read().map_err(|_| {
871 TorshDistributedError::InternalError("Failed to acquire devices lock".to_string())
872 })?;
873
874 Ok(devices.values().cloned().collect())
875 }
876
877 pub fn update_device_status(&self, device_id: &str, status: DeviceStatus) -> TorshResult<()> {
879 let mut devices = self.devices.write().map_err(|_| {
880 TorshDistributedError::InternalError("Failed to acquire devices lock".to_string())
881 })?;
882
883 if let Some(device) = devices.get_mut(device_id) {
884 device.status = status;
885 device.last_seen = SystemTime::now();
886 tracing::debug!("Updated device {} status to {:?}", device_id, status);
887 }
888
889 Ok(())
890 }
891
892 pub async fn start_device_monitoring(&self) -> TorshResult<()> {
894 let heartbeat_interval =
895 Duration::from_secs(self.config.device_discovery.heartbeat_interval);
896 let mut interval_timer = interval(heartbeat_interval);
897
898 loop {
899 interval_timer.tick().await;
900
901 if let Err(e) = self.check_device_health().await {
903 tracing::error!("Device health check failed: {}", e);
904 }
905 }
906 }
907
908 async fn check_device_health(&self) -> TorshResult<()> {
910 let device_timeout = Duration::from_secs(self.config.device_discovery.device_timeout);
911 let now = SystemTime::now();
912
913 let mut devices = self.devices.write().map_err(|_| {
914 TorshDistributedError::InternalError("Failed to acquire devices lock".to_string())
915 })?;
916
917 for device in devices.values_mut() {
918 if let Ok(elapsed) = now.duration_since(device.last_seen) {
919 if elapsed > device_timeout && device.status != DeviceStatus::Disconnected {
920 device.status = DeviceStatus::Disconnected;
921 tracing::warn!(
922 "Device {} marked as disconnected due to timeout",
923 device.device_id
924 );
925 }
926 }
927 }
928
929 Ok(())
930 }
931}
932
933pub struct CommunicationManager {
935 config: BandwidthAdaptationConfig,
936 bandwidth_history: VecDeque<(SystemTime, f64)>,
937}
938
939impl CommunicationManager {
940 pub fn new(config: &BandwidthAdaptationConfig) -> TorshResult<Self> {
941 Ok(Self {
942 config: config.clone(),
943 bandwidth_history: VecDeque::with_capacity(100),
944 })
945 }
946
947 pub async fn measure_bandwidth(&mut self, device_id: &str) -> TorshResult<f64> {
949 let bandwidth = 50.0 + (device_id.len() as f64 * 10.0) % 100.0; self.bandwidth_history
953 .push_back((SystemTime::now(), bandwidth));
954 if self.bandwidth_history.len() > 100 {
955 self.bandwidth_history.pop_front();
956 }
957
958 Ok(bandwidth)
959 }
960
961 pub fn get_adaptive_params(&self, current_bandwidth: f64) -> AdaptiveCommunicationParams {
963 let should_compress = current_bandwidth < self.config.compression_threshold;
964 let timeout_multiplier = if current_bandwidth < self.config.min_bandwidth {
965 3.0
966 } else if current_bandwidth < self.config.compression_threshold {
967 2.0
968 } else {
969 1.0
970 };
971
972 AdaptiveCommunicationParams {
973 use_compression: should_compress,
974 compression_ratio: if should_compress { 0.1 } else { 1.0 },
975 timeout_multiplier,
976 max_batch_size: if self.config.adaptive_batch_size {
977 ((current_bandwidth / 10.0) as usize).clamp(1, 64)
978 } else {
979 32
980 },
981 }
982 }
983}
984
/// Communication parameters produced by `CommunicationManager::get_adaptive_params`.
#[derive(Debug, Clone)]
pub struct AdaptiveCommunicationParams {
    /// Whether payloads should be compressed (bandwidth below the compression threshold).
    pub use_compression: bool,
    /// Compression ratio to apply: 0.1 when compressing, 1.0 otherwise.
    pub compression_ratio: f64,
    /// Factor applied to timeouts: 1.0, 2.0 or 3.0 depending on bandwidth.
    pub timeout_multiplier: f64,
    /// Maximum batch size to send.
    pub max_batch_size: usize,
}
993
994pub struct BandwidthMonitor {
996 config: BandwidthAdaptationConfig,
997 measurements: HashMap<String, VecDeque<(SystemTime, f64)>>,
998}
999
1000impl BandwidthMonitor {
1001 pub fn new(config: &BandwidthAdaptationConfig) -> TorshResult<Self> {
1002 Ok(Self {
1003 config: config.clone(),
1004 measurements: HashMap::new(),
1005 })
1006 }
1007
1008 pub fn record_measurement(&mut self, device_id: String, bandwidth: f64) {
1010 let measurements = self
1011 .measurements
1012 .entry(device_id)
1013 .or_insert_with(|| VecDeque::with_capacity(100));
1014 measurements.push_back((SystemTime::now(), bandwidth));
1015
1016 if measurements.len() > 100 {
1017 measurements.pop_front();
1018 }
1019 }
1020
1021 pub fn get_average_bandwidth(&self, device_id: &str, window_minutes: u64) -> Option<f64> {
1023 let measurements = self.measurements.get(device_id)?;
1024 let window = Duration::from_secs(window_minutes * 60);
1025 let now = SystemTime::now();
1026
1027 let recent_measurements: Vec<f64> = measurements
1028 .iter()
1029 .filter_map(|(time, bandwidth)| {
1030 if now.duration_since(*time).unwrap_or(Duration::MAX) <= window {
1031 Some(*bandwidth)
1032 } else {
1033 None
1034 }
1035 })
1036 .collect();
1037
1038 if recent_measurements.is_empty() {
1039 None
1040 } else {
1041 Some(recent_measurements.iter().sum::<f64>() / recent_measurements.len() as f64)
1042 }
1043 }
1044}
1045
1046pub struct PrivacyManager {
1048 config: PrivacyConfig,
1049}
1050
1051impl PrivacyManager {
1052 pub fn new(config: &PrivacyConfig) -> TorshResult<Self> {
1053 Ok(Self {
1054 config: config.clone(),
1055 })
1056 }
1057
1058 pub fn apply_differential_privacy(
1060 &self,
1061 gradients: &[f32],
1062 sensitivity: f64,
1063 ) -> TorshResult<Vec<f32>> {
1064 if !self.config.differential_privacy {
1065 return Ok(gradients.to_vec());
1066 }
1067
1068 let noise_scale = sensitivity / self.config.privacy_budget;
1070 let mut private_gradients = Vec::with_capacity(gradients.len());
1071
1072 for &gradient in gradients {
1073 let noise = (gradient.abs() * 0.01) * (2.0 * std::f32::consts::PI).sin(); private_gradients.push(gradient + noise * noise_scale as f32);
1076 }
1077
1078 Ok(private_gradients)
1079 }
1080}
1081
1082pub struct FederatedLearningCoordinator {
1084 config: FederatedLearningConfig,
1085 current_round: Arc<std::sync::atomic::AtomicUsize>,
1086 aggregation_buffer: Arc<Mutex<HashMap<String, Vec<f32>>>>,
1087}
1088
1089impl FederatedLearningCoordinator {
1090 pub fn new(config: &FederatedLearningConfig) -> TorshResult<Self> {
1091 Ok(Self {
1092 config: config.clone(),
1093 current_round: Arc::new(std::sync::atomic::AtomicUsize::new(0)),
1094 aggregation_buffer: Arc::new(Mutex::new(HashMap::new())),
1095 })
1096 }
1097
1098 pub fn aggregate_updates(
1100 &self,
1101 client_updates: HashMap<String, Vec<f32>>,
1102 ) -> TorshResult<Vec<f32>> {
1103 if client_updates.is_empty() {
1104 return Err(TorshDistributedError::InternalError(
1105 "No client updates to aggregate".to_string(),
1106 ));
1107 }
1108
1109 match self.config.aggregation {
1110 AggregationStrategy::FedAvg => self.federated_averaging(client_updates),
1111 AggregationStrategy::WeightedAvg => self.weighted_averaging(client_updates),
1112 AggregationStrategy::Median => self.median_aggregation(client_updates),
1113 _ => self.federated_averaging(client_updates), }
1115 }
1116
1117 fn federated_averaging(
1119 &self,
1120 client_updates: HashMap<String, Vec<f32>>,
1121 ) -> TorshResult<Vec<f32>> {
1122 if client_updates.is_empty() {
1123 return Err(TorshDistributedError::InternalError(
1124 "No updates to aggregate".to_string(),
1125 ));
1126 }
1127
1128 let num_clients = client_updates.len() as f32;
1129 let update_size = client_updates
1130 .values()
1131 .next()
1132 .expect("client_updates should not be empty")
1133 .len();
1134 let mut aggregated = vec![0.0; update_size];
1135
1136 for updates in client_updates.values() {
1137 for (i, &update) in updates.iter().enumerate() {
1138 aggregated[i] += update / num_clients;
1139 }
1140 }
1141
1142 Ok(aggregated)
1143 }
1144
1145 fn weighted_averaging(
1147 &self,
1148 client_updates: HashMap<String, Vec<f32>>,
1149 ) -> TorshResult<Vec<f32>> {
1150 self.federated_averaging(client_updates)
1153 }
1154
1155 fn median_aggregation(
1157 &self,
1158 client_updates: HashMap<String, Vec<f32>>,
1159 ) -> TorshResult<Vec<f32>> {
1160 if client_updates.is_empty() {
1161 return Err(TorshDistributedError::InternalError(
1162 "No updates to aggregate".to_string(),
1163 ));
1164 }
1165
1166 let update_size = client_updates
1167 .values()
1168 .next()
1169 .expect("client_updates should not be empty")
1170 .len();
1171 let mut aggregated = vec![0.0; update_size];
1172
1173 for i in 0..update_size {
1174 let mut values: Vec<f32> = client_updates.values().map(|updates| updates[i]).collect();
1175 values.sort_by(|a, b| a.partial_cmp(b).unwrap_or(std::cmp::Ordering::Equal));
1176
1177 aggregated[i] = if values.len() % 2 == 0 {
1178 (values[values.len() / 2 - 1] + values[values.len() / 2]) / 2.0
1179 } else {
1180 values[values.len() / 2]
1181 };
1182 }
1183
1184 Ok(aggregated)
1185 }
1186}
1187
1188pub struct HierarchicalTrainingCoordinator {
1190 config: HierarchicalTrainingConfig,
1191 tier_assignments: HashMap<String, TrainingTier>,
1192}
1193
1194impl HierarchicalTrainingCoordinator {
1195 pub fn new(config: &HierarchicalTrainingConfig) -> TorshResult<Self> {
1196 Ok(Self {
1197 config: config.clone(),
1198 tier_assignments: HashMap::new(),
1199 })
1200 }
1201
1202 pub fn assign_device_tier(&mut self, device_id: String, device: &EdgeDevice) -> TrainingTier {
1204 let tier = match device.device_type {
1205 DeviceType::Smartphone
1206 | DeviceType::Tablet
1207 | DeviceType::IoTSensor
1208 | DeviceType::Embedded => TrainingTier::Edge,
1209 DeviceType::Laptop | DeviceType::EdgeServer | DeviceType::Automotive => {
1210 TrainingTier::Fog
1211 }
1212 DeviceType::FogNode => TrainingTier::Cloud,
1213 };
1214
1215 self.tier_assignments.insert(device_id, tier);
1216 tier
1217 }
1218
1219 pub fn get_tier_devices(&self, tier: TrainingTier) -> Vec<String> {
1221 self.tier_assignments
1222 .iter()
1223 .filter_map(|(device_id, &device_tier)| {
1224 if device_tier == tier {
1225 Some(device_id.clone())
1226 } else {
1227 None
1228 }
1229 })
1230 .collect()
1231 }
1232}
1233
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds the available smartphone fixture shared by several tests.
    fn smartphone_fixture(device_id: &str) -> EdgeDevice {
        EdgeDevice {
            device_id: device_id.to_string(),
            device_type: DeviceType::Smartphone,
            compute_capability: ComputeCapability {
                cpu_cores: 8,
                cpu_frequency: 2400,
                ram_mb: 6144,
                has_gpu: true,
                gpu_memory_mb: 1024,
                has_accelerator: false,
                estimated_flops: 1e9,
                power_consumption: 5.0,
            },
            network_info: NetworkInfo {
                connection_type: ConnectionType::Cellular5G,
                bandwidth: 50.0,
                latency: 20.0,
                packet_loss: 0.01,
                is_stable: true,
                data_limits: None,
            },
            status: DeviceStatus::Available,
            resources: DeviceResources {
                cpu_available: 80.0,
                memory_available_mb: 2048,
                storage_available_mb: 5120,
                battery_level: Some(85.0),
                is_charging: Some(false),
                thermal_state: ThermalState::Normal,
            },
            data_info: DataInfo {
                sample_count: 1000,
                quality_score: 0.8,
                diversity_score: 0.6,
                label_distribution: HashMap::new(),
                freshness_hours: 2.0,
                privacy_level: PrivacyLevel::Internal,
            },
            last_seen: SystemTime::now(),
            location: None,
        }
    }

    #[test]
    fn test_edge_computing_config_default() {
        let EdgeComputingConfig {
            heterogeneous_devices,
            adaptive_communication,
            federated_learning,
            ..
        } = EdgeComputingConfig::default();
        assert!(heterogeneous_devices);
        assert!(adaptive_communication);
        assert!(federated_learning);
    }

    #[test]
    fn test_device_creation() {
        let device = smartphone_fixture("test-device");

        assert_eq!(device.device_type, DeviceType::Smartphone);
        assert_eq!(device.status, DeviceStatus::Available);
    }

    #[tokio::test]
    async fn test_edge_computing_manager_creation() {
        let manager = EdgeComputingManager::new(EdgeComputingConfig::default()).unwrap();

        let mock = manager.create_mock_device("test-device", DeviceType::Smartphone);
        manager.register_device(mock).unwrap();

        assert!(manager.get_device("test-device").unwrap().is_some());
    }

    #[tokio::test]
    async fn test_device_discovery() {
        let manager = EdgeComputingManager::new(EdgeComputingConfig::default()).unwrap();

        let found = manager.discover_devices().await.unwrap();
        assert!(!found.is_empty());
    }

    #[test]
    fn test_client_selection() {
        let manager = EdgeComputingManager::new(EdgeComputingConfig::default()).unwrap();

        (0..5)
            .map(|i| manager.create_mock_device(&format!("device-{}", i), DeviceType::Smartphone))
            .for_each(|device| manager.register_device(device).unwrap());

        let chosen = manager.select_clients(1).unwrap();
        assert!(!chosen.is_empty());
        assert!(chosen.len() <= 5);
    }

    #[test]
    fn test_federated_aggregation() {
        let coordinator =
            FederatedLearningCoordinator::new(&FederatedLearningConfig::default()).unwrap();

        let client_updates: HashMap<String, Vec<f32>> = vec![
            ("client1", vec![1.0, 2.0, 3.0]),
            ("client2", vec![2.0, 3.0, 4.0]),
            ("client3", vec![3.0, 4.0, 5.0]),
        ]
        .into_iter()
        .map(|(client, update)| (client.to_string(), update))
        .collect();

        let aggregated = coordinator.aggregate_updates(client_updates).unwrap();
        let expected = [2.0, 3.0, 4.0];
        assert_eq!(aggregated.len(), expected.len());
        aggregated
            .iter()
            .zip(expected.iter())
            .enumerate()
            .for_each(|(i, (&actual, &exp))| {
                assert!(
                    (actual - exp).abs() < 1e-6,
                    "Element {} mismatch: expected {}, got {}",
                    i,
                    exp,
                    actual
                );
            });
    }

    #[test]
    fn test_bandwidth_adaptation() {
        let comm_manager =
            CommunicationManager::new(&BandwidthAdaptationConfig::default()).unwrap();

        // Well above the default compression threshold.
        let fast = comm_manager.get_adaptive_params(100.0);
        assert!(!fast.use_compression);

        // Below the default compression threshold.
        let slow = comm_manager.get_adaptive_params(5.0);
        assert!(slow.use_compression);
        assert!(slow.timeout_multiplier > 1.0);
    }

    #[test]
    fn test_privacy_mechanism() {
        let privacy_manager = PrivacyManager::new(&PrivacyConfig::default()).unwrap();

        let gradients = vec![1.0, 2.0, 3.0, 4.0, 5.0];
        let noisy = privacy_manager
            .apply_differential_privacy(&gradients, 1.0)
            .unwrap();

        assert_eq!(noisy.len(), gradients.len());
    }

    #[test]
    fn test_hierarchical_training() {
        let mut coordinator =
            HierarchicalTrainingCoordinator::new(&HierarchicalTrainingConfig::default()).unwrap();
        let phone_device = smartphone_fixture("phone");

        let tier = coordinator.assign_device_tier("phone".to_string(), &phone_device);
        assert_eq!(tier, TrainingTier::Edge);

        let edge_devices = coordinator.get_tier_devices(TrainingTier::Edge);
        assert!(edge_devices.contains(&"phone".to_string()));
    }
}