Skip to main content

rvf_federation/
aggregate.rs

1//! Federated aggregation: FedAvg, FedProx, Byzantine-tolerant weighted averaging.
2
3use crate::error::FederationError;
4use crate::types::AggregateWeights;
5
/// Selects how the contributions of a round are combined into one update.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum AggregationStrategy {
    /// Federated Averaging (McMahan et al., 2017): contributions are weighted
    /// by their trajectory count.
    FedAvg,
    /// Federated Proximal (Li et al., 2020).
    ///
    /// `mu` is stored as an integer so the enum can stay `Eq`; the aggregator
    /// divides it by 100 before use, i.e. `mu: 50` means a proximal
    /// coefficient of `0.5`.
    FedProx { mu: u32 },
    /// Plain weighted average: contributions are weighted by their
    /// quality weight.
    WeightedAverage,
}
16
17impl Default for AggregationStrategy {
18    fn default() -> Self {
19        Self::FedAvg
20    }
21}
22
/// One participant's input to a federated averaging round.
#[derive(Debug, Clone)]
pub struct Contribution {
    /// Pseudonym identifying the contributor.
    pub contributor: String,
    /// The contributed weight vector (LoRA deltas).
    pub weights: Vec<f64>,
    /// Quality/reputation weight assigned to this contributor; used as the
    /// averaging weight by the `WeightedAverage` strategy.
    pub quality_weight: f64,
    /// How many training trajectories back this contribution; used as the
    /// averaging weight by the `FedAvg` strategy.
    pub trajectory_count: u64,
}
35
36/// Federated aggregation server.
37pub struct FederatedAggregator {
38    /// Aggregation strategy.
39    strategy: AggregationStrategy,
40    /// Domain identifier.
41    domain_id: String,
42    /// Current round number.
43    round: u64,
44    /// Minimum contributions required for a round.
45    min_contributions: usize,
46    /// Standard deviation threshold for Byzantine outlier detection.
47    byzantine_std_threshold: f64,
48    /// Collected contributions for the current round.
49    contributions: Vec<Contribution>,
50}
51
52impl FederatedAggregator {
53    /// Create a new aggregator.
54    pub fn new(domain_id: String, strategy: AggregationStrategy) -> Self {
55        Self {
56            strategy,
57            domain_id,
58            round: 0,
59            min_contributions: 2,
60            byzantine_std_threshold: 2.0,
61            contributions: Vec::new(),
62        }
63    }
64
65    /// Set minimum contributions required.
66    pub fn with_min_contributions(mut self, min: usize) -> Self {
67        self.min_contributions = min;
68        self
69    }
70
71    /// Set Byzantine outlier threshold (in standard deviations).
72    pub fn with_byzantine_threshold(mut self, threshold: f64) -> Self {
73        self.byzantine_std_threshold = threshold;
74        self
75    }
76
77    /// Add a contribution for the current round.
78    pub fn add_contribution(&mut self, contribution: Contribution) {
79        self.contributions.push(contribution);
80    }
81
82    /// Number of contributions collected so far.
83    pub fn contribution_count(&self) -> usize {
84        self.contributions.len()
85    }
86
87    /// Current round number.
88    pub fn round(&self) -> u64 {
89        self.round
90    }
91
92    /// Check if we have enough contributions to aggregate.
93    pub fn ready(&self) -> bool {
94        self.contributions.len() >= self.min_contributions
95    }
96
97    /// Detect and remove Byzantine outliers.
98    ///
99    /// Returns the number of outliers removed.
100    fn remove_byzantine_outliers(&mut self) -> u32 {
101        if self.contributions.len() < 3 {
102            return 0; // Need at least 3 for meaningful outlier detection
103        }
104
105        let dim = self.contributions[0].weights.len();
106        if dim == 0 || !self.contributions.iter().all(|c| c.weights.len() == dim) {
107            return 0;
108        }
109
110        // Compute mean and std of L2 norms
111        let norms: Vec<f64> = self.contributions.iter()
112            .map(|c| c.weights.iter().map(|w| w * w).sum::<f64>().sqrt())
113            .collect();
114
115        let mean_norm = norms.iter().sum::<f64>() / norms.len() as f64;
116        let variance = norms.iter().map(|n| (n - mean_norm).powi(2)).sum::<f64>() / norms.len() as f64;
117        let std_dev = variance.sqrt();
118
119        if std_dev < 1e-10 {
120            return 0;
121        }
122
123        let original_count = self.contributions.len();
124        let threshold = self.byzantine_std_threshold;
125
126        self.contributions.retain(|c| {
127            let norm = c.weights.iter().map(|w| w * w).sum::<f64>().sqrt();
128            ((norm - mean_norm) / std_dev).abs() <= threshold
129        });
130
131        (original_count - self.contributions.len()) as u32
132    }
133
134    /// Aggregate contributions and produce an `AggregateWeights` segment.
135    pub fn aggregate(&mut self) -> Result<AggregateWeights, FederationError> {
136        if self.contributions.len() < self.min_contributions {
137            return Err(FederationError::InsufficientContributions {
138                min: self.min_contributions,
139                got: self.contributions.len(),
140            });
141        }
142
143        // Byzantine outlier removal
144        let outliers_removed = self.remove_byzantine_outliers();
145
146        if self.contributions.is_empty() {
147            return Err(FederationError::InsufficientContributions {
148                min: self.min_contributions,
149                got: 0,
150            });
151        }
152
153        let dim = self.contributions[0].weights.len();
154
155        let result = match self.strategy {
156            AggregationStrategy::FedAvg => self.fedavg(dim),
157            AggregationStrategy::FedProx { mu } => self.fedprox(dim, mu as f64 / 100.0),
158            AggregationStrategy::WeightedAverage => self.weighted_avg(dim),
159        };
160
161        self.round += 1;
162        let participation_count = self.contributions.len() as u32;
163
164        // Compute loss stats
165        let losses: Vec<f64> = self.contributions.iter()
166            .map(|c| {
167                // Use inverse quality as a proxy for loss
168                1.0 - c.quality_weight.clamp(0.0, 1.0)
169            })
170            .collect();
171        let mean_loss = losses.iter().sum::<f64>() / losses.len() as f64;
172        let loss_variance = losses.iter().map(|l| (l - mean_loss).powi(2)).sum::<f64>() / losses.len() as f64;
173
174        self.contributions.clear();
175
176        Ok(AggregateWeights {
177            round: self.round,
178            participation_count,
179            lora_deltas: result.0,
180            confidences: result.1,
181            mean_loss,
182            loss_variance,
183            domain_id: self.domain_id.clone(),
184            byzantine_filtered: outliers_removed > 0,
185            outliers_removed,
186        })
187    }
188
189    /// FedAvg: weighted average by trajectory count.
190    fn fedavg(&self, dim: usize) -> (Vec<f64>, Vec<f64>) {
191        let total_trajectories: f64 = self.contributions.iter()
192            .map(|c| c.trajectory_count as f64)
193            .sum();
194
195        let mut avg = vec![0.0f64; dim];
196        let mut confidences = vec![0.0f64; dim];
197
198        if total_trajectories <= 0.0 {
199            return (avg, confidences);
200        }
201
202        for c in &self.contributions {
203            let w = c.trajectory_count as f64 / total_trajectories;
204            for (i, val) in c.weights.iter().enumerate() {
205                if i < dim {
206                    avg[i] += w * val;
207                }
208            }
209        }
210
211        // Confidence = inverse of variance across contributions per dimension
212        for i in 0..dim {
213            let mean = avg[i];
214            let var: f64 = self.contributions.iter()
215                .map(|c| {
216                    let v = if i < c.weights.len() { c.weights[i] } else { 0.0 };
217                    (v - mean).powi(2)
218                })
219                .sum::<f64>() / self.contributions.len() as f64;
220            confidences[i] = 1.0 / (1.0 + var);
221        }
222
223        (avg, confidences)
224    }
225
226    /// FedProx: weighted average with proximal term.
227    fn fedprox(&self, dim: usize, mu: f64) -> (Vec<f64>, Vec<f64>) {
228        let (mut avg, confidences) = self.fedavg(dim);
229        // Apply proximal regularization: pull toward zero (global model)
230        for val in &mut avg {
231            *val *= 1.0 / (1.0 + mu);
232        }
233        (avg, confidences)
234    }
235
236    /// Weighted average by quality_weight.
237    fn weighted_avg(&self, dim: usize) -> (Vec<f64>, Vec<f64>) {
238        let total_weight: f64 = self.contributions.iter().map(|c| c.quality_weight).sum();
239
240        let mut avg = vec![0.0f64; dim];
241        let mut confidences = vec![0.0f64; dim];
242
243        if total_weight <= 0.0 {
244            return (avg, confidences);
245        }
246
247        for c in &self.contributions {
248            let w = c.quality_weight / total_weight;
249            for (i, val) in c.weights.iter().enumerate() {
250                if i < dim {
251                    avg[i] += w * val;
252                }
253            }
254        }
255
256        for i in 0..dim {
257            let mean = avg[i];
258            let var: f64 = self.contributions.iter()
259                .map(|c| {
260                    let v = if i < c.weights.len() { c.weights[i] } else { 0.0 };
261                    (v - mean).powi(2)
262                })
263                .sum::<f64>() / self.contributions.len() as f64;
264            confidences[i] = 1.0 / (1.0 + var);
265        }
266
267        (avg, confidences)
268    }
269}
270
#[cfg(test)]
mod tests {
    use super::*;

