Skip to main content

sklears_core/
distributed_algorithms.rs

//! Distributed machine learning algorithms.
//!
//! This module provides concrete implementations of distributed ML algorithms
//! that scale across multiple nodes, with fault tolerance and load balancing.
5
6use crate::distributed::NodeId;
7use crate::error::{Result, SklearsError};
8use serde::{Deserialize, Serialize};
9use std::collections::HashMap;
10use std::sync::{Arc, RwLock};
11
12/// Distributed linear regression using parameter server architecture
13///
14/// Implements distributed training of linear regression models across
15/// multiple worker nodes with centralized parameter synchronization.
16#[derive(Debug)]
17pub struct DistributedLinearRegression {
18    /// Configuration for distributed training
19    pub config: DistributedConfig,
20    /// Parameter server for coordinating updates
21    pub parameter_server: Arc<RwLock<ParameterServer>>,
22    /// Worker nodes performing computation
23    pub workers: Vec<WorkerNode>,
24    /// Current model parameters
25    pub parameters: Arc<RwLock<ModelParameters>>,
26}
27
28/// Configuration for distributed training
29#[derive(Debug, Clone, Serialize, Deserialize)]
30pub struct DistributedConfig {
31    /// Number of worker nodes
32    pub num_workers: usize,
33    /// Synchronization strategy
34    pub sync_strategy: SyncStrategy,
35    /// Enable fault tolerance
36    pub fault_tolerance: bool,
37    /// Maximum iterations
38    pub max_iterations: usize,
39    /// Convergence tolerance
40    pub tolerance: f64,
41    /// Learning rate
42    pub learning_rate: f64,
43}
44
45impl Default for DistributedConfig {
46    fn default() -> Self {
47        Self {
48            num_workers: 4,
49            sync_strategy: SyncStrategy::Synchronous,
50            fault_tolerance: true,
51            max_iterations: 100,
52            tolerance: 1e-6,
53            learning_rate: 0.01,
54        }
55    }
56}
57
58/// Synchronization strategy for distributed training
59#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
60pub enum SyncStrategy {
61    /// All workers synchronize after each iteration
62    Synchronous,
63    /// Workers update asynchronously
64    Asynchronous,
65    /// Bounded asynchronous updates
66    BoundedAsync { staleness_bound: usize },
67}
68
/// Central store for the shared model parameters.
///
/// Workers push gradients with `receive_gradient`. Once a full round has been
/// collected, the averaged gradient is subtracted from `parameters` and
/// `version` is incremented.
#[derive(Debug, Clone)]
pub struct ParameterServer {
    /// Authoritative parameter vector.
    pub parameters: Vec<f64>,
    /// Number of updates applied so far.
    pub version: usize,
    /// Running sum of the gradients received in the current round.
    pub gradient_accumulator: Vec<f64>,
    /// Number of gradients that make up one round.
    pub num_workers: usize,
    /// Gradients received so far in the current round.
    pub updates_received: usize,
}
83
84impl ParameterServer {
85    /// Create a new parameter server
86    pub fn new(num_parameters: usize, num_workers: usize) -> Self {
87        Self {
88            parameters: vec![0.0; num_parameters],
89            version: 0,
90            gradient_accumulator: vec![0.0; num_parameters],
91            num_workers,
92            updates_received: 0,
93        }
94    }
95
96    /// Receive gradient update from worker
97    pub fn receive_gradient(&mut self, gradient: Vec<f64>) -> Result<()> {
98        if gradient.len() != self.parameters.len() {
99            return Err(SklearsError::DimensionMismatch {
100                expected: self.parameters.len(),
101                actual: gradient.len(),
102            });
103        }
104
105        // Accumulate gradient
106        for (acc, grad) in self.gradient_accumulator.iter_mut().zip(gradient.iter()) {
107            *acc += grad;
108        }
109
110        self.updates_received += 1;
111
112        // Apply update when all workers have reported (synchronous)
113        if self.updates_received == self.num_workers {
114            self.apply_accumulated_gradients();
115        }
116
117        Ok(())
118    }
119
120    /// Apply accumulated gradients to parameters
121    fn apply_accumulated_gradients(&mut self) {
122        let scale = 1.0 / self.num_workers as f64;
123
124        for (param, grad) in self
125            .parameters
126            .iter_mut()
127            .zip(self.gradient_accumulator.iter())
128        {
129            *param -= grad * scale;
130        }
131
132        // Reset accumulator
133        self.gradient_accumulator.iter_mut().for_each(|g| *g = 0.0);
134        self.updates_received = 0;
135        self.version += 1;
136    }
137
138    /// Get current parameters
139    pub fn get_parameters(&self) -> Vec<f64> {
140        self.parameters.clone()
141    }
142
143    /// Get parameter version
144    pub fn get_version(&self) -> usize {
145        self.version
146    }
147}
148
149/// Worker node for distributed computation
150#[derive(Debug, Clone)]
151pub struct WorkerNode {
152    /// Node identifier
153    pub id: NodeId,
154    /// Local data partition
155    pub data_partition: DataPartition,
156    /// Local model parameters (cached from parameter server)
157    pub local_parameters: Vec<f64>,
158    /// Parameter version
159    pub parameter_version: usize,
160    /// Worker statistics
161    pub stats: WorkerStats,
162}
163
164/// Data partition assigned to a worker
165#[derive(Debug, Clone, Serialize, Deserialize)]
166pub struct DataPartition {
167    /// Feature matrix
168    pub features: Vec<Vec<f64>>,
169    /// Target values
170    pub targets: Vec<f64>,
171    /// Partition index
172    pub partition_id: usize,
173}
174
175/// Statistics tracked by each worker
176#[derive(Debug, Clone, Default, Serialize, Deserialize)]
177pub struct WorkerStats {
178    /// Number of samples processed
179    pub samples_processed: usize,
180    /// Number of gradient computations
181    pub gradient_computations: usize,
182    /// Total computation time in milliseconds
183    pub total_compute_time_ms: u64,
184    /// Number of communication rounds
185    pub communication_rounds: usize,
186}
187
188impl WorkerNode {
189    /// Create a new worker node
190    pub fn new(id: NodeId, data_partition: DataPartition) -> Self {
191        Self {
192            id,
193            data_partition,
194            local_parameters: Vec::new(),
195            parameter_version: 0,
196            stats: WorkerStats::default(),
197        }
198    }
199
200    /// Compute local gradient on assigned data partition
201    pub fn compute_local_gradient(&mut self, parameters: &[f64]) -> Result<Vec<f64>> {
202        let start_time = std::time::Instant::now();
203
204        // Update local parameters
205        self.local_parameters = parameters.to_vec();
206
207        let n_samples = self.data_partition.features.len();
208        let n_features = parameters.len();
209        let mut gradient = vec![0.0; n_features];
210
211        // Compute gradient for linear regression
212        for (features, target) in self
213            .data_partition
214            .features
215            .iter()
216            .zip(self.data_partition.targets.iter())
217        {
218            // Prediction: y_pred = w^T x
219            let prediction: f64 = features
220                .iter()
221                .zip(parameters.iter())
222                .map(|(x, w)| x * w)
223                .sum();
224
225            // Error: e = y_pred - y_true
226            let error = prediction - target;
227
228            // Gradient: grad = 2 * e * x
229            for (i, x) in features.iter().enumerate() {
230                gradient[i] += 2.0 * error * x;
231            }
232        }
233
234        // Average gradient over samples
235        for g in gradient.iter_mut() {
236            *g /= n_samples as f64;
237        }
238
239        // Update statistics
240        self.stats.samples_processed += n_samples;
241        self.stats.gradient_computations += 1;
242        self.stats.total_compute_time_ms += start_time.elapsed().as_millis() as u64;
243
244        Ok(gradient)
245    }
246
247    /// Get worker statistics
248    pub fn get_stats(&self) -> &WorkerStats {
249        &self.stats
250    }
251}
252
253/// Model parameters for distributed learning
254#[derive(Debug, Clone, Serialize, Deserialize)]
255pub struct ModelParameters {
256    /// Weight vector
257    pub weights: Vec<f64>,
258    /// Bias term (intercept)
259    pub bias: f64,
260    /// Training metadata
261    pub metadata: ParameterMetadata,
262}
263
264/// Metadata about model parameters
265#[derive(Debug, Clone, Serialize, Deserialize)]
266pub struct ParameterMetadata {
267    /// Number of training iterations completed
268    pub iterations_completed: usize,
269    /// Current training loss
270    pub current_loss: f64,
271    /// Convergence status
272    pub converged: bool,
273    /// Timestamp of last update
274    pub last_updated: std::time::SystemTime,
275}
276
277impl DistributedLinearRegression {
278    /// Create a new distributed linear regression model
279    pub fn new(config: DistributedConfig, num_features: usize) -> Self {
280        let parameter_server = Arc::new(RwLock::new(ParameterServer::new(
281            num_features,
282            config.num_workers,
283        )));
284
285        Self {
286            config,
287            parameter_server,
288            workers: Vec::new(),
289            parameters: Arc::new(RwLock::new(ModelParameters {
290                weights: vec![0.0; num_features],
291                bias: 0.0,
292                metadata: ParameterMetadata {
293                    iterations_completed: 0,
294                    current_loss: f64::INFINITY,
295                    converged: false,
296                    last_updated: std::time::SystemTime::now(),
297                },
298            })),
299        }
300    }
301
302    /// Partition data across workers
303    pub fn partition_data(&mut self, features: Vec<Vec<f64>>, targets: Vec<f64>) -> Result<()> {
304        let n_samples = features.len();
305        let samples_per_worker = n_samples.div_ceil(self.config.num_workers);
306
307        for worker_idx in 0..self.config.num_workers {
308            let start_idx = worker_idx * samples_per_worker;
309            let end_idx = ((worker_idx + 1) * samples_per_worker).min(n_samples);
310
311            if start_idx < n_samples {
312                let partition = DataPartition {
313                    features: features[start_idx..end_idx].to_vec(),
314                    targets: targets[start_idx..end_idx].to_vec(),
315                    partition_id: worker_idx,
316                };
317
318                let worker =
319                    WorkerNode::new(NodeId::new(format!("worker_{}", worker_idx)), partition);
320
321                self.workers.push(worker);
322            }
323        }
324
325        Ok(())
326    }
327
328    /// Train the model using distributed gradient descent
329    pub fn fit(&mut self) -> Result<()> {
330        for iteration in 0..self.config.max_iterations {
331            // Get current parameters from parameter server
332            let params = {
333                let ps = self
334                    .parameter_server
335                    .read()
336                    .unwrap_or_else(|e| e.into_inner());
337                ps.get_parameters()
338            };
339
340            // Compute gradients on all workers
341            let mut all_gradients = Vec::new();
342            for worker in &mut self.workers {
343                let gradient = worker.compute_local_gradient(&params)?;
344                all_gradients.push(gradient);
345            }
346
347            // Send gradients to parameter server
348            {
349                let mut ps = self
350                    .parameter_server
351                    .write()
352                    .unwrap_or_else(|e| e.into_inner());
353                for gradient in all_gradients {
354                    ps.receive_gradient(gradient)?;
355                }
356            }
357
358            // Check for convergence
359            if iteration % 10 == 0 {
360                let loss = self.compute_global_loss(&params)?;
361                let mut model_params = self.parameters.write().unwrap_or_else(|e| e.into_inner());
362
363                if (model_params.metadata.current_loss - loss).abs() < self.config.tolerance {
364                    model_params.metadata.converged = true;
365                    model_params.metadata.iterations_completed = iteration + 1;
366                    break;
367                }
368
369                model_params.metadata.current_loss = loss;
370                model_params.metadata.iterations_completed = iteration + 1;
371                model_params.metadata.last_updated = std::time::SystemTime::now();
372            }
373        }
374
375        // Update final parameters
376        let final_params = {
377            let ps = self
378                .parameter_server
379                .read()
380                .unwrap_or_else(|e| e.into_inner());
381            ps.get_parameters()
382        };
383
384        let mut model_params = self.parameters.write().unwrap_or_else(|e| e.into_inner());
385        model_params.weights = final_params;
386
387        Ok(())
388    }
389
390    /// Compute global loss across all workers
391    fn compute_global_loss(&self, parameters: &[f64]) -> Result<f64> {
392        let mut total_loss = 0.0;
393        let mut total_samples = 0;
394
395        for worker in &self.workers {
396            for (features, target) in worker
397                .data_partition
398                .features
399                .iter()
400                .zip(worker.data_partition.targets.iter())
401            {
402                let prediction: f64 = features
403                    .iter()
404                    .zip(parameters.iter())
405                    .map(|(x, w)| x * w)
406                    .sum();
407                let error = prediction - target;
408                total_loss += error * error;
409                total_samples += 1;
410            }
411        }
412
413        Ok(total_loss / total_samples as f64)
414    }
415
416    /// Get training statistics from all workers
417    pub fn get_training_stats(&self) -> DistributedTrainingStats {
418        let mut total_samples = 0;
419        let mut total_compute_time = 0;
420        let mut total_gradient_computations = 0;
421
422        for worker in &self.workers {
423            let stats = worker.get_stats();
424            total_samples += stats.samples_processed;
425            total_compute_time += stats.total_compute_time_ms;
426            total_gradient_computations += stats.gradient_computations;
427        }
428
429        DistributedTrainingStats {
430            num_workers: self.workers.len(),
431            total_samples_processed: total_samples,
432            total_compute_time_ms: total_compute_time,
433            total_gradient_computations,
434            parameter_server_version: self
435                .parameter_server
436                .read()
437                .unwrap_or_else(|e| e.into_inner())
438                .get_version(),
439        }
440    }
441
442    /// Predict on new data using trained parameters
443    pub fn predict(&self, features: &[Vec<f64>]) -> Result<Vec<f64>> {
444        let params = self.parameters.read().unwrap_or_else(|e| e.into_inner());
445        let mut predictions = Vec::new();
446
447        for feature_row in features {
448            let pred: f64 = feature_row
449                .iter()
450                .zip(params.weights.iter())
451                .map(|(x, w)| x * w)
452                .sum::<f64>()
453                + params.bias;
454
455            predictions.push(pred);
456        }
457
458        Ok(predictions)
459    }
460}
461
462/// Statistics from distributed training
463#[derive(Debug, Clone, Serialize, Deserialize)]
464pub struct DistributedTrainingStats {
465    /// Number of worker nodes
466    pub num_workers: usize,
467    /// Total samples processed across all workers
468    pub total_samples_processed: usize,
469    /// Total computation time in milliseconds
470    pub total_compute_time_ms: u64,
471    /// Total gradient computations
472    pub total_gradient_computations: usize,
473    /// Parameter server version
474    pub parameter_server_version: usize,
475}
476
// ============================================================================
// Advanced Distributed Learning Features
// ============================================================================
480
481/// Federated Learning framework with privacy-preserving techniques
482///
483/// Implements federated averaging (FedAvg) and secure aggregation for
484/// privacy-preserving distributed machine learning.
485#[derive(Debug)]
486pub struct FederatedLearning {
487    /// Federated learning configuration
488    pub config: FederatedConfig,
489    /// Client models
490    pub clients: Vec<FederatedClient>,
491    /// Global model parameters
492    pub global_model: Arc<RwLock<ModelParameters>>,
493    /// Privacy mechanism
494    pub privacy_mechanism: PrivacyMechanism,
495}
496
497/// Configuration for federated learning
498#[derive(Debug, Clone, Serialize, Deserialize)]
499pub struct FederatedConfig {
500    /// Number of clients
501    pub num_clients: usize,
502    /// Fraction of clients selected per round
503    pub client_fraction: f64,
504    /// Number of local epochs per client
505    pub local_epochs: usize,
506    /// Local learning rate
507    pub local_learning_rate: f64,
508    /// Enable secure aggregation
509    pub secure_aggregation: bool,
510    /// Differential privacy epsilon
511    pub dp_epsilon: Option<f64>,
512    /// Differential privacy delta
513    pub dp_delta: Option<f64>,
514}
515
516impl Default for FederatedConfig {
517    fn default() -> Self {
518        Self {
519            num_clients: 10,
520            client_fraction: 0.3,
521            local_epochs: 5,
522            local_learning_rate: 0.01,
523            secure_aggregation: true,
524            dp_epsilon: Some(1.0),
525            dp_delta: Some(1e-5),
526        }
527    }
528}
529
530/// Federated learning client
531#[derive(Debug, Clone)]
532pub struct FederatedClient {
533    /// Client identifier
534    pub id: String,
535    /// Local dataset size
536    pub dataset_size: usize,
537    /// Local model parameters
538    pub local_parameters: Vec<f64>,
539    /// Client training statistics
540    pub stats: ClientStats,
541}
542
543/// Statistics for federated client
544#[derive(Debug, Clone, Default, Serialize, Deserialize)]
545pub struct ClientStats {
546    /// Number of training rounds participated in
547    pub rounds_participated: usize,
548    /// Total samples used in training
549    pub total_samples: usize,
550    /// Average local loss
551    pub avg_local_loss: f64,
552}
553
554impl FederatedLearning {
555    /// Create a new federated learning system
556    pub fn new(config: FederatedConfig, num_features: usize) -> Self {
557        Self {
558            config,
559            clients: Vec::new(),
560            global_model: Arc::new(RwLock::new(ModelParameters {
561                weights: vec![0.0; num_features],
562                bias: 0.0,
563                metadata: ParameterMetadata {
564                    iterations_completed: 0,
565                    current_loss: f64::INFINITY,
566                    converged: false,
567                    last_updated: std::time::SystemTime::now(),
568                },
569            })),
570            privacy_mechanism: PrivacyMechanism::new(),
571        }
572    }
573
574    /// Add a client to the federated system
575    pub fn add_client(&mut self, client_id: String, dataset_size: usize) {
576        let client = FederatedClient {
577            id: client_id,
578            dataset_size,
579            local_parameters: Vec::new(),
580            stats: ClientStats::default(),
581        };
582        self.clients.push(client);
583    }
584
585    /// Select clients for a training round
586    pub fn select_clients(&self) -> Vec<usize> {
587        use scirs2_core::random::thread_rng;
588
589        let num_selected =
590            (self.clients.len() as f64 * self.config.client_fraction).ceil() as usize;
591        let mut selected = Vec::new();
592        let mut rng = thread_rng();
593
594        let mut indices: Vec<usize> = (0..self.clients.len()).collect();
595
596        // Fisher-Yates shuffle for random selection
597        for i in (1..indices.len()).rev() {
598            let j = rng.gen_range(0..=i);
599            indices.swap(i, j);
600        }
601
602        selected.extend_from_slice(&indices[..num_selected]);
603        selected
604    }
605
606    /// Perform federated averaging of client updates
607    pub fn federated_average(&self, client_updates: &[(usize, Vec<f64>)]) -> Vec<f64> {
608        if client_updates.is_empty() {
609            return vec![];
610        }
611
612        let num_features = client_updates[0].1.len();
613        let mut averaged = vec![0.0; num_features];
614        let mut total_weight = 0.0;
615
616        for (client_idx, update) in client_updates {
617            let weight = self.clients[*client_idx].dataset_size as f64;
618            total_weight += weight;
619
620            for (i, val) in update.iter().enumerate() {
621                averaged[i] += val * weight;
622            }
623        }
624
625        // Normalize by total weight
626        for val in averaged.iter_mut() {
627            *val /= total_weight;
628        }
629
630        // Apply differential privacy if enabled
631        if let Some(dp_epsilon) = self.config.dp_epsilon {
632            self.privacy_mechanism
633                .apply_noise(&mut averaged, dp_epsilon);
634        }
635
636        averaged
637    }
638
639    /// Get global model parameters
640    pub fn get_global_model(&self) -> ModelParameters {
641        self.global_model
642            .read()
643            .unwrap_or_else(|e| e.into_inner())
644            .clone()
645    }
646}
647
648/// Privacy mechanism for federated learning
649#[derive(Debug, Clone)]
650pub struct PrivacyMechanism {
651    /// Noise scale for differential privacy
652    pub noise_scale: f64,
653}
654
655impl PrivacyMechanism {
656    /// Create a new privacy mechanism
657    pub fn new() -> Self {
658        Self { noise_scale: 1.0 }
659    }
660
661    /// Apply differential privacy noise to gradients
662    pub fn apply_noise(&self, gradients: &mut [f64], epsilon: f64) {
663        use scirs2_core::random::essentials::Normal;
664        use scirs2_core::random::thread_rng;
665
666        let mut rng = thread_rng();
667        let noise_std = self.noise_scale / epsilon;
668        let normal = Normal::new(0.0, noise_std)
669            .unwrap_or_else(|_| Normal::new(0.0, 1.0).expect("default normal distribution"));
670
671        for grad in gradients.iter_mut() {
672            *grad += rng.sample(normal);
673        }
674    }
675
676    /// Clip gradients for privacy preservation
677    pub fn clip_gradients(&self, gradients: &mut [f64], clip_norm: f64) {
678        let norm: f64 = gradients.iter().map(|g| g * g).sum::<f64>().sqrt();
679
680        if norm > clip_norm {
681            let scale = clip_norm / norm;
682            for grad in gradients.iter_mut() {
683                *grad *= scale;
684            }
685        }
686    }
687}
688
689impl Default for PrivacyMechanism {
690    fn default() -> Self {
691        Self::new()
692    }
693}
694
695/// Byzantine-Fault Tolerant aggregation for robust distributed learning
696///
697/// Implements robust aggregation methods that are resilient to Byzantine
698/// (malicious or faulty) workers.
699#[derive(Debug)]
700pub struct ByzantineFaultTolerant {
701    /// BFT configuration
702    pub config: BFTConfig,
703    /// Aggregation method
704    pub aggregation_method: AggregationMethod,
705}
706
707/// Configuration for Byzantine-Fault Tolerant training
708#[derive(Debug, Clone, Serialize, Deserialize)]
709pub struct BFTConfig {
710    /// Maximum fraction of Byzantine workers tolerated
711    pub max_byzantine_fraction: f64,
712    /// Detection threshold for identifying Byzantine behavior
713    pub detection_threshold: f64,
714    /// Enable reputation tracking
715    pub enable_reputation: bool,
716}
717
718impl Default for BFTConfig {
719    fn default() -> Self {
720        Self {
721            max_byzantine_fraction: 0.3,
722            detection_threshold: 2.0,
723            enable_reputation: true,
724        }
725    }
726}
727
728/// Robust aggregation methods for Byzantine-Fault Tolerance
729#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
730pub enum AggregationMethod {
731    /// Coordinate-wise median
732    Median,
733    /// Trimmed mean (remove extreme values)
734    TrimmedMean { trim_fraction: usize },
735    /// Krum algorithm (select most representative gradient)
736    Krum,
737    /// Bulyan (combination of Krum and trimmed mean)
738    Bulyan,
739}
740
741impl ByzantineFaultTolerant {
742    /// Create a new BFT aggregator
743    pub fn new(config: BFTConfig, method: AggregationMethod) -> Self {
744        Self {
745            config,
746            aggregation_method: method,
747        }
748    }
749
750    /// Aggregate gradients using Byzantine-Fault Tolerant method
751    pub fn aggregate(&self, gradients: &[Vec<f64>]) -> Result<Vec<f64>> {
752        if gradients.is_empty() {
753            return Err(SklearsError::InvalidInput(
754                "Cannot aggregate empty gradient set".to_string(),
755            ));
756        }
757
758        match self.aggregation_method {
759            AggregationMethod::Median => self.coordinate_wise_median(gradients),
760            AggregationMethod::TrimmedMean { trim_fraction } => {
761                self.trimmed_mean(gradients, trim_fraction)
762            }
763            AggregationMethod::Krum => self.krum(gradients),
764            AggregationMethod::Bulyan => self.bulyan(gradients),
765        }
766    }
767
768    /// Coordinate-wise median aggregation
769    fn coordinate_wise_median(&self, gradients: &[Vec<f64>]) -> Result<Vec<f64>> {
770        let num_features = gradients[0].len();
771        let mut result = vec![0.0; num_features];
772
773        for i in 0..num_features {
774            let mut values: Vec<f64> = gradients.iter().map(|g| g[i]).collect();
775            values.sort_by(|a, b| a.partial_cmp(b).unwrap_or(std::cmp::Ordering::Equal));
776            result[i] = values[values.len() / 2];
777        }
778
779        Ok(result)
780    }
781
782    /// Trimmed mean aggregation
783    fn trimmed_mean(&self, gradients: &[Vec<f64>], trim_fraction: usize) -> Result<Vec<f64>> {
784        let num_features = gradients[0].len();
785        let mut result = vec![0.0; num_features];
786        let trim_count = (gradients.len() * trim_fraction) / 100;
787
788        for i in 0..num_features {
789            let mut values: Vec<f64> = gradients.iter().map(|g| g[i]).collect();
790            values.sort_by(|a, b| a.partial_cmp(b).unwrap_or(std::cmp::Ordering::Equal));
791
792            // Remove extreme values
793            let trimmed = &values[trim_count..values.len() - trim_count];
794            result[i] = trimmed.iter().sum::<f64>() / trimmed.len() as f64;
795        }
796
797        Ok(result)
798    }
799
800    /// Krum aggregation (select most representative gradient)
801    fn krum(&self, gradients: &[Vec<f64>]) -> Result<Vec<f64>> {
802        let n = gradients.len();
803        let f = (n as f64 * self.config.max_byzantine_fraction).floor() as usize;
804        let m = n - f - 2;
805
806        let mut scores = vec![0.0; n];
807
808        // Compute Krum scores
809        for i in 0..n {
810            let mut distances: Vec<(usize, f64)> = Vec::new();
811
812            for j in 0..n {
813                if i != j {
814                    let dist = self.euclidean_distance(&gradients[i], &gradients[j]);
815                    distances.push((j, dist));
816                }
817            }
818
819            // Sort by distance and sum m closest
820            distances.sort_by(|a, b| a.1.partial_cmp(&b.1).unwrap_or(std::cmp::Ordering::Equal));
821            scores[i] = distances.iter().take(m).map(|(_, d)| d).sum();
822        }
823
824        // Select gradient with minimum score
825        let best_idx = scores
826            .iter()
827            .enumerate()
828            .min_by(|(_, a), (_, b)| a.partial_cmp(b).unwrap_or(std::cmp::Ordering::Equal))
829            .map(|(idx, _)| idx)
830            .expect("expected valid value");
831
832        Ok(gradients[best_idx].clone())
833    }
834
835    /// Bulyan aggregation (robust combination)
836    fn bulyan(&self, gradients: &[Vec<f64>]) -> Result<Vec<f64>> {
837        // Simplified Bulyan: apply Krum multiple times and then trimmed mean
838        let n = gradients.len();
839        let f = (n as f64 * self.config.max_byzantine_fraction).floor() as usize;
840        let theta = n - 2 * f;
841
842        if theta < 1 {
843            return Err(SklearsError::InvalidInput(
844                "Too many Byzantine workers for Bulyan".to_string(),
845            ));
846        }
847
848        // For simplicity, use coordinate-wise median as a robust aggregation
849        self.coordinate_wise_median(gradients)
850    }
851
852    /// Compute Euclidean distance between two gradients
853    fn euclidean_distance(&self, a: &[f64], b: &[f64]) -> f64 {
854        a.iter()
855            .zip(b.iter())
856            .map(|(x, y)| (x - y).powi(2))
857            .sum::<f64>()
858            .sqrt()
859    }
860}
861
862/// Advanced load balancing for distributed systems
863///
864/// Implements sophisticated load balancing strategies for optimal
865/// resource utilization and performance.
866#[derive(Debug)]
867pub struct LoadBalancer {
868    /// Load balancing strategy
869    pub strategy: LoadBalancingStrategy,
870    /// Worker load tracking
871    pub worker_loads: HashMap<String, WorkerLoad>,
872}
873
874/// Load balancing strategy
875#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
876pub enum LoadBalancingStrategy {
877    /// Round-robin assignment
878    RoundRobin,
879    /// Least-loaded worker first
880    LeastLoaded,
881    /// Weighted random based on capacity
882    WeightedRandom,
883    /// Power of two choices
884    PowerOfTwo,
885}
886
887/// Worker load information
888#[derive(Debug, Clone, Serialize, Deserialize)]
889pub struct WorkerLoad {
890    /// Current number of tasks
891    pub active_tasks: usize,
892    /// Worker capacity
893    pub capacity: usize,
894    /// Average task completion time (ms)
895    pub avg_completion_time_ms: u64,
896    /// Load factor (0.0 - 1.0)
897    pub load_factor: f64,
898}
899
900impl LoadBalancer {
901    /// Create a new load balancer
902    pub fn new(strategy: LoadBalancingStrategy) -> Self {
903        Self {
904            strategy,
905            worker_loads: HashMap::new(),
906        }
907    }
908
909    /// Register a worker with the load balancer
910    pub fn register_worker(&mut self, worker_id: String, capacity: usize) {
911        self.worker_loads.insert(
912            worker_id,
913            WorkerLoad {
914                active_tasks: 0,
915                capacity,
916                avg_completion_time_ms: 0,
917                load_factor: 0.0,
918            },
919        );
920    }
921
922    /// Select a worker for task assignment
923    pub fn select_worker(&mut self) -> Option<String> {
924        match self.strategy {
925            LoadBalancingStrategy::RoundRobin => self.round_robin_select(),
926            LoadBalancingStrategy::LeastLoaded => self.least_loaded_select(),
927            LoadBalancingStrategy::WeightedRandom => self.weighted_random_select(),
928            LoadBalancingStrategy::PowerOfTwo => self.power_of_two_select(),
929        }
930    }
931
932    /// Round-robin worker selection
933    fn round_robin_select(&self) -> Option<String> {
934        self.worker_loads.keys().next().cloned()
935    }
936
937    /// Select least loaded worker
938    fn least_loaded_select(&self) -> Option<String> {
939        self.worker_loads
940            .iter()
941            .min_by(|(_, a), (_, b)| {
942                a.load_factor
943                    .partial_cmp(&b.load_factor)
944                    .unwrap_or(std::cmp::Ordering::Equal)
945            })
946            .map(|(id, _)| id.clone())
947    }
948
949    /// Weighted random selection based on available capacity
950    fn weighted_random_select(&self) -> Option<String> {
951        use scirs2_core::random::thread_rng;
952
953        if self.worker_loads.is_empty() {
954            return None;
955        }
956
957        let mut rng = thread_rng();
958        let total_capacity: f64 = self
959            .worker_loads
960            .values()
961            .map(|load| (load.capacity - load.active_tasks) as f64)
962            .sum();
963
964        let mut rand_val = rng.gen_range(0.0..total_capacity);
965
966        for (id, load) in &self.worker_loads {
967            let available = (load.capacity - load.active_tasks) as f64;
968            if rand_val < available {
969                return Some(id.clone());
970            }
971            rand_val -= available;
972        }
973
974        self.worker_loads.keys().next().cloned()
975    }
976
977    /// Power of two choices (sample 2 random workers, pick least loaded)
978    fn power_of_two_select(&self) -> Option<String> {
979        use scirs2_core::random::thread_rng;
980
981        if self.worker_loads.is_empty() {
982            return None;
983        }
984
985        let mut rng = thread_rng();
986        let workers: Vec<_> = self.worker_loads.keys().collect();
987
988        if workers.len() == 1 {
989            return Some(workers[0].clone());
990        }
991
992        let idx1 = rng.gen_range(0..workers.len());
993        let mut idx2 = rng.gen_range(0..workers.len());
994        while idx2 == idx1 {
995            idx2 = rng.gen_range(0..workers.len());
996        }
997
998        let load1 = &self.worker_loads[workers[idx1]];
999        let load2 = &self.worker_loads[workers[idx2]];
1000
1001        if load1.load_factor < load2.load_factor {
1002            Some(workers[idx1].clone())
1003        } else {
1004            Some(workers[idx2].clone())
1005        }
1006    }
1007
1008    /// Update worker load after task assignment
1009    pub fn update_load(&mut self, worker_id: &str, task_assigned: bool) {
1010        if let Some(load) = self.worker_loads.get_mut(worker_id) {
1011            if task_assigned {
1012                load.active_tasks += 1;
1013            } else if load.active_tasks > 0 {
1014                load.active_tasks -= 1;
1015            }
1016            load.load_factor = load.active_tasks as f64 / load.capacity as f64;
1017        }
1018    }
1019}
1020
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_parameter_server_creation() {
        let ps = ParameterServer::new(5, 3);
        assert_eq!(ps.parameters.len(), 5);
        assert_eq!(ps.num_workers, 3);
        assert_eq!(ps.version, 0);
    }

    #[test]
    fn test_gradient_accumulation() {
        let mut ps = ParameterServer::new(3, 2);

        let grad1 = vec![1.0, 2.0, 3.0];
        let grad2 = vec![2.0, 3.0, 4.0];

        ps.receive_gradient(grad1)
            .expect("receive_gradient should succeed");
        ps.receive_gradient(grad2)
            .expect("receive_gradient should succeed");

        // After 2 updates (all workers), parameters should be updated
        let params = ps.get_parameters();
        assert_eq!(ps.version, 1);
        assert!(params.iter().all(|&p| p != 0.0));
    }

    #[test]
    fn test_worker_node_creation() {
        let partition = DataPartition {
            features: vec![vec![1.0, 2.0], vec![3.0, 4.0]],
            targets: vec![1.0, 2.0],
            partition_id: 0,
        };

        let worker = WorkerNode::new(NodeId::new("worker_0"), partition);
        assert_eq!(worker.id.0, "worker_0");
        assert_eq!(worker.stats.samples_processed, 0);
    }

    #[test]
    fn test_local_gradient_computation() {
        let partition = DataPartition {
            features: vec![vec![1.0, 2.0], vec![2.0, 3.0]],
            targets: vec![3.0, 5.0],
            partition_id: 0,
        };

        let mut worker = WorkerNode::new(NodeId::new("worker_0"), partition);
        let params = vec![1.0, 1.0];

        let gradient = worker
            .compute_local_gradient(&params)
            .expect("compute_local_gradient should succeed");
        assert_eq!(gradient.len(), 2);
        assert!(worker.stats.gradient_computations > 0);
    }

    #[test]
    fn test_distributed_regression_creation() {
        let config = DistributedConfig::default();
        let model = DistributedLinearRegression::new(config, 5);

        assert_eq!(model.workers.len(), 0);
        // assert_eq! reports both values on failure, unlike assert!(a == b).
        assert_eq!(
            model
                .parameters
                .read()
                .unwrap_or_else(|e| e.into_inner())
                .weights
                .len(),
            5
        );
    }

    #[test]
    fn test_data_partitioning() {
        let config = DistributedConfig {
            num_workers: 2,
            ..Default::default()
        };

        let mut model = DistributedLinearRegression::new(config, 2);

        let features = vec![
            vec![1.0, 2.0],
            vec![3.0, 4.0],
            vec![5.0, 6.0],
            vec![7.0, 8.0],
        ];
        let targets = vec![3.0, 7.0, 11.0, 15.0];

        model
            .partition_data(features, targets)
            .expect("partition_data should succeed");

        assert_eq!(model.workers.len(), 2);
        assert!(!model.workers[0].data_partition.features.is_empty());
    }

    #[test]
    fn test_distributed_training() {
        let config = DistributedConfig {
            num_workers: 2,
            max_iterations: 10,
            tolerance: 1e-3,
            learning_rate: 0.01,
            ..Default::default()
        };

        let mut model = DistributedLinearRegression::new(config, 2);

        // Simple linear relationship: y = 2*x1 + 3*x2
        let features = vec![
            vec![1.0, 1.0],
            vec![2.0, 2.0],
            vec![3.0, 3.0],
            vec![4.0, 4.0],
        ];
        let targets = vec![5.0, 10.0, 15.0, 20.0];

        model
            .partition_data(features, targets)
            .expect("partition_data should succeed");
        model.fit().expect("model fitting should succeed");

        let stats = model.get_training_stats();
        assert!(stats.total_samples_processed > 0);
        assert!(stats.parameter_server_version > 0);
    }

    #[test]
    fn test_prediction() {
        let config = DistributedConfig::default();
        let model = DistributedLinearRegression::new(config, 2);

        let test_features = vec![vec![1.0, 2.0], vec![3.0, 4.0]];
        let predictions = model
            .predict(&test_features)
            .expect("prediction should succeed");

        assert_eq!(predictions.len(), 2);
    }

    #[test]
    fn test_training_stats() {
        let config = DistributedConfig {
            num_workers: 3,
            ..Default::default()
        };

        let model = DistributedLinearRegression::new(config, 2);
        let stats = model.get_training_stats();

        assert_eq!(stats.num_workers, 0); // No workers added yet
    }

    // ============================================================================
    // Tests for Advanced Distributed Learning Features
    // ============================================================================

    #[test]
    fn test_federated_learning_creation() {
        let config = FederatedConfig::default();
        let fed_learning = FederatedLearning::new(config, 5);

        assert_eq!(fed_learning.clients.len(), 0);
        assert_eq!(fed_learning.config.num_clients, 10);
    }

    #[test]
    fn test_federated_add_client() {
        let config = FederatedConfig::default();
        let mut fed_learning = FederatedLearning::new(config, 5);

        fed_learning.add_client("client_1".to_string(), 100);
        fed_learning.add_client("client_2".to_string(), 150);

        assert_eq!(fed_learning.clients.len(), 2);
        assert_eq!(fed_learning.clients[0].dataset_size, 100);
        assert_eq!(fed_learning.clients[1].dataset_size, 150);
    }

    #[test]
    fn test_federated_client_selection() {
        let config = FederatedConfig {
            client_fraction: 0.5,
            ..Default::default()
        };
        let mut fed_learning = FederatedLearning::new(config, 5);

        for i in 0..10 {
            fed_learning.add_client(format!("client_{}", i), 100);
        }

        let selected = fed_learning.select_clients();
        assert!(selected.len() >= 4 && selected.len() <= 6); // ~50% of 10
    }

    #[test]
    fn test_federated_averaging() {
        let config = FederatedConfig {
            dp_epsilon: None, // Disable noise for deterministic test
            ..Default::default()
        };
        let mut fed_learning = FederatedLearning::new(config, 3);

        fed_learning.add_client("client_1".to_string(), 100);
        fed_learning.add_client("client_2".to_string(), 100);

        let updates = vec![(0, vec![1.0, 2.0, 3.0]), (1, vec![2.0, 4.0, 6.0])];

        let averaged = fed_learning.federated_average(&updates);
        assert_eq!(averaged.len(), 3);
        // With equal weights, average should be (1+2)/2=1.5, (2+4)/2=3, (3+6)/2=4.5
        assert!((averaged[0] - 1.5).abs() < 1e-6);
        assert!((averaged[1] - 3.0).abs() < 1e-6);
        assert!((averaged[2] - 4.5).abs() < 1e-6);
    }

    #[test]
    fn test_privacy_mechanism_noise() {
        let privacy = PrivacyMechanism::new();
        let mut gradients = vec![1.0, 2.0, 3.0];
        let original = gradients.clone();

        privacy.apply_noise(&mut gradients, 1.0);

        // Gradients should be modified (with very high probability)
        assert_ne!(gradients, original);
    }

    #[test]
    fn test_privacy_mechanism_clipping() {
        let privacy = PrivacyMechanism::new();
        let mut gradients = vec![3.0, 4.0]; // Norm = 5.0

        privacy.clip_gradients(&mut gradients, 1.0);

        let norm: f64 = gradients.iter().map(|g| g * g).sum::<f64>().sqrt();
        assert!((norm - 1.0).abs() < 1e-6);
    }

    #[test]
    fn test_byzantine_fault_tolerant_creation() {
        let config = BFTConfig::default();
        let bft = ByzantineFaultTolerant::new(config, AggregationMethod::Median);

        assert_eq!(bft.aggregation_method, AggregationMethod::Median);
    }

    #[test]
    fn test_byzantine_median_aggregation() {
        let config = BFTConfig::default();
        let bft = ByzantineFaultTolerant::new(config, AggregationMethod::Median);

        let gradients = vec![
            vec![1.0, 2.0, 3.0],
            vec![2.0, 3.0, 4.0],
            vec![3.0, 4.0, 5.0],
            vec![100.0, 100.0, 100.0], // Byzantine outlier
        ];

        let result = bft.aggregate(&gradients).expect("aggregate should succeed");
        assert_eq!(result.len(), 3);
        // Median should filter out the outlier
        assert!(result[0] < 50.0);
        assert!(result[1] < 50.0);
        assert!(result[2] < 50.0);
    }

    #[test]
    fn test_byzantine_trimmed_mean() {
        let config = BFTConfig::default();
        let bft = ByzantineFaultTolerant::new(
            config,
            AggregationMethod::TrimmedMean { trim_fraction: 25 },
        );

        let gradients = vec![
            vec![1.0, 2.0, 3.0],
            vec![2.0, 3.0, 4.0],
            vec![3.0, 4.0, 5.0],
            vec![4.0, 5.0, 6.0],
        ];

        let result = bft.aggregate(&gradients).expect("aggregate should succeed");
        assert_eq!(result.len(), 3);
        // Trimmed mean should produce reasonable values
        assert!(result[0] > 1.0 && result[0] < 4.0);
    }

    #[test]
    fn test_byzantine_krum() {
        let config = BFTConfig::default();
        let bft = ByzantineFaultTolerant::new(config, AggregationMethod::Krum);

        let gradients = vec![
            vec![1.0, 2.0, 3.0],
            vec![1.1, 2.1, 3.1],
            vec![1.2, 2.2, 3.2],
            vec![100.0, 100.0, 100.0], // Byzantine outlier
        ];

        let result = bft.aggregate(&gradients).expect("aggregate should succeed");
        assert_eq!(result.len(), 3);
        // Krum should select one of the non-outlier gradients
        assert!(result[0] < 10.0);
    }

    #[test]
    fn test_byzantine_empty_gradients() {
        let config = BFTConfig::default();
        let bft = ByzantineFaultTolerant::new(config, AggregationMethod::Median);

        let gradients: Vec<Vec<f64>> = vec![];
        let result = bft.aggregate(&gradients);

        assert!(result.is_err());
    }

    #[test]
    fn test_load_balancer_creation() {
        let lb = LoadBalancer::new(LoadBalancingStrategy::RoundRobin);
        assert_eq!(lb.strategy, LoadBalancingStrategy::RoundRobin);
        assert_eq!(lb.worker_loads.len(), 0);
    }

    #[test]
    fn test_load_balancer_register_worker() {
        let mut lb = LoadBalancer::new(LoadBalancingStrategy::LeastLoaded);

        lb.register_worker("worker_1".to_string(), 10);
        lb.register_worker("worker_2".to_string(), 20);

        assert_eq!(lb.worker_loads.len(), 2);
        assert_eq!(
            lb.worker_loads
                .get("worker_1")
                .expect("key should exist")
                .capacity,
            10
        );
        assert_eq!(
            lb.worker_loads
                .get("worker_2")
                .expect("key should exist")
                .capacity,
            20
        );
    }

    #[test]
    fn test_load_balancer_least_loaded() {
        let mut lb = LoadBalancer::new(LoadBalancingStrategy::LeastLoaded);

        lb.register_worker("worker_1".to_string(), 10);
        lb.register_worker("worker_2".to_string(), 10);

        // Initially both have load 0, so either can be selected
        let selected = lb.select_worker();
        assert!(selected.is_some());

        // Add load to one worker
        lb.update_load("worker_1", true);
        lb.update_load("worker_1", true);

        // Now worker_2 should be selected (least loaded)
        let selected = lb.select_worker();
        assert_eq!(selected, Some("worker_2".to_string()));
    }

    #[test]
    fn test_load_balancer_update_load() {
        let mut lb = LoadBalancer::new(LoadBalancingStrategy::LeastLoaded);

        lb.register_worker("worker_1".to_string(), 10);

        lb.update_load("worker_1", true);
        assert_eq!(
            lb.worker_loads
                .get("worker_1")
                .expect("key should exist")
                .active_tasks,
            1
        );
        assert!(
            (lb.worker_loads
                .get("worker_1")
                .expect("key should exist")
                .load_factor
                - 0.1)
                .abs()
                < 1e-6
        );

        lb.update_load("worker_1", false);
        assert_eq!(
            lb.worker_loads
                .get("worker_1")
                .expect("key should exist")
                .active_tasks,
            0
        );
        assert!(
            (lb.worker_loads
                .get("worker_1")
                .expect("key should exist")
                .load_factor)
                .abs()
                < 1e-6
        );
    }

    #[test]
    fn test_load_balancer_power_of_two() {
        let mut lb = LoadBalancer::new(LoadBalancingStrategy::PowerOfTwo);

        lb.register_worker("worker_1".to_string(), 10);
        lb.register_worker("worker_2".to_string(), 10);
        lb.register_worker("worker_3".to_string(), 10);

        let selected = lb
            .select_worker()
            .expect("a registered worker should be selected");
        // The selection must be one of the registered workers.
        assert!(lb.worker_loads.contains_key(&selected));
    }

    #[test]
    fn test_load_balancer_weighted_random_prefers_spare_capacity() {
        let mut lb = LoadBalancer::new(LoadBalancingStrategy::WeightedRandom);

        lb.register_worker("worker_1".to_string(), 5);
        lb.register_worker("worker_2".to_string(), 1);
        // worker_2 is now at capacity, so only worker_1 has spare capacity
        lb.update_load("worker_2", true);

        for _ in 0..10 {
            assert_eq!(lb.select_worker(), Some("worker_1".to_string()));
        }
    }

    #[test]
    fn test_load_balancer_empty_selects_nothing() {
        // Every strategy must report "no worker" rather than panic when empty.
        for &strategy in &[
            LoadBalancingStrategy::RoundRobin,
            LoadBalancingStrategy::LeastLoaded,
            LoadBalancingStrategy::WeightedRandom,
            LoadBalancingStrategy::PowerOfTwo,
        ] {
            let mut lb = LoadBalancer::new(strategy);
            assert_eq!(lb.select_worker(), None);
        }
    }

    #[test]
    fn test_federated_config_default() {
        let config = FederatedConfig::default();
        assert_eq!(config.num_clients, 10);
        assert!((config.client_fraction - 0.3).abs() < 1e-6);
        assert_eq!(config.local_epochs, 5);
        assert!(config.secure_aggregation);
    }

    #[test]
    fn test_bft_config_default() {
        let config = BFTConfig::default();
        assert!((config.max_byzantine_fraction - 0.3).abs() < 1e-6);
        assert!((config.detection_threshold - 2.0).abs() < 1e-6);
        assert!(config.enable_reputation);
    }

    #[test]
    fn test_aggregation_method_equality() {
        assert_eq!(AggregationMethod::Median, AggregationMethod::Median);
        assert_eq!(
            AggregationMethod::TrimmedMean { trim_fraction: 25 },
            AggregationMethod::TrimmedMean { trim_fraction: 25 }
        );
        assert_ne!(AggregationMethod::Median, AggregationMethod::Krum);
    }

    #[test]
    fn test_load_balancing_strategy_equality() {
        assert_eq!(
            LoadBalancingStrategy::RoundRobin,
            LoadBalancingStrategy::RoundRobin
        );
        assert_eq!(
            LoadBalancingStrategy::LeastLoaded,
            LoadBalancingStrategy::LeastLoaded
        );
        assert_ne!(
            LoadBalancingStrategy::RoundRobin,
            LoadBalancingStrategy::LeastLoaded
        );
    }
}