// tenflowers_dataset/federated.rs
1//! Federated learning dataset utilities for privacy-preserving distributed ML
2//!
3//! This module provides basic infrastructure for federated learning scenarios,
4//! including client data partitioning, differential privacy, and heterogeneous
5//! data distribution management.
6
7use crate::{Dataset, Result};
8use scirs2_core::random::rand_prelude::*;
9use scirs2_core::random::rngs::StdRng;
10use scirs2_core::random::SeedableRng;
11use scirs2_core::random::{rng, Random};
12use serde::{Deserialize, Serialize};
13use std::collections::HashMap;
14use std::sync::{Arc, Mutex};
15use tenflowers_core::{Tensor, TensorError};
16
17/// Unique identifier for federated learning clients
18pub type ClientId = String;
19
/// Federated learning client configuration
///
/// Fully serializable so a client's setup can be persisted or exchanged
/// with a coordinator.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ClientConfig {
    /// Unique client identifier (used as the aggregation key)
    pub client_id: ClientId,
    /// How this client's local data is distributed (IID vs. non-IID variants)
    pub distribution_type: DataDistribution,
    /// Differential-privacy settings applied to this client's queries
    pub privacy_config: PrivacyConfig,
    /// Free-form client-specific metadata
    pub metadata: HashMap<String, String>,
}
32
/// Data distribution types for federated clients
///
/// Describes how a client's local data relates to the global distribution;
/// used as metadata when partitioning datasets across clients.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub enum DataDistribution {
    /// Independent and identically distributed (IID)
    Iid,
    /// Non-IID with class imbalance; `class_weights` are per-class proportions
    NonIidClassImbalance { class_weights: Vec<f64> },
    /// Non-IID with feature shift, scaled by `shift_factor`
    NonIidFeatureShift { shift_factor: f64 },
    /// Non-IID with both class and feature shifts
    NonIidMixed {
        class_weights: Vec<f64>,
        shift_factor: f64,
    },
    /// Custom distribution strategy, identified by name with numeric parameters
    Custom {
        strategy_name: String,
        parameters: HashMap<String, f64>,
    },
}
53
/// Privacy configuration for federated learning
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct PrivacyConfig {
    /// Enable differential privacy; when false, noise functions pass values through
    pub enable_dp: bool,
    /// Epsilon value for differential privacy (smaller = more private)
    pub epsilon: f64,
    /// Delta value for differential privacy (failure probability of the guarantee)
    pub delta: f64,
    /// Noise mechanism used to perturb released values
    pub noise_mechanism: NoiseMechanism,
    /// Total privacy budget; `PrivacyManager` deducts `epsilon` per query
    pub privacy_budget: f64,
}
68
/// Noise mechanisms for differential privacy
///
/// NOTE(review): the per-variant `sensitivity` fields are not currently read
/// by `PrivacyManager::calculate_noise_scale`, which takes the query
/// sensitivity as a separate argument — confirm which is authoritative.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub enum NoiseMechanism {
    /// Laplace mechanism
    Laplace { sensitivity: f64 },
    /// Gaussian mechanism
    Gaussian { sensitivity: f64 },
    /// Exponential mechanism
    Exponential { sensitivity: f64 },
}
79
impl Default for PrivacyConfig {
    /// Conservative defaults: DP disabled, with standard
    /// (epsilon = 1.0, delta = 1e-5) Gaussian-mechanism parameters ready if
    /// it is switched on, and a total budget of 10 epsilon units.
    fn default() -> Self {
        Self {
            enable_dp: false,
            epsilon: 1.0,
            delta: 1e-5,
            noise_mechanism: NoiseMechanism::Gaussian { sensitivity: 1.0 },
            privacy_budget: 10.0,
        }
    }
}
91
/// Client dataset for federated learning
///
/// Wraps a local dataset with a client configuration, precomputed
/// statistics, and a shared privacy manager for DP noise.
#[derive(Debug)]
pub struct FederatedClientDataset<T, D> {
    /// Client configuration
    config: ClientConfig,
    /// Local dataset
    dataset: D,
    /// Privacy manager for differential privacy; mutex-guarded because
    /// `get_private` takes `&self` but noise generation mutates RNG and budget
    privacy_manager: Arc<Mutex<PrivacyManager>>,
    /// Client statistics, computed once at construction
    stats: ClientStats,
    // Marks the sample type T without storing one.
    _phantom: std::marker::PhantomData<T>,
}
105
/// Privacy manager for differential privacy operations
///
/// Tracks the remaining epsilon budget and owns a seeded RNG so noise is
/// reproducible for a given seed.
#[derive(Debug)]
pub struct PrivacyManager {
    /// Remaining privacy budget (epsilon units)
    remaining_budget: f64,
    /// Noise generation RNG (deterministic, seeded at construction)
    rng: StdRng,
    /// Memoized noise scales keyed by mechanism parameters
    noise_scale_cache: HashMap<String, f64>,
}
116
/// Client-specific statistics
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ClientStats {
    /// Number of samples in the client's local dataset
    pub sample_count: usize,
    /// Class distribution (class label -> count); currently left empty
    pub class_distribution: HashMap<String, usize>,
    /// Per-feature statistics; currently left empty
    pub feature_stats: FederatedFeatureStats,
    /// Data quality metrics
    pub quality_metrics: QualityMetrics,
}
129
/// Feature statistics for federated analysis
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct FederatedFeatureStats {
    /// Per-feature means
    pub means: Vec<f64>,
    /// Per-feature standard deviations
    pub stds: Vec<f64>,
    /// Per-feature (min, max) ranges
    pub ranges: Vec<(f64, f64)>,
}
140
/// Data quality metrics for federated learning
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct QualityMetrics {
    /// Missing value percentage
    pub missing_percentage: f64,
    /// Outlier percentage
    pub outlier_percentage: f64,
    /// Data consistency score (0.0 to 1.0, higher is better)
    pub consistency_score: f64,
}
151
152impl PrivacyManager {
153    /// Create a new privacy manager
154    pub fn new(config: &PrivacyConfig, seed: u64) -> Self {
155        Self {
156            remaining_budget: config.privacy_budget,
157            rng: StdRng::seed_from_u64(seed),
158            noise_scale_cache: HashMap::new(),
159        }
160    }
161
162    /// Add differential privacy noise to a value
163    pub fn add_noise(
164        &mut self,
165        value: f64,
166        config: &PrivacyConfig,
167        query_sensitivity: f64,
168    ) -> Result<f64> {
169        if !config.enable_dp {
170            return Ok(value);
171        }
172
173        if self.remaining_budget <= 0.0 {
174            return Err(TensorError::invalid_argument(
175                "Privacy budget exhausted".to_string(),
176            ));
177        }
178
179        let noise_scale = self.calculate_noise_scale(config, query_sensitivity);
180        let noise = match &config.noise_mechanism {
181            NoiseMechanism::Laplace { .. } => {
182                // Laplace noise: scale = sensitivity / epsilon
183                let scale = noise_scale;
184                self.sample_laplace(scale)
185            }
186            NoiseMechanism::Gaussian { .. } => {
187                // Gaussian noise: sigma = sqrt(2 * ln(1.25/delta)) * sensitivity / epsilon
188                let sigma = noise_scale;
189                self.sample_gaussian(sigma)
190            }
191            NoiseMechanism::Exponential { .. } => {
192                // Simplified exponential mechanism (not full implementation)
193                let scale = noise_scale;
194                self.sample_laplace(scale)
195            }
196        };
197
198        // Consume privacy budget
199        self.remaining_budget -= config.epsilon;
200
201        Ok(value + noise)
202    }
203
204    /// Add noise to a tensor with differential privacy
205    pub fn add_noise_tensor<T>(
206        &mut self,
207        tensor: &Tensor<T>,
208        config: &PrivacyConfig,
209        sensitivity: f64,
210    ) -> Result<Tensor<T>>
211    where
212        T: Clone + Default + Send + Sync + 'static,
213        T: From<f64> + Into<f64>,
214    {
215        if !config.enable_dp {
216            return Ok(tensor.clone());
217        }
218
219        let shape = tensor.shape().dims().to_vec();
220        let mut noisy_data = Vec::new();
221
222        if let Some(slice) = tensor.as_slice() {
223            for value in slice {
224                let original_value: f64 = value.clone().into();
225                let noisy_value = self.add_noise(original_value, config, sensitivity)?;
226                noisy_data.push(T::from(noisy_value));
227            }
228        } else {
229            // Handle scalar tensor
230            let value: f64 = tensor.get(&[]).unwrap_or_default().into();
231            let noisy_value = self.add_noise(value, config, sensitivity)?;
232            noisy_data.push(T::from(noisy_value));
233        }
234
235        Tensor::from_vec(noisy_data, &shape)
236    }
237
238    fn calculate_noise_scale(&mut self, config: &PrivacyConfig, sensitivity: f64) -> f64 {
239        let cache_key = format!("{}_{}", config.epsilon, sensitivity);
240
241        if let Some(&cached_scale) = self.noise_scale_cache.get(&cache_key) {
242            return cached_scale;
243        }
244
245        let scale = match &config.noise_mechanism {
246            NoiseMechanism::Laplace { .. } => sensitivity / config.epsilon,
247            NoiseMechanism::Gaussian { .. } => {
248                let factor = (2.0 * (1.25 / config.delta).ln()).sqrt();
249                factor * sensitivity / config.epsilon
250            }
251            NoiseMechanism::Exponential { .. } => sensitivity / config.epsilon,
252        };
253
254        self.noise_scale_cache.insert(cache_key, scale);
255        scale
256    }
257
258    fn sample_laplace(&mut self, scale: f64) -> f64 {
259        // Box-Muller transform for Laplace distribution approximation
260        let u1: f64 = self.rng.random();
261        let u2: f64 = self.rng.random();
262
263        let sign = if u1 < 0.5 { -1.0 } else { 1.0 };
264        sign * scale * (1.0_f64 - 2.0_f64 * u2.abs()).max(1e-10_f64).ln()
265    }
266
267    fn sample_gaussian(&mut self, sigma: f64) -> f64 {
268        // Box-Muller transform to generate Gaussian noise
269        let u1: f64 = self.rng.random();
270        let u2: f64 = self.rng.random();
271
272        let z0 = (-2.0 * u1.ln()).sqrt() * (2.0 * std::f64::consts::PI * u2).cos();
273        z0 * sigma
274    }
275
276    /// Check if privacy budget allows the operation
277    pub fn can_spend_budget(&self, epsilon: f64) -> bool {
278        self.remaining_budget >= epsilon
279    }
280
281    /// Get remaining privacy budget
282    pub fn remaining_budget(&self) -> f64 {
283        self.remaining_budget
284    }
285}
286
impl<T, D> FederatedClientDataset<T, D>
where
    D: Dataset<T>,
    T: Clone + Default + Send + Sync + 'static,
{
    /// Create a new federated client dataset
    ///
    /// Computes basic statistics up front and constructs a privacy manager
    /// with a hard-coded seed of 42.
    // NOTE(review): every client shares the same seed, so all clients draw
    // identical noise sequences — confirm whether per-client seeds are wanted.
    pub fn new(dataset: D, config: ClientConfig) -> Self {
        let stats = Self::compute_basic_stats(&dataset);
        let privacy_manager = Arc::new(Mutex::new(PrivacyManager::new(&config.privacy_config, 42)));

        Self {
            config,
            dataset,
            privacy_manager,
            stats,
            _phantom: std::marker::PhantomData,
        }
    }

    /// Get client configuration
    pub fn config(&self) -> &ClientConfig {
        &self.config
    }

    /// Get client statistics
    pub fn stats(&self) -> &ClientStats {
        &self.stats
    }

    /// Get a private sample (with differential privacy if enabled)
    ///
    /// Noise is added to the features only; labels are returned unchanged.
    /// When DP is disabled this behaves like [`Dataset::get`].
    ///
    /// # Errors
    /// Propagates dataset access errors; fails when the privacy budget is
    /// exhausted.
    pub fn get_private(&self, index: usize) -> Result<(Tensor<T>, Tensor<T>)>
    where
        T: From<f64> + Into<f64>,
    {
        let (features, labels) = self.dataset.get(index)?;

        if !self.config.privacy_config.enable_dp {
            return Ok((features, labels));
        }

        let mut privacy_manager = self
            .privacy_manager
            .lock()
            .expect("lock should not be poisoned");
        // Fixed per-element sensitivity of 1.0.
        let noisy_features =
            privacy_manager.add_noise_tensor(&features, &self.config.privacy_config, 1.0)?;

        Ok((noisy_features, labels))
    }

    /// Compute aggregated statistics with differential privacy
    ///
    /// Aggregates per-feature means over the whole local dataset, then adds
    /// DP noise to each mean and to the sample count.
    ///
    /// # Errors
    /// Propagates dataset access errors and privacy-budget exhaustion.
    pub fn compute_private_statistics(&self) -> Result<PrivateStats>
    where
        T: From<f64> + Into<f64>,
    {
        let mut feature_sums = Vec::new();
        let mut feature_counts = Vec::new();
        let sample_count = self.dataset.len();

        if sample_count == 0 {
            // Nothing to aggregate; return empty stats without spending budget.
            return Ok(PrivateStats {
                sample_count: 0,
                feature_means: Vec::new(),
                class_counts: HashMap::new(),
            });
        }

        // Get first sample to determine dimensions; a scalar tensor (no
        // backing slice) is treated as one-dimensional.
        let (first_features, _) = self.dataset.get(0)?;
        let feature_dim = if let Some(slice) = first_features.as_slice() {
            slice.len()
        } else {
            1
        };

        feature_sums.resize(feature_dim, 0.0);
        feature_counts.resize(feature_dim, 0);

        // Aggregate features across all samples.
        // NOTE(review): assumes every sample has the same feature width as
        // sample 0 — a wider sample would panic on the indexed accumulation.
        for i in 0..sample_count {
            let (features, _) = self.dataset.get(i)?;

            if let Some(slice) = features.as_slice() {
                for (j, value) in slice.iter().enumerate() {
                    feature_sums[j] += value.clone().into();
                    feature_counts[j] += 1;
                }
            } else {
                let value: f64 = features.get(&[]).unwrap_or(T::default()).into();
                feature_sums[0] += value;
                feature_counts[0] += 1;
            }
        }

        // Compute means and add DP noise to each (one epsilon spend per mean).
        let mut private_means = Vec::new();
        let mut privacy_manager = self
            .privacy_manager
            .lock()
            .expect("lock should not be poisoned");

        for i in 0..feature_dim {
            let mean = if feature_counts[i] > 0 {
                feature_sums[i] / feature_counts[i] as f64
            } else {
                0.0
            };

            let private_mean = privacy_manager.add_noise(mean, &self.config.privacy_config, 1.0)?;
            private_means.push(private_mean);
        }

        // Add noise to the sample count as well; a negative noised value
        // saturates to 0 under Rust's float-to-usize `as` cast.
        let private_sample_count =
            privacy_manager.add_noise(sample_count as f64, &self.config.privacy_config, 1.0)?
                as usize;

        Ok(PrivateStats {
            sample_count: private_sample_count,
            feature_means: private_means,
            class_counts: HashMap::new(), // Simplified for basic implementation
        })
    }

    /// Build placeholder statistics: only the sample count is computed;
    /// class distribution and feature stats are left empty for now.
    fn compute_basic_stats(dataset: &D) -> ClientStats {
        let sample_count = dataset.len();

        ClientStats {
            sample_count,
            class_distribution: HashMap::new(), // Simplified
            feature_stats: FederatedFeatureStats {
                means: Vec::new(),
                stds: Vec::new(),
                ranges: Vec::new(),
            },
            quality_metrics: QualityMetrics {
                missing_percentage: 0.0,
                outlier_percentage: 0.0,
                consistency_score: 1.0,
            },
        }
    }
}
430
// Pass-through Dataset implementation: plain `get` yields the raw,
// non-private samples. Use `get_private` for DP-noised access.
impl<T, D> Dataset<T> for FederatedClientDataset<T, D>
where
    D: Dataset<T>,
    T: Clone + Default + Send + Sync + 'static,
{
    fn len(&self) -> usize {
        self.dataset.len()
    }

    fn get(&self, index: usize) -> Result<(Tensor<T>, Tensor<T>)> {
        self.dataset.get(index)
    }
}
444
/// Statistics computed with differential privacy
///
/// Produced by `compute_private_statistics` and consumed by
/// `FederatedAggregator`; all values already contain DP noise.
#[derive(Debug, Clone)]
pub struct PrivateStats {
    /// Sample count (with noise)
    pub sample_count: usize,
    /// Per-feature means (with noise)
    pub feature_means: Vec<f64>,
    /// Class counts (with noise)
    pub class_counts: HashMap<String, usize>,
}
455
/// Federated dataset partitioner for distributing data across clients
#[derive(Debug)]
pub struct FederatedPartitioner {
    /// Total number of clients to partition into
    num_clients: usize,
    /// Partitioning strategy
    strategy: PartitioningStrategy,
    /// Seeded RNG so partitioning is reproducible
    rng: StdRng,
}
466
/// Partitioning strategies for federated learning
#[derive(Debug, Clone)]
pub enum PartitioningStrategy {
    /// Uniform random distribution (shuffled, near-equal shares)
    Uniform,
    /// Dirichlet-style distribution for non-IID data; larger `alpha` spreads
    /// proportions more widely
    Dirichlet { alpha: f64 },
    /// Class-based partitioning (currently falls back to uniform)
    ClassBased { classes_per_client: usize },
    /// Quantity-based partitioning: client sizes vary by up to `size_variance`
    QuantityBased { size_variance: f64 },
}
479
480impl FederatedPartitioner {
481    /// Create a new federated partitioner
482    pub fn new(num_clients: usize, strategy: PartitioningStrategy, seed: u64) -> Self {
483        Self {
484            num_clients,
485            strategy,
486            rng: StdRng::seed_from_u64(seed),
487        }
488    }
489
490    /// Partition a dataset across multiple clients
491    pub fn partition<T, D>(
492        &mut self,
493        dataset: D,
494    ) -> Result<Vec<FederatedClientDataset<T, ClientIndexedDataset<T, D>>>>
495    where
496        D: Dataset<T> + Clone,
497        T: Clone + Default + Send + Sync + 'static,
498    {
499        let total_samples = dataset.len();
500        let client_assignments = self.generate_client_assignments(total_samples)?;
501
502        let mut client_datasets = Vec::new();
503
504        for (client_idx, indices) in client_assignments.into_iter().enumerate() {
505            let client_id = format!("client_{client_idx}");
506            let client_dataset = ClientIndexedDataset::new(dataset.clone(), indices);
507
508            let config = ClientConfig {
509                client_id: client_id.clone(),
510                distribution_type: self.get_distribution_type_for_client(client_idx),
511                privacy_config: PrivacyConfig::default(),
512                metadata: HashMap::new(),
513            };
514
515            let federated_client = FederatedClientDataset::new(client_dataset, config);
516            client_datasets.push(federated_client);
517        }
518
519        Ok(client_datasets)
520    }
521
522    fn generate_client_assignments(&mut self, total_samples: usize) -> Result<Vec<Vec<usize>>> {
523        match &self.strategy {
524            PartitioningStrategy::Uniform => self.uniform_partition(total_samples),
525            PartitioningStrategy::Dirichlet { alpha } => {
526                self.dirichlet_partition(total_samples, *alpha)
527            }
528            PartitioningStrategy::ClassBased {
529                classes_per_client: _,
530            } => {
531                // Simplified class-based partitioning
532                self.uniform_partition(total_samples)
533            }
534            PartitioningStrategy::QuantityBased { size_variance } => {
535                self.quantity_based_partition(total_samples, *size_variance)
536            }
537        }
538    }
539
540    fn uniform_partition(&mut self, total_samples: usize) -> Result<Vec<Vec<usize>>> {
541        let mut indices: Vec<usize> = (0..total_samples).collect();
542
543        // Shuffle indices
544        for i in (1..indices.len()).rev() {
545            let j = self.rng.random_range(0..i);
546            indices.swap(i, j);
547        }
548
549        let base_size = total_samples / self.num_clients;
550        let remainder = total_samples % self.num_clients;
551
552        let mut client_assignments = Vec::new();
553        let mut start_idx = 0;
554
555        for i in 0..self.num_clients {
556            let client_size = base_size + if i < remainder { 1 } else { 0 };
557            let end_idx = start_idx + client_size;
558
559            client_assignments.push(indices[start_idx..end_idx].to_vec());
560            start_idx = end_idx;
561        }
562
563        Ok(client_assignments)
564    }
565
566    fn dirichlet_partition(&mut self, total_samples: usize, alpha: f64) -> Result<Vec<Vec<usize>>> {
567        // Simplified Dirichlet partitioning (basic implementation)
568        // In a full implementation, this would use proper Dirichlet distribution
569        let mut proportions = Vec::new();
570        let mut sum = 0.0;
571
572        for _ in 0..self.num_clients {
573            let prop = self.rng.random::<f64>() * alpha + 0.1; // Simple approximation
574            proportions.push(prop);
575            sum += prop;
576        }
577
578        // Normalize proportions
579        for prop in &mut proportions {
580            *prop /= sum;
581        }
582
583        let mut client_assignments = Vec::new();
584        let mut assigned_samples = 0;
585
586        for (i, &proportion) in proportions.iter().enumerate() {
587            let client_samples = if i == self.num_clients - 1 {
588                // Last client gets remaining samples
589                total_samples - assigned_samples
590            } else {
591                (total_samples as f64 * proportion) as usize
592            };
593
594            let indices: Vec<usize> =
595                (assigned_samples..assigned_samples + client_samples).collect();
596            client_assignments.push(indices);
597            assigned_samples += client_samples;
598        }
599
600        Ok(client_assignments)
601    }
602
603    fn quantity_based_partition(
604        &mut self,
605        total_samples: usize,
606        size_variance: f64,
607    ) -> Result<Vec<Vec<usize>>> {
608        let base_size = total_samples as f64 / self.num_clients as f64;
609        let mut client_sizes = Vec::new();
610        let mut total_assigned = 0;
611
612        for i in 0..self.num_clients {
613            let variance_factor = 1.0 + (self.rng.random::<f64>() - 0.5) * 2.0 * size_variance;
614            let client_size = if i == self.num_clients - 1 {
615                // Last client gets remaining samples
616                total_samples - total_assigned
617            } else {
618                ((base_size * variance_factor) as usize).min(total_samples - total_assigned)
619            };
620
621            client_sizes.push(client_size);
622            total_assigned += client_size;
623
624            if total_assigned >= total_samples {
625                break;
626            }
627        }
628
629        let mut client_assignments = Vec::new();
630        let mut start_idx = 0;
631
632        for &size in &client_sizes {
633            let end_idx = (start_idx + size).min(total_samples);
634            let indices: Vec<usize> = (start_idx..end_idx).collect();
635            client_assignments.push(indices);
636            start_idx = end_idx;
637        }
638
639        Ok(client_assignments)
640    }
641
642    fn get_distribution_type_for_client(&self, _client_idx: usize) -> DataDistribution {
643        match &self.strategy {
644            PartitioningStrategy::Uniform => DataDistribution::Iid,
645            PartitioningStrategy::Dirichlet { alpha } => DataDistribution::NonIidClassImbalance {
646                class_weights: vec![*alpha, 1.0 - alpha],
647            },
648            PartitioningStrategy::ClassBased { .. } => DataDistribution::NonIidClassImbalance {
649                class_weights: vec![0.8, 0.2],
650            },
651            PartitioningStrategy::QuantityBased { .. } => DataDistribution::Iid,
652        }
653    }
654}
655
/// Dataset wrapper that provides access to a subset of indices
///
/// Local position `i` maps to `indices[i]` in the wrapped dataset.
#[derive(Debug, Clone)]
pub struct ClientIndexedDataset<T, D> {
    // The full underlying dataset (each client holds its own clone).
    dataset: D,
    // Positions into `dataset` owned by this client.
    indices: Vec<usize>,
    _phantom: std::marker::PhantomData<T>,
}
663
impl<T, D> ClientIndexedDataset<T, D>
where
    D: Dataset<T>,
    T: Clone + Default + Send + Sync + 'static,
{
    /// Create a new client indexed dataset
    ///
    /// `indices` are positions into `dataset`; they are not validated here,
    /// so out-of-range entries surface as errors from the inner dataset.
    pub fn new(dataset: D, indices: Vec<usize>) -> Self {
        Self {
            dataset,
            indices,
            _phantom: std::marker::PhantomData,
        }
    }

    /// Get the underlying (full) dataset
    pub fn inner(&self) -> &D {
        &self.dataset
    }

    /// Get the client indices into the underlying dataset
    pub fn indices(&self) -> &[usize] {
        &self.indices
    }
}
688
689impl<T, D> Dataset<T> for ClientIndexedDataset<T, D>
690where
691    D: Dataset<T>,
692    T: Clone + Default + Send + Sync + 'static,
693{
694    fn len(&self) -> usize {
695        self.indices.len()
696    }
697
698    fn get(&self, index: usize) -> Result<(Tensor<T>, Tensor<T>)> {
699        if index >= self.indices.len() {
700            return Err(TensorError::invalid_argument(format!(
701                "Index {} out of bounds for client dataset of length {}",
702                index,
703                self.indices.len()
704            )));
705        }
706
707        let actual_index = self.indices[index];
708        self.dataset.get(actual_index)
709    }
710}
711
/// Federated aggregator for combining results from multiple clients
#[derive(Debug)]
pub struct FederatedAggregator {
    /// Aggregation strategy
    strategy: AggregationStrategy,
    /// Explicit per-client weights, set via `set_client_weight`
    // NOTE(review): these weights are stored but not yet consulted by
    // `aggregate_statistics` — confirm intended use.
    client_weights: HashMap<ClientId, f64>,
}
720
/// Aggregation strategies for federated learning
#[derive(Debug, Clone)]
pub enum AggregationStrategy {
    /// Simple averaging across clients
    Average,
    /// Weighted averaging by client dataset size
    WeightedBySize,
    /// Weighted averaging by data quality (currently weights all clients equally)
    WeightedByQuality,
    /// Median aggregation (currently delegates to averaging)
    Median,
    /// Trimmed mean excluding outliers (currently delegates to averaging)
    TrimmedMean { trim_fraction: f64 },
}
735
736impl FederatedAggregator {
737    /// Create a new federated aggregator
738    pub fn new(strategy: AggregationStrategy) -> Self {
739        Self {
740            strategy,
741            client_weights: HashMap::new(),
742        }
743    }
744
745    /// Set client weight for weighted aggregation
746    pub fn set_client_weight(&mut self, client_id: ClientId, weight: f64) {
747        self.client_weights.insert(client_id, weight);
748    }
749
750    /// Aggregate statistics from multiple clients
751    pub fn aggregate_statistics(
752        &self,
753        client_stats: Vec<(ClientId, PrivateStats)>,
754    ) -> Result<PrivateStats> {
755        if client_stats.is_empty() {
756            return Err(TensorError::invalid_argument(
757                "No client statistics provided".to_string(),
758            ));
759        }
760
761        match &self.strategy {
762            AggregationStrategy::Average => self.average_statistics(client_stats),
763            AggregationStrategy::WeightedBySize => {
764                self.weighted_statistics(client_stats, |stats| stats.sample_count as f64)
765            }
766            AggregationStrategy::WeightedByQuality => {
767                self.weighted_statistics(client_stats, |_| 1.0)
768            } // Simplified
769            AggregationStrategy::Median => self.median_statistics(client_stats),
770            AggregationStrategy::TrimmedMean { trim_fraction } => {
771                self.trimmed_mean_statistics(client_stats, *trim_fraction)
772            }
773        }
774    }
775
776    fn average_statistics(
777        &self,
778        client_stats: Vec<(ClientId, PrivateStats)>,
779    ) -> Result<PrivateStats> {
780        let num_clients = client_stats.len() as f64;
781        let mut total_samples = 0;
782        let mut aggregated_means = Vec::new();
783        let mut aggregated_class_counts = HashMap::new();
784
785        // Initialize aggregated means with zeros
786        if let Some((_, first_stats)) = client_stats.first() {
787            aggregated_means.resize(first_stats.feature_means.len(), 0.0);
788        }
789
790        for (_, stats) in &client_stats {
791            total_samples += stats.sample_count;
792
793            // Aggregate feature means
794            for (i, &mean) in stats.feature_means.iter().enumerate() {
795                if i < aggregated_means.len() {
796                    aggregated_means[i] += mean / num_clients;
797                }
798            }
799
800            // Aggregate class counts
801            for (class, &count) in &stats.class_counts {
802                *aggregated_class_counts.entry(class.clone()).or_insert(0) += count;
803            }
804        }
805
806        Ok(PrivateStats {
807            sample_count: total_samples,
808            feature_means: aggregated_means,
809            class_counts: aggregated_class_counts,
810        })
811    }
812
813    fn weighted_statistics<F>(
814        &self,
815        client_stats: Vec<(ClientId, PrivateStats)>,
816        weight_fn: F,
817    ) -> Result<PrivateStats>
818    where
819        F: Fn(&PrivateStats) -> f64,
820    {
821        #[allow(unused_assignments)]
822        let mut total_weight = 0.0;
823        let mut total_samples = 0;
824        let mut aggregated_means = Vec::new();
825        let mut aggregated_class_counts = HashMap::new();
826
827        // Calculate weights and initialize
828        let weights: Vec<f64> = client_stats
829            .iter()
830            .map(|(_, stats)| weight_fn(stats))
831            .collect();
832
833        total_weight = weights.iter().sum();
834
835        if let Some((_, first_stats)) = client_stats.first() {
836            aggregated_means.resize(first_stats.feature_means.len(), 0.0);
837        }
838
839        for ((_, stats), weight) in client_stats.iter().zip(weights.iter()) {
840            let normalized_weight = weight / total_weight;
841            total_samples += stats.sample_count;
842
843            // Aggregate feature means
844            for (i, &mean) in stats.feature_means.iter().enumerate() {
845                if i < aggregated_means.len() {
846                    aggregated_means[i] += mean * normalized_weight;
847                }
848            }
849
850            // Aggregate class counts
851            for (class, &count) in &stats.class_counts {
852                let weighted_count = (count as f64 * normalized_weight) as usize;
853                *aggregated_class_counts.entry(class.clone()).or_insert(0) += weighted_count;
854            }
855        }
856
857        Ok(PrivateStats {
858            sample_count: total_samples,
859            feature_means: aggregated_means,
860            class_counts: aggregated_class_counts,
861        })
862    }
863
864    fn median_statistics(
865        &self,
866        client_stats: Vec<(ClientId, PrivateStats)>,
867    ) -> Result<PrivateStats> {
868        // Simplified median implementation
869        self.average_statistics(client_stats)
870    }
871
872    fn trimmed_mean_statistics(
873        &self,
874        client_stats: Vec<(ClientId, PrivateStats)>,
875        _trim_fraction: f64,
876    ) -> Result<PrivateStats> {
877        // Simplified trimmed mean implementation
878        self.average_statistics(client_stats)
879    }
880}
881
/// Extension trait for federated dataset operations
///
/// Blanket-implemented for every [`Dataset`], so these methods are available
/// on any dataset once the trait is in scope.
pub trait FederatedDatasetExt<T>: Dataset<T> + Sized
where
    T: Clone + Default + Send + Sync + 'static,
{
    /// Wrap this dataset as a single federated client with the given config.
    fn federated_client(self, config: ClientConfig) -> FederatedClientDataset<T, Self> {
        FederatedClientDataset::new(self, config)
    }

    /// Partition this dataset across `num_clients` federated clients.
    ///
    /// `seed` makes the randomized partitioning reproducible. Requires
    /// `Clone` because each client holds its own copy of the dataset.
    fn partition_federated(
        self,
        num_clients: usize,
        strategy: PartitioningStrategy,
        seed: u64,
    ) -> Result<Vec<FederatedClientDataset<T, ClientIndexedDataset<T, Self>>>>
    where
        Self: Clone,
    {
        let mut partitioner = FederatedPartitioner::new(num_clients, strategy, seed);
        partitioner.partition(self)
    }
}
906
907impl<T, D: Dataset<T>> FederatedDatasetExt<T> for D where T: Clone + Default + Send + Sync + 'static {}
908
#[cfg(test)]
mod tests {
    use super::*;
    use crate::TensorDataset;