    /// Tolerance for floating-point comparisons.
    const TOLERANCE: f64 = 1e-10;

    /// Builds a contribution from its four fields.
    fn contribution(who: &str, weights: &[f64], quality: f64, trajectories: u64) -> Contribution {
        Contribution {
            contributor: who.to_string(),
            weights: weights.to_vec(),
            quality_weight: quality,
            trajectory_count: trajectories,
        }
    }

    /// Builds an aggregator for the "test" domain requiring `min` contributions.
    fn aggregator(strategy: AggregationStrategy, min: usize) -> FederatedAggregator {
        FederatedAggregator::new("test".into(), strategy).with_min_contributions(min)
    }

    /// Asserts that `actual` is within `TOLERANCE` of `expected`.
    fn assert_close(actual: f64, expected: f64) {
        assert!((actual - expected).abs() < TOLERANCE);
    }

    #[test]
    fn fedavg_two_equal_contributions() {
        let mut agg = aggregator(AggregationStrategy::FedAvg, 2);
        agg.add_contribution(contribution("a", &[1.0, 2.0, 3.0], 1.0, 100));
        agg.add_contribution(contribution("b", &[3.0, 4.0, 5.0], 1.0, 100));

        let result = agg.aggregate().unwrap();

        assert_eq!(result.round, 1);
        assert_eq!(result.participation_count, 2);
        // Equal trajectory counts: every component is the midpoint.
        for (index, expected) in [2.0, 3.0, 4.0].iter().enumerate() {
            assert_close(result.lora_deltas[index], *expected);
        }
    }

