1use crate::{Dataset, Result};
8use scirs2_core::random::rand_prelude::*;
9use scirs2_core::random::rngs::StdRng;
10use scirs2_core::random::SeedableRng;
11use scirs2_core::random::{rng, Random};
12use serde::{Deserialize, Serialize};
13use std::collections::HashMap;
14use std::sync::{Arc, Mutex};
15use tenflowers_core::{Tensor, TensorError};
16
/// Unique identifier for a federated client.
pub type ClientId = String;
19
/// Per-client configuration for a federated learning participant.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ClientConfig {
    /// Unique identifier of this client.
    pub client_id: ClientId,
    /// How this client's local data is assumed to be distributed.
    pub distribution_type: DataDistribution,
    /// Differential-privacy settings applied to this client's outputs.
    pub privacy_config: PrivacyConfig,
    /// Free-form key/value metadata about the client.
    pub metadata: HashMap<String, String>,
}
32
/// Assumed distribution of a client's local data relative to the global data.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub enum DataDistribution {
    /// Independent and identically distributed data.
    Iid,
    /// Non-IID data with skewed class proportions.
    NonIidClassImbalance { class_weights: Vec<f64> },
    /// Non-IID data with a feature (covariate) shift.
    NonIidFeatureShift { shift_factor: f64 },
    /// Non-IID data with both class imbalance and feature shift.
    NonIidMixed {
        class_weights: Vec<f64>,
        shift_factor: f64,
    },
    /// User-defined distribution strategy identified by name.
    Custom {
        strategy_name: String,
        parameters: HashMap<String, f64>,
    },
}
53
/// Differential-privacy settings for a client.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct PrivacyConfig {
    /// Whether differential-privacy noise is applied at all.
    pub enable_dp: bool,
    /// Privacy-loss parameter spent per noisy query.
    pub epsilon: f64,
    /// Failure probability for (epsilon, delta)-DP; used by the Gaussian mechanism.
    pub delta: f64,
    /// Noise distribution used to perturb released values.
    pub noise_mechanism: NoiseMechanism,
    /// Total epsilon budget available before noisy queries are refused.
    pub privacy_budget: f64,
}
68
/// Noise distribution used for differential privacy; each variant carries the
/// sensitivity it was configured with.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub enum NoiseMechanism {
    /// Laplace noise (pure epsilon-DP).
    Laplace { sensitivity: f64 },
    /// Gaussian noise ((epsilon, delta)-DP).
    Gaussian { sensitivity: f64 },
    /// Exponential mechanism — approximated with Laplace noise in this module.
    Exponential { sensitivity: f64 },
}
79
80impl Default for PrivacyConfig {
81 fn default() -> Self {
82 Self {
83 enable_dp: false,
84 epsilon: 1.0,
85 delta: 1e-5,
86 noise_mechanism: NoiseMechanism::Gaussian { sensitivity: 1.0 },
87 privacy_budget: 10.0,
88 }
89 }
90}
91
/// A client-local dataset wrapper that can release differentially private
/// samples and statistics.
#[derive(Debug)]
pub struct FederatedClientDataset<T, D> {
    // Client identity, distribution assumption, and privacy settings.
    config: ClientConfig,
    // The wrapped local dataset.
    dataset: D,
    // Shared, mutex-guarded DP state (budget, RNG, noise-scale cache).
    privacy_manager: Arc<Mutex<PrivacyManager>>,
    // Cheap statistics computed at construction time.
    stats: ClientStats,
    // Marks the element type `T` without storing one.
    _phantom: std::marker::PhantomData<T>,
}
105
/// Tracks the remaining privacy budget and supplies calibrated random noise.
#[derive(Debug)]
pub struct PrivacyManager {
    // Epsilon budget still available for noisy queries.
    remaining_budget: f64,
    // Deterministic RNG so noise streams are reproducible per seed.
    rng: StdRng,
    // Memoized noise scales, keyed by query parameters.
    noise_scale_cache: HashMap<String, f64>,
}
116
/// Locally computed, non-private summary of a client's data.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ClientStats {
    /// Number of samples held by the client.
    pub sample_count: usize,
    /// Count of samples per class label (left empty by `compute_basic_stats`).
    pub class_distribution: HashMap<String, usize>,
    /// Per-feature summary statistics (left empty by `compute_basic_stats`).
    pub feature_stats: FederatedFeatureStats,
    /// Data-quality indicators.
    pub quality_metrics: QualityMetrics,
}
129
/// Per-feature summary statistics for a client's dataset.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct FederatedFeatureStats {
    /// Mean of each feature.
    pub means: Vec<f64>,
    /// Standard deviation of each feature.
    pub stds: Vec<f64>,
    /// (min, max) range of each feature.
    pub ranges: Vec<(f64, f64)>,
}
140
/// Coarse data-quality indicators for a client's dataset.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct QualityMetrics {
    /// Fraction of missing values, in percent.
    pub missing_percentage: f64,
    /// Fraction of outlier values, in percent.
    pub outlier_percentage: f64,
    /// Consistency score; presumably in [0, 1] with 1.0 = fully consistent —
    /// TODO confirm (only 1.0 is ever assigned in this module).
    pub consistency_score: f64,
}
151
152impl PrivacyManager {
153 pub fn new(config: &PrivacyConfig, seed: u64) -> Self {
155 Self {
156 remaining_budget: config.privacy_budget,
157 rng: StdRng::seed_from_u64(seed),
158 noise_scale_cache: HashMap::new(),
159 }
160 }
161
162 pub fn add_noise(
164 &mut self,
165 value: f64,
166 config: &PrivacyConfig,
167 query_sensitivity: f64,
168 ) -> Result<f64> {
169 if !config.enable_dp {
170 return Ok(value);
171 }
172
173 if self.remaining_budget <= 0.0 {
174 return Err(TensorError::invalid_argument(
175 "Privacy budget exhausted".to_string(),
176 ));
177 }
178
179 let noise_scale = self.calculate_noise_scale(config, query_sensitivity);
180 let noise = match &config.noise_mechanism {
181 NoiseMechanism::Laplace { .. } => {
182 let scale = noise_scale;
184 self.sample_laplace(scale)
185 }
186 NoiseMechanism::Gaussian { .. } => {
187 let sigma = noise_scale;
189 self.sample_gaussian(sigma)
190 }
191 NoiseMechanism::Exponential { .. } => {
192 let scale = noise_scale;
194 self.sample_laplace(scale)
195 }
196 };
197
198 self.remaining_budget -= config.epsilon;
200
201 Ok(value + noise)
202 }
203
204 pub fn add_noise_tensor<T>(
206 &mut self,
207 tensor: &Tensor<T>,
208 config: &PrivacyConfig,
209 sensitivity: f64,
210 ) -> Result<Tensor<T>>
211 where
212 T: Clone + Default + Send + Sync + 'static,
213 T: From<f64> + Into<f64>,
214 {
215 if !config.enable_dp {
216 return Ok(tensor.clone());
217 }
218
219 let shape = tensor.shape().dims().to_vec();
220 let mut noisy_data = Vec::new();
221
222 if let Some(slice) = tensor.as_slice() {
223 for value in slice {
224 let original_value: f64 = value.clone().into();
225 let noisy_value = self.add_noise(original_value, config, sensitivity)?;
226 noisy_data.push(T::from(noisy_value));
227 }
228 } else {
229 let value: f64 = tensor.get(&[]).unwrap_or_default().into();
231 let noisy_value = self.add_noise(value, config, sensitivity)?;
232 noisy_data.push(T::from(noisy_value));
233 }
234
235 Tensor::from_vec(noisy_data, &shape)
236 }
237
238 fn calculate_noise_scale(&mut self, config: &PrivacyConfig, sensitivity: f64) -> f64 {
239 let cache_key = format!("{}_{}", config.epsilon, sensitivity);
240
241 if let Some(&cached_scale) = self.noise_scale_cache.get(&cache_key) {
242 return cached_scale;
243 }
244
245 let scale = match &config.noise_mechanism {
246 NoiseMechanism::Laplace { .. } => sensitivity / config.epsilon,
247 NoiseMechanism::Gaussian { .. } => {
248 let factor = (2.0 * (1.25 / config.delta).ln()).sqrt();
249 factor * sensitivity / config.epsilon
250 }
251 NoiseMechanism::Exponential { .. } => sensitivity / config.epsilon,
252 };
253
254 self.noise_scale_cache.insert(cache_key, scale);
255 scale
256 }
257
258 fn sample_laplace(&mut self, scale: f64) -> f64 {
259 let u1: f64 = self.rng.random();
261 let u2: f64 = self.rng.random();
262
263 let sign = if u1 < 0.5 { -1.0 } else { 1.0 };
264 sign * scale * (1.0_f64 - 2.0_f64 * u2.abs()).max(1e-10_f64).ln()
265 }
266
267 fn sample_gaussian(&mut self, sigma: f64) -> f64 {
268 let u1: f64 = self.rng.random();
270 let u2: f64 = self.rng.random();
271
272 let z0 = (-2.0 * u1.ln()).sqrt() * (2.0 * std::f64::consts::PI * u2).cos();
273 z0 * sigma
274 }
275
276 pub fn can_spend_budget(&self, epsilon: f64) -> bool {
278 self.remaining_budget >= epsilon
279 }
280
281 pub fn remaining_budget(&self) -> f64 {
283 self.remaining_budget
284 }
285}
286
287impl<T, D> FederatedClientDataset<T, D>
288where
289 D: Dataset<T>,
290 T: Clone + Default + Send + Sync + 'static,
291{
292 pub fn new(dataset: D, config: ClientConfig) -> Self {
294 let stats = Self::compute_basic_stats(&dataset);
295 let privacy_manager = Arc::new(Mutex::new(PrivacyManager::new(&config.privacy_config, 42)));
296
297 Self {
298 config,
299 dataset,
300 privacy_manager,
301 stats,
302 _phantom: std::marker::PhantomData,
303 }
304 }
305
306 pub fn config(&self) -> &ClientConfig {
308 &self.config
309 }
310
311 pub fn stats(&self) -> &ClientStats {
313 &self.stats
314 }
315
316 pub fn get_private(&self, index: usize) -> Result<(Tensor<T>, Tensor<T>)>
318 where
319 T: From<f64> + Into<f64>,
320 {
321 let (features, labels) = self.dataset.get(index)?;
322
323 if !self.config.privacy_config.enable_dp {
324 return Ok((features, labels));
325 }
326
327 let mut privacy_manager = self
328 .privacy_manager
329 .lock()
330 .expect("lock should not be poisoned");
331 let noisy_features =
332 privacy_manager.add_noise_tensor(&features, &self.config.privacy_config, 1.0)?;
333
334 Ok((noisy_features, labels))
335 }
336
337 pub fn compute_private_statistics(&self) -> Result<PrivateStats>
339 where
340 T: From<f64> + Into<f64>,
341 {
342 let mut feature_sums = Vec::new();
343 let mut feature_counts = Vec::new();
344 let sample_count = self.dataset.len();
345
346 if sample_count == 0 {
347 return Ok(PrivateStats {
348 sample_count: 0,
349 feature_means: Vec::new(),
350 class_counts: HashMap::new(),
351 });
352 }
353
354 let (first_features, _) = self.dataset.get(0)?;
356 let feature_dim = if let Some(slice) = first_features.as_slice() {
357 slice.len()
358 } else {
359 1
360 };
361
362 feature_sums.resize(feature_dim, 0.0);
363 feature_counts.resize(feature_dim, 0);
364
365 for i in 0..sample_count {
367 let (features, _) = self.dataset.get(i)?;
368
369 if let Some(slice) = features.as_slice() {
370 for (j, value) in slice.iter().enumerate() {
371 feature_sums[j] += value.clone().into();
372 feature_counts[j] += 1;
373 }
374 } else {
375 let value: f64 = features.get(&[]).unwrap_or(T::default()).into();
376 feature_sums[0] += value;
377 feature_counts[0] += 1;
378 }
379 }
380
381 let mut private_means = Vec::new();
383 let mut privacy_manager = self
384 .privacy_manager
385 .lock()
386 .expect("lock should not be poisoned");
387
388 for i in 0..feature_dim {
389 let mean = if feature_counts[i] > 0 {
390 feature_sums[i] / feature_counts[i] as f64
391 } else {
392 0.0
393 };
394
395 let private_mean = privacy_manager.add_noise(mean, &self.config.privacy_config, 1.0)?;
396 private_means.push(private_mean);
397 }
398
399 let private_sample_count =
401 privacy_manager.add_noise(sample_count as f64, &self.config.privacy_config, 1.0)?
402 as usize;
403
404 Ok(PrivateStats {
405 sample_count: private_sample_count,
406 feature_means: private_means,
407 class_counts: HashMap::new(), })
409 }
410
411 fn compute_basic_stats(dataset: &D) -> ClientStats {
412 let sample_count = dataset.len();
413
414 ClientStats {
415 sample_count,
416 class_distribution: HashMap::new(), feature_stats: FederatedFeatureStats {
418 means: Vec::new(),
419 stds: Vec::new(),
420 ranges: Vec::new(),
421 },
422 quality_metrics: QualityMetrics {
423 missing_percentage: 0.0,
424 outlier_percentage: 0.0,
425 consistency_score: 1.0,
426 },
427 }
428 }
429}
430
impl<T, D> Dataset<T> for FederatedClientDataset<T, D>
where
    D: Dataset<T>,
    T: Clone + Default + Send + Sync + 'static,
{
    /// Number of samples in the wrapped dataset.
    fn len(&self) -> usize {
        self.dataset.len()
    }

    /// Fetches a raw sample WITHOUT privacy noise; use `get_private` for the
    /// differentially private variant.
    fn get(&self, index: usize) -> Result<(Tensor<T>, Tensor<T>)> {
        self.dataset.get(index)
    }
}
444
/// Differentially private statistics a client shares with an aggregator.
#[derive(Debug, Clone)]
pub struct PrivateStats {
    /// Noisy sample count (may deviate from the true count).
    pub sample_count: usize,
    /// Noisy per-feature means.
    pub feature_means: Vec<f64>,
    /// Noisy per-class counts (left empty by `compute_private_statistics`).
    pub class_counts: HashMap<String, usize>,
}
455
/// Splits a dataset into per-client partitions according to a strategy.
#[derive(Debug)]
pub struct FederatedPartitioner {
    // Number of client partitions to produce.
    num_clients: usize,
    // How samples are assigned to clients.
    strategy: PartitioningStrategy,
    // Seeded RNG so partitions are reproducible.
    rng: StdRng,
}
466
/// Strategy for assigning samples to federated clients.
#[derive(Debug, Clone)]
pub enum PartitioningStrategy {
    /// Shuffle and split into near-equal shares.
    Uniform,
    /// Unequal client shares controlled by concentration parameter `alpha`.
    Dirichlet { alpha: f64 },
    /// Restrict each client to a subset of classes. Currently unimplemented —
    /// the partitioner falls back to a uniform split.
    ClassBased { classes_per_client: usize },
    /// Vary client sizes by up to +/- `size_variance` around the uniform share.
    QuantityBased { size_variance: f64 },
}
479
480impl FederatedPartitioner {
481 pub fn new(num_clients: usize, strategy: PartitioningStrategy, seed: u64) -> Self {
483 Self {
484 num_clients,
485 strategy,
486 rng: StdRng::seed_from_u64(seed),
487 }
488 }
489
490 pub fn partition<T, D>(
492 &mut self,
493 dataset: D,
494 ) -> Result<Vec<FederatedClientDataset<T, ClientIndexedDataset<T, D>>>>
495 where
496 D: Dataset<T> + Clone,
497 T: Clone + Default + Send + Sync + 'static,
498 {
499 let total_samples = dataset.len();
500 let client_assignments = self.generate_client_assignments(total_samples)?;
501
502 let mut client_datasets = Vec::new();
503
504 for (client_idx, indices) in client_assignments.into_iter().enumerate() {
505 let client_id = format!("client_{client_idx}");
506 let client_dataset = ClientIndexedDataset::new(dataset.clone(), indices);
507
508 let config = ClientConfig {
509 client_id: client_id.clone(),
510 distribution_type: self.get_distribution_type_for_client(client_idx),
511 privacy_config: PrivacyConfig::default(),
512 metadata: HashMap::new(),
513 };
514
515 let federated_client = FederatedClientDataset::new(client_dataset, config);
516 client_datasets.push(federated_client);
517 }
518
519 Ok(client_datasets)
520 }
521
522 fn generate_client_assignments(&mut self, total_samples: usize) -> Result<Vec<Vec<usize>>> {
523 match &self.strategy {
524 PartitioningStrategy::Uniform => self.uniform_partition(total_samples),
525 PartitioningStrategy::Dirichlet { alpha } => {
526 self.dirichlet_partition(total_samples, *alpha)
527 }
528 PartitioningStrategy::ClassBased {
529 classes_per_client: _,
530 } => {
531 self.uniform_partition(total_samples)
533 }
534 PartitioningStrategy::QuantityBased { size_variance } => {
535 self.quantity_based_partition(total_samples, *size_variance)
536 }
537 }
538 }
539
540 fn uniform_partition(&mut self, total_samples: usize) -> Result<Vec<Vec<usize>>> {
541 let mut indices: Vec<usize> = (0..total_samples).collect();
542
543 for i in (1..indices.len()).rev() {
545 let j = self.rng.random_range(0..i);
546 indices.swap(i, j);
547 }
548
549 let base_size = total_samples / self.num_clients;
550 let remainder = total_samples % self.num_clients;
551
552 let mut client_assignments = Vec::new();
553 let mut start_idx = 0;
554
555 for i in 0..self.num_clients {
556 let client_size = base_size + if i < remainder { 1 } else { 0 };
557 let end_idx = start_idx + client_size;
558
559 client_assignments.push(indices[start_idx..end_idx].to_vec());
560 start_idx = end_idx;
561 }
562
563 Ok(client_assignments)
564 }
565
566 fn dirichlet_partition(&mut self, total_samples: usize, alpha: f64) -> Result<Vec<Vec<usize>>> {
567 let mut proportions = Vec::new();
570 let mut sum = 0.0;
571
572 for _ in 0..self.num_clients {
573 let prop = self.rng.random::<f64>() * alpha + 0.1; proportions.push(prop);
575 sum += prop;
576 }
577
578 for prop in &mut proportions {
580 *prop /= sum;
581 }
582
583 let mut client_assignments = Vec::new();
584 let mut assigned_samples = 0;
585
586 for (i, &proportion) in proportions.iter().enumerate() {
587 let client_samples = if i == self.num_clients - 1 {
588 total_samples - assigned_samples
590 } else {
591 (total_samples as f64 * proportion) as usize
592 };
593
594 let indices: Vec<usize> =
595 (assigned_samples..assigned_samples + client_samples).collect();
596 client_assignments.push(indices);
597 assigned_samples += client_samples;
598 }
599
600 Ok(client_assignments)
601 }
602
603 fn quantity_based_partition(
604 &mut self,
605 total_samples: usize,
606 size_variance: f64,
607 ) -> Result<Vec<Vec<usize>>> {
608 let base_size = total_samples as f64 / self.num_clients as f64;
609 let mut client_sizes = Vec::new();
610 let mut total_assigned = 0;
611
612 for i in 0..self.num_clients {
613 let variance_factor = 1.0 + (self.rng.random::<f64>() - 0.5) * 2.0 * size_variance;
614 let client_size = if i == self.num_clients - 1 {
615 total_samples - total_assigned
617 } else {
618 ((base_size * variance_factor) as usize).min(total_samples - total_assigned)
619 };
620
621 client_sizes.push(client_size);
622 total_assigned += client_size;
623
624 if total_assigned >= total_samples {
625 break;
626 }
627 }
628
629 let mut client_assignments = Vec::new();
630 let mut start_idx = 0;
631
632 for &size in &client_sizes {
633 let end_idx = (start_idx + size).min(total_samples);
634 let indices: Vec<usize> = (start_idx..end_idx).collect();
635 client_assignments.push(indices);
636 start_idx = end_idx;
637 }
638
639 Ok(client_assignments)
640 }
641
642 fn get_distribution_type_for_client(&self, _client_idx: usize) -> DataDistribution {
643 match &self.strategy {
644 PartitioningStrategy::Uniform => DataDistribution::Iid,
645 PartitioningStrategy::Dirichlet { alpha } => DataDistribution::NonIidClassImbalance {
646 class_weights: vec![*alpha, 1.0 - alpha],
647 },
648 PartitioningStrategy::ClassBased { .. } => DataDistribution::NonIidClassImbalance {
649 class_weights: vec![0.8, 0.2],
650 },
651 PartitioningStrategy::QuantityBased { .. } => DataDistribution::Iid,
652 }
653 }
654}
655
/// A view over a dataset restricted to an explicit list of sample indices.
#[derive(Debug, Clone)]
pub struct ClientIndexedDataset<T, D> {
    // The full underlying dataset.
    dataset: D,
    // Indices into `dataset` that belong to this client.
    indices: Vec<usize>,
    // Marks the element type `T` without storing one.
    _phantom: std::marker::PhantomData<T>,
}
663
664impl<T, D> ClientIndexedDataset<T, D>
665where
666 D: Dataset<T>,
667 T: Clone + Default + Send + Sync + 'static,
668{
669 pub fn new(dataset: D, indices: Vec<usize>) -> Self {
671 Self {
672 dataset,
673 indices,
674 _phantom: std::marker::PhantomData,
675 }
676 }
677
678 pub fn inner(&self) -> &D {
680 &self.dataset
681 }
682
683 pub fn indices(&self) -> &[usize] {
685 &self.indices
686 }
687}
688
689impl<T, D> Dataset<T> for ClientIndexedDataset<T, D>
690where
691 D: Dataset<T>,
692 T: Clone + Default + Send + Sync + 'static,
693{
694 fn len(&self) -> usize {
695 self.indices.len()
696 }
697
698 fn get(&self, index: usize) -> Result<(Tensor<T>, Tensor<T>)> {
699 if index >= self.indices.len() {
700 return Err(TensorError::invalid_argument(format!(
701 "Index {} out of bounds for client dataset of length {}",
702 index,
703 self.indices.len()
704 )));
705 }
706
707 let actual_index = self.indices[index];
708 self.dataset.get(actual_index)
709 }
710}
711
/// Combines per-client private statistics into a global summary.
#[derive(Debug)]
pub struct FederatedAggregator {
    // How client contributions are combined.
    strategy: AggregationStrategy,
    // Explicit per-client weights set via `set_client_weight`.
    client_weights: HashMap<ClientId, f64>,
}
720
/// How client statistics are combined by the aggregator.
#[derive(Debug, Clone)]
pub enum AggregationStrategy {
    /// Unweighted mean across clients.
    Average,
    /// Weight each client by its (noisy) sample count.
    WeightedBySize,
    /// Weight each client by a quality score (currently uniform weights).
    WeightedByQuality,
    /// Element-wise median (currently falls back to the average).
    Median,
    /// Mean after trimming extremes (currently falls back to the average).
    TrimmedMean { trim_fraction: f64 },
}
735
736impl FederatedAggregator {
737 pub fn new(strategy: AggregationStrategy) -> Self {
739 Self {
740 strategy,
741 client_weights: HashMap::new(),
742 }
743 }
744
745 pub fn set_client_weight(&mut self, client_id: ClientId, weight: f64) {
747 self.client_weights.insert(client_id, weight);
748 }
749
750 pub fn aggregate_statistics(
752 &self,
753 client_stats: Vec<(ClientId, PrivateStats)>,
754 ) -> Result<PrivateStats> {
755 if client_stats.is_empty() {
756 return Err(TensorError::invalid_argument(
757 "No client statistics provided".to_string(),
758 ));
759 }
760
761 match &self.strategy {
762 AggregationStrategy::Average => self.average_statistics(client_stats),
763 AggregationStrategy::WeightedBySize => {
764 self.weighted_statistics(client_stats, |stats| stats.sample_count as f64)
765 }
766 AggregationStrategy::WeightedByQuality => {
767 self.weighted_statistics(client_stats, |_| 1.0)
768 } AggregationStrategy::Median => self.median_statistics(client_stats),
770 AggregationStrategy::TrimmedMean { trim_fraction } => {
771 self.trimmed_mean_statistics(client_stats, *trim_fraction)
772 }
773 }
774 }
775
776 fn average_statistics(
777 &self,
778 client_stats: Vec<(ClientId, PrivateStats)>,
779 ) -> Result<PrivateStats> {
780 let num_clients = client_stats.len() as f64;
781 let mut total_samples = 0;
782 let mut aggregated_means = Vec::new();
783 let mut aggregated_class_counts = HashMap::new();
784
785 if let Some((_, first_stats)) = client_stats.first() {
787 aggregated_means.resize(first_stats.feature_means.len(), 0.0);
788 }
789
790 for (_, stats) in &client_stats {
791 total_samples += stats.sample_count;
792
793 for (i, &mean) in stats.feature_means.iter().enumerate() {
795 if i < aggregated_means.len() {
796 aggregated_means[i] += mean / num_clients;
797 }
798 }
799
800 for (class, &count) in &stats.class_counts {
802 *aggregated_class_counts.entry(class.clone()).or_insert(0) += count;
803 }
804 }
805
806 Ok(PrivateStats {
807 sample_count: total_samples,
808 feature_means: aggregated_means,
809 class_counts: aggregated_class_counts,
810 })
811 }
812
813 fn weighted_statistics<F>(
814 &self,
815 client_stats: Vec<(ClientId, PrivateStats)>,
816 weight_fn: F,
817 ) -> Result<PrivateStats>
818 where
819 F: Fn(&PrivateStats) -> f64,
820 {
821 #[allow(unused_assignments)]
822 let mut total_weight = 0.0;
823 let mut total_samples = 0;
824 let mut aggregated_means = Vec::new();
825 let mut aggregated_class_counts = HashMap::new();
826
827 let weights: Vec<f64> = client_stats
829 .iter()
830 .map(|(_, stats)| weight_fn(stats))
831 .collect();
832
833 total_weight = weights.iter().sum();
834
835 if let Some((_, first_stats)) = client_stats.first() {
836 aggregated_means.resize(first_stats.feature_means.len(), 0.0);
837 }
838
839 for ((_, stats), weight) in client_stats.iter().zip(weights.iter()) {
840 let normalized_weight = weight / total_weight;
841 total_samples += stats.sample_count;
842
843 for (i, &mean) in stats.feature_means.iter().enumerate() {
845 if i < aggregated_means.len() {
846 aggregated_means[i] += mean * normalized_weight;
847 }
848 }
849
850 for (class, &count) in &stats.class_counts {
852 let weighted_count = (count as f64 * normalized_weight) as usize;
853 *aggregated_class_counts.entry(class.clone()).or_insert(0) += weighted_count;
854 }
855 }
856
857 Ok(PrivateStats {
858 sample_count: total_samples,
859 feature_means: aggregated_means,
860 class_counts: aggregated_class_counts,
861 })
862 }
863
864 fn median_statistics(
865 &self,
866 client_stats: Vec<(ClientId, PrivateStats)>,
867 ) -> Result<PrivateStats> {
868 self.average_statistics(client_stats)
870 }
871
872 fn trimmed_mean_statistics(
873 &self,
874 client_stats: Vec<(ClientId, PrivateStats)>,
875 _trim_fraction: f64,
876 ) -> Result<PrivateStats> {
877 self.average_statistics(client_stats)
879 }
880}
881
/// Convenience extension methods for turning any dataset into federated
/// client datasets.
pub trait FederatedDatasetExt<T>: Dataset<T> + Sized
where
    T: Clone + Default + Send + Sync + 'static,
{
    /// Wraps this dataset as a single federated client with `config`.
    fn federated_client(self, config: ClientConfig) -> FederatedClientDataset<T, Self> {
        FederatedClientDataset::new(self, config)
    }

    /// Partitions this dataset into `num_clients` federated clients using
    /// `strategy`; `seed` makes the split reproducible.
    ///
    /// # Errors
    /// Propagates partitioning failures.
    fn partition_federated(
        self,
        num_clients: usize,
        strategy: PartitioningStrategy,
        seed: u64,
    ) -> Result<Vec<FederatedClientDataset<T, ClientIndexedDataset<T, Self>>>>
    where
        Self: Clone,
    {
        let mut partitioner = FederatedPartitioner::new(num_clients, strategy, seed);
        partitioner.partition(self)
    }
}
906
/// Blanket implementation: every `Dataset<T>` automatically gains the
/// federated extension methods.
impl<T, D: Dataset<T>> FederatedDatasetExt<T> for D where T: Clone + Default + Send + Sync + 'static {}
908
#[cfg(test)]
mod tests {
    use super::*;
    use crate::TensorDataset;