    // Verifies DP budget accounting: full budget up front, budget consumed
    // after one noisy query, and the returned value is perturbed.
    #[test]
    fn test_privacy_manager() {
        let config = PrivacyConfig {
            enable_dp: true,
            epsilon: 1.0,
            delta: 1e-5,
            noise_mechanism: NoiseMechanism::Gaussian { sensitivity: 1.0 },
            privacy_budget: 10.0,
        };

        // Fixed seed so the noise draw is deterministic across runs.
        let mut privacy_manager = PrivacyManager::new(&config, 42);

        assert_eq!(privacy_manager.remaining_budget(), 10.0);
        assert!(privacy_manager.can_spend_budget(1.0));

        let noisy_value = privacy_manager
            .add_noise(5.0, &config, 1.0)
            .expect("test: operation should succeed");
        // Spending epsilon must reduce the remaining budget.
        assert!(privacy_manager.remaining_budget() < 10.0);
        // NOTE(review): continuous noise is almost surely non-zero, so this
        // inequality holds in practice for the seeded RNG used above.
        assert_ne!(noisy_value, 5.0); // Should have noise added
    }

    // Wraps a 3-sample tensor dataset as a federated client and checks that
    // length, client id, and computed per-client stats are exposed.
    #[test]
    fn test_federated_client_dataset() {
        // Create test dataset
        let features_data = vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0];
        let labels_data = vec![0.0, 1.0, 0.0];
        let features =
            Tensor::from_vec(features_data, &[3, 2]).expect("test: tensor creation should succeed");
        let labels =
            Tensor::from_vec(labels_data, &[3]).expect("test: tensor creation should succeed");
        let dataset = TensorDataset::new(features, labels);