    #[test]
    fn fedavg_weighted_by_trajectories() {
        let mut agg = aggregator(AggregationStrategy::FedAvg, 2);
        // "a" carries three times as many trajectories as "b", so it dominates.
        agg.add_contribution(contribution("a", &[10.0], 1.0, 300));
        agg.add_contribution(contribution("b", &[0.0], 1.0, 100));

        let result = agg.aggregate().unwrap();

        // (300 * 10 + 100 * 0) / 400 = 7.5
        assert_close(result.lora_deltas[0], 7.5);
    }

    #[test]
    fn fedprox_shrinks_toward_zero() {
        // Run the same two contributions through both strategies.
        let run = |strategy: AggregationStrategy| {
            let mut agg = aggregator(strategy, 2);
            agg.add_contribution(contribution("a", &[10.0], 1.0, 100));
            agg.add_contribution(contribution("b", &[10.0], 1.0, 100));
            agg.aggregate().unwrap()
        };

        let avg_result = run(AggregationStrategy::FedAvg);
        let prox_result = run(AggregationStrategy::FedProx { mu: 50 });

        // Proximal regularization must yield the smaller value.
        assert!(prox_result.lora_deltas[0] < avg_result.lora_deltas[0]);
    }

    #[test]
    fn byzantine_outlier_removal() {
        let mut agg = aggregator(AggregationStrategy::FedAvg, 2).with_byzantine_threshold(2.0);

        // With k honest contributions and one attacker, the attacker's z-score
        // grows with sqrt(k); six honest ones push it past the 2.0 threshold.
        let honest: [(&str, [f64; 2]); 6] = [
            ("good1", [1.0, 1.0]),
            ("good2", [1.1, 0.9]),
            ("good3", [0.9, 1.1]),
            ("good4", [1.0, 1.0]),
            ("good5", [1.0, 1.0]),
            ("good6", [1.0, 1.0]),
        ];
        for (name, weights) in honest.iter() {
            agg.add_contribution(contribution(name, weights, 1.0, 100));
        }
        // The outlier.
        agg.add_contribution(contribution("evil", &[100.0, 100.0], 1.0, 100));

        let result = agg.aggregate().unwrap();

        assert!(result.byzantine_filtered);
        assert!(result.outliers_removed >= 1);
        // The aggregate stays near 1.0 instead of being dragged toward 100.
        assert!(result.lora_deltas[0] < 5.0);
    }