    /// DP-enabled noise must spend budget and actually perturb the value.
    #[test]
    fn test_privacy_manager() {
        let config = PrivacyConfig {
            enable_dp: true,
            epsilon: 1.0,
            delta: 1e-5,
            noise_mechanism: NoiseMechanism::Gaussian { sensitivity: 1.0 },
            privacy_budget: 10.0,
        };

        let mut privacy_manager = PrivacyManager::new(&config, 42);

        assert_eq!(privacy_manager.remaining_budget(), 10.0);
        assert!(privacy_manager.can_spend_budget(1.0));

        let noisy_value = privacy_manager
            .add_noise(5.0, &config, 1.0)
            .expect("test: operation should succeed");

        assert!(privacy_manager.remaining_budget() < 10.0);
        assert_ne!(noisy_value, 5.0);
    }

    /// Wrapping a dataset exposes its length, config, and basic stats.
    #[test]
    fn test_federated_client_dataset() {
        let features = Tensor::from_vec(vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0], &[3, 2])
            .expect("test: tensor creation should succeed");
        let labels = Tensor::from_vec(vec![0.0, 1.0, 0.0], &[3])
            .expect("test: tensor creation should succeed");
        let dataset = TensorDataset::new(features, labels);

        let config = ClientConfig {
            client_id: "test_client".to_string(),
            distribution_type: DataDistribution::Iid,
            privacy_config: PrivacyConfig::default(),
            metadata: HashMap::new(),
        };

