// sklears_core/distributed_algorithms.rs

//! Distributed Machine Learning Algorithms
//!
//! This module provides concrete implementations of distributed ML algorithms
//! that scale across multiple nodes with fault tolerance and load balancing.

use crate::distributed::NodeId;
use crate::error::{Result, SklearsError};
use serde::{Deserialize, Serialize};
use std::collections::HashMap;
use std::sync::{Arc, RwLock};

/// Distributed linear regression using parameter server architecture
///
/// Implements distributed training of linear regression models across
/// multiple worker nodes with centralized parameter synchronization.
#[derive(Debug)]
pub struct DistributedLinearRegression {
    /// Configuration for distributed training
    pub config: DistributedConfig,
    /// Parameter server for coordinating updates
    pub parameter_server: Arc<RwLock<ParameterServer>>,
    /// Worker nodes performing computation
    pub workers: Vec<WorkerNode>,
    /// Current model parameters
    pub parameters: Arc<RwLock<ModelParameters>>,
}

/// Configuration for distributed training
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct DistributedConfig {
    /// Number of worker nodes
    pub num_workers: usize,
    /// Synchronization strategy
    pub sync_strategy: SyncStrategy,
    /// Enable fault tolerance
    pub fault_tolerance: bool,
    /// Maximum iterations
    pub max_iterations: usize,
    /// Convergence tolerance
    pub tolerance: f64,
    /// Learning rate
    pub learning_rate: f64,
}

impl Default for DistributedConfig {
    fn default() -> Self {
        Self {
            num_workers: 4,
            sync_strategy: SyncStrategy::Synchronous,
            fault_tolerance: true,
            max_iterations: 100,
            tolerance: 1e-6,
            learning_rate: 0.01,
        }
    }
}
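
// A minimal configuration sketch (values are illustrative, not defaults mandated by
// this crate): start from `DistributedConfig::default()` and override selected fields
// with struct update syntax.
//
//     let config = DistributedConfig {
//         num_workers: 8,
//         learning_rate: 0.05,
//         ..Default::default()
//     };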

/// Synchronization strategy for distributed training
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
pub enum SyncStrategy {
    /// All workers synchronize after each iteration
    Synchronous,
    /// Workers update asynchronously
    Asynchronous,
    /// Bounded asynchronous updates: workers may run ahead of the slowest
    /// worker by at most `staleness_bound` parameter versions
    BoundedAsync { staleness_bound: usize },
}
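
// Sketch of selecting a strategy (note: the `ParameterServer` below implements the
// synchronous path; the asynchronous variants are configuration hooks):
//
//     let strategy = SyncStrategy::BoundedAsync { staleness_bound: 4 };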

/// Parameter server for coordinating distributed training
#[derive(Debug, Clone)]
pub struct ParameterServer {
    /// Current global parameters
    pub parameters: Vec<f64>,
    /// Version number for parameters
    pub version: usize,
    /// Gradient accumulator
    pub gradient_accumulator: Vec<f64>,
    /// Number of workers
    pub num_workers: usize,
    /// Updates received in current iteration
    pub updates_received: usize,
}

impl ParameterServer {
    /// Create a new parameter server
    pub fn new(num_parameters: usize, num_workers: usize) -> Self {
        Self {
            parameters: vec![0.0; num_parameters],
            version: 0,
            gradient_accumulator: vec![0.0; num_parameters],
            num_workers,
            updates_received: 0,
        }
    }

    /// Receive a gradient update from a worker
    ///
    /// The server accumulates raw gradients and applies their average once all
    /// workers have reported; any learning-rate scaling is expected to be done
    /// by the caller before the gradient is sent.
    pub fn receive_gradient(&mut self, gradient: Vec<f64>) -> Result<()> {
        if gradient.len() != self.parameters.len() {
            return Err(SklearsError::DimensionMismatch {
                expected: self.parameters.len(),
                actual: gradient.len(),
            });
        }

        // Accumulate gradient
        for (acc, grad) in self.gradient_accumulator.iter_mut().zip(gradient.iter()) {
            *acc += grad;
        }

        self.updates_received += 1;

        // Apply update when all workers have reported (synchronous)
        if self.updates_received == self.num_workers {
            self.apply_accumulated_gradients();
        }

        Ok(())
    }

    /// Apply the averaged accumulated gradients to the parameters
    fn apply_accumulated_gradients(&mut self) {
        let scale = 1.0 / self.num_workers as f64;

        for (param, grad) in self
            .parameters
            .iter_mut()
            .zip(self.gradient_accumulator.iter())
        {
            *param -= grad * scale;
        }

        // Reset accumulator
        self.gradient_accumulator.iter_mut().for_each(|g| *g = 0.0);
        self.updates_received = 0;
        self.version += 1;
    }

    /// Get current parameters
    pub fn get_parameters(&self) -> Vec<f64> {
        self.parameters.clone()
    }

    /// Get parameter version
    pub fn get_version(&self) -> usize {
        self.version
    }
}
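
// A minimal, self-contained sketch of one synchronous aggregation round (the numbers
// are illustrative and the test name is ours, not part of the original test suite).
#[cfg(test)]
mod parameter_server_round_example {
    use super::*;

    #[test]
    fn one_synchronous_round_applies_the_averaged_gradient() {
        let mut ps = ParameterServer::new(2, 2);
        ps.receive_gradient(vec![0.5, -0.5]).unwrap();
        // The update is only applied once every worker has reported.
        assert_eq!(ps.get_version(), 0);
        ps.receive_gradient(vec![1.5, -1.5]).unwrap();
        // Averaged gradient is (1.0, -1.0), so the parameters move to (-1.0, 1.0).
        assert_eq!(ps.get_version(), 1);
        assert_eq!(ps.get_parameters(), vec![-1.0, 1.0]);
    }
}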

/// Worker node for distributed computation
#[derive(Debug, Clone)]
pub struct WorkerNode {
    /// Node identifier
    pub id: NodeId,
    /// Local data partition
    pub data_partition: DataPartition,
    /// Local model parameters (cached from parameter server)
    pub local_parameters: Vec<f64>,
    /// Parameter version
    pub parameter_version: usize,
    /// Worker statistics
    pub stats: WorkerStats,
}

/// Data partition assigned to a worker
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct DataPartition {
    /// Feature matrix
    pub features: Vec<Vec<f64>>,
    /// Target values
    pub targets: Vec<f64>,
    /// Partition index
    pub partition_id: usize,
}

/// Statistics tracked by each worker
#[derive(Debug, Clone, Default, Serialize, Deserialize)]
pub struct WorkerStats {
    /// Number of samples processed
    pub samples_processed: usize,
    /// Number of gradient computations
    pub gradient_computations: usize,
    /// Total computation time in milliseconds
    pub total_compute_time_ms: u64,
    /// Number of communication rounds
    pub communication_rounds: usize,
}

impl WorkerNode {
    /// Create a new worker node
    pub fn new(id: NodeId, data_partition: DataPartition) -> Self {
        Self {
            id,
            data_partition,
            local_parameters: Vec::new(),
            parameter_version: 0,
            stats: WorkerStats::default(),
        }
    }

    /// Compute the local least-squares gradient on the assigned data partition
    pub fn compute_local_gradient(&mut self, parameters: &[f64]) -> Result<Vec<f64>> {
        let start_time = std::time::Instant::now();

        // Update local parameters
        self.local_parameters = parameters.to_vec();

        let n_samples = self.data_partition.features.len();
        if n_samples == 0 {
            return Err(SklearsError::InvalidInput(
                "Worker has an empty data partition".to_string(),
            ));
        }
        let n_features = parameters.len();
        let mut gradient = vec![0.0; n_features];

        // Compute gradient for linear regression
        for (features, target) in self
            .data_partition
            .features
            .iter()
            .zip(self.data_partition.targets.iter())
        {
            // Prediction: y_pred = w^T x
            let prediction: f64 = features
                .iter()
                .zip(parameters.iter())
                .map(|(x, w)| x * w)
                .sum();

            // Error: e = y_pred - y_true
            let error = prediction - target;

            // Gradient: grad = 2 * e * x
            for (i, x) in features.iter().enumerate() {
                gradient[i] += 2.0 * error * x;
            }
        }

        // Average gradient over samples
        for g in gradient.iter_mut() {
            *g /= n_samples as f64;
        }

        // Update statistics
        self.stats.samples_processed += n_samples;
        self.stats.gradient_computations += 1;
        self.stats.total_compute_time_ms += start_time.elapsed().as_millis() as u64;

        Ok(gradient)
    }

    /// Get worker statistics
    pub fn get_stats(&self) -> &WorkerStats {
        &self.stats
    }
}
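
// A worked example of the gradient formula above (our test name; values are
// illustrative). With one sample x = [1, 2], target y = 5, and weights w = [1, 1]:
// prediction = 3, error = -2, and the averaged gradient is 2 * (-2) * x = [-4, -8].
#[cfg(test)]
mod worker_gradient_example {
    use super::*;

    #[test]
    fn gradient_matches_hand_computation() {
        let partition = DataPartition {
            features: vec![vec![1.0, 2.0]],
            targets: vec![5.0],
            partition_id: 0,
        };
        let mut worker = WorkerNode::new(NodeId::new("worker_example"), partition);
        let gradient = worker.compute_local_gradient(&[1.0, 1.0]).unwrap();
        assert_eq!(gradient, vec![-4.0, -8.0]);
    }
}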

/// Model parameters for distributed learning
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ModelParameters {
    /// Weight vector
    pub weights: Vec<f64>,
    /// Bias term (intercept)
    pub bias: f64,
    /// Training metadata
    pub metadata: ParameterMetadata,
}

/// Metadata about model parameters
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ParameterMetadata {
    /// Number of training iterations completed
    pub iterations_completed: usize,
    /// Current training loss
    pub current_loss: f64,
    /// Convergence status
    pub converged: bool,
    /// Timestamp of last update
    pub last_updated: std::time::SystemTime,
}

impl DistributedLinearRegression {
    /// Create a new distributed linear regression model
    pub fn new(config: DistributedConfig, num_features: usize) -> Self {
        let parameter_server = Arc::new(RwLock::new(ParameterServer::new(
            num_features,
            config.num_workers,
        )));

        Self {
            config,
            parameter_server,
            workers: Vec::new(),
            parameters: Arc::new(RwLock::new(ModelParameters {
                weights: vec![0.0; num_features],
                bias: 0.0,
                metadata: ParameterMetadata {
                    iterations_completed: 0,
                    current_loss: f64::INFINITY,
                    converged: false,
                    last_updated: std::time::SystemTime::now(),
                },
            })),
        }
    }

    /// Partition data across workers
    pub fn partition_data(&mut self, features: Vec<Vec<f64>>, targets: Vec<f64>) -> Result<()> {
        let n_samples = features.len();
        let samples_per_worker =
            (n_samples + self.config.num_workers - 1) / self.config.num_workers;

        for worker_idx in 0..self.config.num_workers {
            let start_idx = worker_idx * samples_per_worker;
            let end_idx = ((worker_idx + 1) * samples_per_worker).min(n_samples);

            if start_idx < n_samples {
                let partition = DataPartition {
                    features: features[start_idx..end_idx].to_vec(),
                    targets: targets[start_idx..end_idx].to_vec(),
                    partition_id: worker_idx,
                };

                let worker =
                    WorkerNode::new(NodeId::new(format!("worker_{}", worker_idx)), partition);

                self.workers.push(worker);
            }
        }

        Ok(())
    }

    /// Train the model using distributed gradient descent
    pub fn fit(&mut self) -> Result<()> {
        for iteration in 0..self.config.max_iterations {
            // Get current parameters from parameter server
            let params = {
                let ps = self.parameter_server.read().unwrap();
                ps.get_parameters()
            };

            // Compute gradients on all workers, scaled by the learning rate so the
            // parameter server's averaged update is a proper gradient-descent step
            let learning_rate = self.config.learning_rate;
            let mut all_gradients = Vec::new();
            for worker in &mut self.workers {
                let gradient = worker.compute_local_gradient(&params)?;
                let scaled: Vec<f64> = gradient.iter().map(|g| g * learning_rate).collect();
                all_gradients.push(scaled);
            }

            // Send gradients to parameter server
            {
                let mut ps = self.parameter_server.write().unwrap();
                for gradient in all_gradients {
                    ps.receive_gradient(gradient)?;
                }
            }

            // Check for convergence every 10 iterations
            if iteration % 10 == 0 {
                let loss = self.compute_global_loss(&params)?;
                let mut model_params = self.parameters.write().unwrap();

                if (model_params.metadata.current_loss - loss).abs() < self.config.tolerance {
                    model_params.metadata.converged = true;
                    model_params.metadata.iterations_completed = iteration + 1;
                    break;
                }

                model_params.metadata.current_loss = loss;
                model_params.metadata.iterations_completed = iteration + 1;
                model_params.metadata.last_updated = std::time::SystemTime::now();
            }
        }

        // Update final parameters
        let final_params = {
            let ps = self.parameter_server.read().unwrap();
            ps.get_parameters()
        };

        let mut model_params = self.parameters.write().unwrap();
        model_params.weights = final_params;

        Ok(())
    }

    /// Compute global mean squared error across all workers
    fn compute_global_loss(&self, parameters: &[f64]) -> Result<f64> {
        let mut total_loss = 0.0;
        let mut total_samples = 0;

        for worker in &self.workers {
            for (features, target) in worker
                .data_partition
                .features
                .iter()
                .zip(worker.data_partition.targets.iter())
            {
                let prediction: f64 = features
                    .iter()
                    .zip(parameters.iter())
                    .map(|(x, w)| x * w)
                    .sum();
                let error = prediction - target;
                total_loss += error * error;
                total_samples += 1;
            }
        }

        if total_samples == 0 {
            return Err(SklearsError::InvalidInput(
                "No data has been partitioned; call partition_data before fit".to_string(),
            ));
        }

        Ok(total_loss / total_samples as f64)
    }

    /// Get training statistics from all workers
    pub fn get_training_stats(&self) -> DistributedTrainingStats {
        let mut total_samples = 0;
        let mut total_compute_time = 0;
        let mut total_gradient_computations = 0;

        for worker in &self.workers {
            let stats = worker.get_stats();
            total_samples += stats.samples_processed;
            total_compute_time += stats.total_compute_time_ms;
            total_gradient_computations += stats.gradient_computations;
        }

        DistributedTrainingStats {
            num_workers: self.workers.len(),
            total_samples_processed: total_samples,
            total_compute_time_ms: total_compute_time,
            total_gradient_computations,
            parameter_server_version: self.parameter_server.read().unwrap().get_version(),
        }
    }

    /// Predict on new data using trained parameters
    ///
    /// Note: the bias term is carried in `ModelParameters` but is not currently
    /// updated by `fit`, so it only contributes if set explicitly.
    pub fn predict(&self, features: &[Vec<f64>]) -> Result<Vec<f64>> {
        let params = self.parameters.read().unwrap();
        let mut predictions = Vec::new();

        for feature_row in features {
            let pred: f64 = feature_row
                .iter()
                .zip(params.weights.iter())
                .map(|(x, w)| x * w)
                .sum::<f64>()
                + params.bias;

            predictions.push(pred);
        }

        Ok(predictions)
    }
}
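
// A minimal end-to-end usage sketch (commented out; the data and bindings such as
// `features`, `targets`, and `new_rows` are illustrative, and error handling is
// elided with `?`):
//
//     let mut model = DistributedLinearRegression::new(DistributedConfig::default(), 2);
//     model.partition_data(features, targets)?;   // split rows across the workers
//     model.fit()?;                                // synchronous distributed gradient descent
//     let predictions = model.predict(&new_rows)?; // uses the aggregated weights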

/// Statistics from distributed training
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct DistributedTrainingStats {
    /// Number of worker nodes
    pub num_workers: usize,
    /// Total samples processed across all workers
    pub total_samples_processed: usize,
    /// Total computation time in milliseconds
    pub total_compute_time_ms: u64,
    /// Total gradient computations
    pub total_gradient_computations: usize,
    /// Parameter server version
    pub parameter_server_version: usize,
}

// ============================================================================
// Advanced Distributed Learning Features
// ============================================================================

/// Federated learning framework with privacy-preserving techniques
///
/// Implements federated averaging (FedAvg) and secure aggregation for
/// privacy-preserving distributed machine learning.
#[derive(Debug)]
pub struct FederatedLearning {
    /// Federated learning configuration
    pub config: FederatedConfig,
    /// Client models
    pub clients: Vec<FederatedClient>,
    /// Global model parameters
    pub global_model: Arc<RwLock<ModelParameters>>,
    /// Privacy mechanism
    pub privacy_mechanism: PrivacyMechanism,
}

/// Configuration for federated learning
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct FederatedConfig {
    /// Number of clients
    pub num_clients: usize,
    /// Fraction of clients selected per round
    pub client_fraction: f64,
    /// Number of local epochs per client
    pub local_epochs: usize,
    /// Local learning rate
    pub local_learning_rate: f64,
    /// Enable secure aggregation
    pub secure_aggregation: bool,
    /// Differential privacy epsilon
    pub dp_epsilon: Option<f64>,
    /// Differential privacy delta
    pub dp_delta: Option<f64>,
}

impl Default for FederatedConfig {
    fn default() -> Self {
        Self {
            num_clients: 10,
            client_fraction: 0.3,
            local_epochs: 5,
            local_learning_rate: 0.01,
            secure_aggregation: true,
            dp_epsilon: Some(1.0),
            dp_delta: Some(1e-5),
        }
    }
}

/// Federated learning client
#[derive(Debug, Clone)]
pub struct FederatedClient {
    /// Client identifier
    pub id: String,
    /// Local dataset size
    pub dataset_size: usize,
    /// Local model parameters
    pub local_parameters: Vec<f64>,
    /// Client training statistics
    pub stats: ClientStats,
}

/// Statistics for a federated client
#[derive(Debug, Clone, Default, Serialize, Deserialize)]
pub struct ClientStats {
    /// Number of training rounds participated in
    pub rounds_participated: usize,
    /// Total samples used in training
    pub total_samples: usize,
    /// Average local loss
    pub avg_local_loss: f64,
}

impl FederatedLearning {
    /// Create a new federated learning system
    pub fn new(config: FederatedConfig, num_features: usize) -> Self {
        Self {
            config,
            clients: Vec::new(),
            global_model: Arc::new(RwLock::new(ModelParameters {
                weights: vec![0.0; num_features],
                bias: 0.0,
                metadata: ParameterMetadata {
                    iterations_completed: 0,
                    current_loss: f64::INFINITY,
                    converged: false,
                    last_updated: std::time::SystemTime::now(),
                },
            })),
            privacy_mechanism: PrivacyMechanism::new(),
        }
    }

    /// Add a client to the federated system
    pub fn add_client(&mut self, client_id: String, dataset_size: usize) {
        let client = FederatedClient {
            id: client_id,
            dataset_size,
            local_parameters: Vec::new(),
            stats: ClientStats::default(),
        };
        self.clients.push(client);
    }

    /// Select a random subset of clients for a training round
    pub fn select_clients(&self) -> Vec<usize> {
        use scirs2_core::random::thread_rng;

        let num_selected = ((self.clients.len() as f64 * self.config.client_fraction).ceil()
            as usize)
            .min(self.clients.len());
        let mut selected = Vec::new();
        let mut rng = thread_rng();

        let mut indices: Vec<usize> = (0..self.clients.len()).collect();

        // Fisher-Yates shuffle for random selection
        for i in (1..indices.len()).rev() {
            let j = rng.gen_range(0..=i);
            indices.swap(i, j);
        }

        selected.extend_from_slice(&indices[..num_selected]);
        selected
    }

    /// Perform federated averaging of client updates, weighted by dataset size
    pub fn federated_average(&self, client_updates: &[(usize, Vec<f64>)]) -> Vec<f64> {
        if client_updates.is_empty() {
            return vec![];
        }

        let num_features = client_updates[0].1.len();
        let mut averaged = vec![0.0; num_features];
        let mut total_weight = 0.0;

        for (client_idx, update) in client_updates {
            let weight = self.clients[*client_idx].dataset_size as f64;
            total_weight += weight;

            for (i, val) in update.iter().enumerate() {
                averaged[i] += val * weight;
            }
        }

        // Normalize by total weight
        for val in averaged.iter_mut() {
            *val /= total_weight;
        }

        // Apply differential privacy noise if enabled
        if let Some(epsilon) = self.config.dp_epsilon {
            self.privacy_mechanism.apply_noise(&mut averaged, epsilon);
        }

        averaged
    }

    /// Get global model parameters
    pub fn get_global_model(&self) -> ModelParameters {
        self.global_model.read().unwrap().clone()
    }
}
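
// A worked example of the weighting above (our test name; differential privacy is
// disabled so the result is deterministic). With dataset sizes 100 and 300, the
// weighted average of [1, 1] and [5, 5] is (100*1 + 300*5) / 400 = 4.0 per coordinate.
#[cfg(test)]
mod federated_average_example {
    use super::*;

    #[test]
    fn averaging_is_weighted_by_dataset_size() {
        let config = FederatedConfig {
            dp_epsilon: None,
            ..Default::default()
        };
        let mut fed = FederatedLearning::new(config, 2);
        fed.add_client("small".to_string(), 100);
        fed.add_client("large".to_string(), 300);

        let updates = vec![(0, vec![1.0, 1.0]), (1, vec![5.0, 5.0])];
        let averaged = fed.federated_average(&updates);
        assert_eq!(averaged, vec![4.0, 4.0]);
    }
}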

/// Privacy mechanism for federated learning
#[derive(Debug, Clone)]
pub struct PrivacyMechanism {
    /// Noise scale for differential privacy
    pub noise_scale: f64,
}

impl PrivacyMechanism {
    /// Create a new privacy mechanism
    pub fn new() -> Self {
        Self { noise_scale: 1.0 }
    }

    /// Apply differential privacy noise to gradients
    ///
    /// Adds Gaussian noise with standard deviation `noise_scale / epsilon`. This is a
    /// simplified mechanism; a full (epsilon, delta) calibration would also take the
    /// configured `dp_delta` into account.
    pub fn apply_noise(&self, gradients: &mut [f64], epsilon: f64) {
        use scirs2_core::random::essentials::Normal;
        use scirs2_core::random::thread_rng;

        let mut rng = thread_rng();
        let noise_std = self.noise_scale / epsilon;
        let normal = Normal::new(0.0, noise_std).unwrap();

        for grad in gradients.iter_mut() {
            *grad += rng.sample(normal);
        }
    }

    /// Clip gradients to a maximum L2 norm for privacy preservation
    pub fn clip_gradients(&self, gradients: &mut [f64], clip_norm: f64) {
        let norm: f64 = gradients.iter().map(|g| g * g).sum::<f64>().sqrt();

        if norm > clip_norm {
            let scale = clip_norm / norm;
            for grad in gradients.iter_mut() {
                *grad *= scale;
            }
        }
    }
}

impl Default for PrivacyMechanism {
    fn default() -> Self {
        Self::new()
    }
}
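
// A minimal usage sketch of clip-then-noise, the usual order in DP-style training
// (commented out; the values are illustrative):
//
//     let privacy = PrivacyMechanism { noise_scale: 0.5 };
//     let mut grads = vec![3.0, 4.0];            // L2 norm 5.0
//     privacy.clip_gradients(&mut grads, 1.0);   // rescaled to [0.6, 0.8]
//     privacy.apply_noise(&mut grads, 1.0);      // adds Gaussian noise with std 0.5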

/// Byzantine-fault-tolerant aggregation for robust distributed learning
///
/// Implements robust aggregation methods that are resilient to Byzantine
/// (malicious or faulty) workers.
#[derive(Debug)]
pub struct ByzantineFaultTolerant {
    /// BFT configuration
    pub config: BFTConfig,
    /// Aggregation method
    pub aggregation_method: AggregationMethod,
}

/// Configuration for Byzantine-fault-tolerant training
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct BFTConfig {
    /// Maximum fraction of Byzantine workers tolerated
    pub max_byzantine_fraction: f64,
    /// Detection threshold for identifying Byzantine behavior
    pub detection_threshold: f64,
    /// Enable reputation tracking
    pub enable_reputation: bool,
}

impl Default for BFTConfig {
    fn default() -> Self {
        Self {
            max_byzantine_fraction: 0.3,
            detection_threshold: 2.0,
            enable_reputation: true,
        }
    }
}

/// Robust aggregation methods for Byzantine fault tolerance
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
pub enum AggregationMethod {
    /// Coordinate-wise median
    Median,
    /// Trimmed mean: `trim_fraction` is a percentage (0-100) of gradients
    /// removed from each end of the sorted values per coordinate
    TrimmedMean { trim_fraction: usize },
    /// Krum algorithm (select the most representative gradient)
    Krum,
    /// Bulyan (combination of Krum and trimmed mean)
    Bulyan,
}
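
// Sketch of choosing an aggregation rule (illustrative): trimming 25% from each end
// of four gradients drops the single smallest and single largest value per coordinate.
//
//     let method = AggregationMethod::TrimmedMean { trim_fraction: 25 };
//     let bft = ByzantineFaultTolerant::new(BFTConfig::default(), method);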

impl ByzantineFaultTolerant {
    /// Create a new BFT aggregator
    pub fn new(config: BFTConfig, method: AggregationMethod) -> Self {
        Self {
            config,
            aggregation_method: method,
        }
    }

    /// Aggregate gradients using the configured Byzantine-fault-tolerant method
    pub fn aggregate(&self, gradients: &[Vec<f64>]) -> Result<Vec<f64>> {
        if gradients.is_empty() {
            return Err(SklearsError::InvalidInput(
                "Cannot aggregate empty gradient set".to_string(),
            ));
        }

        match self.aggregation_method {
            AggregationMethod::Median => self.coordinate_wise_median(gradients),
            AggregationMethod::TrimmedMean { trim_fraction } => {
                self.trimmed_mean(gradients, trim_fraction)
            }
            AggregationMethod::Krum => self.krum(gradients),
            AggregationMethod::Bulyan => self.bulyan(gradients),
        }
    }

    /// Coordinate-wise median aggregation
    fn coordinate_wise_median(&self, gradients: &[Vec<f64>]) -> Result<Vec<f64>> {
        let num_features = gradients[0].len();
        let mut result = vec![0.0; num_features];

        for i in 0..num_features {
            let mut values: Vec<f64> = gradients.iter().map(|g| g[i]).collect();
            values.sort_by(|a, b| a.partial_cmp(b).unwrap());
            result[i] = values[values.len() / 2];
        }

        Ok(result)
    }

    /// Trimmed mean aggregation (`trim_fraction` is a percentage trimmed from each end)
    fn trimmed_mean(&self, gradients: &[Vec<f64>], trim_fraction: usize) -> Result<Vec<f64>> {
        let num_features = gradients[0].len();
        let mut result = vec![0.0; num_features];
        // Clamp so that at least one value survives the trimming
        let trim_count =
            ((gradients.len() * trim_fraction) / 100).min(gradients.len().saturating_sub(1) / 2);

        for i in 0..num_features {
            let mut values: Vec<f64> = gradients.iter().map(|g| g[i]).collect();
            values.sort_by(|a, b| a.partial_cmp(b).unwrap());

            // Remove extreme values from both ends
            let trimmed = &values[trim_count..values.len() - trim_count];
            result[i] = trimmed.iter().sum::<f64>() / trimmed.len() as f64;
        }

        Ok(result)
    }

    /// Krum aggregation (select the most representative gradient)
    fn krum(&self, gradients: &[Vec<f64>]) -> Result<Vec<f64>> {
        let n = gradients.len();
        let f = (n as f64 * self.config.max_byzantine_fraction).floor() as usize;
        // Classic Krum sums the n - f - 2 closest distances; saturate for small n
        let m = n.saturating_sub(f + 2).max(1);

        let mut scores = vec![0.0; n];

        // Compute Krum scores
        for i in 0..n {
            let mut distances: Vec<(usize, f64)> = Vec::new();

            for j in 0..n {
                if i != j {
                    let dist = self.euclidean_distance(&gradients[i], &gradients[j]);
                    distances.push((j, dist));
                }
            }

            // Sort by distance and sum the m closest
            distances.sort_by(|a, b| a.1.partial_cmp(&b.1).unwrap());
            scores[i] = distances.iter().take(m).map(|(_, d)| d).sum();
        }

        // Select the gradient with the minimum score
        let best_idx = scores
            .iter()
            .enumerate()
            .min_by(|(_, a), (_, b)| a.partial_cmp(b).unwrap())
            .map(|(idx, _)| idx)
            .unwrap();

        Ok(gradients[best_idx].clone())
    }

    /// Bulyan aggregation (robust combination)
    fn bulyan(&self, gradients: &[Vec<f64>]) -> Result<Vec<f64>> {
        // Simplified Bulyan: check the Byzantine bound, then fall back to a robust
        // coordinate-wise median rather than the full Krum + trimmed-mean pipeline
        let n = gradients.len();
        let f = (n as f64 * self.config.max_byzantine_fraction).floor() as usize;
        let theta = n - 2 * f;

        if theta < 1 {
            return Err(SklearsError::InvalidInput(
                "Too many Byzantine workers for Bulyan".to_string(),
            ));
        }

        self.coordinate_wise_median(gradients)
    }

    /// Compute the Euclidean distance between two gradients
    fn euclidean_distance(&self, a: &[f64], b: &[f64]) -> f64 {
        a.iter()
            .zip(b.iter())
            .map(|(x, y)| (x - y).powi(2))
            .sum::<f64>()
            .sqrt()
    }
}
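
// A worked example of the coordinate-wise median above (our test name): with sorted
// values [1, 2, 3, 100], index len/2 = 2 selects the upper median 3.0, so a single
// large outlier does not move the aggregate.
#[cfg(test)]
mod bft_median_example {
    use super::*;

    #[test]
    fn median_ignores_a_single_outlier() {
        let bft = ByzantineFaultTolerant::new(BFTConfig::default(), AggregationMethod::Median);
        let gradients = vec![vec![1.0], vec![2.0], vec![3.0], vec![100.0]];
        assert_eq!(bft.aggregate(&gradients).unwrap(), vec![3.0]);
    }
}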

/// Advanced load balancing for distributed systems
///
/// Implements several load-balancing strategies (round-robin, least-loaded,
/// weighted random, power of two choices) for distributing tasks across workers.
#[derive(Debug)]
pub struct LoadBalancer {
    /// Load balancing strategy
    pub strategy: LoadBalancingStrategy,
    /// Worker load tracking
    pub worker_loads: HashMap<String, WorkerLoad>,
}

/// Load balancing strategy
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
pub enum LoadBalancingStrategy {
    /// Round-robin assignment
    RoundRobin,
    /// Least-loaded worker first
    LeastLoaded,
    /// Weighted random based on capacity
    WeightedRandom,
    /// Power of two choices
    PowerOfTwo,
}

/// Worker load information
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct WorkerLoad {
    /// Current number of tasks
    pub active_tasks: usize,
    /// Worker capacity
    pub capacity: usize,
    /// Average task completion time (ms)
    pub avg_completion_time_ms: u64,
    /// Load factor (0.0 - 1.0)
    pub load_factor: f64,
}

impl LoadBalancer {
    /// Create a new load balancer
    pub fn new(strategy: LoadBalancingStrategy) -> Self {
        Self {
            strategy,
            worker_loads: HashMap::new(),
        }
    }

    /// Register a worker with the load balancer
    pub fn register_worker(&mut self, worker_id: String, capacity: usize) {
        self.worker_loads.insert(
            worker_id,
            WorkerLoad {
                active_tasks: 0,
                capacity,
                avg_completion_time_ms: 0,
                load_factor: 0.0,
            },
        );
    }

    /// Select a worker for task assignment
    pub fn select_worker(&mut self) -> Option<String> {
        match self.strategy {
            LoadBalancingStrategy::RoundRobin => self.round_robin_select(),
            LoadBalancingStrategy::LeastLoaded => self.least_loaded_select(),
            LoadBalancingStrategy::WeightedRandom => self.weighted_random_select(),
            LoadBalancingStrategy::PowerOfTwo => self.power_of_two_select(),
        }
    }

    /// Round-robin worker selection
    ///
    /// Simplified: this returns an arbitrary registered worker because `HashMap`
    /// iteration order is unspecified; a true round-robin would keep a cursor.
    fn round_robin_select(&self) -> Option<String> {
        self.worker_loads.keys().next().cloned()
    }

    /// Select the least loaded worker
    fn least_loaded_select(&self) -> Option<String> {
        self.worker_loads
            .iter()
            .min_by(|(_, a), (_, b)| a.load_factor.partial_cmp(&b.load_factor).unwrap())
            .map(|(id, _)| id.clone())
    }

    /// Weighted random selection based on available capacity
    fn weighted_random_select(&self) -> Option<String> {
        use scirs2_core::random::thread_rng;

        if self.worker_loads.is_empty() {
            return None;
        }

        let mut rng = thread_rng();
        let total_capacity: f64 = self
            .worker_loads
            .values()
            .map(|load| load.capacity.saturating_sub(load.active_tasks) as f64)
            .sum();

        // If every worker is saturated, fall back to an arbitrary worker
        if total_capacity <= 0.0 {
            return self.worker_loads.keys().next().cloned();
        }

        let mut rand_val = rng.gen_range(0.0..total_capacity);

        for (id, load) in &self.worker_loads {
            let available = load.capacity.saturating_sub(load.active_tasks) as f64;
            if rand_val < available {
                return Some(id.clone());
            }
            rand_val -= available;
        }

        self.worker_loads.keys().next().cloned()
    }

    /// Power of two choices (sample 2 random workers, pick the least loaded)
    fn power_of_two_select(&self) -> Option<String> {
        use scirs2_core::random::thread_rng;

        if self.worker_loads.is_empty() {
            return None;
        }

        let mut rng = thread_rng();
        let workers: Vec<_> = self.worker_loads.keys().collect();

        if workers.len() == 1 {
            return Some(workers[0].clone());
        }

        let idx1 = rng.gen_range(0..workers.len());
        let mut idx2 = rng.gen_range(0..workers.len());
        while idx2 == idx1 {
            idx2 = rng.gen_range(0..workers.len());
        }

        let load1 = &self.worker_loads[workers[idx1]];
        let load2 = &self.worker_loads[workers[idx2]];

        if load1.load_factor < load2.load_factor {
            Some(workers[idx1].clone())
        } else {
            Some(workers[idx2].clone())
        }
    }

    /// Update worker load after a task is assigned (`true`) or completed (`false`)
    pub fn update_load(&mut self, worker_id: &str, task_assigned: bool) {
        if let Some(load) = self.worker_loads.get_mut(worker_id) {
            if task_assigned {
                load.active_tasks += 1;
            } else if load.active_tasks > 0 {
                load.active_tasks -= 1;
            }
            load.load_factor = load.active_tasks as f64 / load.capacity as f64;
        }
    }
}
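
// A deterministic usage sketch of the least-loaded strategy (our test name; worker
// names and capacities are illustrative).
#[cfg(test)]
mod load_balancer_usage_example {
    use super::*;

    #[test]
    fn least_loaded_prefers_the_idle_worker() {
        let mut lb = LoadBalancer::new(LoadBalancingStrategy::LeastLoaded);
        lb.register_worker("busy".to_string(), 4);
        lb.register_worker("idle".to_string(), 4);

        // Two assignments raise "busy" to a load factor of 0.5; "idle" stays at 0.0.
        lb.update_load("busy", true);
        lb.update_load("busy", true);

        assert_eq!(lb.select_worker(), Some("idle".to_string()));
    }
}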

#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_parameter_server_creation() {
        let ps = ParameterServer::new(5, 3);
        assert_eq!(ps.parameters.len(), 5);
        assert_eq!(ps.num_workers, 3);
        assert_eq!(ps.version, 0);
    }

    #[test]
    fn test_gradient_accumulation() {
        let mut ps = ParameterServer::new(3, 2);

        let grad1 = vec![1.0, 2.0, 3.0];
        let grad2 = vec![2.0, 3.0, 4.0];

        ps.receive_gradient(grad1).unwrap();
        ps.receive_gradient(grad2).unwrap();

        // After 2 updates (all workers), parameters should be updated
        let params = ps.get_parameters();
        assert_eq!(ps.version, 1);
        assert!(params.iter().all(|&p| p != 0.0));
    }

    #[test]
    fn test_worker_node_creation() {
        let partition = DataPartition {
            features: vec![vec![1.0, 2.0], vec![3.0, 4.0]],
            targets: vec![1.0, 2.0],
            partition_id: 0,
        };

        let worker = WorkerNode::new(NodeId::new("worker_0"), partition);
        assert_eq!(worker.id.0, "worker_0");
        assert_eq!(worker.stats.samples_processed, 0);
    }

    #[test]
    fn test_local_gradient_computation() {
        let partition = DataPartition {
            features: vec![vec![1.0, 2.0], vec![2.0, 3.0]],
            targets: vec![3.0, 5.0],
            partition_id: 0,
        };

        let mut worker = WorkerNode::new(NodeId::new("worker_0"), partition);
        let params = vec![1.0, 1.0];

        let gradient = worker.compute_local_gradient(&params).unwrap();
        assert_eq!(gradient.len(), 2);
        assert!(worker.stats.gradient_computations > 0);
    }

    #[test]
    fn test_distributed_regression_creation() {
        let config = DistributedConfig::default();
        let model = DistributedLinearRegression::new(config, 5);

        assert_eq!(model.workers.len(), 0);
        assert_eq!(model.parameters.read().unwrap().weights.len(), 5);
    }

    #[test]
    fn test_data_partitioning() {
        let config = DistributedConfig {
            num_workers: 2,
            ..Default::default()
        };

        let mut model = DistributedLinearRegression::new(config, 2);

        let features = vec![
            vec![1.0, 2.0],
            vec![3.0, 4.0],
            vec![5.0, 6.0],
            vec![7.0, 8.0],
        ];
        let targets = vec![3.0, 7.0, 11.0, 15.0];

        model.partition_data(features, targets).unwrap();

        assert_eq!(model.workers.len(), 2);
        assert!(!model.workers[0].data_partition.features.is_empty());
    }

    #[test]
    fn test_distributed_training() {
        let config = DistributedConfig {
            num_workers: 2,
            max_iterations: 10,
            tolerance: 1e-3,
            learning_rate: 0.01,
            ..Default::default()
        };

        let mut model = DistributedLinearRegression::new(config, 2);

        // Simple linear relationship: y = 2*x1 + 3*x2
        let features = vec![
            vec![1.0, 1.0],
            vec![2.0, 2.0],
            vec![3.0, 3.0],
            vec![4.0, 4.0],
        ];
        let targets = vec![5.0, 10.0, 15.0, 20.0];

        model.partition_data(features, targets).unwrap();
        model.fit().unwrap();

        let stats = model.get_training_stats();
        assert!(stats.total_samples_processed > 0);
        assert!(stats.parameter_server_version > 0);
    }

    #[test]
    fn test_prediction() {
        let config = DistributedConfig::default();
        let model = DistributedLinearRegression::new(config, 2);

        let test_features = vec![vec![1.0, 2.0], vec![3.0, 4.0]];
        let predictions = model.predict(&test_features).unwrap();

        assert_eq!(predictions.len(), 2);
    }

    #[test]
    fn test_training_stats() {
        let config = DistributedConfig {
            num_workers: 3,
            ..Default::default()
        };

        let model = DistributedLinearRegression::new(config, 2);
        let stats = model.get_training_stats();

        assert_eq!(stats.num_workers, 0); // No workers added yet
    }

    // ============================================================================
    // Tests for Advanced Distributed Learning Features
    // ============================================================================

    #[test]
    fn test_federated_learning_creation() {
        let config = FederatedConfig::default();
        let fed_learning = FederatedLearning::new(config, 5);

        assert_eq!(fed_learning.clients.len(), 0);
        assert_eq!(fed_learning.config.num_clients, 10);
    }

    #[test]
    fn test_federated_add_client() {
        let config = FederatedConfig::default();
        let mut fed_learning = FederatedLearning::new(config, 5);

        fed_learning.add_client("client_1".to_string(), 100);
        fed_learning.add_client("client_2".to_string(), 150);

        assert_eq!(fed_learning.clients.len(), 2);
        assert_eq!(fed_learning.clients[0].dataset_size, 100);
        assert_eq!(fed_learning.clients[1].dataset_size, 150);
    }

    #[test]
    fn test_federated_client_selection() {
        let config = FederatedConfig {
            client_fraction: 0.5,
            ..Default::default()
        };
        let mut fed_learning = FederatedLearning::new(config, 5);

        for i in 0..10 {
            fed_learning.add_client(format!("client_{}", i), 100);
        }

        let selected = fed_learning.select_clients();
        assert!(selected.len() >= 4 && selected.len() <= 6); // ~50% of 10
    }

    #[test]
    fn test_federated_averaging() {
        let config = FederatedConfig {
            dp_epsilon: None, // Disable noise for deterministic test
            ..Default::default()
        };
        let mut fed_learning = FederatedLearning::new(config, 3);

        fed_learning.add_client("client_1".to_string(), 100);
        fed_learning.add_client("client_2".to_string(), 100);

        let updates = vec![(0, vec![1.0, 2.0, 3.0]), (1, vec![2.0, 4.0, 6.0])];

        let averaged = fed_learning.federated_average(&updates);
        assert_eq!(averaged.len(), 3);
        // With equal weights, the average should be (1+2)/2=1.5, (2+4)/2=3, (3+6)/2=4.5
        assert!((averaged[0] - 1.5).abs() < 1e-6);
        assert!((averaged[1] - 3.0).abs() < 1e-6);
        assert!((averaged[2] - 4.5).abs() < 1e-6);
    }

    #[test]
    fn test_privacy_mechanism_noise() {
        let privacy = PrivacyMechanism::new();
        let mut gradients = vec![1.0, 2.0, 3.0];
        let original = gradients.clone();

        privacy.apply_noise(&mut gradients, 1.0);

        // Gradients should be modified (with very high probability)
        assert_ne!(gradients, original);
    }

    #[test]
    fn test_privacy_mechanism_clipping() {
        let privacy = PrivacyMechanism::new();
        let mut gradients = vec![3.0, 4.0]; // Norm = 5.0

        privacy.clip_gradients(&mut gradients, 1.0);

        let norm: f64 = gradients.iter().map(|g| g * g).sum::<f64>().sqrt();
        assert!((norm - 1.0).abs() < 1e-6);
    }

    #[test]
    fn test_byzantine_fault_tolerant_creation() {
        let config = BFTConfig::default();
        let bft = ByzantineFaultTolerant::new(config, AggregationMethod::Median);

        assert_eq!(bft.aggregation_method, AggregationMethod::Median);
    }

    #[test]
    fn test_byzantine_median_aggregation() {
        let config = BFTConfig::default();
        let bft = ByzantineFaultTolerant::new(config, AggregationMethod::Median);

        let gradients = vec![
            vec![1.0, 2.0, 3.0],
            vec![2.0, 3.0, 4.0],
            vec![3.0, 4.0, 5.0],
            vec![100.0, 100.0, 100.0], // Byzantine outlier
        ];

        let result = bft.aggregate(&gradients).unwrap();
        assert_eq!(result.len(), 3);
        // Median should filter out the outlier
        assert!(result[0] < 50.0);
        assert!(result[1] < 50.0);
        assert!(result[2] < 50.0);
    }

    #[test]
    fn test_byzantine_trimmed_mean() {
        let config = BFTConfig::default();
        let bft = ByzantineFaultTolerant::new(
            config,
            AggregationMethod::TrimmedMean { trim_fraction: 25 },
        );

        let gradients = vec![
            vec![1.0, 2.0, 3.0],
            vec![2.0, 3.0, 4.0],
            vec![3.0, 4.0, 5.0],
            vec![4.0, 5.0, 6.0],
        ];

        let result = bft.aggregate(&gradients).unwrap();
        assert_eq!(result.len(), 3);
        // Trimmed mean should produce reasonable values
        assert!(result[0] > 1.0 && result[0] < 4.0);
    }

    #[test]
    fn test_byzantine_krum() {
        let config = BFTConfig::default();
        let bft = ByzantineFaultTolerant::new(config, AggregationMethod::Krum);

        let gradients = vec![
            vec![1.0, 2.0, 3.0],
            vec![1.1, 2.1, 3.1],
            vec![1.2, 2.2, 3.2],
            vec![100.0, 100.0, 100.0], // Byzantine outlier
        ];

        let result = bft.aggregate(&gradients).unwrap();
        assert_eq!(result.len(), 3);
        // Krum should select one of the non-outlier gradients
        assert!(result[0] < 10.0);
    }

    #[test]
    fn test_byzantine_empty_gradients() {
        let config = BFTConfig::default();
        let bft = ByzantineFaultTolerant::new(config, AggregationMethod::Median);

        let gradients: Vec<Vec<f64>> = vec![];
        let result = bft.aggregate(&gradients);

        assert!(result.is_err());
    }

    #[test]
    fn test_load_balancer_creation() {
        let lb = LoadBalancer::new(LoadBalancingStrategy::RoundRobin);
        assert_eq!(lb.strategy, LoadBalancingStrategy::RoundRobin);
        assert_eq!(lb.worker_loads.len(), 0);
    }

    #[test]
    fn test_load_balancer_register_worker() {
        let mut lb = LoadBalancer::new(LoadBalancingStrategy::LeastLoaded);

        lb.register_worker("worker_1".to_string(), 10);
        lb.register_worker("worker_2".to_string(), 20);

        assert_eq!(lb.worker_loads.len(), 2);
        assert_eq!(lb.worker_loads.get("worker_1").unwrap().capacity, 10);
        assert_eq!(lb.worker_loads.get("worker_2").unwrap().capacity, 20);
    }

    #[test]
    fn test_load_balancer_least_loaded() {
        let mut lb = LoadBalancer::new(LoadBalancingStrategy::LeastLoaded);

        lb.register_worker("worker_1".to_string(), 10);
        lb.register_worker("worker_2".to_string(), 10);

        // Initially both have load 0, so either can be selected
        let selected = lb.select_worker();
        assert!(selected.is_some());

        // Add load to one worker
        lb.update_load("worker_1", true);
        lb.update_load("worker_1", true);

        // Now worker_2 should be selected (least loaded)
        let selected = lb.select_worker();
        assert_eq!(selected, Some("worker_2".to_string()));
    }

    #[test]
    fn test_load_balancer_update_load() {
        let mut lb = LoadBalancer::new(LoadBalancingStrategy::LeastLoaded);

        lb.register_worker("worker_1".to_string(), 10);

        lb.update_load("worker_1", true);
        assert_eq!(lb.worker_loads.get("worker_1").unwrap().active_tasks, 1);
        assert!((lb.worker_loads.get("worker_1").unwrap().load_factor - 0.1).abs() < 1e-6);

        lb.update_load("worker_1", false);
        assert_eq!(lb.worker_loads.get("worker_1").unwrap().active_tasks, 0);
        assert!((lb.worker_loads.get("worker_1").unwrap().load_factor).abs() < 1e-6);
    }

    #[test]
    fn test_load_balancer_power_of_two() {
        let mut lb = LoadBalancer::new(LoadBalancingStrategy::PowerOfTwo);

        lb.register_worker("worker_1".to_string(), 10);
        lb.register_worker("worker_2".to_string(), 10);
        lb.register_worker("worker_3".to_string(), 10);

        let selected = lb.select_worker();
        assert!(selected.is_some());
    }

    #[test]
    fn test_federated_config_default() {
        let config = FederatedConfig::default();
        assert_eq!(config.num_clients, 10);
        assert!((config.client_fraction - 0.3).abs() < 1e-6);
        assert_eq!(config.local_epochs, 5);
        assert!(config.secure_aggregation);
    }

    #[test]
    fn test_bft_config_default() {
        let config = BFTConfig::default();
        assert!((config.max_byzantine_fraction - 0.3).abs() < 1e-6);
        assert!((config.detection_threshold - 2.0).abs() < 1e-6);
        assert!(config.enable_reputation);
    }

    #[test]
    fn test_aggregation_method_equality() {
        assert_eq!(AggregationMethod::Median, AggregationMethod::Median);
        assert_eq!(
            AggregationMethod::TrimmedMean { trim_fraction: 25 },
            AggregationMethod::TrimmedMean { trim_fraction: 25 }
        );
        assert_ne!(AggregationMethod::Median, AggregationMethod::Krum);
    }

    #[test]
    fn test_load_balancing_strategy_equality() {
        assert_eq!(
            LoadBalancingStrategy::RoundRobin,
            LoadBalancingStrategy::RoundRobin
        );
        assert_eq!(
            LoadBalancingStrategy::LeastLoaded,
            LoadBalancingStrategy::LeastLoaded
        );
        assert_ne!(
            LoadBalancingStrategy::RoundRobin,
            LoadBalancingStrategy::LeastLoaded
        );
    }
}
1413}