// trustformers_optim/federated.rs
1//! # Federated Learning Optimization
2//!
3//! This module implements algorithms for federated learning, enabling distributed
4//! training across multiple clients while preserving privacy and handling
5//! heterogeneous data distributions.
6//!
7//! ## Available Algorithms
8//!
9//! - **FedAvg**: Standard federated averaging algorithm
10//! - **FedProx**: Federated optimization with proximal regularization
11//! - **Secure Aggregation**: Privacy-preserving parameter aggregation
12//! - **Differential Privacy**: Add noise for enhanced privacy protection
13//! - **Client Selection**: Strategies for selecting participating clients
14
15use anyhow::{anyhow, Result};
16use scirs2_core::random::StdRng; // Explicit import for type clarity
17use scirs2_core::random::*; // SciRS2 Integration Policy - Replaces rand
18use serde::{Deserialize, Serialize};
19use std::collections::HashMap;
20use trustformers_core::tensor::Tensor;
21
/// Configuration for federated averaging (FedAvg).
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct FedAvgConfig {
    /// Number of local epochs each client trains before reporting back.
    pub local_epochs: usize,
    /// Local learning rate for client updates.
    pub local_learning_rate: f32,
    /// Fraction of available clients participating per round (0.0..=1.0).
    pub client_fraction: f32,
    /// Minimum number of clients required per round (lower clamp on selection).
    pub min_clients: usize,
    /// Maximum number of clients per round (upper clamp on selection).
    pub max_clients: usize,
    /// Weight decay for regularization (0.0 disables it).
    pub weight_decay: f32,
}
38
39impl Default for FedAvgConfig {
40    fn default() -> Self {
41        Self {
42            local_epochs: 5,
43            local_learning_rate: 1e-3,
44            client_fraction: 0.1,
45            min_clients: 2,
46            max_clients: 100,
47            weight_decay: 0.0,
48        }
49    }
50}
51
/// Configuration for FedProx algorithm.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct FedProxConfig {
    /// Base FedAvg configuration (local training and client selection).
    pub fedavg_config: FedAvgConfig,
    /// Proximal term coefficient (μ); larger values keep client updates
    /// closer to the global model.
    pub mu: f32,
}
60
61impl Default for FedProxConfig {
62    fn default() -> Self {
63        Self {
64            fedavg_config: FedAvgConfig::default(),
65            mu: 0.01,
66        }
67    }
68}
69
/// Configuration for differential privacy.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct DifferentialPrivacyConfig {
    /// Privacy budget (epsilon); smaller values mean stronger privacy.
    /// Must be positive — the noise scale divides by it.
    pub epsilon: f32,
    /// Delta parameter for (ε,δ)-differential privacy; used only by the
    /// Gaussian mechanism.
    pub delta: f32,
    /// Sensitivity of the function (max change in output per unit change in input).
    pub sensitivity: f32,
    /// Noise mechanism to use (Gaussian or Laplace).
    pub noise_mechanism: NoiseMechanism,
}
82
83impl Default for DifferentialPrivacyConfig {
84    fn default() -> Self {
85        Self {
86            epsilon: 1.0,
87            delta: 1e-5,
88            sensitivity: 1.0,
89            noise_mechanism: NoiseMechanism::Gaussian,
90        }
91    }
92}
93
/// Types of noise mechanisms for differential privacy.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub enum NoiseMechanism {
    /// Gaussian noise for (ε,δ)-DP; scale σ = sqrt(2·ln(1.25/δ))·Δf/ε.
    Gaussian,
    /// Laplace noise; scale b = Δf/ε.
    Laplace,
}
102
/// Client selection strategies.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub enum ClientSelectionStrategy {
    /// Uniform random selection among available clients.
    Random,
    /// Select the clients with the most data samples (descending `data_size`).
    DataSize,
    /// Select the clients with the highest computational capacity.
    ComputeCapacity,
    /// Select the clients with the best communication quality.
    CommunicationQuality,
}
115
/// Information about a federated client.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ClientInfo {
    /// Client identifier (unique within a federation).
    pub client_id: String,
    /// Number of local data samples held by the client.
    pub data_size: usize,
    /// Computational capacity (relative metric; higher is preferred by
    /// `ClientSelectionStrategy::ComputeCapacity`).
    pub compute_capacity: f32,
    /// Communication quality (bandwidth, latency, etc.; higher is preferred
    /// by `ClientSelectionStrategy::CommunicationQuality`).
    pub communication_quality: f32,
    /// Whether the client can participate this round; unavailable clients
    /// are filtered out before selection.
    pub available: bool,
}
130
/// Federated Averaging (FedAvg) optimizer.
///
/// Implements the standard federated learning algorithm where clients
/// perform local updates and the server aggregates them via weighted averaging.
#[derive(Debug)]
pub struct FedAvg {
    /// Hyperparameters controlling local training and client participation.
    config: FedAvgConfig,
    /// Current global model parameters (set via `initialize_global_parameters`
    /// and refreshed by `aggregate_updates`).
    global_parameters: Vec<Tensor>,
    /// Per-client aggregation weights; clients absent from this map default to 1.0.
    client_weights: HashMap<String, f32>,
    /// Number of aggregation rounds completed so far.
    current_round: usize,
    /// Client ids chosen by the most recent `select_clients` call.
    selected_clients: Vec<String>,
    /// RNG for random client selection; seeded with a fixed value (42) in
    /// `new`, so selection is reproducible across runs.
    rng: StdRng,
}
144
145impl FedAvg {
146    /// Create a new FedAvg optimizer.
147    pub fn new(config: FedAvgConfig) -> Self {
148        Self {
149            config,
150            global_parameters: Vec::new(),
151            client_weights: HashMap::new(),
152            current_round: 0,
153            selected_clients: Vec::new(),
154            rng: StdRng::seed_from_u64(42),
155        }
156    }
157
158    /// Initialize global parameters.
159    pub fn initialize_global_parameters(&mut self, parameters: Vec<Tensor>) {
160        self.global_parameters = parameters;
161    }
162
163    /// Select clients for the current round.
164    pub fn select_clients(
165        &mut self,
166        available_clients: &[ClientInfo],
167        strategy: ClientSelectionStrategy,
168    ) -> Result<Vec<String>> {
169        let available: Vec<&ClientInfo> =
170            available_clients.iter().filter(|c| c.available).collect();
171
172        if available.is_empty() {
173            return Err(anyhow!("No available clients"));
174        }
175
176        let num_clients = (available.len() as f32 * self.config.client_fraction).round() as usize;
177        let num_clients = num_clients
178            .max(self.config.min_clients)
179            .min(self.config.max_clients)
180            .min(available.len());
181
182        let selected = match strategy {
183            ClientSelectionStrategy::Random => {
184                let mut indices: Vec<usize> = (0..available.len()).collect();
185                for i in 0..num_clients {
186                    let j = self.rng.gen_range(i..indices.len());
187                    indices.swap(i, j);
188                }
189                indices[..num_clients].iter().map(|&i| available[i].client_id.clone()).collect()
190            },
191            ClientSelectionStrategy::DataSize => {
192                let mut clients_with_size: Vec<_> =
193                    available.iter().map(|c| (c.client_id.clone(), c.data_size)).collect();
194                clients_with_size.sort_by_key(|(_, size)| std::cmp::Reverse(*size));
195                clients_with_size[..num_clients].iter().map(|(id, _)| id.clone()).collect()
196            },
197            ClientSelectionStrategy::ComputeCapacity => {
198                let mut clients_with_capacity: Vec<_> =
199                    available.iter().map(|c| (c.client_id.clone(), c.compute_capacity)).collect();
200                clients_with_capacity.sort_by(|(_, a), (_, b)| b.partial_cmp(a).unwrap());
201                clients_with_capacity[..num_clients].iter().map(|(id, _)| id.clone()).collect()
202            },
203            ClientSelectionStrategy::CommunicationQuality => {
204                let mut clients_with_quality: Vec<_> = available
205                    .iter()
206                    .map(|c| (c.client_id.clone(), c.communication_quality))
207                    .collect();
208                clients_with_quality.sort_by(|(_, a), (_, b)| b.partial_cmp(a).unwrap());
209                clients_with_quality[..num_clients].iter().map(|(id, _)| id.clone()).collect()
210            },
211        };
212
213        self.selected_clients = selected;
214        Ok(self.selected_clients.clone())
215    }
216
217    /// Aggregate client updates using weighted averaging.
218    pub fn aggregate_updates(
219        &mut self,
220        client_updates: HashMap<String, Vec<Tensor>>,
221    ) -> Result<Vec<Tensor>> {
222        if client_updates.is_empty() {
223            return Err(anyhow!("No client updates to aggregate"));
224        }
225
226        let total_weight: f32 = client_updates
227            .keys()
228            .map(|client_id| self.client_weights.get(client_id).unwrap_or(&1.0))
229            .sum();
230
231        if total_weight == 0.0 {
232            return Err(anyhow!("Total client weight is zero"));
233        }
234
235        // Initialize aggregated parameters with zeros
236        let param_count = client_updates.values().next().unwrap().len();
237        let mut aggregated = Vec::with_capacity(param_count);
238
239        for i in 0..param_count {
240            // Get shape from first client's parameter
241            let first_param = &client_updates.values().next().unwrap()[i];
242            aggregated.push(Tensor::zeros_like(first_param)?);
243        }
244
245        // Weighted aggregation
246        for (client_id, updates) in &client_updates {
247            let weight = self.client_weights.get(client_id).unwrap_or(&1.0) / total_weight;
248
249            for (i, update) in updates.iter().enumerate() {
250                let weighted_update = update.mul_scalar(weight)?;
251                aggregated[i] = aggregated[i].add(&weighted_update)?;
252            }
253        }
254
255        // Update global parameters
256        self.global_parameters = aggregated.clone();
257        self.current_round += 1;
258
259        Ok(aggregated)
260    }
261
262    /// Set client weights for aggregation.
263    pub fn set_client_weights(&mut self, weights: HashMap<String, f32>) {
264        self.client_weights = weights;
265    }
266
267    /// Get current global parameters.
268    pub fn get_global_parameters(&self) -> &[Tensor] {
269        &self.global_parameters
270    }
271
272    /// Get current round number.
273    pub fn get_current_round(&self) -> usize {
274        self.current_round
275    }
276}
277
/// FedProx optimizer with proximal regularization.
///
/// Extends FedAvg with a proximal term to handle client heterogeneity
/// by adding regularization that keeps client updates close to global model.
#[derive(Debug)]
pub struct FedProx {
    /// Inner FedAvg optimizer; client selection and aggregation delegate to it.
    fedavg: FedAvg,
    /// FedProx configuration (proximal coefficient μ plus FedAvg settings).
    config: FedProxConfig,
}
287
288impl FedProx {
289    /// Create a new FedProx optimizer.
290    pub fn new(config: FedProxConfig) -> Self {
291        Self {
292            fedavg: FedAvg::new(config.fedavg_config.clone()),
293            config,
294        }
295    }
296
297    /// Compute proximal term for client update.
298    pub fn compute_proximal_term(
299        &self,
300        client_params: &[Tensor],
301        global_params: &[Tensor],
302    ) -> Result<f32> {
303        if client_params.len() != global_params.len() {
304            return Err(anyhow!("Parameter count mismatch"));
305        }
306
307        let mut proximal_loss = 0.0;
308        for (client_param, global_param) in client_params.iter().zip(global_params.iter()) {
309            let diff = client_param.sub(global_param)?;
310            let norm_sq = diff.norm_squared()?.to_scalar()?;
311            proximal_loss += norm_sq;
312        }
313
314        Ok(self.config.mu * proximal_loss / 2.0)
315    }
316
317    /// Apply proximal update to client parameters.
318    pub fn apply_proximal_update(
319        &self,
320        client_params: &mut [Tensor],
321        global_params: &[Tensor],
322        learning_rate: f32,
323    ) -> Result<()> {
324        for (client_param, global_param) in client_params.iter_mut().zip(global_params.iter()) {
325            let diff = client_param.sub(global_param)?;
326            let proximal_grad = diff.mul_scalar(self.config.mu)?;
327            let update = proximal_grad.mul_scalar(learning_rate)?;
328            *client_param = client_param.sub(&update)?;
329        }
330        Ok(())
331    }
332
333    /// Delegate to FedAvg for other operations.
334    pub fn select_clients(
335        &mut self,
336        available_clients: &[ClientInfo],
337        strategy: ClientSelectionStrategy,
338    ) -> Result<Vec<String>> {
339        self.fedavg.select_clients(available_clients, strategy)
340    }
341
342    pub fn aggregate_updates(
343        &mut self,
344        client_updates: HashMap<String, Vec<Tensor>>,
345    ) -> Result<Vec<Tensor>> {
346        self.fedavg.aggregate_updates(client_updates)
347    }
348
349    pub fn get_global_parameters(&self) -> &[Tensor] {
350        self.fedavg.get_global_parameters()
351    }
352
353    pub fn get_current_round(&self) -> usize {
354        self.fedavg.get_current_round()
355    }
356}
357
/// Differential privacy mechanism for federated learning.
///
/// Adds calibrated Gaussian or Laplace noise to parameter tensors
/// according to the configured privacy budget.
pub struct DifferentialPrivacy {
    /// Privacy parameters (ε, δ, sensitivity) and mechanism selection.
    config: DifferentialPrivacyConfig,
    /// RNG for noise sampling; seeded with a fixed value (42) in `new`, so
    /// the noise stream is deterministic.
    /// NOTE(review): deterministic noise undermines the privacy guarantee in
    /// production — confirm the intended seeding strategy.
    rng: StdRng,
}
363
364impl DifferentialPrivacy {
365    /// Create a new differential privacy mechanism.
366    pub fn new(config: DifferentialPrivacyConfig) -> Self {
367        Self {
368            config,
369            rng: StdRng::seed_from_u64(42),
370        }
371    }
372
373    /// Add noise to parameters for differential privacy.
374    pub fn add_noise(&mut self, parameters: &mut [Tensor]) -> Result<()> {
375        let noise_scale = self.compute_noise_scale()?;
376
377        for param in parameters.iter_mut() {
378            let noise = self.generate_noise_tensor(param, noise_scale)?;
379            *param = param.add(&noise)?;
380        }
381
382        Ok(())
383    }
384
385    fn compute_noise_scale(&self) -> Result<f32> {
386        match self.config.noise_mechanism {
387            NoiseMechanism::Gaussian => {
388                // For Gaussian mechanism: σ = sqrt(2 * ln(1.25/δ)) * Δf / ε
389                let ln_term = (1.25 / self.config.delta).ln();
390                let sigma = (2.0 * ln_term).sqrt() * self.config.sensitivity / self.config.epsilon;
391                Ok(sigma)
392            },
393            NoiseMechanism::Laplace => {
394                // For Laplace mechanism: b = Δf / ε
395                Ok(self.config.sensitivity / self.config.epsilon)
396            },
397        }
398    }
399
400    fn generate_noise_tensor(&mut self, reference: &Tensor, scale: f32) -> Result<Tensor> {
401        let shape = reference.shape();
402        let mut noise_data = Vec::new();
403
404        match self.config.noise_mechanism {
405            NoiseMechanism::Gaussian => {
406                use scirs2_core::random::{Distribution, Normal}; // SciRS2 Integration Policy
407                let normal = Normal::new(0.0, scale)
408                    .map_err(|e| anyhow!("Normal distribution error: {}", e))?;
409
410                for _ in 0..shape.iter().product::<usize>() {
411                    noise_data.push(normal.sample(&mut self.rng));
412                }
413            },
414            NoiseMechanism::Laplace => {
415                // Use exponential distribution to simulate Laplace
416                // Laplace(0, b) can be simulated as: sign * Exponential(1/b)
417                use scirs2_core::random::{Distribution, Exp}; // SciRS2 Integration Policy
418                let exp_dist = Exp::new(1.0 / scale)
419                    .map_err(|e| anyhow!("Exponential distribution error: {}", e))?;
420
421                for _ in 0..shape.iter().product::<usize>() {
422                    let sign = if self.rng.random::<bool>() { 1.0 } else { -1.0 };
423                    let exp_sample = exp_dist.sample(&mut self.rng);
424                    noise_data.push(sign * exp_sample);
425                }
426            },
427        }
428
429        Ok(Tensor::from_data(noise_data, &shape.to_vec())?)
430    }
431}
432
/// Secure aggregation for federated learning.
///
/// Implements privacy-preserving aggregation where the server cannot
/// see individual client updates, only the aggregated result.
pub struct SecureAggregation {
    /// Minimum number of clients that must contribute before an aggregate
    /// may be computed; enforced by `secure_aggregate`.
    threshold: usize,
    /// Total number of clients in the cohort; recorded for protocol
    /// bookkeeping but not read by the current simplified implementation.
    #[allow(dead_code)]
    total_clients: usize,
}
442
443impl SecureAggregation {
444    /// Create a new secure aggregation instance.
445    pub fn new(threshold: usize, total_clients: usize) -> Result<Self> {
446        if threshold > total_clients {
447            return Err(anyhow!("Threshold cannot exceed total clients"));
448        }
449
450        Ok(Self {
451            threshold,
452            total_clients,
453        })
454    }
455
456    /// Generate random masks for secure aggregation.
457    /// In practice, this would use cryptographic protocols.
458    pub fn generate_masks(&self, client_id: &str, round: usize) -> Result<Vec<Tensor>> {
459        // This is a simplified implementation
460        // Real secure aggregation uses secret sharing and cryptographic techniques
461        let mut rng = StdRng::from_seed({
462            let mut seed = [0u8; 32];
463            let client_hash = format!("{}-{}", client_id, round);
464            let bytes = client_hash.as_bytes();
465            for (i, &byte) in bytes.iter().enumerate().take(32) {
466                seed[i] = byte;
467            }
468            seed
469        });
470
471        // Generate cryptographic masks for secure aggregation
472        // Each mask is a random tensor that will be used to blind the client's update
473        let mut masks = Vec::new();
474
475        // Generate masks based on client's expected parameter shapes
476        // In practice, these shapes would be communicated during federated setup
477        let parameter_shapes = vec![
478            vec![100, 50], // Example: First layer weights
479            vec![50],      // Example: First layer bias
480            vec![50, 20],  // Example: Second layer weights
481            vec![20],      // Example: Second layer bias
482        ];
483
484        for shape in parameter_shapes {
485            // Generate random mask with same shape as parameter
486            let mask_size = shape.iter().product::<usize>();
487            let mut mask_data: Vec<f32> = Vec::with_capacity(mask_size);
488
489            for _ in 0..mask_size {
490                // Generate random float in range [-1.0, 1.0] for better numerical stability
491                mask_data.push(rng.gen_range(-1.0..1.0));
492            }
493
494            let mask = Tensor::from_data(mask_data, &shape)?;
495            masks.push(mask);
496        }
497
498        Ok(masks)
499    }
500
501    /// Aggregate masked updates securely.
502    pub fn secure_aggregate(
503        &self,
504        masked_updates: HashMap<String, Vec<Tensor>>,
505    ) -> Result<Vec<Tensor>> {
506        if masked_updates.len() < self.threshold {
507            return Err(anyhow!("Not enough clients for secure aggregation"));
508        }
509
510        // In a real implementation, this would:
511        // 1. Collect masked updates from clients
512        // 2. Aggregate the masks
513        // 3. Remove the aggregate mask to reveal the sum
514        // 4. Compute the average
515
516        // Enhanced secure aggregation with validation and error handling
517        let mut result = Vec::new();
518        let client_count = masked_updates.len() as f32;
519
520        // Validate that all clients have the same number of parameters
521        let parameter_count =
522            masked_updates.values().next().map(|update| update.len()).unwrap_or(0);
523
524        for (client_id, update) in &masked_updates {
525            if update.len() != parameter_count {
526                return Err(anyhow!(
527                    "Client {} has {} parameters, expected {}",
528                    client_id,
529                    update.len(),
530                    parameter_count
531                ));
532            }
533        }
534
535        // Aggregate masked updates parameter by parameter
536        for param_idx in 0..parameter_count {
537            // Collect all client updates for this parameter
538            let mut parameter_updates = Vec::new();
539            let mut expected_shape: Option<Vec<usize>> = None;
540
541            for (client_id, update) in &masked_updates {
542                let param_update = &update[param_idx];
543
544                // Validate tensor shapes are consistent across clients
545                if let Some(ref shape) = expected_shape {
546                    if param_update.shape() != *shape {
547                        return Err(anyhow!(
548                            "Client {} parameter {} has shape {:?}, expected {:?}",
549                            client_id,
550                            param_idx,
551                            param_update.shape(),
552                            shape
553                        ));
554                    }
555                } else {
556                    expected_shape = Some(param_update.shape());
557                }
558
559                parameter_updates.push(param_update);
560            }
561
562            // Sum all client updates for this parameter
563            let mut aggregated_param = Tensor::zeros(&expected_shape.unwrap())?;
564            for param_update in parameter_updates {
565                aggregated_param = aggregated_param.add(param_update)?;
566            }
567
568            // Average the aggregated parameter
569            // In secure aggregation, masks cancel out during summation
570            // so we get the true average without revealing individual updates
571            result.push(aggregated_param.div_scalar(client_count)?);
572        }
573
574        Ok(result)
575    }
576}
577
#[cfg(test)]
mod tests {
    use super::*;

