1use anyhow::{anyhow, Result};
16use scirs2_core::random::StdRng; use scirs2_core::random::*; use serde::{Deserialize, Serialize};
19use std::collections::HashMap;
20use trustformers_core::tensor::Tensor;
21
22#[derive(Debug, Clone, Serialize, Deserialize)]
24pub struct FedAvgConfig {
25 pub local_epochs: usize,
27 pub local_learning_rate: f32,
29 pub client_fraction: f32,
31 pub min_clients: usize,
33 pub max_clients: usize,
35 pub weight_decay: f32,
37}
38
39impl Default for FedAvgConfig {
40 fn default() -> Self {
41 Self {
42 local_epochs: 5,
43 local_learning_rate: 1e-3,
44 client_fraction: 0.1,
45 min_clients: 2,
46 max_clients: 100,
47 weight_decay: 0.0,
48 }
49 }
50}
51
52#[derive(Debug, Clone, Serialize, Deserialize)]
54pub struct FedProxConfig {
55 pub fedavg_config: FedAvgConfig,
57 pub mu: f32,
59}
60
61impl Default for FedProxConfig {
62 fn default() -> Self {
63 Self {
64 fedavg_config: FedAvgConfig::default(),
65 mu: 0.01,
66 }
67 }
68}
69
70#[derive(Debug, Clone, Serialize, Deserialize)]
72pub struct DifferentialPrivacyConfig {
73 pub epsilon: f32,
75 pub delta: f32,
77 pub sensitivity: f32,
79 pub noise_mechanism: NoiseMechanism,
81}
82
83impl Default for DifferentialPrivacyConfig {
84 fn default() -> Self {
85 Self {
86 epsilon: 1.0,
87 delta: 1e-5,
88 sensitivity: 1.0,
89 noise_mechanism: NoiseMechanism::Gaussian,
90 }
91 }
92}
93
94#[derive(Debug, Clone, Serialize, Deserialize)]
96pub enum NoiseMechanism {
97 Gaussian,
99 Laplace,
101}
102
103#[derive(Debug, Clone, Serialize, Deserialize)]
105pub enum ClientSelectionStrategy {
106 Random,
108 DataSize,
110 ComputeCapacity,
112 CommunicationQuality,
114}
115
116#[derive(Debug, Clone, Serialize, Deserialize)]
118pub struct ClientInfo {
119 pub client_id: String,
121 pub data_size: usize,
123 pub compute_capacity: f32,
125 pub communication_quality: f32,
127 pub available: bool,
129}
130
131#[derive(Debug)]
136pub struct FedAvg {
137 config: FedAvgConfig,
138 global_parameters: Vec<Tensor>,
139 client_weights: HashMap<String, f32>,
140 current_round: usize,
141 selected_clients: Vec<String>,
142 rng: StdRng,
143}
144
145impl FedAvg {
146 pub fn new(config: FedAvgConfig) -> Self {
148 Self {
149 config,
150 global_parameters: Vec::new(),
151 client_weights: HashMap::new(),
152 current_round: 0,
153 selected_clients: Vec::new(),
154 rng: StdRng::seed_from_u64(42),
155 }
156 }
157
158 pub fn initialize_global_parameters(&mut self, parameters: Vec<Tensor>) {
160 self.global_parameters = parameters;
161 }
162
163 pub fn select_clients(
165 &mut self,
166 available_clients: &[ClientInfo],
167 strategy: ClientSelectionStrategy,
168 ) -> Result<Vec<String>> {
169 let available: Vec<&ClientInfo> =
170 available_clients.iter().filter(|c| c.available).collect();
171
172 if available.is_empty() {
173 return Err(anyhow!("No available clients"));
174 }
175
176 let num_clients = (available.len() as f32 * self.config.client_fraction).round() as usize;
177 let num_clients = num_clients
178 .max(self.config.min_clients)
179 .min(self.config.max_clients)
180 .min(available.len());
181
182 let selected = match strategy {
183 ClientSelectionStrategy::Random => {
184 let mut indices: Vec<usize> = (0..available.len()).collect();
185 for i in 0..num_clients {
186 let j = self.rng.gen_range(i..indices.len());
187 indices.swap(i, j);
188 }
189 indices[..num_clients].iter().map(|&i| available[i].client_id.clone()).collect()
190 },
191 ClientSelectionStrategy::DataSize => {
192 let mut clients_with_size: Vec<_> =
193 available.iter().map(|c| (c.client_id.clone(), c.data_size)).collect();
194 clients_with_size.sort_by_key(|(_, size)| std::cmp::Reverse(*size));
195 clients_with_size[..num_clients].iter().map(|(id, _)| id.clone()).collect()
196 },
197 ClientSelectionStrategy::ComputeCapacity => {
198 let mut clients_with_capacity: Vec<_> =
199 available.iter().map(|c| (c.client_id.clone(), c.compute_capacity)).collect();
200 clients_with_capacity.sort_by(|(_, a), (_, b)| b.partial_cmp(a).unwrap());
201 clients_with_capacity[..num_clients].iter().map(|(id, _)| id.clone()).collect()
202 },
203 ClientSelectionStrategy::CommunicationQuality => {
204 let mut clients_with_quality: Vec<_> = available
205 .iter()
206 .map(|c| (c.client_id.clone(), c.communication_quality))
207 .collect();
208 clients_with_quality.sort_by(|(_, a), (_, b)| b.partial_cmp(a).unwrap());
209 clients_with_quality[..num_clients].iter().map(|(id, _)| id.clone()).collect()
210 },
211 };
212
213 self.selected_clients = selected;
214 Ok(self.selected_clients.clone())
215 }
216
217 pub fn aggregate_updates(
219 &mut self,
220 client_updates: HashMap<String, Vec<Tensor>>,
221 ) -> Result<Vec<Tensor>> {
222 if client_updates.is_empty() {
223 return Err(anyhow!("No client updates to aggregate"));
224 }
225
226 let total_weight: f32 = client_updates
227 .keys()
228 .map(|client_id| self.client_weights.get(client_id).unwrap_or(&1.0))
229 .sum();
230
231 if total_weight == 0.0 {
232 return Err(anyhow!("Total client weight is zero"));
233 }
234
235 let param_count = client_updates.values().next().unwrap().len();
237 let mut aggregated = Vec::with_capacity(param_count);
238
239 for i in 0..param_count {
240 let first_param = &client_updates.values().next().unwrap()[i];
242 aggregated.push(Tensor::zeros_like(first_param)?);
243 }
244
245 for (client_id, updates) in &client_updates {
247 let weight = self.client_weights.get(client_id).unwrap_or(&1.0) / total_weight;
248
249 for (i, update) in updates.iter().enumerate() {
250 let weighted_update = update.mul_scalar(weight)?;
251 aggregated[i] = aggregated[i].add(&weighted_update)?;
252 }
253 }
254
255 self.global_parameters = aggregated.clone();
257 self.current_round += 1;
258
259 Ok(aggregated)
260 }
261
262 pub fn set_client_weights(&mut self, weights: HashMap<String, f32>) {
264 self.client_weights = weights;
265 }
266
267 pub fn get_global_parameters(&self) -> &[Tensor] {
269 &self.global_parameters
270 }
271
272 pub fn get_current_round(&self) -> usize {
274 self.current_round
275 }
276}
277
278#[derive(Debug)]
283pub struct FedProx {
284 fedavg: FedAvg,
285 config: FedProxConfig,
286}
287
288impl FedProx {
289 pub fn new(config: FedProxConfig) -> Self {
291 Self {
292 fedavg: FedAvg::new(config.fedavg_config.clone()),
293 config,
294 }
295 }
296
297 pub fn compute_proximal_term(
299 &self,
300 client_params: &[Tensor],
301 global_params: &[Tensor],
302 ) -> Result<f32> {
303 if client_params.len() != global_params.len() {
304 return Err(anyhow!("Parameter count mismatch"));
305 }
306
307 let mut proximal_loss = 0.0;
308 for (client_param, global_param) in client_params.iter().zip(global_params.iter()) {
309 let diff = client_param.sub(global_param)?;
310 let norm_sq = diff.norm_squared()?.to_scalar()?;
311 proximal_loss += norm_sq;
312 }
313
314 Ok(self.config.mu * proximal_loss / 2.0)
315 }
316
317 pub fn apply_proximal_update(
319 &self,
320 client_params: &mut [Tensor],
321 global_params: &[Tensor],
322 learning_rate: f32,
323 ) -> Result<()> {
324 for (client_param, global_param) in client_params.iter_mut().zip(global_params.iter()) {
325 let diff = client_param.sub(global_param)?;
326 let proximal_grad = diff.mul_scalar(self.config.mu)?;
327 let update = proximal_grad.mul_scalar(learning_rate)?;
328 *client_param = client_param.sub(&update)?;
329 }
330 Ok(())
331 }
332
333 pub fn select_clients(
335 &mut self,
336 available_clients: &[ClientInfo],
337 strategy: ClientSelectionStrategy,
338 ) -> Result<Vec<String>> {
339 self.fedavg.select_clients(available_clients, strategy)
340 }
341
342 pub fn aggregate_updates(
343 &mut self,
344 client_updates: HashMap<String, Vec<Tensor>>,
345 ) -> Result<Vec<Tensor>> {
346 self.fedavg.aggregate_updates(client_updates)
347 }
348
349 pub fn get_global_parameters(&self) -> &[Tensor] {
350 self.fedavg.get_global_parameters()
351 }
352
353 pub fn get_current_round(&self) -> usize {
354 self.fedavg.get_current_round()
355 }
356}
357
358pub struct DifferentialPrivacy {
360 config: DifferentialPrivacyConfig,
361 rng: StdRng,
362}
363
364impl DifferentialPrivacy {
365 pub fn new(config: DifferentialPrivacyConfig) -> Self {
367 Self {
368 config,
369 rng: StdRng::seed_from_u64(42),
370 }
371 }
372
373 pub fn add_noise(&mut self, parameters: &mut [Tensor]) -> Result<()> {
375 let noise_scale = self.compute_noise_scale()?;
376
377 for param in parameters.iter_mut() {
378 let noise = self.generate_noise_tensor(param, noise_scale)?;
379 *param = param.add(&noise)?;
380 }
381
382 Ok(())
383 }
384
385 fn compute_noise_scale(&self) -> Result<f32> {
386 match self.config.noise_mechanism {
387 NoiseMechanism::Gaussian => {
388 let ln_term = (1.25 / self.config.delta).ln();
390 let sigma = (2.0 * ln_term).sqrt() * self.config.sensitivity / self.config.epsilon;
391 Ok(sigma)
392 },
393 NoiseMechanism::Laplace => {
394 Ok(self.config.sensitivity / self.config.epsilon)
396 },
397 }
398 }
399
400 fn generate_noise_tensor(&mut self, reference: &Tensor, scale: f32) -> Result<Tensor> {
401 let shape = reference.shape();
402 let mut noise_data = Vec::new();
403
404 match self.config.noise_mechanism {
405 NoiseMechanism::Gaussian => {
406 use scirs2_core::random::{Distribution, Normal}; let normal = Normal::new(0.0, scale)
408 .map_err(|e| anyhow!("Normal distribution error: {}", e))?;
409
410 for _ in 0..shape.iter().product::<usize>() {
411 noise_data.push(normal.sample(&mut self.rng));
412 }
413 },
414 NoiseMechanism::Laplace => {
415 use scirs2_core::random::{Distribution, Exp}; let exp_dist = Exp::new(1.0 / scale)
419 .map_err(|e| anyhow!("Exponential distribution error: {}", e))?;
420
421 for _ in 0..shape.iter().product::<usize>() {
422 let sign = if self.rng.random::<bool>() { 1.0 } else { -1.0 };
423 let exp_sample = exp_dist.sample(&mut self.rng);
424 noise_data.push(sign * exp_sample);
425 }
426 },
427 }
428
429 Ok(Tensor::from_data(noise_data, &shape.to_vec())?)
430 }
431}
432
/// Threshold-gated aggregation of masked client updates.
#[derive(Debug, Clone)]
pub struct SecureAggregation {
    /// Minimum number of client submissions `secure_aggregate` accepts.
    threshold: usize,
    // Validated against `threshold` in `new` and read only by the unit tests; no
    // non-test code reads it, which is why the dead-code lint is silenced.
    #[allow(dead_code)]
    total_clients: usize,
}
442
443impl SecureAggregation {
444 pub fn new(threshold: usize, total_clients: usize) -> Result<Self> {
446 if threshold > total_clients {
447 return Err(anyhow!("Threshold cannot exceed total clients"));
448 }
449
450 Ok(Self {
451 threshold,
452 total_clients,
453 })
454 }
455
456 pub fn generate_masks(&self, client_id: &str, round: usize) -> Result<Vec<Tensor>> {
459 let mut rng = StdRng::from_seed({
462 let mut seed = [0u8; 32];
463 let client_hash = format!("{}-{}", client_id, round);
464 let bytes = client_hash.as_bytes();
465 for (i, &byte) in bytes.iter().enumerate().take(32) {
466 seed[i] = byte;
467 }
468 seed
469 });
470
471 let mut masks = Vec::new();
474
475 let parameter_shapes = vec![
478 vec![100, 50], vec![50], vec![50, 20], vec![20], ];
483
484 for shape in parameter_shapes {
485 let mask_size = shape.iter().product::<usize>();
487 let mut mask_data: Vec<f32> = Vec::with_capacity(mask_size);
488
489 for _ in 0..mask_size {
490 mask_data.push(rng.gen_range(-1.0..1.0));
492 }
493
494 let mask = Tensor::from_data(mask_data, &shape)?;
495 masks.push(mask);
496 }
497
498 Ok(masks)
499 }
500
501 pub fn secure_aggregate(
503 &self,
504 masked_updates: HashMap<String, Vec<Tensor>>,
505 ) -> Result<Vec<Tensor>> {
506 if masked_updates.len() < self.threshold {
507 return Err(anyhow!("Not enough clients for secure aggregation"));
508 }
509
510 let mut result = Vec::new();
518 let client_count = masked_updates.len() as f32;
519
520 let parameter_count =
522 masked_updates.values().next().map(|update| update.len()).unwrap_or(0);
523
524 for (client_id, update) in &masked_updates {
525 if update.len() != parameter_count {
526 return Err(anyhow!(
527 "Client {} has {} parameters, expected {}",
528 client_id,
529 update.len(),
530 parameter_count
531 ));
532 }
533 }
534
535 for param_idx in 0..parameter_count {
537 let mut parameter_updates = Vec::new();
539 let mut expected_shape: Option<Vec<usize>> = None;
540
541 for (client_id, update) in &masked_updates {
542 let param_update = &update[param_idx];
543
544 if let Some(ref shape) = expected_shape {
546 if param_update.shape() != *shape {
547 return Err(anyhow!(
548 "Client {} parameter {} has shape {:?}, expected {:?}",
549 client_id,
550 param_idx,
551 param_update.shape(),
552 shape
553 ));
554 }
555 } else {
556 expected_shape = Some(param_update.shape());
557 }
558
559 parameter_updates.push(param_update);
560 }
561
562 let mut aggregated_param = Tensor::zeros(&expected_shape.unwrap())?;
564 for param_update in parameter_updates {
565 aggregated_param = aggregated_param.add(param_update)?;
566 }
567
568 result.push(aggregated_param.div_scalar(client_count)?);
572 }
573
574 Ok(result)
575 }
576}
577
#[cfg(test)]
mod tests {
    use super::*;

