scirs2_cluster/distributed/
core.rs

1//! Core distributed K-means clustering implementation
2//!
3//! This module provides the main distributed K-means algorithm with
4//! support for multiple workers, fault tolerance, and load balancing.
5
6use scirs2_core::ndarray::{s, Array1, Array2, ArrayView1, ArrayView2, Axis};
7use scirs2_core::numeric::{Float, FromPrimitive, Zero};
8use scirs2_core::random::prelude::*;
9use scirs2_core::random::rand_prelude::IndexedRandom;
10use std::collections::HashMap;
11use std::fmt::Debug;
12use std::sync::{Arc, Mutex};
13use std::time::{Duration, Instant, SystemTime, UNIX_EPOCH};
14
15use serde::{Deserialize, Serialize};
16
17use crate::error::{ClusteringError, Result};
18use crate::vq::euclidean_distance;
19
20use super::fault_tolerance::{DataPartition, FaultToleranceCoordinator};
21use super::load_balancing::LoadBalancingCoordinator;
22use super::message_passing::{ClusteringMessage, MessagePassingCoordinator, MessagePriority};
23use super::monitoring::PerformanceMonitor;
24use super::partitioning::{DataPartitioner, PartitioningConfig};
25
26/// Main distributed K-means clustering algorithm
/// Main distributed K-means clustering algorithm.
///
/// Owns the data partitions plus the coordinators for fault tolerance,
/// load balancing, message passing, and performance monitoring that the
/// `fit` loop drives each iteration.
#[derive(Debug)]
pub struct DistributedKMeans<F: Float> {
    /// Number of clusters
    pub k: usize,
    /// Configuration parameters
    pub config: DistributedKMeansConfig,
    /// Current centroids (`None` until initialized by `fit`)
    pub centroids: Option<Array2<F>>,
    /// Worker assignments and data partitions
    pub partitions: Vec<DataPartition<F>>,
    /// Fault tolerance coordinator (checkpointing / recovery)
    pub fault_coordinator: FaultToleranceCoordinator<F>,
    /// Load balancing coordinator
    pub load_balancer: LoadBalancingCoordinator,
    /// Performance monitor
    pub performance_monitor: PerformanceMonitor,
    /// Message passing coordinator (only created when `n_workers > 1`)
    pub message_coordinator: Option<MessagePassingCoordinator<F>>,
    /// Data partitioner
    pub partitioner: DataPartitioner<F>,
    /// Current iteration
    pub current_iteration: usize,
    /// Convergence history, one entry per completed iteration
    pub convergence_history: Vec<ConvergenceInfo>,
    /// Global inertia (within-cluster sum of squares; `INFINITY` before fitting)
    pub global_inertia: f64,
}
54
/// Configuration for distributed K-means
#[derive(Debug, Clone)]
pub struct DistributedKMeansConfig {
    /// Maximum number of clustering iterations before stopping.
    pub max_iterations: usize,
    /// Convergence tolerance, compared against centroid movement and
    /// relative inertia change.
    pub tolerance: f64,
    /// Number of workers the data is partitioned across.
    pub n_workers: usize,
    /// Strategy used to choose the initial centroids.
    pub init_method: InitializationMethod,
    /// Enable checkpointing/recovery via the fault-tolerance coordinator.
    pub enable_fault_tolerance: bool,
    /// Enable periodic data re-partitioning when load is imbalanced.
    pub enable_load_balancing: bool,
    /// Enable detailed per-worker performance monitoring.
    pub enable_monitoring: bool,
    /// Interval (in iterations) between convergence checks.
    /// NOTE(review): not referenced in this module — confirm where it is consumed.
    pub convergence_check_interval: usize,
    /// Interval (in iterations) between fault-tolerance checkpoints.
    pub checkpoint_interval: usize,
    /// Print progress information to stdout during fitting.
    pub verbose: bool,
    /// Optional seed for reproducible runs.
    /// NOTE(review): not consumed by the initialization code in this module — verify.
    pub random_seed: Option<u64>,
}
70
71impl Default for DistributedKMeansConfig {
72    fn default() -> Self {
73        Self {
74            max_iterations: 100,
75            tolerance: 1e-4,
76            n_workers: 4,
77            init_method: InitializationMethod::KMeansPlusPlus,
78            enable_fault_tolerance: true,
79            enable_load_balancing: true,
80            enable_monitoring: true,
81            convergence_check_interval: 5,
82            checkpoint_interval: 10,
83            verbose: false,
84            random_seed: None,
85        }
86    }
87}
88
/// Centroid initialization methods
#[derive(Debug, Clone)]
pub enum InitializationMethod {
    /// Random initialization: pick k distinct data points uniformly at random.
    Random,
    /// K-means++ initialization: spread seeds with probability proportional
    /// to squared distance from the nearest already-chosen centroid.
    KMeansPlusPlus,
    /// Forgy initialization (equivalent to `Random` in this implementation).
    Forgy,
    /// Custom centroids provided by user; must be shaped `(k, n_features)`.
    Custom(Array2<f64>),
}
101
/// Convergence information recorded once per iteration of `fit`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ConvergenceInfo {
    /// Iteration index this entry describes.
    pub iteration: usize,
    /// Global inertia (within-cluster sum of squares) at this iteration.
    pub inertia: f64,
    /// Maximum centroid movement since the previous iteration.
    pub centroid_movement: f64,
    /// Whether the convergence criteria were met at this iteration.
    pub converged: bool,
    /// Wall-clock time when this entry was recorded.
    pub timestamp: SystemTime,
    /// Time spent on this iteration, in milliseconds.
    pub computation_time_ms: u64,
}
112
/// Final clustering result returned by `DistributedKMeans::fit`.
#[derive(Debug, Clone)]
pub struct ClusteringResult<F: Float> {
    /// Final cluster centroids, shaped `(k, n_features)`
    pub centroids: Array2<F>,
    /// Cluster labels for all data points (partition order, not necessarily
    /// original input order — see `collect_final_labels`)
    pub labels: Array1<usize>,
    /// Final inertia (within-cluster sum of squares)
    pub inertia: f64,
    /// Number of iterations performed
    pub n_iterations: usize,
    /// Convergence information from the last iteration
    pub convergence_info: ConvergenceInfo,
    /// Performance statistics accumulated over the whole fit
    pub performance_stats: PerformanceStatistics,
}
129
/// Performance statistics for clustering, accumulated across the fit.
#[derive(Debug, Clone)]
pub struct PerformanceStatistics {
    /// Total wall-clock time of the fit, in milliseconds.
    pub total_time_ms: u64,
    /// Time spent on communication (partitioning, broadcasts), in milliseconds.
    pub communication_time_ms: u64,
    /// Time spent on local computation, in milliseconds.
    pub computation_time_ms: u64,
    /// Time spent aggregating/synchronizing worker results, in milliseconds.
    pub synchronization_time_ms: u64,
    /// Average worker health score in [0, 1] (from the performance monitor).
    pub worker_efficiency: f64,
    /// Load-balance score in (0, 1]; 1.0 means perfectly even partitions.
    pub load_balance_score: f64,
    /// Number of fault-tolerance/rebalancing events that occurred.
    pub fault_tolerance_events: usize,
}
141
/// Per-worker computation result for one iteration.
#[derive(Debug, Clone)]
pub struct WorkerResult<F: Float> {
    /// Identifier of the worker that produced this result.
    pub worker_id: usize,
    /// Local centroid means computed over this worker's partition, `(k, n_features)`.
    pub local_centroids: Array2<F>,
    /// Cluster label for each point in this worker's partition.
    pub local_labels: Array1<usize>,
    /// Local within-cluster sum of squares for this partition.
    pub local_inertia: f64,
    /// Number of points assigned to each cluster on this worker.
    pub point_counts: Array1<usize>,
    /// Time this worker spent computing, in milliseconds.
    pub computation_time_ms: u64,
}
152
153impl<F: Float + FromPrimitive + Debug + Send + Sync + 'static> DistributedKMeans<F> {
154    /// Create new distributed K-means instance
155    pub fn new(k: usize, config: DistributedKMeansConfig) -> Result<Self> {
156        if k == 0 {
157            return Err(ClusteringError::InvalidInput(
158                "Number of clusters must be greater than 0".to_string(),
159            ));
160        }
161
162        if config.n_workers == 0 {
163            return Err(ClusteringError::InvalidInput(
164                "Number of workers must be greater than 0".to_string(),
165            ));
166        }
167
168        let partitioner_config = PartitioningConfig {
169            n_workers: config.n_workers,
170            ..Default::default()
171        };
172
173        let fault_tolerance_config = super::fault_tolerance::FaultToleranceConfig {
174            enabled: config.enable_fault_tolerance,
175            ..Default::default()
176        };
177
178        let load_balancing_config = super::load_balancing::LoadBalancingConfig {
179            enable_dynamic_balancing: config.enable_load_balancing,
180            ..Default::default()
181        };
182
183        let monitoring_config = super::monitoring::MonitoringConfig {
184            enable_detailed_monitoring: config.enable_monitoring,
185            ..Default::default()
186        };
187
188        Ok(Self {
189            k,
190            config,
191            centroids: None,
192            partitions: Vec::new(),
193            fault_coordinator: FaultToleranceCoordinator::new(fault_tolerance_config),
194            load_balancer: LoadBalancingCoordinator::new(load_balancing_config),
195            performance_monitor: PerformanceMonitor::new(monitoring_config),
196            message_coordinator: None,
197            partitioner: DataPartitioner::new(partitioner_config),
198            current_iteration: 0,
199            convergence_history: Vec::new(),
200            global_inertia: f64::INFINITY,
201        })
202    }
203
204    /// Fit the distributed K-means model to data
205    pub fn fit(&mut self, data: ArrayView2<F>) -> Result<ClusteringResult<F>> {
206        let start_time = Instant::now();
207        let mut stats = PerformanceStatistics {
208            total_time_ms: 0,
209            communication_time_ms: 0,
210            computation_time_ms: 0,
211            synchronization_time_ms: 0,
212            worker_efficiency: 0.0,
213            load_balance_score: 0.0,
214            fault_tolerance_events: 0,
215        };
216
217        // Validate input data
218        self.validate_input(data)?;
219
220        // Initialize workers and message passing
221        self.initialize_workers()?;
222
223        // Partition data across workers
224        let partition_start = Instant::now();
225        self.partitions = self.partitioner.partition_data(data)?;
226        stats.communication_time_ms += partition_start.elapsed().as_millis() as u64;
227
228        if self.config.verbose {
229            println!("Data partitioned across {} workers", self.config.n_workers);
230        }
231
232        // Initialize centroids
233        let init_start = Instant::now();
234        self.centroids = Some(self.initialize_centroids(data)?);
235        stats.computation_time_ms += init_start.elapsed().as_millis() as u64;
236
237        // Main clustering loop
238        let mut converged = false;
239        self.current_iteration = 0;
240
241        while self.current_iteration < self.config.max_iterations && !converged {
242            let iteration_start = Instant::now();
243
244            // Perform one iteration of distributed K-means
245            converged = self.perform_iteration(&mut stats)?;
246
247            // Update convergence history
248            let iteration_time = iteration_start.elapsed().as_millis() as u64;
249            self.update_convergence_history(iteration_time)?;
250
251            // Check for rebalancing if needed
252            if self.config.enable_load_balancing && self.current_iteration.is_multiple_of(10) {
253                self.check_and_rebalance(data, &mut stats)?;
254            }
255
256            // Create checkpoint if configured
257            if self.config.enable_fault_tolerance
258                && self
259                    .current_iteration
260                    .is_multiple_of(self.config.checkpoint_interval)
261            {
262                self.create_checkpoint()?;
263            }
264
265            self.current_iteration += 1;
266
267            if self.config.verbose && self.current_iteration.is_multiple_of(10) {
268                println!(
269                    "Iteration {}: inertia = {:.6}",
270                    self.current_iteration, self.global_inertia
271                );
272            }
273        }
274
275        // Finalize results
276        stats.total_time_ms = start_time.elapsed().as_millis() as u64;
277        stats.worker_efficiency = self.calculate_worker_efficiency();
278        stats.load_balance_score = self.calculate_load_balance_score();
279
280        let final_labels = self.collect_final_labels()?;
281        let final_convergence =
282            self.convergence_history
283                .last()
284                .cloned()
285                .unwrap_or_else(|| ConvergenceInfo {
286                    iteration: self.current_iteration,
287                    inertia: self.global_inertia,
288                    centroid_movement: 0.0,
289                    converged,
290                    timestamp: SystemTime::now(),
291                    computation_time_ms: 0,
292                });
293
294        Ok(ClusteringResult {
295            centroids: self.centroids.as_ref().unwrap().clone(),
296            labels: final_labels,
297            inertia: self.global_inertia,
298            n_iterations: self.current_iteration,
299            convergence_info: final_convergence,
300            performance_stats: stats,
301        })
302    }
303
304    /// Validate input data
305    fn validate_input(&self, data: ArrayView2<F>) -> Result<()> {
306        if data.nrows() == 0 {
307            return Err(ClusteringError::InvalidInput(
308                "Input data is empty".to_string(),
309            ));
310        }
311
312        if data.ncols() == 0 {
313            return Err(ClusteringError::InvalidInput(
314                "Input data has no features".to_string(),
315            ));
316        }
317
318        if data.nrows() < self.k {
319            return Err(ClusteringError::InvalidInput(format!(
320                "Number of samples ({}) must be at least k ({})",
321                data.nrows(),
322                self.k
323            )));
324        }
325
326        if data.nrows() < self.config.n_workers {
327            return Err(ClusteringError::InvalidInput(format!(
328                "Number of samples ({}) must be at least number of workers ({})",
329                data.nrows(),
330                self.config.n_workers
331            )));
332        }
333
334        Ok(())
335    }
336
337    /// Initialize workers and communication infrastructure
338    fn initialize_workers(&mut self) -> Result<()> {
339        // Register workers with fault tolerance coordinator
340        for worker_id in 0..self.config.n_workers {
341            self.fault_coordinator.register_worker(worker_id);
342            self.performance_monitor.register_worker(worker_id);
343        }
344
345        // Initialize message passing coordinator if needed
346        if self.config.n_workers > 1 {
347            let message_config = super::message_passing::MessagePassingConfig::default();
348            self.message_coordinator = Some(MessagePassingCoordinator::new(0, message_config));
349        }
350
351        Ok(())
352    }
353
354    /// Initialize cluster centroids
355    fn initialize_centroids(&self, data: ArrayView2<F>) -> Result<Array2<F>> {
356        match &self.config.init_method {
357            InitializationMethod::Random => self.random_initialization(data),
358            InitializationMethod::KMeansPlusPlus => self.kmeans_plus_plus_initialization(data),
359            InitializationMethod::Forgy => self.forgy_initialization(data),
360            InitializationMethod::Custom(centroids) => {
361                if centroids.nrows() != self.k || centroids.ncols() != data.ncols() {
362                    return Err(ClusteringError::InvalidInput(
363                        "Custom centroids dimensions don't match".to_string(),
364                    ));
365                }
366                let converted_centroids =
367                    Array2::from_shape_fn((self.k, data.ncols()), |(i, j)| {
368                        F::from(centroids[[i, j]]).unwrap_or_else(F::zero)
369                    });
370                Ok(converted_centroids)
371            }
372        }
373    }
374
375    /// Random centroid initialization
376    fn random_initialization(&self, data: ArrayView2<F>) -> Result<Array2<F>> {
377        use scirs2_core::random::seq::SliceRandom;
378
379        let mut rng = scirs2_core::random::rng();
380        let data_indices: Vec<usize> = (0..data.nrows()).collect();
381        let selected_indices: Vec<_> = data_indices
382            .as_slice()
383            .choose_multiple(&mut rng, self.k)
384            .cloned()
385            .collect();
386
387        let mut centroids = Array2::zeros((self.k, data.ncols()));
388        for (i, &data_idx) in selected_indices.iter().enumerate() {
389            centroids.row_mut(i).assign(&data.row(data_idx));
390        }
391
392        Ok(centroids)
393    }
394
395    /// K-means++ centroid initialization
396    fn kmeans_plus_plus_initialization(&self, data: ArrayView2<F>) -> Result<Array2<F>> {
397        use scirs2_core::random::Rng;
398
399        let mut rng = scirs2_core::random::rng();
400        let mut centroids = Array2::zeros((self.k, data.ncols()));
401
402        // Choose first centroid randomly
403        let first_idx = rng.random_range(0..data.nrows());
404        centroids.row_mut(0).assign(&data.row(first_idx));
405
406        // Choose remaining centroids using K-means++ method
407        for k in 1..self.k {
408            let mut distances = Array1::zeros(data.nrows());
409
410            // Calculate distance to nearest centroid for each point
411            for (i, point) in data.rows().into_iter().enumerate() {
412                let mut min_dist = F::infinity();
413                for centroid in centroids.rows().into_iter().take(k) {
414                    let dist = euclidean_distance(point, centroid);
415                    if dist < min_dist {
416                        min_dist = dist;
417                    }
418                }
419                distances[i] = min_dist.to_f64().unwrap_or(f64::INFINITY);
420            }
421
422            // Choose next centroid with probability proportional to squared distance
423            let total_dist: f64 = distances.iter().map(|&d| d * d).sum();
424            if total_dist <= 0.0 {
425                // Fallback to random selection
426                let random_idx = rng.random_range(0..data.nrows());
427                centroids.row_mut(k).assign(&data.row(random_idx));
428            } else {
429                let mut cumulative = 0.0;
430                let threshold = rng.random::<f64>() * total_dist;
431
432                let mut selected_idx = 0;
433                for (i, &dist) in distances.iter().enumerate() {
434                    cumulative += dist * dist;
435                    if cumulative >= threshold {
436                        selected_idx = i;
437                        break;
438                    }
439                }
440                centroids.row_mut(k).assign(&data.row(selected_idx));
441            }
442        }
443
444        Ok(centroids)
445    }
446
447    /// Forgy centroid initialization
448    fn forgy_initialization(&self, data: ArrayView2<F>) -> Result<Array2<F>> {
449        // Forgy method is equivalent to random initialization
450        self.random_initialization(data)
451    }
452
453    /// Perform one iteration of distributed K-means
454    fn perform_iteration(&mut self, stats: &mut PerformanceStatistics) -> Result<bool> {
455        let iteration_start = Instant::now();
456
457        // Broadcast current centroids to all workers
458        if self.config.n_workers > 1 {
459            let broadcast_start = Instant::now();
460            self.broadcast_centroids()?;
461            stats.communication_time_ms += broadcast_start.elapsed().as_millis() as u64;
462        }
463
464        // Compute local assignments and centroids on each worker
465        let compute_start = Instant::now();
466        let worker_results = self.compute_worker_assignments()?;
467        stats.computation_time_ms += compute_start.elapsed().as_millis() as u64;
468
469        // Synchronize and aggregate results
470        let sync_start = Instant::now();
471        let (new_centroids, new_inertia) = self.aggregate_worker_results(&worker_results)?;
472        stats.synchronization_time_ms += sync_start.elapsed().as_millis() as u64;
473
474        // Check for convergence
475        let converged = self.check_convergence(&new_centroids, new_inertia)?;
476
477        // Update centroids and inertia
478        self.centroids = Some(new_centroids);
479        self.global_inertia = new_inertia;
480
481        Ok(converged)
482    }
483
484    /// Broadcast current centroids to all workers
485    fn broadcast_centroids(&mut self) -> Result<()> {
486        if let (Some(ref centroids), Some(ref mut coordinator)) =
487            (&self.centroids, &mut self.message_coordinator)
488        {
489            let message = ClusteringMessage::UpdateCentroids {
490                round: self.current_iteration,
491                centroids: centroids.clone(),
492            };
493
494            coordinator.broadcast_message(message, MessagePriority::Normal)?;
495        }
496
497        Ok(())
498    }
499
500    /// Compute assignments and local centroids on each worker
501    fn compute_worker_assignments(&mut self) -> Result<Vec<WorkerResult<F>>> {
502        let mut results = Vec::new();
503
504        if let Some(ref centroids) = self.centroids {
505            for partition in &self.partitions {
506                let worker_start = Instant::now();
507
508                // Assign points to nearest centroids
509                let mut labels = Array1::zeros(partition.data.nrows());
510                let mut local_inertia = F::zero();
511
512                for (i, point) in partition.data.rows().into_iter().enumerate() {
513                    let mut min_dist = F::infinity();
514                    let mut best_cluster = 0;
515
516                    for (j, centroid) in centroids.rows().into_iter().enumerate() {
517                        let dist = euclidean_distance(point, centroid);
518                        if dist < min_dist {
519                            min_dist = dist;
520                            best_cluster = j;
521                        }
522                    }
523
524                    labels[i] = best_cluster;
525                    local_inertia = local_inertia + min_dist * min_dist;
526                }
527
528                // Compute local centroids
529                let mut local_centroids = Array2::zeros((self.k, partition.data.ncols()));
530                let mut point_counts = Array1::zeros(self.k);
531
532                for (i, point) in partition.data.rows().into_iter().enumerate() {
533                    let cluster = labels[i];
534                    point_counts[cluster] += 1;
535
536                    for (j, &value) in point.iter().enumerate() {
537                        local_centroids[[cluster, j]] = local_centroids[[cluster, j]] + value;
538                    }
539                }
540
541                // Normalize to get means
542                for k in 0..self.k {
543                    if point_counts[k] > 0 {
544                        let count = F::from(point_counts[k]).unwrap();
545                        for j in 0..partition.data.ncols() {
546                            local_centroids[[k, j]] = local_centroids[[k, j]] / count;
547                        }
548                    }
549                }
550
551                let computation_time = worker_start.elapsed().as_millis() as u64;
552
553                results.push(WorkerResult {
554                    worker_id: partition.workerid,
555                    local_centroids,
556                    local_labels: labels,
557                    local_inertia: local_inertia.to_f64().unwrap_or(f64::INFINITY),
558                    point_counts,
559                    computation_time_ms: computation_time,
560                });
561
562                // Update worker performance metrics
563                let throughput = partition.data.nrows() as f64 / (computation_time as f64 / 1000.0);
564                let efficiency = 1.0 / (1.0 + computation_time as f64 / 10000.0); // Simplified efficiency
565                self.performance_monitor.update_worker_metrics(
566                    partition.workerid,
567                    0.5, // CPU usage (placeholder)
568                    0.4, // Memory usage (placeholder)
569                    throughput,
570                    computation_time as f64,
571                )?;
572            }
573        }
574
575        Ok(results)
576    }
577
578    /// Aggregate results from all workers
579    fn aggregate_worker_results(
580        &self,
581        worker_results: &[WorkerResult<F>],
582    ) -> Result<(Array2<F>, f64)> {
583        if worker_results.is_empty() {
584            return Err(ClusteringError::InvalidInput(
585                "No worker results to aggregate".to_string(),
586            ));
587        }
588
589        let n_features = worker_results[0].local_centroids.ncols();
590        let mut global_centroids = Array2::zeros((self.k, n_features));
591        let mut global_counts: Array1<usize> = Array1::zeros(self.k);
592        let mut global_inertia = 0.0;
593
594        // Aggregate weighted centroids and counts
595        for result in worker_results {
596            global_inertia += result.local_inertia;
597
598            for k in 0..self.k {
599                let count = F::from(result.point_counts[k]).unwrap();
600                global_counts[k] += result.point_counts[k];
601
602                for j in 0..n_features {
603                    global_centroids[[k, j]] =
604                        global_centroids[[k, j]] + result.local_centroids[[k, j]] * count;
605                }
606            }
607        }
608
609        // Normalize to get global means
610        for k in 0..self.k {
611            if global_counts[k] > 0 {
612                let count = F::from(global_counts[k]).unwrap();
613                for j in 0..n_features {
614                    global_centroids[[k, j]] = global_centroids[[k, j]] / count;
615                }
616            }
617        }
618
619        Ok((global_centroids, global_inertia))
620    }
621
622    /// Check for convergence
623    fn check_convergence(&self, new_centroids: &Array2<F>, newinertia: f64) -> Result<bool> {
624        if let Some(ref old_centroids) = self.centroids {
625            // Calculate centroid movement
626            let mut max_movement = F::zero();
627            for (old_row, new_row) in old_centroids.rows().into_iter().zip(new_centroids.rows()) {
628                let movement = euclidean_distance(old_row, new_row);
629                if movement > max_movement {
630                    max_movement = movement;
631                }
632            }
633
634            // Check convergence criteria
635            let movement_converged =
636                max_movement.to_f64().unwrap_or(f64::INFINITY) < self.config.tolerance;
637            let inertia_change = (self.global_inertia - newinertia).abs();
638            let inertia_converged =
639                inertia_change < self.config.tolerance * self.global_inertia.abs();
640
641            Ok(movement_converged || inertia_converged)
642        } else {
643            Ok(false)
644        }
645    }
646
647    /// Update convergence history
648    fn update_convergence_history(&mut self, iteration_timems: u64) -> Result<()> {
649        let centroid_movement = if let Some(ref centroids) = self.centroids {
650            if self.convergence_history.is_empty() {
651                0.0
652            } else {
653                // Calculate movement from previous iteration (simplified)
654                self.config.tolerance * 2.0 // Placeholder
655            }
656        } else {
657            0.0
658        };
659
660        let converged = self.current_iteration >= self.config.max_iterations
661            || centroid_movement < self.config.tolerance;
662
663        let convergence_info = ConvergenceInfo {
664            iteration: self.current_iteration,
665            inertia: self.global_inertia,
666            centroid_movement,
667            converged,
668            timestamp: SystemTime::now(),
669            computation_time_ms: iteration_timems,
670        };
671
672        self.convergence_history.push(convergence_info);
673
674        Ok(())
675    }
676
677    /// Check for load imbalance and rebalance if needed
678    fn check_and_rebalance(
679        &mut self,
680        data: ArrayView2<F>,
681        stats: &mut PerformanceStatistics,
682    ) -> Result<()> {
683        if !self.config.enable_load_balancing {
684            return Ok(());
685        }
686
687        // Check if rebalancing is needed
688        if self.fault_coordinator.should_rebalance() {
689            let rebalance_start = Instant::now();
690
691            // Re-partition data
692            self.partitions = self.partitioner.partition_data(data)?;
693
694            stats.communication_time_ms += rebalance_start.elapsed().as_millis() as u64;
695            stats.fault_tolerance_events += 1;
696
697            if self.config.verbose {
698                println!(
699                    "Load rebalancing performed at iteration {}",
700                    self.current_iteration
701                );
702            }
703        }
704
705        Ok(())
706    }
707
708    /// Create checkpoint for fault tolerance
709    fn create_checkpoint(&mut self) -> Result<()> {
710        if !self.config.enable_fault_tolerance {
711            return Ok(());
712        }
713
714        let worker_assignments = self
715            .partitions
716            .iter()
717            .map(|p| (p.workerid, vec![p.partition_id]))
718            .collect();
719
720        self.fault_coordinator.create_checkpoint(
721            self.current_iteration,
722            self.centroids.as_ref(),
723            self.global_inertia,
724            &[], // Convergence history (simplified)
725            &worker_assignments,
726        );
727
728        Ok(())
729    }
730
731    /// Calculate worker efficiency
732    fn calculate_worker_efficiency(&self) -> f64 {
733        let worker_metrics = self.performance_monitor.get_worker_metrics();
734        if worker_metrics.is_empty() {
735            return 0.0;
736        }
737
738        let avg_health_score = worker_metrics.values().map(|m| m.health_score).sum::<f64>()
739            / worker_metrics.len() as f64;
740
741        avg_health_score
742    }
743
744    /// Calculate load balance score
745    fn calculate_load_balance_score(&self) -> f64 {
746        if self.partitions.is_empty() {
747            return 1.0;
748        }
749
750        let partition_sizes: Vec<usize> = self.partitions.iter().map(|p| p.data.nrows()).collect();
751        let avg_size = partition_sizes.iter().sum::<usize>() as f64 / partition_sizes.len() as f64;
752
753        if avg_size == 0.0 {
754            return 1.0;
755        }
756
757        let variance = partition_sizes
758            .iter()
759            .map(|&size| (size as f64 - avg_size).powi(2))
760            .sum::<f64>()
761            / partition_sizes.len() as f64;
762
763        let coefficient_of_variation = variance.sqrt() / avg_size;
764        1.0 / (1.0 + coefficient_of_variation)
765    }
766
767    /// Collect final labels from all partitions
768    fn collect_final_labels(&self) -> Result<Array1<usize>> {
769        let total_points: usize = self.partitions.iter().map(|p| p.data.nrows()).sum();
770        let mut labels = Array1::zeros(total_points);
771        let mut offset = 0;
772
773        // This is a simplified version - in practice, we'd need to track
774        // original data point indices through the partitioning process
775        for partition in &self.partitions {
776            if let Some(ref partition_labels) = partition.labels {
777                let end_offset = offset + partition_labels.len();
778                labels
779                    .slice_mut(s![offset..end_offset])
780                    .assign(&Array1::from_vec(partition_labels.clone()).view());
781                offset = end_offset;
782            }
783        }
784
785        Ok(labels)
786    }
787
788    /// Predict cluster assignments for new data
789    pub fn predict(&self, data: ArrayView2<F>) -> Result<Array1<usize>> {
790        if let Some(ref centroids) = self.centroids {
791            let mut labels = Array1::zeros(data.nrows());
792
793            for (i, point) in data.rows().into_iter().enumerate() {
794                let mut min_dist = F::infinity();
795                let mut best_cluster = 0;
796
797                for (j, centroid) in centroids.rows().into_iter().enumerate() {
798                    let dist = euclidean_distance(point, centroid);
799                    if dist < min_dist {
800                        min_dist = dist;
801                        best_cluster = j;
802                    }
803                }
804
805                labels[i] = best_cluster;
806            }
807
808            Ok(labels)
809        } else {
810            Err(ClusteringError::InvalidInput(
811                "Model has not been fitted yet".to_string(),
812            ))
813        }
814    }
815
816    /// Get current centroids
817    pub fn centroids(&self) -> Option<&Array2<F>> {
818        self.centroids.as_ref()
819    }
820
821    /// Get convergence history
822    pub fn convergence_history(&self) -> &[ConvergenceInfo] {
823        &self.convergence_history
824    }
825
826    /// Get current inertia
827    pub fn inertia(&self) -> f64 {
828        self.global_inertia
829    }
830
831    /// Get number of iterations performed
832    pub fn n_iterations(&self) -> usize {
833        self.current_iteration
834    }
835
    /// Returns a reference to the performance monitor tracking this run.
    pub fn performance_monitor(&self) -> &PerformanceMonitor {
        &self.performance_monitor
    }