        let federated_dataset = FederatedClientDataset::new(dataset, config);

        assert_eq!(federated_dataset.len(), 3);
        assert_eq!(federated_dataset.config().client_id, "test_client");
        assert_eq!(federated_dataset.stats().sample_count, 3);
    }

    /// A uniform split over two clients must cover every sample exactly once.
    #[test]
    fn test_federated_partitioner() {
        let features = Tensor::from_vec(vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0], &[4, 2])
            .expect("test: tensor creation should succeed");
        let labels = Tensor::from_vec(vec![0.0, 1.0, 0.0, 1.0], &[4])
            .expect("test: tensor creation should succeed");
        let dataset = TensorDataset::new(features, labels);

        let mut partitioner = FederatedPartitioner::new(2, PartitioningStrategy::Uniform, 42);
        let client_datasets = partitioner
            .partition(dataset)
            .expect("test: operation should succeed");

        assert_eq!(client_datasets.len(), 2);

        let total_samples: usize = client_datasets.iter().map(|d| d.len()).sum();
        assert_eq!(total_samples, 4);
    }

    /// An index view must remap local indices onto the wrapped dataset.
    #[test]
    fn test_client_indexed_dataset() {
        let features = Tensor::from_vec(vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0], &[3, 2])
            .expect("test: tensor creation should succeed");
        let labels = Tensor::from_vec(vec![0.0, 1.0, 0.0], &[3])
            .expect("test: tensor creation should succeed");
        let dataset = TensorDataset::new(features, labels);

        // View onto rows 0 and 2 of the underlying dataset.
        let client_dataset = ClientIndexedDataset::new(dataset, vec![0, 2]);

        assert_eq!(client_dataset.len(), 2);
        assert_eq!(client_dataset.indices(), &[0, 2]);

        let (features, labels) = client_dataset.get(0).expect("index should be in bounds");
        assert_eq!(
            features.as_slice().expect("tensor should be contiguous"),
            &[1.0, 2.0]
        );
        assert_eq!(labels.get(&[]).expect("test: get should succeed"), 0.0);

        let (features, labels) = client_dataset.get(1).expect("index should be in bounds");
        assert_eq!(
            features.as_slice().expect("tensor should be contiguous"),
            &[5.0, 6.0]
        );
        assert_eq!(labels.get(&[]).expect("test: get should succeed"), 0.0);
    }

    /// Average aggregation sums counts and averages feature means.
    #[test]
    fn test_federated_aggregator() {
        let aggregator = FederatedAggregator::new(AggregationStrategy::Average);

        let client_stats = vec![
            (
                "client1".to_string(),
                PrivateStats {
                    sample_count: 100,
                    feature_means: vec![1.0, 2.0],
                    class_counts: HashMap::new(),
                },
            ),
            (
                "client2".to_string(),
                PrivateStats {
                    sample_count: 200,
                    feature_means: vec![3.0, 4.0],
                    class_counts: HashMap::new(),
                },
            ),
        ];

        let aggregated = aggregator
            .aggregate_statistics(client_stats)
            .expect("test: operation should succeed");

        assert_eq!(aggregated.sample_count, 300);
        assert_eq!(aggregated.feature_means, vec![2.0, 3.0]);
    }

    /// A config must survive a JSON round trip.
    #[test]
    fn test_privacy_config_serialization() {
        let config = PrivacyConfig {
            enable_dp: true,
            epsilon: 1.0,
            delta: 1e-5,
            noise_mechanism: NoiseMechanism::Gaussian { sensitivity: 1.0 },
            privacy_budget: 10.0,
        };

        let json = serde_json::to_string(&config).expect("test: serialization should succeed");
        let deserialized: PrivacyConfig =
            serde_json::from_str(&json).expect("test: JSON parsing should succeed");

        assert!(deserialized.enable_dp);
        assert_eq!(deserialized.epsilon, 1.0);
    }

    /// The extension trait exposes both wrapping and partitioning.
    #[test]
    fn test_extension_trait() {
        let features = Tensor::from_vec(vec![1.0, 2.0, 3.0, 4.0], &[2, 2])
            .expect("test: tensor creation should succeed");
        let labels = Tensor::from_vec(vec![0.0, 1.0], &[2])
            .expect("test: tensor creation should succeed");
        let dataset = TensorDataset::new(features, labels);

        let config = ClientConfig {
            client_id: "test_client".to_string(),
            distribution_type: DataDistribution::Iid,
            privacy_config: PrivacyConfig::default(),
            metadata: HashMap::new(),
        };

        let federated_dataset = dataset.federated_client(config);
        assert_eq!(federated_dataset.len(), 2);

        let features = Tensor::from_vec(vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0], &[4, 2])
            .expect("test: tensor creation should succeed");
        let labels = Tensor::from_vec(vec![0.0, 1.0, 0.0, 1.0], &[4])
            .expect("test: tensor creation should succeed");
        let dataset = TensorDataset::new(features, labels);

        let client_datasets = dataset
            .partition_federated(2, PartitioningStrategy::Uniform, 42)
            .expect("test: operation should succeed");
        assert_eq!(client_datasets.len(), 2);
    }
}