    /// Two available clients with distinct data sizes and scores.
    fn sample_clients() -> Vec<ClientInfo> {
        vec![
            ClientInfo {
                client_id: "client1".to_string(),
                data_size: 100,
                compute_capacity: 0.8,
                communication_quality: 0.9,
                available: true,
            },
            ClientInfo {
                client_id: "client2".to_string(),
                data_size: 200,
                compute_capacity: 0.6,
                communication_quality: 0.7,
                available: true,
            },
        ]
    }

    #[test]
    fn test_fedavg_config_default() {
        let config = FedAvgConfig::default();
        assert_eq!(config.local_epochs, 5);
        assert_eq!(config.client_fraction, 0.1);
        assert_eq!(config.min_clients, 2);
        assert_eq!(config.max_clients, 100);
    }

    #[test]
    fn test_fedprox_config_default() {
        let config = FedProxConfig::default();
        assert_eq!(config.mu, 0.01);
        assert_eq!(config.fedavg_config.local_epochs, 5);
    }

    #[test]
    fn test_differential_privacy_config() {
        let config = DifferentialPrivacyConfig::default();
        assert_eq!(config.epsilon, 1.0);
        assert_eq!(config.delta, 1e-5);
        assert!(matches!(config.noise_mechanism, NoiseMechanism::Gaussian));
    }