        let config = ClientConfig {
            client_id: "test_client".to_string(),
            distribution_type: DataDistribution::Iid,
            privacy_config: PrivacyConfig::default(),
            metadata: HashMap::new(),
        };

        let federated_dataset = FederatedClientDataset::new(dataset, config);

        assert_eq!(federated_dataset.len(), 3);
        assert_eq!(federated_dataset.config().client_id, "test_client");
        assert_eq!(federated_dataset.stats().sample_count, 3);
    }

    // Uniform partitioning into two clients must not lose or duplicate
    // samples: client sizes sum back to the original dataset size.
    #[test]
    fn test_federated_partitioner() {
        // Create test dataset
        let features_data = vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0];
        let labels_data = vec![0.0, 1.0, 0.0, 1.0];
        let features =
            Tensor::from_vec(features_data, &[4, 2]).expect("test: tensor creation should succeed");
        let labels =
            Tensor::from_vec(labels_data, &[4]).expect("test: tensor creation should succeed");
        let dataset = TensorDataset::new(features, labels);

        let mut partitioner = FederatedPartitioner::new(2, PartitioningStrategy::Uniform, 42);
        let client_datasets = partitioner
            .partition(dataset)
            .expect("test: operation should succeed");

        assert_eq!(client_datasets.len(), 2);

        let total_samples: usize = client_datasets.iter().map(|d| d.len()).sum();
        assert_eq!(total_samples, 4);
    }

    // Index-mapped view: local positions 0 and 1 must resolve to base
    // dataset rows 0 and 2 (row 1 is deliberately excluded).
    #[test]
    fn test_client_indexed_dataset() {
        // Create test dataset
        let features_data = vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0];
        let labels_data = vec![0.0, 1.0, 0.0];
        let features =
            Tensor::from_vec(features_data, &[3, 2]).expect("test: tensor creation should succeed");
        let labels =
            Tensor::from_vec(labels_data, &[3]).expect("test: tensor creation should succeed");
        let dataset = TensorDataset::new(features, labels);

        let indices = vec![0, 2]; // Skip index 1
        let client_dataset = ClientIndexedDataset::new(dataset, indices);

        assert_eq!(client_dataset.len(), 2);
        assert_eq!(client_dataset.indices(), &[0, 2]);

        let (features, labels) = client_dataset.get(0).expect("index should be in bounds");
        let features_slice = features.as_slice().expect("tensor should be contiguous");
        assert_eq!(features_slice, &[1.0, 2.0]); // First sample
        // Per-sample label is a scalar tensor, read with an empty index.
        assert_eq!(labels.get(&[]).expect("test: get should succeed"), 0.0);

        let (features, labels) = client_dataset.get(1).expect("index should be in bounds");
        let features_slice = features.as_slice().expect("tensor should be contiguous");
        assert_eq!(features_slice, &[5.0, 6.0]); // Third sample (index 2)
        assert_eq!(labels.get(&[]).expect("test: get should succeed"), 0.0);
    }

    // Average aggregation: sample counts sum, feature means are the
    // element-wise (unweighted) average across clients.
    #[test]
    fn test_federated_aggregator() {
        let aggregator = FederatedAggregator::new(AggregationStrategy::Average);

        let client_stats = vec![
            (
                "client1".to_string(),
                PrivateStats {
                    sample_count: 100,
                    feature_means: vec![1.0, 2.0],
                    class_counts: HashMap::new(),
                },
            ),
            (
                "client2".to_string(),
                PrivateStats {
                    sample_count: 200,
                    feature_means: vec![3.0, 4.0],
                    class_counts: HashMap::new(),
                },
            ),
        ];

        let aggregated = aggregator
            .aggregate_statistics(client_stats)
            .expect("test: operation should succeed");

        assert_eq!(aggregated.sample_count, 300);
        assert_eq!(aggregated.feature_means, vec![2.0, 3.0]); // Average of [1,2] and [3,4]
    }

    // JSON round-trip: PrivacyConfig must survive serialize + deserialize
    // with its fields intact.
    #[test]
    fn test_privacy_config_serialization() {
        let config = PrivacyConfig {
            enable_dp: true,
            epsilon: 1.0,
            delta: 1e-5,
            noise_mechanism: NoiseMechanism::Gaussian { sensitivity: 1.0 },
            privacy_budget: 10.0,
        };

        let json = serde_json::to_string(&config).expect("test: serialization should succeed");
        let deserialized: PrivacyConfig =
            serde_json::from_str(&json).expect("test: JSON parsing should succeed");

        assert!(deserialized.enable_dp);
        assert_eq!(deserialized.epsilon, 1.0);
    }

    // Exercises the FederatedDatasetExt blanket impl: both extension
    // methods must be callable directly on a TensorDataset.
    #[test]
    fn test_extension_trait() {
        // Create test dataset
        let features_data = vec![1.0, 2.0, 3.0, 4.0];
        let labels_data = vec![0.0, 1.0];
        let features =
            Tensor::from_vec(features_data, &[2, 2]).expect("test: tensor creation should succeed");
        let labels =
            Tensor::from_vec(labels_data, &[2]).expect("test: tensor creation should succeed");
        let dataset = TensorDataset::new(features, labels);

        // Test federated client creation
        let config = ClientConfig {
            client_id: "test_client".to_string(),
            distribution_type: DataDistribution::Iid,
            privacy_config: PrivacyConfig::default(),
            metadata: HashMap::new(),
        };

        let federated_dataset = dataset.federated_client(config);
        assert_eq!(federated_dataset.len(), 2);

        // Test partitioning
        let features_data = vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0];
        let labels_data = vec![0.0, 1.0, 0.0, 1.0];
        let features =
            Tensor::from_vec(features_data, &[4, 2]).expect("test: tensor creation should succeed");
        let labels =
            Tensor::from_vec(labels_data, &[4]).expect("test: tensor creation should succeed");
        let dataset = TensorDataset::new(features, labels);

        let client_datasets = dataset
            .partition_federated(2, PartitioningStrategy::Uniform, 42)
            .expect("test: operation should succeed");
        assert_eq!(client_datasets.len(), 2);
    }
}