1use crate::error::FederationError;
4use crate::types::AggregateWeights;
5
/// Rule used by `FederatedAggregator` to merge the contributions of a round.
#[derive(Debug, Clone, Copy, Eq, PartialEq)]
pub enum AggregationStrategy {
    /// Element-wise mean of the contributions' weights, each contribution
    /// weighted by its `trajectory_count`.
    FedAvg,
    /// `FedAvg`, after which every coordinate is multiplied by `1 / (1 + mu / 100)`.
    ///
    /// `mu` is stored in hundredths, so `mu: 50` means a coefficient of 0.5.
    /// NOTE(review): this is a server-side shrinkage toward zero; the FedProx
    /// proximal term normally lives in the client's local objective — confirm
    /// this simplification is intended.
    FedProx {
        mu: u32,
    },
    /// Element-wise mean of the contributions' weights, each contribution
    /// weighted by its `quality_weight`.
    WeightedAverage,
}
16
17impl Default for AggregationStrategy {
18 fn default() -> Self {
19 Self::FedAvg
20 }
21}
22
/// One participant's update submitted for an aggregation round.
#[derive(Debug, Clone)]
pub struct Contribution {
    /// Identifier of the submitting participant. The aggregation math never
    /// reads it; it only travels with the update.
    pub contributor: String,
    /// Flattened parameter update, combined element-wise with the other
    /// contributions into the aggregate's deltas.
    pub weights: Vec<f64>,
    /// Weight under `WeightedAverage`; also mapped to the round's loss proxy
    /// as `1 - clamp(quality_weight, 0, 1)`.
    pub quality_weight: f64,
    /// Weight under `FedAvg` / `FedProx`.
    pub trajectory_count: u64,
}
35
36pub struct FederatedAggregator {
38 strategy: AggregationStrategy,
40 domain_id: String,
42 round: u64,
44 min_contributions: usize,
46 byzantine_std_threshold: f64,
48 contributions: Vec<Contribution>,
50}
51
52impl FederatedAggregator {
53 pub fn new(domain_id: String, strategy: AggregationStrategy) -> Self {
55 Self {
56 strategy,
57 domain_id,
58 round: 0,
59 min_contributions: 2,
60 byzantine_std_threshold: 2.0,
61 contributions: Vec::new(),
62 }
63 }
64
65 pub fn with_min_contributions(mut self, min: usize) -> Self {
67 self.min_contributions = min;
68 self
69 }
70
71 pub fn with_byzantine_threshold(mut self, threshold: f64) -> Self {
73 self.byzantine_std_threshold = threshold;
74 self
75 }
76
77 pub fn add_contribution(&mut self, contribution: Contribution) {
79 self.contributions.push(contribution);
80 }
81
82 pub fn contribution_count(&self) -> usize {
84 self.contributions.len()
85 }
86
87 pub fn round(&self) -> u64 {
89 self.round
90 }
91
92 pub fn ready(&self) -> bool {
94 self.contributions.len() >= self.min_contributions
95 }
96
97 fn remove_byzantine_outliers(&mut self) -> u32 {
101 if self.contributions.len() < 3 {
102 return 0; }
104
105 let dim = self.contributions[0].weights.len();
106 if dim == 0 || !self.contributions.iter().all(|c| c.weights.len() == dim) {
107 return 0;
108 }
109
110 let norms: Vec<f64> = self.contributions.iter()
112 .map(|c| c.weights.iter().map(|w| w * w).sum::<f64>().sqrt())
113 .collect();
114
115 let mean_norm = norms.iter().sum::<f64>() / norms.len() as f64;
116 let variance = norms.iter().map(|n| (n - mean_norm).powi(2)).sum::<f64>() / norms.len() as f64;
117 let std_dev = variance.sqrt();
118
119 if std_dev < 1e-10 {
120 return 0;
121 }
122
123 let original_count = self.contributions.len();
124 let threshold = self.byzantine_std_threshold;
125
126 self.contributions.retain(|c| {
127 let norm = c.weights.iter().map(|w| w * w).sum::<f64>().sqrt();
128 ((norm - mean_norm) / std_dev).abs() <= threshold
129 });
130
131 (original_count - self.contributions.len()) as u32
132 }
133
134 pub fn aggregate(&mut self) -> Result<AggregateWeights, FederationError> {
136 if self.contributions.len() < self.min_contributions {
137 return Err(FederationError::InsufficientContributions {
138 min: self.min_contributions,
139 got: self.contributions.len(),
140 });
141 }
142
143 let outliers_removed = self.remove_byzantine_outliers();
145
146 if self.contributions.is_empty() {
147 return Err(FederationError::InsufficientContributions {
148 min: self.min_contributions,
149 got: 0,
150 });
151 }
152
153 let dim = self.contributions[0].weights.len();
154
155 let result = match self.strategy {
156 AggregationStrategy::FedAvg => self.fedavg(dim),
157 AggregationStrategy::FedProx { mu } => self.fedprox(dim, mu as f64 / 100.0),
158 AggregationStrategy::WeightedAverage => self.weighted_avg(dim),
159 };
160
161 self.round += 1;
162 let participation_count = self.contributions.len() as u32;
163
164 let losses: Vec<f64> = self.contributions.iter()
166 .map(|c| {
167 1.0 - c.quality_weight.clamp(0.0, 1.0)
169 })
170 .collect();
171 let mean_loss = losses.iter().sum::<f64>() / losses.len() as f64;
172 let loss_variance = losses.iter().map(|l| (l - mean_loss).powi(2)).sum::<f64>() / losses.len() as f64;
173
174 self.contributions.clear();
175
176 Ok(AggregateWeights {
177 round: self.round,
178 participation_count,
179 lora_deltas: result.0,
180 confidences: result.1,
181 mean_loss,
182 loss_variance,
183 domain_id: self.domain_id.clone(),
184 byzantine_filtered: outliers_removed > 0,
185 outliers_removed,
186 })
187 }
188
189 fn fedavg(&self, dim: usize) -> (Vec<f64>, Vec<f64>) {
191 let total_trajectories: f64 = self.contributions.iter()
192 .map(|c| c.trajectory_count as f64)
193 .sum();
194
195 let mut avg = vec![0.0f64; dim];
196 let mut confidences = vec![0.0f64; dim];
197
198 if total_trajectories <= 0.0 {
199 return (avg, confidences);
200 }
201
202 for c in &self.contributions {
203 let w = c.trajectory_count as f64 / total_trajectories;
204 for (i, val) in c.weights.iter().enumerate() {
205 if i < dim {
206 avg[i] += w * val;
207 }
208 }
209 }
210
211 for i in 0..dim {
213 let mean = avg[i];
214 let var: f64 = self.contributions.iter()
215 .map(|c| {
216 let v = if i < c.weights.len() { c.weights[i] } else { 0.0 };
217 (v - mean).powi(2)
218 })
219 .sum::<f64>() / self.contributions.len() as f64;
220 confidences[i] = 1.0 / (1.0 + var);
221 }
222
223 (avg, confidences)
224 }
225
226 fn fedprox(&self, dim: usize, mu: f64) -> (Vec<f64>, Vec<f64>) {
228 let (mut avg, confidences) = self.fedavg(dim);
229 for val in &mut avg {
231 *val *= 1.0 / (1.0 + mu);
232 }
233 (avg, confidences)
234 }
235
236 fn weighted_avg(&self, dim: usize) -> (Vec<f64>, Vec<f64>) {
238 let total_weight: f64 = self.contributions.iter().map(|c| c.quality_weight).sum();
239
240 let mut avg = vec![0.0f64; dim];
241 let mut confidences = vec![0.0f64; dim];
242
243 if total_weight <= 0.0 {
244 return (avg, confidences);
245 }
246
247 for c in &self.contributions {
248 let w = c.quality_weight / total_weight;
249 for (i, val) in c.weights.iter().enumerate() {
250 if i < dim {
251 avg[i] += w * val;
252 }
253 }
254 }
255
256 for i in 0..dim {
257 let mean = avg[i];
258 let var: f64 = self.contributions.iter()
259 .map(|c| {
260 let v = if i < c.weights.len() { c.weights[i] } else { 0.0 };
261 (v - mean).powi(2)
262 })
263 .sum::<f64>() / self.contributions.len() as f64;
264 confidences[i] = 1.0 / (1.0 + var);
265 }
266
267 (avg, confidences)
268 }
269}
270
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a contribution with the given weights, quality and trajectory count.
    fn make_contribution(name: &str, weights: Vec<f64>, quality: f64, trajectories: u64) -> Contribution {
        Contribution {
            contributor: name.to_string(),
            weights,
            quality_weight: quality,
            trajectory_count: trajectories,
        }
    }

    /// Aggregator for domain "test" that requires `min` contributions.
    fn aggregator(strategy: AggregationStrategy, min: usize) -> FederatedAggregator {
        FederatedAggregator::new("test".into(), strategy).with_min_contributions(min)
    }

    #[test]
    fn default_strategy_is_fedavg() {
        assert_eq!(AggregationStrategy::default(), AggregationStrategy::FedAvg);
    }

    #[test]
    fn fedavg_two_equal_contributions() {
        let mut agg = aggregator(AggregationStrategy::FedAvg, 2);
        agg.add_contribution(make_contribution("a", vec![1.0, 2.0, 3.0], 1.0, 100));
        agg.add_contribution(make_contribution("b", vec![3.0, 4.0, 5.0], 1.0, 100));

        let result = agg.aggregate().unwrap();
        assert_eq!(result.round, 1);
        assert_eq!(result.participation_count, 2);
        assert_eq!(result.domain_id, "test");
        for (got, want) in result.lora_deltas.iter().zip([2.0, 3.0, 4.0]) {
            assert!((got - want).abs() < 1e-10);
        }
    }

    #[test]
    fn fedavg_weighted_by_trajectories() {
        let mut agg = aggregator(AggregationStrategy::FedAvg, 2);
        // 300 trajectories on 10.0 and 100 on 0.0 => 0.75 * 10 = 7.5.
        agg.add_contribution(make_contribution("a", vec![10.0], 1.0, 300));
        agg.add_contribution(make_contribution("b", vec![0.0], 1.0, 100));

        let result = agg.aggregate().unwrap();
        assert!((result.lora_deltas[0] - 7.5).abs() < 1e-10);
    }

    #[test]
    fn fedavg_zero_trajectories_yields_zeros() {
        let mut agg = aggregator(AggregationStrategy::FedAvg, 2);
        agg.add_contribution(make_contribution("a", vec![1.0, 2.0], 1.0, 0));
        agg.add_contribution(make_contribution("b", vec![3.0, 4.0], 1.0, 0));

        let result = agg.aggregate().unwrap();
        assert_eq!(result.lora_deltas, vec![0.0, 0.0]);
        assert_eq!(result.confidences, vec![0.0, 0.0]);
    }

    #[test]
    fn fedprox_shrinks_toward_zero() {
        let mut agg_avg = aggregator(AggregationStrategy::FedAvg, 2);
        agg_avg.add_contribution(make_contribution("a", vec![10.0], 1.0, 100));
        agg_avg.add_contribution(make_contribution("b", vec![10.0], 1.0, 100));
        let avg_result = agg_avg.aggregate().unwrap();

        let mut agg_prox = aggregator(AggregationStrategy::FedProx { mu: 50 }, 2);
        agg_prox.add_contribution(make_contribution("a", vec![10.0], 1.0, 100));
        agg_prox.add_contribution(make_contribution("b", vec![10.0], 1.0, 100));
        let prox_result = agg_prox.aggregate().unwrap();

        assert!(prox_result.lora_deltas[0] < avg_result.lora_deltas[0]);
        // mu = 50 hundredths => factor 1 / 1.5.
        assert!((prox_result.lora_deltas[0] - 10.0 / 1.5).abs() < 1e-9);
    }

    #[test]
    fn byzantine_outlier_removal() {
        let mut agg = aggregator(AggregationStrategy::FedAvg, 2).with_byzantine_threshold(2.0);
        agg.add_contribution(make_contribution("good1", vec![1.0, 1.0], 1.0, 100));
        agg.add_contribution(make_contribution("good2", vec![1.1, 0.9], 1.0, 100));
        agg.add_contribution(make_contribution("good3", vec![0.9, 1.1], 1.0, 100));
        agg.add_contribution(make_contribution("good4", vec![1.0, 1.0], 1.0, 100));
        agg.add_contribution(make_contribution("good5", vec![1.0, 1.0], 1.0, 100));
        agg.add_contribution(make_contribution("good6", vec![1.0, 1.0], 1.0, 100));
        agg.add_contribution(make_contribution("evil", vec![100.0, 100.0], 1.0, 100));

        let result = agg.aggregate().unwrap();
        assert!(result.byzantine_filtered);
        assert_eq!(result.outliers_removed, 1);
        assert_eq!(result.participation_count, 6);
        assert!(result.lora_deltas[0] < 5.0);
        assert!((result.lora_deltas[0] - 1.0).abs() < 1e-9);
    }

    #[test]
    fn identical_norms_are_not_filtered() {
        let mut agg = aggregator(AggregationStrategy::FedAvg, 2);
        for name in ["a", "b", "c"] {
            agg.add_contribution(make_contribution(name, vec![1.0, 2.0], 1.0, 10));
        }

        let result = agg.aggregate().unwrap();
        assert!(!result.byzantine_filtered);
        assert_eq!(result.outliers_removed, 0);
        assert_eq!(result.participation_count, 3);
    }

    #[test]
    fn insufficient_contributions_error() {
        let mut agg = aggregator(AggregationStrategy::FedAvg, 3);
        agg.add_contribution(make_contribution("a", vec![1.0], 1.0, 100));

        let result = agg.aggregate();
        assert!(result.is_err());
        assert!(matches!(
            result,
            Err(FederationError::InsufficientContributions { min: 3, got: 1 })
        ));
        // A failed quorum check must not consume the queue or advance the round.
        assert_eq!(agg.contribution_count(), 1);
        assert_eq!(agg.round(), 0);
    }

    #[test]
    fn ready_and_count_track_queue() {
        let mut agg = aggregator(AggregationStrategy::FedAvg, 2);
        assert!(!agg.ready());
        assert_eq!(agg.contribution_count(), 0);

        agg.add_contribution(make_contribution("a", vec![1.0], 1.0, 1));
        assert!(!agg.ready());
        agg.add_contribution(make_contribution("b", vec![1.0], 1.0, 1));
        assert!(agg.ready());
        assert_eq!(agg.contribution_count(), 2);

        agg.aggregate().unwrap();
        assert_eq!(agg.contribution_count(), 0);
        assert_eq!(agg.round(), 1);
    }

    #[test]
    fn weighted_average_strategy() {
        let mut agg = aggregator(AggregationStrategy::WeightedAverage, 2);
        agg.add_contribution(make_contribution("a", vec![10.0], 0.9, 10));
        agg.add_contribution(make_contribution("b", vec![0.0], 0.1, 10));

        let result = agg.aggregate().unwrap();
        assert!((result.lora_deltas[0] - 9.0).abs() < 1e-10);
    }

    #[test]
    fn loss_proxy_uses_clamped_quality() {
        let mut agg = aggregator(AggregationStrategy::FedAvg, 2);
        // 1.5 clamps to 1.0 (loss 0.0); 0.5 gives loss 0.5.
        agg.add_contribution(make_contribution("a", vec![1.0], 1.5, 10));
        agg.add_contribution(make_contribution("b", vec![1.0], 0.5, 10));

        let result = agg.aggregate().unwrap();
        assert!((result.mean_loss - 0.25).abs() < 1e-12);
        assert!((result.loss_variance - 0.0625).abs() < 1e-12);
    }

    #[test]
    fn round_increments() {
        let mut agg = aggregator(AggregationStrategy::FedAvg, 2);

        agg.add_contribution(make_contribution("a", vec![1.0], 1.0, 100));
        agg.add_contribution(make_contribution("b", vec![2.0], 1.0, 100));
        assert_eq!(agg.aggregate().unwrap().round, 1);

        agg.add_contribution(make_contribution("a", vec![3.0], 1.0, 100));
        agg.add_contribution(make_contribution("b", vec![4.0], 1.0, 100));
        assert_eq!(agg.aggregate().unwrap().round, 2);
    }

    #[test]
    fn confidences_high_when_agreement() {
        let mut agg = aggregator(AggregationStrategy::FedAvg, 2);
        agg.add_contribution(make_contribution("a", vec![1.0], 1.0, 100));
        agg.add_contribution(make_contribution("b", vec![1.0], 1.0, 100));

        let result = agg.aggregate().unwrap();
        assert!((result.confidences[0] - 1.0).abs() < 1e-10);
    }

    #[test]
    fn confidences_lower_when_disagreement() {
        let mut agg = aggregator(AggregationStrategy::FedAvg, 2);
        agg.add_contribution(make_contribution("a", vec![0.0], 1.0, 100));
        agg.add_contribution(make_contribution("b", vec![10.0], 1.0, 100));

        let result = agg.aggregate().unwrap();
        assert!(result.confidences[0] < 1.0);
    }
}