Skip to main content

smp_tee_runtime/aggregation/
multi_krum.rs

/// Computes the squared Euclidean (L2) distance between two slices.
///
/// Returns `None` when the slices have different lengths; otherwise returns
/// the sum of the squared element-wise differences, accumulated in `f32`.
fn squared_l2_distance(left: &[f32], right: &[f32]) -> Option<f32> {
    let same_length = left.len() == right.len();

    same_length.then(|| {
        left.iter()
            .zip(right)
            .map(|(&a, &b)| (a - b) * (a - b))
            .sum::<f32>()
    })
}
16
17pub fn multi_krum(vectors: &[Vec<f32>], byzantine_tolerance: usize) -> Option<Vec<f32>> {
18    let n = vectors.len();
19    if n < 2 * byzantine_tolerance + 3 {
20        return None;
21    }
22    let dimension = vectors.first()?.len();
23    if vectors.iter().any(|vector| vector.len() != dimension) {
24        return None;
25    }
26
27    let neighbors = n.checked_sub(byzantine_tolerance + 2)?;
28    let mut best: Option<(usize, f32)> = None;
29
30    for (i, candidate) in vectors.iter().enumerate() {
31        let mut distances = vectors
32            .iter()
33            .enumerate()
34            .filter_map(|(j, other)| {
35                if i == j {
36                    return None;
37                }
38                squared_l2_distance(candidate, other)
39            })
40            .collect::<Vec<_>>();
41
42        distances.sort_by(|a, b| a.total_cmp(b));
43        let score: f32 = distances.iter().take(neighbors).sum();
44
45        match best {
46            Some((_, best_score)) if best_score <= score => {}
47            _ => best = Some((i, score)),
48        }
49    }
50
51    best.map(|(idx, _)| vectors[idx].clone())
52}
53
#[cfg(test)]
mod tests {
    use super::*;

    /// Four tightly clustered updates plus one far outlier: with a tolerance
    /// of one, both components of the selected vector must stay below 2.0,
    /// which rules out the outlier.
    #[test]
    fn multi_krum_chooses_honest_update() {
        let updates: [Vec<f32>; 5] = [
            vec![1.0, 1.0],
            vec![1.1, 1.0],
            vec![0.9, 1.1],
            vec![1.0, 0.95],
            vec![50.0, -50.0],
        ];

        let selected = multi_krum(&updates, 1).unwrap();

        assert!(selected[0] < 2.0);
        assert!(selected[1] < 2.0);
    }
}