    /// Build an available `ClientInfo` with the given statistics.
    fn make_client(id: &str, data_size: usize, capacity: f32, quality: f32) -> ClientInfo {
        ClientInfo {
            client_id: id.to_string(),
            data_size,
            compute_capacity: capacity,
            communication_quality: quality,
            available: true,
        }
    }

    #[test]
    fn test_fedavg_config_default() {
        let cfg = FedAvgConfig::default();
        assert_eq!(cfg.local_epochs, 5);
        assert_eq!(cfg.client_fraction, 0.1);
        assert_eq!(cfg.min_clients, 2);
    }

    #[test]
    fn test_fedprox_config_default() {
        let cfg = FedProxConfig::default();
        assert_eq!(cfg.mu, 0.01);
        assert_eq!(cfg.fedavg_config.local_epochs, 5);
    }

    #[test]
    fn test_differential_privacy_config() {
        let cfg = DifferentialPrivacyConfig::default();
        assert_eq!(cfg.epsilon, 1.0);
        assert_eq!(cfg.delta, 1e-5);
        assert!(matches!(cfg.noise_mechanism, NoiseMechanism::Gaussian));
    }

    #[test]
    fn test_client_selection_strategies() {
        let clients = vec![
            make_client("client1", 100, 0.8, 0.9),
            make_client("client2", 200, 0.6, 0.7),
        ];

        let mut optimizer = FedAvg::new(FedAvgConfig::default());

        // Random selection yields a non-empty subset.
        let picked = optimizer.select_clients(&clients, ClientSelectionStrategy::Random).unwrap();
        assert!(!picked.is_empty());

        // Data-size based selection likewise.
        let picked = optimizer.select_clients(&clients, ClientSelectionStrategy::DataSize).unwrap();
        assert!(!picked.is_empty());
    }

    #[test]
    fn test_secure_aggregation_creation() {
        let aggregator = SecureAggregation::new(3, 5).unwrap();
        assert_eq!(aggregator.threshold, 3);
        assert_eq!(aggregator.total_clients, 5);

        // A threshold above the cohort size must be rejected.
        assert!(SecureAggregation::new(6, 5).is_err());
    }
}