    #[test]
    fn insufficient_contributions_error() {
        let mut agg = aggregator(AggregationStrategy::FedAvg, 3);
        agg.add_contribution(contribution("a", &[1.0], 1.0, 100));

        assert!(agg.aggregate().is_err());
    }

    #[test]
    fn weighted_average_strategy() {
        let mut agg = aggregator(AggregationStrategy::WeightedAverage, 2);
        agg.add_contribution(contribution("a", &[10.0], 0.9, 10));
        agg.add_contribution(contribution("b", &[0.0], 0.1, 10));

        let result = agg.aggregate().unwrap();

        // (0.9 * 10 + 0.1 * 0) / 1.0 = 9.0
        assert_close(result.lora_deltas[0], 9.0);
    }

    #[test]
    fn round_increments() {
        let mut agg = aggregator(AggregationStrategy::FedAvg, 2);

        agg.add_contribution(contribution("a", &[1.0], 1.0, 100));
        agg.add_contribution(contribution("b", &[2.0], 1.0, 100));
        let first = agg.aggregate().unwrap();
        assert_eq!(first.round, 1);

        agg.add_contribution(contribution("a", &[3.0], 1.0, 100));
        agg.add_contribution(contribution("b", &[4.0], 1.0, 100));
        let second = agg.aggregate().unwrap();
        assert_eq!(second.round, 2);
    }

    #[test]
    fn confidences_high_when_agreement() {
        let mut agg = aggregator(AggregationStrategy::FedAvg, 2);
        agg.add_contribution(contribution("a", &[1.0], 1.0, 100));
        agg.add_contribution(contribution("b", &[1.0], 1.0, 100));

        let result = agg.aggregate().unwrap();

        // Full agreement: variance = 0, so confidence = 1 / (1 + 0) = 1.0
        assert_close(result.confidences[0], 1.0);
    }

    #[test]
    fn confidences_lower_when_disagreement() {
        let mut agg = aggregator(AggregationStrategy::FedAvg, 2);
        agg.add_contribution(contribution("a", &[0.0], 1.0, 100));
        agg.add_contribution(contribution("b", &[10.0], 1.0, 100));

        let result = agg.aggregate().unwrap();

        // Disagreement must push the confidence below 1.0
        assert!(result.confidences[0] < 1.0);
    }
}