840
    /// Returns a reference to the fault-tolerance coordinator for this instance.
    pub fn fault_coordinator(&self) -> &FaultToleranceCoordinator<F> {
        &self.fault_coordinator
    }
845}
846
#[cfg(test)]
mod tests {
    use super::*;
    use scirs2_core::ndarray::Array2;

    /// Construction succeeds and starts with no fitted centroids.
    #[test]
    fn test_distributed_kmeans_creation() {
        let config = DistributedKMeansConfig::default();
        let kmeans = DistributedKMeans::<f64>::new(3, config);

        assert!(kmeans.is_ok());
        let kmeans = kmeans.unwrap();
        assert_eq!(kmeans.k, 3);
        assert!(kmeans.centroids.is_none());
    }

    /// Empty and undersized inputs are rejected; adequately sized input passes.
    #[test]
    fn test_input_validation() {
        let config = DistributedKMeansConfig::default();
        let kmeans = DistributedKMeans::<f64>::new(3, config).unwrap();

        // Empty data
        let empty_data = Array2::<f64>::zeros((0, 2));
        assert!(kmeans.validate_input(empty_data.view()).is_err());

        // Too few samples
        let small_data = Array2::<f64>::zeros((2, 2));
        assert!(kmeans.validate_input(small_data.view()).is_err());

        // Valid data
        let valid_data = Array2::<f64>::zeros((10, 2));
        assert!(kmeans.validate_input(valid_data.view()).is_ok());
    }

