1use rand::rngs::StdRng;
7use rand::{Rng, SeedableRng, rng};
8use rustkernel_core::{domain::Domain, kernel::KernelMetadata, traits::GpuKernel};
9use serde::{Deserialize, Serialize};
10
11#[derive(Debug, Clone, Serialize, Deserialize)]
17pub struct SecureAggConfig {
18 pub min_participants: usize,
20 pub max_participants: usize,
22 pub epsilon: f64,
24 pub clip_threshold: f64,
26 pub add_noise: bool,
28 pub seed: Option<u64>,
30}
31
32impl Default for SecureAggConfig {
33 fn default() -> Self {
34 Self {
35 min_participants: 3,
36 max_participants: 100,
37 epsilon: 1.0,
38 clip_threshold: 1.0,
39 add_noise: true,
40 seed: None,
41 }
42 }
43}
44
45#[derive(Debug, Clone, Serialize, Deserialize)]
47pub struct ParticipantUpdate {
48 pub participant_id: String,
50 pub parameters: Vec<f64>,
52 pub sample_count: usize,
54 pub local_loss: Option<f64>,
56}
57
58#[derive(Debug, Clone, Serialize, Deserialize)]
60pub struct AggregationResult {
61 pub aggregated_params: Vec<f64>,
63 pub participant_count: usize,
65 pub total_samples: usize,
67 pub average_loss: Option<f64>,
69 pub privacy_guarantee: PrivacyGuarantee,
71 pub included_participants: Vec<String>,
73 pub excluded_participants: Vec<String>,
75}
76
77#[derive(Debug, Clone, Serialize, Deserialize)]
79pub struct PrivacyGuarantee {
80 pub epsilon: f64,
82 pub delta: f64,
84 pub secure_aggregation: bool,
86 pub noise_scale: f64,
88}
89
/// Pseudo-random mask tied to a participant, together with the seed that produced it.
// Nothing in this file constructs or reads a mask yet, hence the lint allowance.
#[allow(dead_code)]
#[derive(Debug, Clone)]
struct SecureMask {
    /// Participant the mask belongs to.
    participant_id: String,
    /// Mask values, one per parameter.
    mask: Vec<f64>,
    /// Seed used to generate `mask`.
    seed: u64,
}
98
99#[allow(dead_code)]
100impl SecureMask {
101 fn generate(participant_id: &str, size: usize, seed: u64) -> Self {
102 let mut rng = StdRng::seed_from_u64(seed);
103 let mask: Vec<f64> = (0..size).map(|_| rng.random_range(-1.0..1.0)).collect();
104 Self {
105 participant_id: participant_id.to_string(),
106 mask,
107 seed,
108 }
109 }
110}
111
112#[derive(Debug, Clone)]
118pub struct SecureAggregation {
119 metadata: KernelMetadata,
120}
121
122impl Default for SecureAggregation {
123 fn default() -> Self {
124 Self::new()
125 }
126}
127
128impl SecureAggregation {
129 #[must_use]
131 pub fn new() -> Self {
132 Self {
133 metadata: KernelMetadata::batch("ml/secure-aggregation", Domain::StatisticalML)
134 .with_description("Privacy-preserving federated model aggregation")
135 .with_throughput(1_000)
136 .with_latency_us(500.0),
137 }
138 }
139
140 pub fn aggregate(updates: &[ParticipantUpdate], config: &SecureAggConfig) -> AggregationResult {
142 if updates.is_empty() {
143 return AggregationResult {
144 aggregated_params: Vec::new(),
145 participant_count: 0,
146 total_samples: 0,
147 average_loss: None,
148 privacy_guarantee: PrivacyGuarantee {
149 epsilon: config.epsilon,
150 delta: 1e-5,
151 secure_aggregation: false,
152 noise_scale: 0.0,
153 },
154 included_participants: Vec::new(),
155 excluded_participants: Vec::new(),
156 };
157 }
158
159 if updates.len() < config.min_participants {
161 return AggregationResult {
162 aggregated_params: Vec::new(),
163 participant_count: 0,
164 total_samples: 0,
165 average_loss: None,
166 privacy_guarantee: PrivacyGuarantee {
167 epsilon: f64::INFINITY,
168 delta: 1.0,
169 secure_aggregation: false,
170 noise_scale: 0.0,
171 },
172 included_participants: Vec::new(),
173 excluded_participants: updates.iter().map(|u| u.participant_id.clone()).collect(),
174 };
175 }
176
177 let param_size = updates[0].parameters.len();
178 let mut included = Vec::new();
179 let mut excluded = Vec::new();
180
181 let clipped_updates: Vec<(String, Vec<f64>, usize)> = updates
183 .iter()
184 .filter_map(|u| {
185 if u.parameters.len() != param_size {
186 excluded.push(u.participant_id.clone());
187 return None;
188 }
189 included.push(u.participant_id.clone());
190 let clipped = Self::clip_update(&u.parameters, config.clip_threshold);
191 Some((u.participant_id.clone(), clipped, u.sample_count))
192 })
193 .collect();
194
195 if clipped_updates.len() < config.min_participants {
196 return AggregationResult {
197 aggregated_params: Vec::new(),
198 participant_count: 0,
199 total_samples: 0,
200 average_loss: None,
201 privacy_guarantee: PrivacyGuarantee {
202 epsilon: f64::INFINITY,
203 delta: 1.0,
204 secure_aggregation: false,
205 noise_scale: 0.0,
206 },
207 included_participants: Vec::new(),
208 excluded_participants: updates.iter().map(|u| u.participant_id.clone()).collect(),
209 };
210 }
211
212 let total_samples: usize = clipped_updates.iter().map(|(_, _, s)| s).sum();
214 let mut aggregated = vec![0.0; param_size];
215
216 for (_, params, sample_count) in &clipped_updates {
217 let weight = *sample_count as f64 / total_samples as f64;
218 for (i, &p) in params.iter().enumerate() {
219 aggregated[i] += p * weight;
220 }
221 }
222
223 let noise_scale = if config.add_noise {
225 Self::add_dp_noise(&mut aggregated, config)
226 } else {
227 0.0
228 };
229
230 let average_loss = {
232 let losses: Vec<f64> = updates.iter().filter_map(|u| u.local_loss).collect();
233 if losses.is_empty() {
234 None
235 } else {
236 Some(losses.iter().sum::<f64>() / losses.len() as f64)
237 }
238 };
239
240 AggregationResult {
241 aggregated_params: aggregated,
242 participant_count: clipped_updates.len(),
243 total_samples,
244 average_loss,
245 privacy_guarantee: PrivacyGuarantee {
246 epsilon: config.epsilon,
247 delta: 1e-5,
248 secure_aggregation: true,
249 noise_scale,
250 },
251 included_participants: included,
252 excluded_participants: excluded,
253 }
254 }
255
/// Scales `params` so that its L2 norm does not exceed `threshold`.
///
/// Vectors already within the bound are returned unchanged; otherwise every component
/// is multiplied by `threshold / norm`.
///
/// The norm is first computed as the square root of the sum of squares. If that sum is
/// not finite (for example because squaring a large component overflowed), the norm is
/// recomputed with `f64::hypot`, which stays finite as long as the true norm is
/// representable. Without this, large-but-finite vectors would be "clipped" to zeros.
///
/// Non-finite components (NaN, ±∞) are not sanitized and generally produce NaN entries
/// in the output.
fn clip_update(params: &[f64], threshold: f64) -> Vec<f64> {
    let sum_of_squares: f64 = params.iter().map(|x| x * x).sum();
    let norm = if sum_of_squares.is_finite() {
        sum_of_squares.sqrt()
    } else {
        // Overflow-safe accumulation; only taken when the fast path blew up.
        params.iter().fold(0.0_f64, |acc, &x| acc.hypot(x))
    };

    if norm <= threshold {
        return params.to_vec();
    }

    let scale = threshold / norm;
    params.iter().map(|&x| x * scale).collect()
}
266
267 fn add_dp_noise(params: &mut [f64], config: &SecureAggConfig) -> f64 {
269 let delta = 1e-5;
271 let sensitivity = config.clip_threshold;
272 let sigma = sensitivity * (2.0 * (1.25_f64 / delta).ln()).sqrt() / config.epsilon;
273
274 let mut rng = match config.seed {
275 Some(seed) => StdRng::seed_from_u64(seed),
276 None => StdRng::from_rng(&mut rng()),
277 };
278
279 for p in params.iter_mut() {
280 let u1: f64 = rng.random_range(0.0001..1.0);
282 let u2: f64 = rng.random_range(0.0..1.0);
283 let z = (-2.0 * u1.ln()).sqrt() * (2.0 * std::f64::consts::PI * u2).cos();
284 *p += sigma * z;
285 }
286
287 sigma
288 }
289
290 pub fn verify_aggregation(_updates: &[ParticipantUpdate], result: &AggregationResult) -> bool {
292 if result.participant_count == 0 {
294 return result.aggregated_params.is_empty();
295 }
296
297 if result.aggregated_params.is_empty() {
298 return false;
299 }
300
301 result.included_participants.len() == result.participant_count
303 }
304
305 pub fn simulate_round(
307 _global_model: &[f64],
308 local_updates: &[Vec<f64>],
309 sample_counts: &[usize],
310 config: &SecureAggConfig,
311 ) -> AggregationResult {
312 let updates: Vec<ParticipantUpdate> = local_updates
313 .iter()
314 .zip(sample_counts.iter())
315 .enumerate()
316 .map(|(i, (params, &count))| ParticipantUpdate {
317 participant_id: format!("participant_{}", i),
318 parameters: params.clone(),
319 sample_count: count,
320 local_loss: Some(0.5), })
322 .collect();
323
324 Self::aggregate(&updates, config)
325 }
326}
327
328impl GpuKernel for SecureAggregation {
329 fn metadata(&self) -> &KernelMetadata {
330 &self.metadata
331 }
332}
333
#[cfg(test)]
mod tests {
    use super::*;

