// scirs2_neural/training/federated.rs
1//! Federated Learning Primitives
2//!
3//! Implements core building blocks for federated learning, where a model is
4//! trained collaboratively across multiple clients without sharing raw data.
5//!
6//! # Algorithms
7//!
8//! - **FedAvg**: Federated Averaging (McMahan et al., 2017)
9//! - **Weighted aggregation**: Aggregate client updates proportionally to dataset size
10//! - **Client selection**: Random or importance-based selection per round
11//! - **Differential privacy**: Gaussian mechanism noise injection
12//! - **Gradient compression**: Top-k sparsification to reduce communication
13//!
14//! # Example
15//!
16//! ```rust
17//! use scirs2_neural::training::federated::{
18//!     FederatedConfig, FederatedServer, ClientUpdate,
19//! };
20//! use scirs2_core::ndarray::{array, Array, IxDyn};
21//!
22//! let config = FederatedConfig::builder()
23//!     .num_rounds(10)
24//!     .clients_per_round(3)
25//!     .build()
26//!     .expect("valid config");
27//!
28//! // Start with global parameters
29//! let global_params = vec![array![1.0_f64, 2.0, 3.0].into_dyn()];
30//! let mut server = FederatedServer::new(config, global_params);
31//!
32//! // Simulate client updates
33//! let updates = vec![
34//!     ClientUpdate::new(0, vec![array![1.1, 2.1, 3.1].into_dyn()], 100),
35//!     ClientUpdate::new(1, vec![array![0.9, 1.9, 2.9].into_dyn()], 200),
36//! ];
37//!
38//! server.aggregate_round(&updates).expect("aggregation ok");
39//! assert_eq!(server.current_round(), 1);
40//! ```
41
42use crate::error::{NeuralError, Result};
43use scirs2_core::ndarray::{Array, ArrayD, IxDyn, ScalarOperand};
44use scirs2_core::numeric::{Float, FromPrimitive, NumAssign, ToPrimitive};
45use scirs2_core::random::rngs::SmallRng;
46use scirs2_core::random::{Rng, RngExt, SeedableRng};
47use std::fmt::{self, Debug, Display};
48
49// ============================================================================
50// Types
51// ============================================================================
52
/// Strategy for selecting which clients participate in each round.
///
/// Selection happens on the server side (see `FederatedServer::select_clients`)
/// before updates are requested from clients.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum ClientSelectionStrategy {
    /// Select clients uniformly at random (without replacement).
    Random,
    /// Select clients with probability proportional to their dataset size,
    /// so data-rich clients participate more often.
    ImportanceBased,
    /// Select all available clients (ignores `clients_per_round`).
    All,
}
63
64impl Display for ClientSelectionStrategy {
65    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
66        match self {
67            Self::Random => write!(f, "Random"),
68            Self::ImportanceBased => write!(f, "ImportanceBased"),
69            Self::All => write!(f, "All"),
70        }
71    }
72}
73
/// Aggregation method for combining client updates into the global model.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum AggregationMethod {
    /// FedAvg: weighted average by number of samples (McMahan et al., 2017).
    FedAvg,
    /// Simple (unweighted) mean of client parameters.
    SimpleMean,
    /// Element-wise median across clients (more robust to outliers and
    /// poisoned updates than mean-based methods).
    Median,
}
84
85impl Display for AggregationMethod {
86    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
87        match self {
88            Self::FedAvg => write!(f, "FedAvg"),
89            Self::SimpleMean => write!(f, "SimpleMean"),
90            Self::Median => write!(f, "Median"),
91        }
92    }
93}
94
95// ============================================================================
96// Differential Privacy
97// ============================================================================
98
/// Configuration for differential privacy noise addition (Gaussian mechanism).
///
/// When enabled, the server adds Gaussian noise with standard deviation
/// `noise_multiplier * max_grad_norm` to the aggregated parameters
/// (see `FederatedServer::apply_dp_noise`).
#[derive(Debug, Clone)]
pub struct DifferentialPrivacyConfig {
    /// Whether DP is enabled.
    pub enabled: bool,
    /// Noise multiplier (sigma). Larger = more privacy, less utility.
    pub noise_multiplier: f64,
    /// Maximum L2 norm for gradient clipping before noise addition.
    /// NOTE(review): clipping itself is exposed via `clip_l2_norm` but the
    /// server does not call it automatically — confirm the intended call site.
    pub max_grad_norm: f64,
    /// Target delta for (epsilon, delta)-DP.
    pub delta: f64,
}
111
112impl Default for DifferentialPrivacyConfig {
113    fn default() -> Self {
114        Self {
115            enabled: false,
116            noise_multiplier: 1.0,
117            max_grad_norm: 1.0,
118            delta: 1e-5,
119        }
120    }
121}
122
/// Configuration for gradient compression (top-k sparsification of the
/// client-to-server parameter deltas).
#[derive(Debug, Clone)]
pub struct GradientCompressionConfig {
    /// Whether compression is enabled.
    pub enabled: bool,
    /// Fraction of top-k values to keep (0.0 to 1.0). At least one element
    /// per tensor is always kept (see `compress_gradients`).
    pub top_k_fraction: f64,
}
131
132impl Default for GradientCompressionConfig {
133    fn default() -> Self {
134        Self {
135            enabled: false,
136            top_k_fraction: 0.1,
137        }
138    }
139}
140
141// ============================================================================
142// Configuration
143// ============================================================================
144
/// Configuration for federated learning.
///
/// Build via [`FederatedConfig::builder`]; `build()` validates all fields.
#[derive(Debug, Clone)]
pub struct FederatedConfig {
    /// Total number of communication rounds.
    pub num_rounds: usize,
    /// Number of clients to select per round (capped by availability).
    pub clients_per_round: usize,
    /// Client selection strategy.
    pub client_selection: ClientSelectionStrategy,
    /// Aggregation method.
    pub aggregation: AggregationMethod,
    /// Differential privacy configuration.
    pub dp_config: DifferentialPrivacyConfig,
    /// Gradient compression configuration.
    pub compression: GradientCompressionConfig,
    /// Local epochs per client per round.
    /// NOTE(review): not consumed by `FederatedServer` in this file —
    /// presumably read by the client-side training loop.
    pub local_epochs: usize,
    /// Local learning rate for client training (same note as above).
    pub local_lr: f64,
    /// Random seed for reproducibility; `None` falls back to a fixed seed
    /// in `FederatedServer::new`.
    pub seed: Option<u64>,
}
167
168impl Default for FederatedConfig {
169    fn default() -> Self {
170        Self {
171            num_rounds: 100,
172            clients_per_round: 10,
173            client_selection: ClientSelectionStrategy::Random,
174            aggregation: AggregationMethod::FedAvg,
175            dp_config: DifferentialPrivacyConfig::default(),
176            compression: GradientCompressionConfig::default(),
177            local_epochs: 1,
178            local_lr: 0.01,
179            seed: None,
180        }
181    }
182}
183
184impl FederatedConfig {
185    /// Create a builder.
186    pub fn builder() -> FederatedConfigBuilder {
187        FederatedConfigBuilder::default()
188    }
189
190    /// Validate the configuration.
191    pub fn validate(&self) -> Result<()> {
192        if self.num_rounds == 0 {
193            return Err(NeuralError::InvalidArgument(
194                "num_rounds must be > 0".into(),
195            ));
196        }
197        if self.clients_per_round == 0 {
198            return Err(NeuralError::InvalidArgument(
199                "clients_per_round must be > 0".into(),
200            ));
201        }
202        if self.local_epochs == 0 {
203            return Err(NeuralError::InvalidArgument(
204                "local_epochs must be > 0".into(),
205            ));
206        }
207        if self.local_lr <= 0.0 {
208            return Err(NeuralError::InvalidArgument(
209                "local_lr must be positive".into(),
210            ));
211        }
212        if self.dp_config.enabled && self.dp_config.noise_multiplier <= 0.0 {
213            return Err(NeuralError::InvalidArgument(
214                "noise_multiplier must be positive when DP is enabled".into(),
215            ));
216        }
217        if self.dp_config.enabled && self.dp_config.max_grad_norm <= 0.0 {
218            return Err(NeuralError::InvalidArgument(
219                "max_grad_norm must be positive when DP is enabled".into(),
220            ));
221        }
222        if self.compression.enabled && !(0.0..=1.0).contains(&self.compression.top_k_fraction) {
223            return Err(NeuralError::InvalidArgument(
224                "top_k_fraction must be in [0.0, 1.0]".into(),
225            ));
226        }
227        Ok(())
228    }
229}
230
231// ============================================================================
232// Builder
233// ============================================================================
234
/// Builder for [`FederatedConfig`].
///
/// Starts from `FederatedConfig::default()`; `build()` runs validation.
#[derive(Debug, Clone, Default)]
pub struct FederatedConfigBuilder {
    // Config being assembled; mutated by the setter methods.
    config: FederatedConfig,
}
240
241impl FederatedConfigBuilder {
242    /// Set the number of communication rounds.
243    pub fn num_rounds(mut self, n: usize) -> Self {
244        self.config.num_rounds = n;
245        self
246    }
247
248    /// Set the number of clients per round.
249    pub fn clients_per_round(mut self, n: usize) -> Self {
250        self.config.clients_per_round = n;
251        self
252    }
253
254    /// Set the client selection strategy.
255    pub fn client_selection(mut self, s: ClientSelectionStrategy) -> Self {
256        self.config.client_selection = s;
257        self
258    }
259
260    /// Set the aggregation method.
261    pub fn aggregation(mut self, a: AggregationMethod) -> Self {
262        self.config.aggregation = a;
263        self
264    }
265
266    /// Enable differential privacy.
267    pub fn differential_privacy(mut self, noise_multiplier: f64, max_grad_norm: f64) -> Self {
268        self.config.dp_config.enabled = true;
269        self.config.dp_config.noise_multiplier = noise_multiplier;
270        self.config.dp_config.max_grad_norm = max_grad_norm;
271        self
272    }
273
274    /// Set the DP delta.
275    pub fn dp_delta(mut self, delta: f64) -> Self {
276        self.config.dp_config.delta = delta;
277        self
278    }
279
280    /// Enable gradient compression.
281    pub fn gradient_compression(mut self, top_k_fraction: f64) -> Self {
282        self.config.compression.enabled = true;
283        self.config.compression.top_k_fraction = top_k_fraction;
284        self
285    }
286
287    /// Set local epochs per client.
288    pub fn local_epochs(mut self, n: usize) -> Self {
289        self.config.local_epochs = n;
290        self
291    }
292
293    /// Set local learning rate.
294    pub fn local_lr(mut self, lr: f64) -> Self {
295        self.config.local_lr = lr;
296        self
297    }
298
299    /// Set the random seed.
300    pub fn seed(mut self, s: u64) -> Self {
301        self.config.seed = Some(s);
302        self
303    }
304
305    /// Build the configuration.
306    pub fn build(self) -> Result<FederatedConfig> {
307        self.config.validate()?;
308        Ok(self.config)
309    }
310}
311
312// ============================================================================
313// Client Update
314// ============================================================================
315
/// A client's update to send to the server.
///
/// Construct with [`ClientUpdate::new`]; attach diagnostics via the
/// `with_loss` / `with_metric` builder methods.
#[derive(Debug, Clone)]
pub struct ClientUpdate {
    /// Client identifier.
    pub client_id: usize,
    /// Updated parameters (one array per parameter tensor); must match the
    /// server's global parameters in count and shape.
    pub parameters: Vec<ArrayD<f64>>,
    /// Number of local training samples (FedAvg weight).
    pub num_samples: usize,
    /// Optional: local training loss (for diagnostics).
    pub local_loss: Option<f64>,
    /// Optional: local training metrics.
    pub metrics: std::collections::HashMap<String, f64>,
}
330
331impl ClientUpdate {
332    /// Create a new client update.
333    pub fn new(client_id: usize, parameters: Vec<ArrayD<f64>>, num_samples: usize) -> Self {
334        Self {
335            client_id,
336            parameters,
337            num_samples,
338            local_loss: None,
339            metrics: std::collections::HashMap::new(),
340        }
341    }
342
343    /// Set the local training loss.
344    pub fn with_loss(mut self, loss: f64) -> Self {
345        self.local_loss = Some(loss);
346        self
347    }
348
349    /// Add a metric.
350    pub fn with_metric(mut self, name: &str, value: f64) -> Self {
351        self.metrics.insert(name.to_string(), value);
352        self
353    }
354}
355
356// ============================================================================
357// Round statistics
358// ============================================================================
359
/// Statistics for a single federated round, recorded by
/// `FederatedServer::aggregate_round`.
#[derive(Debug, Clone)]
pub struct RoundStats {
    /// Round number (0-indexed).
    pub round: usize,
    /// Number of clients participating.
    pub num_clients: usize,
    /// Total samples across participating clients.
    pub total_samples: usize,
    /// Average local loss across clients that reported one; `None` if none did.
    pub avg_loss: Option<f64>,
    /// IDs of participating clients.
    pub client_ids: Vec<usize>,
}
374
375// ============================================================================
376// Federated Server
377// ============================================================================
378
/// Federated learning server that coordinates the training process.
///
/// The server holds the global model parameters and aggregates updates
/// from participating clients each round. It never sees raw client data —
/// only parameter tensors and sample counts.
#[derive(Debug, Clone)]
pub struct FederatedServer {
    /// Configuration.
    config: FederatedConfig,
    /// Current global parameters (replaced wholesale each aggregation).
    global_params: Vec<ArrayD<f64>>,
    /// Current communication round (incremented by `aggregate_round`).
    current_round: usize,
    /// History of round statistics, one entry per completed round.
    round_history: Vec<RoundStats>,
    /// RNG for client selection and DP noise generation.
    rng: SmallRng,
}
396
397impl FederatedServer {
398    /// Create a new federated server with initial global parameters.
399    pub fn new(config: FederatedConfig, global_params: Vec<ArrayD<f64>>) -> Self {
400        let rng = match config.seed {
401            Some(s) => SmallRng::seed_from_u64(s),
402            None => SmallRng::seed_from_u64(42),
403        };
404        Self {
405            config,
406            global_params,
407            current_round: 0,
408            round_history: Vec::new(),
409            rng,
410        }
411    }
412
    /// Get the current global parameters (read-only view; `aggregate_round`
    /// replaces them in place).
    pub fn global_params(&self) -> &[ArrayD<f64>] {
        &self.global_params
    }

    /// Get the current round number (number of completed aggregations).
    pub fn current_round(&self) -> usize {
        self.current_round
    }

    /// Get the round history, one `RoundStats` per completed round.
    pub fn round_history(&self) -> &[RoundStats] {
        &self.round_history
    }

    /// Whether training is complete (all configured rounds done).
    pub fn is_complete(&self) -> bool {
        self.current_round >= self.config.num_rounds
    }
432
433    /// Select clients for the current round.
434    ///
435    /// # Arguments
436    /// * `available_clients` - list of (client_id, num_samples)
437    ///
438    /// Returns the selected client IDs.
439    pub fn select_clients(&mut self, available_clients: &[(usize, usize)]) -> Vec<usize> {
440        if available_clients.is_empty() {
441            return Vec::new();
442        }
443
444        let k = self.config.clients_per_round.min(available_clients.len());
445
446        match self.config.client_selection {
447            ClientSelectionStrategy::All => available_clients.iter().map(|&(id, _)| id).collect(),
448            ClientSelectionStrategy::Random => {
449                // Fisher-Yates partial shuffle
450                let mut indices: Vec<usize> = (0..available_clients.len()).collect();
451                for i in 0..k {
452                    let j = i + self.rng.random_range(0..indices.len() - i);
453                    indices.swap(i, j);
454                }
455                indices[..k]
456                    .iter()
457                    .map(|&i| available_clients[i].0)
458                    .collect()
459            }
460            ClientSelectionStrategy::ImportanceBased => {
461                // Select proportionally to dataset size
462                let total: usize = available_clients.iter().map(|&(_, n)| n).sum();
463                if total == 0 {
464                    return available_clients
465                        .iter()
466                        .take(k)
467                        .map(|&(id, _)| id)
468                        .collect();
469                }
470
471                let mut selected = Vec::with_capacity(k);
472                let mut used = vec![false; available_clients.len()];
473
474                for _ in 0..k {
475                    let threshold = self.rng.random_range(0..total);
476                    let mut cumulative = 0usize;
477                    for (idx, &(client_id, n)) in available_clients.iter().enumerate() {
478                        if used[idx] {
479                            continue;
480                        }
481                        cumulative += n;
482                        if cumulative > threshold {
483                            selected.push(client_id);
484                            used[idx] = true;
485                            break;
486                        }
487                    }
488                    // If we didn't select anyone (edge case), pick first unused
489                    if selected.len() < selected.capacity()
490                        && selected.len() < k
491                        && selected.len() == selected.len()
492                    {
493                        // already handled by break above
494                    }
495                }
496
497                // Fill up if importance sampling missed some
498                if selected.len() < k {
499                    for (idx, &(client_id, _)) in available_clients.iter().enumerate() {
500                        if selected.len() >= k {
501                            break;
502                        }
503                        if !used[idx] {
504                            selected.push(client_id);
505                            used[idx] = true;
506                        }
507                    }
508                }
509
510                selected
511            }
512        }
513    }
514
    /// Aggregate client updates for one round using the configured method.
    ///
    /// Pipeline: validate shapes -> (optional) top-k compression round-trip ->
    /// aggregate (FedAvg / mean / median) -> (optional) DP noise -> record
    /// stats and advance the round counter.
    ///
    /// # Errors
    /// - `InvalidArgument` if `updates` is empty.
    /// - `ShapeMismatch` if any client's tensor count or shapes disagree with
    ///   the current global parameters.
    pub fn aggregate_round(&mut self, updates: &[ClientUpdate]) -> Result<()> {
        if updates.is_empty() {
            return Err(NeuralError::InvalidArgument(
                "No client updates to aggregate".into(),
            ));
        }

        // Validate parameter shapes before touching the global model.
        for update in updates {
            if update.parameters.len() != self.global_params.len() {
                return Err(NeuralError::ShapeMismatch(format!(
                    "Client {} has {} parameter tensors, expected {}",
                    update.client_id,
                    update.parameters.len(),
                    self.global_params.len()
                )));
            }
            for (i, param) in update.parameters.iter().enumerate() {
                if param.shape() != self.global_params[i].shape() {
                    return Err(NeuralError::ShapeMismatch(format!(
                        "Client {} param[{}] shape {:?} != global {:?}",
                        update.client_id,
                        i,
                        param.shape(),
                        self.global_params[i].shape()
                    )));
                }
            }
        }

        // Apply compression if enabled. This simulates the lossy client->server
        // channel: each client's delta vs. the global model is top-k sparsified
        // and re-applied to the global parameters before aggregation.
        let processed_updates = if self.config.compression.enabled {
            updates
                .iter()
                .map(|u| {
                    let compressed = compress_gradients(
                        &u.parameters,
                        &self.global_params,
                        self.config.compression.top_k_fraction,
                    );
                    ClientUpdate {
                        client_id: u.client_id,
                        parameters: apply_compressed_delta(&self.global_params, &compressed),
                        num_samples: u.num_samples,
                        local_loss: u.local_loss,
                        metrics: u.metrics.clone(),
                    }
                })
                .collect::<Vec<_>>()
        } else {
            updates.to_vec()
        };

        // Aggregate the (possibly compressed) updates into the global model.
        match self.config.aggregation {
            AggregationMethod::FedAvg => self.fedavg_aggregate(&processed_updates),
            AggregationMethod::SimpleMean => self.simple_mean_aggregate(&processed_updates),
            AggregationMethod::Median => self.median_aggregate(&processed_updates),
        }?;

        // Apply differential privacy noise to the freshly aggregated model.
        if self.config.dp_config.enabled {
            self.apply_dp_noise()?;
        }

        // Record round stats from the raw `updates` (sample counts and losses
        // are unaffected by compression, so this is equivalent and cheaper).
        let total_samples: usize = updates.iter().map(|u| u.num_samples).sum();
        let avg_loss = {
            let losses: Vec<f64> = updates.iter().filter_map(|u| u.local_loss).collect();
            if losses.is_empty() {
                None
            } else {
                Some(losses.iter().sum::<f64>() / losses.len() as f64)
            }
        };

        self.round_history.push(RoundStats {
            round: self.current_round,
            num_clients: updates.len(),
            total_samples,
            avg_loss,
            client_ids: updates.iter().map(|u| u.client_id).collect(),
        });

        self.current_round += 1;
        Ok(())
    }
603
604    /// FedAvg: weighted average by number of samples.
605    fn fedavg_aggregate(&mut self, updates: &[ClientUpdate]) -> Result<()> {
606        let total_samples: f64 = updates.iter().map(|u| u.num_samples as f64).sum();
607        if total_samples < f64::EPSILON {
608            return Err(NeuralError::ComputationError(
609                "Total samples is zero".into(),
610            ));
611        }
612
613        for p_idx in 0..self.global_params.len() {
614            let mut aggregated = ArrayD::<f64>::zeros(self.global_params[p_idx].raw_dim());
615            for update in updates {
616                let weight = update.num_samples as f64 / total_samples;
617                aggregated = aggregated + &update.parameters[p_idx] * weight;
618            }
619            self.global_params[p_idx] = aggregated;
620        }
621        Ok(())
622    }
623
624    /// Simple mean (unweighted).
625    fn simple_mean_aggregate(&mut self, updates: &[ClientUpdate]) -> Result<()> {
626        let n = updates.len() as f64;
627        for p_idx in 0..self.global_params.len() {
628            let mut aggregated = ArrayD::<f64>::zeros(self.global_params[p_idx].raw_dim());
629            for update in updates {
630                aggregated += &update.parameters[p_idx];
631            }
632            self.global_params[p_idx] = aggregated / n;
633        }
634        Ok(())
635    }
636
637    /// Median aggregation (element-wise median across clients).
638    fn median_aggregate(&mut self, updates: &[ClientUpdate]) -> Result<()> {
639        for p_idx in 0..self.global_params.len() {
640            let shape = self.global_params[p_idx].raw_dim();
641            let flat_len = self.global_params[p_idx].len();
642            let mut result = ArrayD::<f64>::zeros(shape);
643
644            for elem_idx in 0..flat_len {
645                let mut values: Vec<f64> = updates
646                    .iter()
647                    .map(|u| {
648                        u.parameters[p_idx]
649                            .as_slice()
650                            .map(|s| s[elem_idx])
651                            .unwrap_or(0.0)
652                    })
653                    .collect();
654                values.sort_by(|a, b| a.partial_cmp(b).unwrap_or(std::cmp::Ordering::Equal));
655
656                let median = if values.len().is_multiple_of(2) && values.len() >= 2 {
657                    (values[values.len() / 2 - 1] + values[values.len() / 2]) / 2.0
658                } else {
659                    values[values.len() / 2]
660                };
661
662                if let Some(slice) = result.as_slice_mut() {
663                    slice[elem_idx] = median;
664                }
665            }
666            self.global_params[p_idx] = result;
667        }
668        Ok(())
669    }
670
671    /// Apply differential privacy Gaussian noise to global parameters.
672    fn apply_dp_noise(&mut self) -> Result<()> {
673        let sigma = self.config.dp_config.noise_multiplier * self.config.dp_config.max_grad_norm;
674
675        for param in &mut self.global_params {
676            let noise = generate_gaussian_noise(param.len(), 0.0, sigma, &mut self.rng);
677            let noise_arr = ArrayD::from_shape_vec(param.raw_dim(), noise).map_err(|e| {
678                NeuralError::ComputationError(format!("Failed to create noise array: {e}"))
679            })?;
680            *param = &*param + &noise_arr;
681        }
682        Ok(())
683    }
684
685    /// Generate a text summary of the federated training.
686    pub fn summary(&self) -> String {
687        let mut out = String::new();
688        out.push_str("=== Federated Learning Summary ===\n");
689        out.push_str(&format!("Aggregation: {}\n", self.config.aggregation));
690        out.push_str(&format!("Selection: {}\n", self.config.client_selection));
691        out.push_str(&format!(
692            "Rounds: {} / {}\n",
693            self.current_round, self.config.num_rounds
694        ));
695        out.push_str(&format!("DP enabled: {}\n", self.config.dp_config.enabled));
696        out.push_str(&format!(
697            "Compression enabled: {}\n",
698            self.config.compression.enabled
699        ));
700
701        if let Some(last) = self.round_history.last() {
702            out.push_str(&format!(
703                "Last round: {} clients, {} samples",
704                last.num_clients, last.total_samples
705            ));
706            if let Some(loss) = last.avg_loss {
707                out.push_str(&format!(", avg_loss={loss:.6}"));
708            }
709            out.push('\n');
710        }
711        out
712    }
713}
714
715// ============================================================================
716// Gradient compression
717// ============================================================================
718
719/// Compress parameter deltas using top-k sparsification.
720///
721/// Computes delta = client_params - global_params, then keeps only the
722/// top-k elements by absolute value.
723fn compress_gradients(
724    client_params: &[ArrayD<f64>],
725    global_params: &[ArrayD<f64>],
726    top_k_fraction: f64,
727) -> Vec<Vec<(usize, f64)>> {
728    let mut compressed = Vec::with_capacity(client_params.len());
729
730    for (cp, gp) in client_params.iter().zip(global_params.iter()) {
731        let delta = cp - gp;
732        let flat: Vec<f64> = delta
733            .as_slice()
734            .map(|s| s.to_vec())
735            .unwrap_or_else(|| delta.iter().copied().collect());
736
737        let k = ((flat.len() as f64 * top_k_fraction).ceil() as usize)
738            .max(1)
739            .min(flat.len());
740
741        // Find top-k by absolute value
742        let mut indexed: Vec<(usize, f64)> = flat.into_iter().enumerate().collect();
743        indexed.sort_by(|a, b| {
744            b.1.abs()
745                .partial_cmp(&a.1.abs())
746                .unwrap_or(std::cmp::Ordering::Equal)
747        });
748        indexed.truncate(k);
749        compressed.push(indexed);
750    }
751
752    compressed
753}
754
755/// Reconstruct parameters from compressed deltas.
756fn apply_compressed_delta(
757    global_params: &[ArrayD<f64>],
758    compressed: &[Vec<(usize, f64)>],
759) -> Vec<ArrayD<f64>> {
760    let mut result = global_params.to_vec();
761    for (p_idx, deltas) in compressed.iter().enumerate() {
762        if let Some(slice) = result[p_idx].as_slice_mut() {
763            for &(idx, val) in deltas {
764                if idx < slice.len() {
765                    slice[idx] += val;
766                }
767            }
768        }
769    }
770    result
771}
772
773/// Clip the L2 norm of a parameter vector.
774pub fn clip_l2_norm(params: &mut [ArrayD<f64>], max_norm: f64) {
775    let norm_sq: f64 = params
776        .iter()
777        .map(|p| p.iter().map(|&x| x * x).sum::<f64>())
778        .sum();
779    let norm = norm_sq.sqrt();
780    if norm > max_norm && norm > f64::EPSILON {
781        let scale = max_norm / norm;
782        for p in params.iter_mut() {
783            p.mapv_inplace(|x| x * scale);
784        }
785    }
786}
787
788/// Generate Gaussian noise for differential privacy.
789fn generate_gaussian_noise(len: usize, mean: f64, std_dev: f64, rng: &mut SmallRng) -> Vec<f64> {
790    // Box-Muller transform for Gaussian noise
791    let mut result = Vec::with_capacity(len);
792    let mut i = 0;
793    while i < len {
794        let u1: f64 = rng.random_range(f64::EPSILON..1.0);
795        let u2: f64 = rng.random_range(0.0..std::f64::consts::TAU);
796        let r = (-2.0 * u1.ln()).sqrt();
797        let z0 = r * u2.cos() * std_dev + mean;
798        let z1 = r * u2.sin() * std_dev + mean;
799        result.push(z0);
800        i += 1;
801        if i < len {
802            result.push(z1);
803            i += 1;
804        }
805    }
806    result
807}
808
809// ============================================================================
810// Tests
811// ============================================================================
812
813#[cfg(test)]
814mod tests {
815    use super::*;
816    use scirs2_core::ndarray::array;
817
818    #[test]
819    fn test_config_defaults() {
820        let config = FederatedConfig::default();
821        assert_eq!(config.num_rounds, 100);
822        assert_eq!(config.clients_per_round, 10);
823        assert_eq!(config.aggregation, AggregationMethod::FedAvg);
824    }
825
826    #[test]
827    fn test_config_builder() {
828        let config = FederatedConfig::builder()
829            .num_rounds(50)
830            .clients_per_round(5)
831            .aggregation(AggregationMethod::SimpleMean)
832            .local_epochs(3)
833            .local_lr(0.1)
834            .seed(123)
835            .build()
836            .expect("valid config");
837
838        assert_eq!(config.num_rounds, 50);
839        assert_eq!(config.clients_per_round, 5);
840        assert_eq!(config.local_epochs, 3);
841    }
842
843    #[test]
844    fn test_config_validation_errors() {
845        assert!(FederatedConfig::builder().num_rounds(0).build().is_err());
846        assert!(FederatedConfig::builder()
847            .clients_per_round(0)
848            .build()
849            .is_err());
850        assert!(FederatedConfig::builder().local_epochs(0).build().is_err());
851        assert!(FederatedConfig::builder().local_lr(-1.0).build().is_err());
852        assert!(FederatedConfig::builder()
853            .differential_privacy(0.0, 1.0)
854            .build()
855            .is_err());
856        assert!(FederatedConfig::builder()
857            .gradient_compression(-0.1)
858            .build()
859            .is_err());
860    }
861
862    #[test]
863    fn test_fedavg_aggregation() {
864        let config = FederatedConfig::builder()
865            .num_rounds(10)
866            .clients_per_round(2)
867            .aggregation(AggregationMethod::FedAvg)
868            .build()
869            .expect("valid");
870
871        let global = vec![array![0.0_f64, 0.0, 0.0].into_dyn()];
872        let mut server = FederatedServer::new(config, global);
873
874        // Client A: 100 samples, params [1, 2, 3]
875        // Client B: 300 samples, params [3, 2, 1]
876        // FedAvg: (100*[1,2,3] + 300*[3,2,1]) / 400 = [2.5, 2.0, 1.5]
877        let updates = vec![
878            ClientUpdate::new(0, vec![array![1.0, 2.0, 3.0].into_dyn()], 100),
879            ClientUpdate::new(1, vec![array![3.0, 2.0, 1.0].into_dyn()], 300),
880        ];
881
882        server.aggregate_round(&updates).expect("ok");
883
884        let result = &server.global_params()[0];
885        let slice = result.as_slice().expect("contiguous");
886        assert!((slice[0] - 2.5).abs() < 1e-10);
887        assert!((slice[1] - 2.0).abs() < 1e-10);
888        assert!((slice[2] - 1.5).abs() < 1e-10);
889        assert_eq!(server.current_round(), 1);
890    }
891
892    #[test]
893    fn test_simple_mean_aggregation() {
894        let config = FederatedConfig::builder()
895            .num_rounds(10)
896            .clients_per_round(3)
897            .aggregation(AggregationMethod::SimpleMean)
898            .build()
899            .expect("valid");
900
901        let global = vec![array![0.0_f64, 0.0].into_dyn()];
902        let mut server = FederatedServer::new(config, global);
903
904        let updates = vec![
905            ClientUpdate::new(0, vec![array![1.0, 4.0].into_dyn()], 10),
906            ClientUpdate::new(1, vec![array![2.0, 5.0].into_dyn()], 10),
907            ClientUpdate::new(2, vec![array![3.0, 6.0].into_dyn()], 10),
908        ];
909
910        server.aggregate_round(&updates).expect("ok");
911
912        let result = &server.global_params()[0];
913        let slice = result.as_slice().expect("contiguous");
914        assert!((slice[0] - 2.0).abs() < 1e-10);
915        assert!((slice[1] - 5.0).abs() < 1e-10);
916    }
917
918    #[test]
919    fn test_median_aggregation() {
920        let config = FederatedConfig::builder()
921            .num_rounds(10)
922            .clients_per_round(3)
923            .aggregation(AggregationMethod::Median)
924            .build()
925            .expect("valid");
926
927        let global = vec![array![0.0_f64, 0.0].into_dyn()];
928        let mut server = FederatedServer::new(config, global);
929
930        let updates = vec![
931            ClientUpdate::new(0, vec![array![1.0, 100.0].into_dyn()], 10),
932            ClientUpdate::new(1, vec![array![2.0, 5.0].into_dyn()], 10),
933            ClientUpdate::new(2, vec![array![3.0, 6.0].into_dyn()], 10),
934        ];
935
936        server.aggregate_round(&updates).expect("ok");
937
938        let result = &server.global_params()[0];
939        let slice = result.as_slice().expect("contiguous");
940        // Median of [1,2,3] = 2, median of [5,6,100] = 6
941        assert!((slice[0] - 2.0).abs() < 1e-10);
942        assert!((slice[1] - 6.0).abs() < 1e-10);
943    }
944
945    #[test]
946    fn test_empty_updates_error() {
947        let config = FederatedConfig::builder()
948            .num_rounds(10)
949            .clients_per_round(2)
950            .build()
951            .expect("valid");
952
953        let global = vec![array![1.0_f64, 2.0].into_dyn()];
954        let mut server = FederatedServer::new(config, global);
955
956        assert!(server.aggregate_round(&[]).is_err());
957    }
958
959    #[test]
960    fn test_shape_mismatch_error() {
961        let config = FederatedConfig::builder()
962            .num_rounds(10)
963            .clients_per_round(1)
964            .build()
965            .expect("valid");
966
967        let global = vec![array![1.0_f64, 2.0].into_dyn()];
968        let mut server = FederatedServer::new(config, global);
969
970        // Wrong number of parameter tensors
971        let updates = vec![ClientUpdate::new(
972            0,
973            vec![array![1.0, 2.0].into_dyn(), array![3.0].into_dyn()],
974            10,
975        )];
976        assert!(server.aggregate_round(&updates).is_err());
977    }
978
979    #[test]
980    fn test_client_selection_random() {
981        let config = FederatedConfig::builder()
982            .num_rounds(10)
983            .clients_per_round(3)
984            .client_selection(ClientSelectionStrategy::Random)
985            .seed(42)
986            .build()
987            .expect("valid");
988
989        let global = vec![array![0.0_f64].into_dyn()];
990        let mut server = FederatedServer::new(config, global);
991
992        let clients = vec![(0, 100), (1, 200), (2, 300), (3, 400), (4, 500)];
993        let selected = server.select_clients(&clients);
994
995        assert_eq!(selected.len(), 3);
996        // All should be valid IDs
997        for id in &selected {
998            assert!(*id <= 4);
999        }
1000    }
1001
1002    #[test]
1003    fn test_client_selection_all() {
1004        let config = FederatedConfig::builder()
1005            .num_rounds(10)
1006            .clients_per_round(2)
1007            .client_selection(ClientSelectionStrategy::All)
1008            .build()
1009            .expect("valid");
1010
1011        let global = vec![array![0.0_f64].into_dyn()];
1012        let mut server = FederatedServer::new(config, global);
1013
1014        let clients = vec![(0, 100), (1, 200), (2, 300)];
1015        let selected = server.select_clients(&clients);
1016
1017        assert_eq!(selected.len(), 3); // All selected
1018    }
1019
1020    #[test]
1021    fn test_client_selection_importance() {
1022        let config = FederatedConfig::builder()
1023            .num_rounds(10)
1024            .clients_per_round(2)
1025            .client_selection(ClientSelectionStrategy::ImportanceBased)
1026            .seed(42)
1027            .build()
1028            .expect("valid");
1029
1030        let global = vec![array![0.0_f64].into_dyn()];
1031        let mut server = FederatedServer::new(config, global);
1032
1033        let clients = vec![(0, 1), (1, 1000), (2, 1)];
1034        let selected = server.select_clients(&clients);
1035
1036        assert_eq!(selected.len(), 2);
1037    }
1038
1039    #[test]
1040    fn test_dp_noise_application() {
1041        let config = FederatedConfig::builder()
1042            .num_rounds(10)
1043            .clients_per_round(1)
1044            .differential_privacy(1.0, 1.0)
1045            .seed(42)
1046            .build()
1047            .expect("valid");
1048
1049        let global = vec![array![0.0_f64, 0.0, 0.0].into_dyn()];
1050        let mut server = FederatedServer::new(config, global);
1051
1052        let updates = vec![ClientUpdate::new(
1053            0,
1054            vec![array![1.0, 2.0, 3.0].into_dyn()],
1055            100,
1056        )];
1057
1058        server.aggregate_round(&updates).expect("ok");
1059
1060        // With DP noise, the result should not be exactly [1, 2, 3]
1061        let result = &server.global_params()[0];
1062        let slice = result.as_slice().expect("contiguous");
1063        let any_noisy = slice[0] != 1.0 || slice[1] != 2.0 || slice[2] != 3.0;
1064        assert!(any_noisy, "DP noise should perturb the result");
1065    }
1066
1067    #[test]
1068    fn test_gradient_compression() {
1069        let config = FederatedConfig::builder()
1070            .num_rounds(10)
1071            .clients_per_round(1)
1072            .gradient_compression(0.5)
1073            .build()
1074            .expect("valid");
1075
1076        let global = vec![array![1.0_f64, 2.0, 3.0, 4.0].into_dyn()];
1077        let mut server = FederatedServer::new(config, global);
1078
1079        // Client has a big delta on elements 0 and 3, small on 1 and 2
1080        let updates = vec![ClientUpdate::new(
1081            0,
1082            vec![array![10.0, 2.1, 3.1, 14.0].into_dyn()],
1083            100,
1084        )];
1085
1086        server.aggregate_round(&updates).expect("ok");
1087
1088        // With top-50% compression, only the 2 largest deltas should be applied
1089        // Deltas: [9, 0.1, 0.1, 10] -> top-2: indices 3(10) and 0(9)
1090        let result = &server.global_params()[0];
1091        let slice = result.as_slice().expect("contiguous");
1092        // Elements 0 and 3 should get the delta; 1 and 2 should stay at global
1093        assert!((slice[0] - 10.0).abs() < 1e-10);
1094        assert!((slice[1] - 2.0).abs() < 1e-10); // no change
1095        assert!((slice[2] - 3.0).abs() < 1e-10); // no change
1096        assert!((slice[3] - 14.0).abs() < 1e-10);
1097    }
1098
1099    #[test]
1100    fn test_clip_l2_norm() {
1101        let mut params = vec![array![3.0_f64, 4.0].into_dyn()];
1102        // norm = 5.0
1103        clip_l2_norm(&mut params, 1.0);
1104        let slice = params[0].as_slice().expect("contiguous");
1105        let norm = (slice[0] * slice[0] + slice[1] * slice[1]).sqrt();
1106        assert!((norm - 1.0).abs() < 1e-10);
1107    }
1108
1109    #[test]
1110    fn test_clip_l2_norm_no_clip_needed() {
1111        let mut params = vec![array![0.3_f64, 0.4].into_dyn()];
1112        // norm = 0.5 < 1.0
1113        clip_l2_norm(&mut params, 1.0);
1114        let slice = params[0].as_slice().expect("contiguous");
1115        assert!((slice[0] - 0.3).abs() < 1e-10);
1116        assert!((slice[1] - 0.4).abs() < 1e-10);
1117    }
1118
1119    #[test]
1120    fn test_multiple_rounds() {
1121        let config = FederatedConfig::builder()
1122            .num_rounds(3)
1123            .clients_per_round(2)
1124            .aggregation(AggregationMethod::SimpleMean)
1125            .build()
1126            .expect("valid");
1127
1128        let global = vec![array![0.0_f64, 0.0].into_dyn()];
1129        let mut server = FederatedServer::new(config, global);
1130
1131        for round in 0..3 {
1132            let v = (round + 1) as f64;
1133            let updates = vec![
1134                ClientUpdate::new(0, vec![array![v, v * 2.0].into_dyn()], 10),
1135                ClientUpdate::new(1, vec![array![v * 3.0, v * 4.0].into_dyn()], 10),
1136            ];
1137            server.aggregate_round(&updates).expect("ok");
1138        }
1139
1140        assert_eq!(server.current_round(), 3);
1141        assert!(server.is_complete());
1142        assert_eq!(server.round_history().len(), 3);
1143    }
1144
1145    #[test]
1146    fn test_client_update_with_metrics() {
1147        let update = ClientUpdate::new(0, vec![array![1.0_f64].into_dyn()], 100)
1148            .with_loss(0.5)
1149            .with_metric("accuracy", 0.95);
1150
1151        assert_eq!(update.local_loss, Some(0.5));
1152        assert!((update.metrics["accuracy"] - 0.95).abs() < 1e-10);
1153    }
1154
1155    #[test]
1156    fn test_round_stats_avg_loss() {
1157        let config = FederatedConfig::builder()
1158            .num_rounds(10)
1159            .clients_per_round(2)
1160            .build()
1161            .expect("valid");
1162
1163        let global = vec![array![0.0_f64].into_dyn()];
1164        let mut server = FederatedServer::new(config, global);
1165
1166        let updates = vec![
1167            ClientUpdate::new(0, vec![array![1.0].into_dyn()], 10).with_loss(0.3),
1168            ClientUpdate::new(1, vec![array![2.0].into_dyn()], 10).with_loss(0.7),
1169        ];
1170
1171        server.aggregate_round(&updates).expect("ok");
1172
1173        let stats = &server.round_history()[0];
1174        assert_eq!(stats.num_clients, 2);
1175        assert_eq!(stats.total_samples, 20);
1176        assert!((stats.avg_loss.expect("has loss") - 0.5).abs() < 1e-10);
1177    }
1178
1179    #[test]
1180    fn test_summary_generation() {
1181        let config = FederatedConfig::builder()
1182            .num_rounds(10)
1183            .clients_per_round(2)
1184            .build()
1185            .expect("valid");
1186
1187        let global = vec![array![0.0_f64].into_dyn()];
1188        let mut server = FederatedServer::new(config, global);
1189
1190        let updates = vec![ClientUpdate::new(0, vec![array![1.0].into_dyn()], 10)];
1191        server.aggregate_round(&updates).expect("ok");
1192
1193        let summary = server.summary();
1194        assert!(summary.contains("Federated Learning Summary"));
1195        assert!(summary.contains("FedAvg"));
1196    }
1197
1198    #[test]
1199    fn test_display_types() {
1200        assert_eq!(format!("{}", ClientSelectionStrategy::Random), "Random");
1201        assert_eq!(
1202            format!("{}", ClientSelectionStrategy::ImportanceBased),
1203            "ImportanceBased"
1204        );
1205        assert_eq!(format!("{}", ClientSelectionStrategy::All), "All");
1206        assert_eq!(format!("{}", AggregationMethod::FedAvg), "FedAvg");
1207        assert_eq!(format!("{}", AggregationMethod::SimpleMean), "SimpleMean");
1208        assert_eq!(format!("{}", AggregationMethod::Median), "Median");
1209    }
1210
1211    #[test]
1212    fn test_gaussian_noise_generation() {
1213        let mut rng = SmallRng::seed_from_u64(42);
1214        let noise = generate_gaussian_noise(1000, 0.0, 1.0, &mut rng);
1215        assert_eq!(noise.len(), 1000);
1216
1217        // Check that mean is approximately 0
1218        let mean = noise.iter().sum::<f64>() / noise.len() as f64;
1219        assert!(mean.abs() < 0.2, "mean={mean}, expected ~0");
1220
1221        // Check that std is approximately 1
1222        let var = noise.iter().map(|&x| (x - mean).powi(2)).sum::<f64>() / noise.len() as f64;
1223        let std = var.sqrt();
1224        assert!((std - 1.0).abs() < 0.2, "std={std}, expected ~1");
1225    }
1226
1227    #[test]
1228    fn test_multi_param_tensors() {
1229        let config = FederatedConfig::builder()
1230            .num_rounds(10)
1231            .clients_per_round(2)
1232            .aggregation(AggregationMethod::SimpleMean)
1233            .build()
1234            .expect("valid");
1235
1236        let global = vec![
1237            array![1.0_f64, 2.0].into_dyn(),
1238            array![3.0_f64, 4.0, 5.0].into_dyn(),
1239        ];
1240        let mut server = FederatedServer::new(config, global);
1241
1242        let updates = vec![
1243            ClientUpdate::new(
1244                0,
1245                vec![
1246                    array![2.0, 4.0].into_dyn(),
1247                    array![6.0, 8.0, 10.0].into_dyn(),
1248                ],
1249                10,
1250            ),
1251            ClientUpdate::new(
1252                1,
1253                vec![
1254                    array![4.0, 6.0].into_dyn(),
1255                    array![9.0, 12.0, 15.0].into_dyn(),
1256                ],
1257                10,
1258            ),
1259        ];
1260
1261        server.aggregate_round(&updates).expect("ok");
1262
1263        let p0 = server.global_params()[0].as_slice().expect("contiguous");
1264        assert!((p0[0] - 3.0).abs() < 1e-10);
1265        assert!((p0[1] - 5.0).abs() < 1e-10);
1266
1267        let p1 = server.global_params()[1].as_slice().expect("contiguous");
1268        assert!((p1[0] - 7.5).abs() < 1e-10);
1269        assert!((p1[1] - 10.0).abs() < 1e-10);
1270        assert!((p1[2] - 12.5).abs() < 1e-10);
1271    }
1272}