    #[test]
    fn test_client_selection_strategies() {
        let clients = sample_clients();
        let mut fedavg = FedAvg::new(FedAvgConfig::default());

        // min_clients (2) dominates the 10% fraction, so both clients are chosen.
        let mut selected = fedavg.select_clients(&clients, ClientSelectionStrategy::Random).unwrap();
        selected.sort();
        assert_eq!(selected, vec!["client1", "client2"]);

        let selected = fedavg.select_clients(&clients, ClientSelectionStrategy::DataSize).unwrap();
        assert_eq!(selected, vec!["client2", "client1"]);

        let selected =
            fedavg.select_clients(&clients, ClientSelectionStrategy::ComputeCapacity).unwrap();
        assert_eq!(selected, vec!["client1", "client2"]);

        let selected = fedavg
            .select_clients(&clients, ClientSelectionStrategy::CommunicationQuality)
            .unwrap();
        assert_eq!(selected, vec!["client1", "client2"]);
    }

    #[test]
    fn test_select_clients_skips_unavailable() {
        let mut clients = sample_clients();
        let mut fedavg = FedAvg::new(FedAvgConfig::default());

        clients[0].available = false;
        let selected = fedavg.select_clients(&clients, ClientSelectionStrategy::DataSize).unwrap();
        assert_eq!(selected, vec!["client2"]);

        for client in &mut clients {
            client.available = false;
        }
        assert!(fedavg.select_clients(&clients, ClientSelectionStrategy::Random).is_err());
    }