    /// Shorthand for building a `ParticipantUpdate` in the tests below.
    fn participant(
        id: &str,
        parameters: &[f64],
        sample_count: usize,
        local_loss: Option<f64>,
    ) -> ParticipantUpdate {
        ParticipantUpdate {
            participant_id: id.to_string(),
            parameters: parameters.to_vec(),
            sample_count,
            local_loss,
        }
    }

    #[test]
    fn test_secure_aggregation_metadata() {
        let kernel = SecureAggregation::new();
        assert_eq!(kernel.metadata().id, "ml/secure-aggregation");
    }

    #[test]
    fn test_basic_aggregation() {
        let updates = vec![
            participant("p1", &[1.0, 2.0, 3.0], 100, Some(0.5)),
            participant("p2", &[2.0, 3.0, 4.0], 100, Some(0.6)),
            participant("p3", &[3.0, 4.0, 5.0], 100, Some(0.7)),
        ];

        // A generous clip threshold keeps the inputs unscaled so the mean is exact.
        let config = SecureAggConfig {
            min_participants: 3,
            add_noise: false,
            clip_threshold: 100.0,
            ..Default::default()
        };

        let result = SecureAggregation::aggregate(&updates, &config);

        assert_eq!(result.participant_count, 3);
        assert_eq!(result.total_samples, 300);
        assert_eq!(result.aggregated_params.len(), 3);

        assert!((result.aggregated_params[0] - 2.0).abs() < 0.01);
        assert!((result.aggregated_params[1] - 3.0).abs() < 0.01);
        assert!((result.aggregated_params[2] - 4.0).abs() < 0.01);
    }