    /// Random initialization produces k centroids with the data's feature count.
    #[test]
    fn test_random_initialization() {
        let config = DistributedKMeansConfig::default();
        let kmeans = DistributedKMeans::<f64>::new(3, config).unwrap();

        let data = Array2::from_shape_vec((10, 2), (0..20).map(|x| x as f64).collect()).unwrap();

        let centroids = kmeans.random_initialization(data.view()).unwrap();
        assert_eq!(centroids.shape(), &[3, 2]);
    }

    /// k-means++ seeding produces the right shape and spreads the seeds apart.
    #[test]
    fn test_kmeans_plus_plus_initialization() {
        let config = DistributedKMeansConfig::default();
        let kmeans = DistributedKMeans::<f64>::new(2, config).unwrap();

        let data = Array2::from_shape_vec(
            (6, 2),
            vec![
                0.0, 0.0, 1.0, 1.0, 10.0, 10.0, 11.0, 11.0, 5.0, 5.0, 6.0, 6.0,
            ],
        )
        .unwrap();

        let centroids = kmeans.kmeans_plus_plus_initialization(data.view()).unwrap();
        assert_eq!(centroids.shape(), &[2, 2]);

        // Centroids should be different (with high probability)
        let dist = euclidean_distance(centroids.row(0), centroids.row(1));
        assert!(dist > 0.0);
    }