    #[test]
    fn test_fedavg_aggregate_updates() {
        let mut fedavg = FedAvg::new(FedAvgConfig::default());
        assert!(fedavg.aggregate_updates(HashMap::new()).is_err());
        assert_eq!(fedavg.get_current_round(), 0);

        let mut updates = HashMap::new();
        updates.insert("a".to_string(), vec![Tensor::zeros(&vec![2usize, 2]).unwrap()]);
        updates.insert("b".to_string(), vec![Tensor::zeros(&vec![2usize, 2]).unwrap()]);

        let aggregated = fedavg.aggregate_updates(updates).unwrap();
        assert_eq!(aggregated.len(), 1);
        assert_eq!(aggregated[0].shape(), vec![2, 2]);
        assert_eq!(fedavg.get_current_round(), 1);
        assert_eq!(fedavg.get_global_parameters().len(), 1);
    }

    #[test]
    fn test_fedavg_rejects_zero_total_weight() {
        let mut fedavg = FedAvg::new(FedAvgConfig::default());
        let mut weights = HashMap::new();
        weights.insert("a".to_string(), 0.0);
        fedavg.set_client_weights(weights);

        let mut updates = HashMap::new();
        updates.insert("a".to_string(), vec![Tensor::zeros(&vec![3usize]).unwrap()]);
        assert!(fedavg.aggregate_updates(updates).is_err());
    }