    #[test]
    fn test_weighted_aggregation() {
        let updates = vec![
            participant("p1", &[1.0], 100, None),
            participant("p2", &[4.0], 200, None),
            // Zero samples: contributes nothing to the weighted mean.
            participant("p3", &[1.0], 0, None),
        ];

        let config = SecureAggConfig {
            min_participants: 2,
            add_noise: false,
            clip_threshold: 100.0,
            ..Default::default()
        };

        let result = SecureAggregation::aggregate(&updates, &config);

        // (1 * 100 + 4 * 200) / 300 = 3
        assert!((result.aggregated_params[0] - 3.0).abs() < 0.01);
    }

    #[test]
    fn test_insufficient_participants() {
        let updates = vec![participant("p1", &[1.0], 100, None)];

        let config = SecureAggConfig {
            min_participants: 3,
            ..Default::default()
        };

        let result = SecureAggregation::aggregate(&updates, &config);

        assert_eq!(result.participant_count, 0);
        assert!(result.aggregated_params.is_empty());
        assert_eq!(result.privacy_guarantee.epsilon, f64::INFINITY);
        assert_eq!(result.excluded_participants, vec!["p1"]);
    }

    #[test]
    fn test_mismatched_dimensions_are_excluded() {
        let updates = vec![
            participant("p1", &[1.0, 2.0], 10, None),
            participant("p2", &[1.0, 2.0], 10, None),
            participant("p3", &[1.0], 10, None),
            participant("p4", &[1.0, 2.0], 10, None),
        ];

        let config = SecureAggConfig {
            min_participants: 3,
            add_noise: false,
            ..Default::default()
        };

        let result = SecureAggregation::aggregate(&updates, &config);

        assert_eq!(result.participant_count, 3);
        assert_eq!(result.total_samples, 30);
        assert_eq!(result.included_participants, vec!["p1", "p2", "p4"]);
        assert_eq!(result.excluded_participants, vec!["p3"]);
    }