    /// Prediction assigns each point to the nearest of two known centroids.
    #[test]
    fn test_predict() {
        let config = DistributedKMeansConfig::default();
        let mut kmeans = DistributedKMeans::<f64>::new(2, config).unwrap();

        // Set known centroids
        let centroids = Array2::from_shape_vec((2, 2), vec![0.0, 0.0, 10.0, 10.0]).unwrap();
        kmeans.centroids = Some(centroids);

        // Test prediction
        let test_data =
            Array2::from_shape_vec((4, 2), vec![1.0, 1.0, 9.0, 9.0, -1.0, -1.0, 11.0, 11.0])
                .unwrap();

        let labels = kmeans.predict(test_data.view()).unwrap();
        assert_eq!(labels.len(), 4);

        // Points should be assigned to nearest centroids
        assert_eq!(labels[0], 0); // (1,1) closer to (0,0)
        assert_eq!(labels[1], 1); // (9,9) closer to (10,10)
        assert_eq!(labels[2], 0); // (-1,-1) closer to (0,0)
        assert_eq!(labels[3], 1); // (11,11) closer to (10,10)
    }

    /// Small centroid movement converges under the tolerance; large does not.
    #[test]
    fn test_convergence_check() {
        let config = DistributedKMeansConfig {
            tolerance: 0.1,
            ..Default::default()
        };
        let kmeans = DistributedKMeans::<f64>::new(2, config).unwrap();

        let old_centroids = Array2::from_shape_vec((2, 2), vec![0.0, 0.0, 1.0, 1.0]).unwrap();

        let new_centroids_converged = Array2::from_shape_vec(
            (2, 2),
            vec![0.05, 0.05, 1.05, 1.05], // Small movement
        )
        .unwrap();

        let new_centroids_not_converged = Array2::from_shape_vec(
            (2, 2),
            vec![0.5, 0.5, 1.5, 1.5], // Large movement
        )
        .unwrap();

        // Set up kmeans with old centroids
        let mut kmeans_converged = kmeans;
        kmeans_converged.centroids = Some(old_centroids.clone());
        kmeans_converged.global_inertia = 100.0;

        // Test convergence
        assert!(kmeans_converged
            .check_convergence(&new_centroids_converged, 99.0)
            .unwrap());

        let mut kmeans_not_converged = DistributedKMeans::<f64>::new(
            2,
            DistributedKMeansConfig {
                tolerance: 0.1,
                ..Default::default()
            },
        )
        .unwrap();
        kmeans_not_converged.centroids = Some(old_centroids);
        kmeans_not_converged.global_inertia = 100.0;

        assert!(!kmeans_not_converged
            .check_convergence(&new_centroids_not_converged, 50.0)
            .unwrap());
    }

    /// Equal-sized partitions score near 1.0; skewed partitions score lower.
    #[test]
    fn test_load_balance_score() {
        let config = DistributedKMeansConfig::default();
        let mut kmeans = DistributedKMeans::<f64>::new(2, config).unwrap();

        // Balanced partitions
        let partition1 = DataPartition::new(0, Array2::zeros((100, 2)), 0);
        let partition2 = DataPartition::new(1, Array2::zeros((100, 2)), 1);
        kmeans.partitions = vec![partition1, partition2];

        let balanced_score = kmeans.calculate_load_balance_score();
        assert!(balanced_score > 0.9);

        // Imbalanced partitions
        let partition1 = DataPartition::new(0, Array2::zeros((10, 2)), 0);
        let partition2 = DataPartition::new(1, Array2::zeros((190, 2)), 1);
        kmeans.partitions = vec![partition1, partition2];

        let imbalanced_score = kmeans.calculate_load_balance_score();
        assert!(imbalanced_score < balanced_score);
    }
}