    #[test]
    fn test_fedprox_proximal_term() {
        let fedprox = FedProx::new(FedProxConfig::default());
        let client = vec![Tensor::zeros(&vec![3usize]).unwrap()];
        let global = vec![Tensor::zeros(&vec![3usize]).unwrap()];

        assert_eq!(fedprox.compute_proximal_term(&client, &global).unwrap(), 0.0);
        assert!(fedprox.compute_proximal_term(&client, &[]).is_err());
    }

    #[test]
    fn test_noise_scale() {
        let laplace = DifferentialPrivacy::new(DifferentialPrivacyConfig {
            noise_mechanism: NoiseMechanism::Laplace,
            ..DifferentialPrivacyConfig::default()
        });
        assert!((laplace.compute_noise_scale().unwrap() - 1.0).abs() < 1e-6);

        let gaussian = DifferentialPrivacy::new(DifferentialPrivacyConfig::default());
        let expected = (2.0f32 * (1.25f32 / 1e-5).ln()).sqrt();
        assert!((gaussian.compute_noise_scale().unwrap() - expected).abs() < 1e-4);
    }

    #[test]
    fn test_add_noise_preserves_shapes() {
        let mut dp = DifferentialPrivacy::new(DifferentialPrivacyConfig::default());
        let mut params = vec![
            Tensor::zeros(&vec![2usize, 3]).unwrap(),
            Tensor::zeros(&vec![4usize]).unwrap(),
        ];
        dp.add_noise(&mut params).unwrap();
        assert_eq!(params[0].shape(), vec![2, 3]);
        assert_eq!(params[1].shape(), vec![4]);
    }

    #[test]
    fn test_secure_aggregation_creation() {
        let secure_agg = SecureAggregation::new(3, 5).unwrap();
        assert_eq!(secure_agg.threshold, 3);
        assert_eq!(secure_agg.total_clients, 5);

        assert!(SecureAggregation::new(6, 5).is_err());
    }

    #[test]
    fn test_secure_aggregation_requires_threshold() {
        let secure_agg = SecureAggregation::new(2, 5).unwrap();
        let mut updates = HashMap::new();
        updates.insert("only".to_string(), vec![Tensor::zeros(&vec![2usize]).unwrap()]);
        assert!(secure_agg.secure_aggregate(updates).is_err());
    }

    #[test]
    fn test_generate_masks_shapes() {
        let secure_agg = SecureAggregation::new(1, 1).unwrap();
        let masks = secure_agg.generate_masks("client1", 0).unwrap();
        assert_eq!(masks.len(), 4);
        assert_eq!(masks[0].shape(), vec![100, 50]);
        assert_eq!(masks[1].shape(), vec![50]);
        assert_eq!(masks[2].shape(), vec![50, 20]);
        assert_eq!(masks[3].shape(), vec![20]);
    }
}