// smp_tee_runtime/aggregation/multi_krum.rs

/// Computes the squared Euclidean (L2) distance between two slices.
///
/// Returns `None` when the slices differ in length; otherwise returns the sum
/// of the squared component-wise differences. Two empty slices yield a zero
/// distance.
fn squared_l2_distance(left: &[f32], right: &[f32]) -> Option<f32> {
    (left.len() == right.len()).then(|| {
        left.iter()
            .zip(right)
            .map(|(a, b)| (a - b) * (a - b))
            .sum::<f32>()
    })
}
16
17pub fn multi_krum(vectors: &[Vec<f32>], byzantine_tolerance: usize) -> Option<Vec<f32>> {
18 let n = vectors.len();
19 if n < 2 * byzantine_tolerance + 3 {
20 return None;
21 }
22 let dimension = vectors.first()?.len();
23 if vectors.iter().any(|vector| vector.len() != dimension) {
24 return None;
25 }
26
27 let neighbors = n.checked_sub(byzantine_tolerance + 2)?;
28 let mut best: Option<(usize, f32)> = None;
29
30 for (i, candidate) in vectors.iter().enumerate() {
31 let mut distances = vectors
32 .iter()
33 .enumerate()
34 .filter_map(|(j, other)| {
35 if i == j {
36 return None;
37 }
38 squared_l2_distance(candidate, other)
39 })
40 .collect::<Vec<_>>();
41
42 distances.sort_by(|a, b| a.total_cmp(b));
43 let score: f32 = distances.iter().take(neighbors).sum();
44
45 match best {
46 Some((_, best_score)) if best_score <= score => {}
47 _ => best = Some((i, score)),
48 }
49 }
50
51 best.map(|(idx, _)| vectors[idx].clone())
52}
53
#[cfg(test)]
mod tests {
    use super::*;

    /// Four tightly clustered updates plus one far-away outlier: with a
    /// tolerance of one, the selected update must come from the cluster.
    #[test]
    fn multi_krum_chooses_honest_update() {
        let updates: [Vec<f32>; 5] = [
            vec![1.0, 1.0],
            vec![1.1, 1.0],
            vec![0.9, 1.1],
            vec![1.0, 0.95],
            // The outlier that must not be chosen.
            vec![50.0, -50.0],
        ];
        let byzantine_tolerance = 1;

        let selected = multi_krum(&updates, byzantine_tolerance).unwrap();

        assert!(selected[0] < 2.0);
        assert!(selected[1] < 2.0);
    }
}