    #[test]
    fn test_clipping() {
        // Norm of [3, 4] is 5, so clipping to 1.0 must rescale it.
        let params = vec![3.0, 4.0];
        let clipped = SecureAggregation::clip_update(&params, 1.0);

        let norm: f64 = clipped.iter().map(|x| x * x).sum::<f64>().sqrt();
        assert!((norm - 1.0).abs() < 0.001);
    }

    #[test]
    fn test_dp_noise_added() {
        let updates = vec![
            participant("p1", &[1.0, 1.0], 100, None),
            participant("p2", &[1.0, 1.0], 100, None),
            participant("p3", &[1.0, 1.0], 100, None),
        ];

        let config = SecureAggConfig {
            min_participants: 3,
            add_noise: true,
            epsilon: 1.0,
            seed: Some(42),
            ..Default::default()
        };

        let result = SecureAggregation::aggregate(&updates, &config);
        assert!(result.privacy_guarantee.noise_scale > 0.0);

        // A fixed seed must make the noisy aggregate reproducible.
        let repeat = SecureAggregation::aggregate(&updates, &config);
        assert_eq!(result.aggregated_params, repeat.aggregated_params);
    }

    #[test]
    fn test_empty_updates() {
        let config = SecureAggConfig::default();
        let result = SecureAggregation::aggregate(&[], &config);

        assert!(result.aggregated_params.is_empty());
        assert_eq!(result.participant_count, 0);
    }

    #[test]
    fn test_simulate_round() {
        let global = vec![0.0, 0.0, 0.0];
        let local_updates = vec![
            vec![0.1, 0.2, 0.3],
            vec![0.2, 0.3, 0.4],
            vec![0.3, 0.4, 0.5],
        ];
        let sample_counts = vec![100, 100, 100];

        let config = SecureAggConfig {
            min_participants: 3,
            add_noise: false,
            ..Default::default()
        };

        let result =
            SecureAggregation::simulate_round(&global, &local_updates, &sample_counts, &config);

        assert_eq!(result.participant_count, 3);
        assert!(result.average_loss.is_some());
    }

    #[test]
    fn test_verify_aggregation() {
        let updates = vec![
            participant("p1", &[1.0], 100, None),
            participant("p2", &[2.0], 100, None),
            participant("p3", &[3.0], 100, None),
        ];

        let config = SecureAggConfig {
            min_participants: 3,
            add_noise: false,
            ..Default::default()
        };

        let result = SecureAggregation::aggregate(&updates, &config);
        assert!(SecureAggregation::verify_aggregation(&updates, &result));
    }
}