1use scirs2_core::ndarray::{s, Array1, Array2, ArrayView1, ArrayView2, Axis};
7use scirs2_core::numeric::{Float, FromPrimitive, Zero};
8use scirs2_core::random::prelude::*;
9use scirs2_core::random::rand_prelude::IndexedRandom;
10use std::collections::HashMap;
11use std::fmt::Debug;
12use std::sync::{Arc, Mutex};
13use std::time::{Duration, Instant, SystemTime, UNIX_EPOCH};
14
15use serde::{Deserialize, Serialize};
16
17use crate::error::{ClusteringError, Result};
18use crate::vq::euclidean_distance;
19
20use super::fault_tolerance::{DataPartition, FaultToleranceCoordinator};
21use super::load_balancing::LoadBalancingCoordinator;
22use super::message_passing::{ClusteringMessage, MessagePassingCoordinator, MessagePriority};
23use super::monitoring::PerformanceMonitor;
24use super::partitioning::{DataPartitioner, PartitioningConfig};
25
/// Coordinator for distributed K-means clustering.
///
/// Owns the data partitions, the current centroid estimate and the
/// sub-systems (fault tolerance, load balancing, monitoring, message
/// passing) used while iterating.
#[derive(Debug)]
pub struct DistributedKMeans<F: Float> {
    /// Number of clusters to fit.
    pub k: usize,
    /// Algorithm and infrastructure options.
    pub config: DistributedKMeansConfig,
    /// Current centroids; `None` until initialized during `fit`.
    pub centroids: Option<Array2<F>>,
    /// Per-worker data partitions produced by the partitioner.
    pub partitions: Vec<DataPartition<F>>,
    /// Checkpointing / recovery coordinator.
    pub fault_coordinator: FaultToleranceCoordinator<F>,
    /// Dynamic load-balancing coordinator.
    pub load_balancer: LoadBalancingCoordinator,
    /// Collects per-worker timing and health metrics.
    pub performance_monitor: PerformanceMonitor,
    /// Inter-worker messaging; only created when more than one worker is used.
    pub message_coordinator: Option<MessagePassingCoordinator<F>>,
    /// Splits input data across workers.
    pub partitioner: DataPartitioner<F>,
    /// Iterations executed so far in the current `fit` run.
    pub current_iteration: usize,
    /// One `ConvergenceInfo` record per completed iteration.
    pub convergence_history: Vec<ConvergenceInfo>,
    /// Latest global inertia; `f64::INFINITY` before the first iteration.
    pub global_inertia: f64,
}
54
/// Tunable options for `DistributedKMeans`.
#[derive(Debug, Clone)]
pub struct DistributedKMeansConfig {
    /// Hard cap on the number of iterations.
    pub max_iterations: usize,
    /// Convergence threshold on centroid movement / relative inertia change.
    pub tolerance: f64,
    /// Number of (simulated) workers to partition the data across.
    pub n_workers: usize,
    /// Centroid initialization strategy.
    pub init_method: InitializationMethod,
    /// Enables checkpointing via the fault-tolerance coordinator.
    pub enable_fault_tolerance: bool,
    /// Enables periodic repartitioning when load is skewed.
    pub enable_load_balancing: bool,
    /// Enables detailed per-worker metric collection.
    pub enable_monitoring: bool,
    /// How often (in iterations) convergence is checked.
    /// NOTE(review): not referenced anywhere in this file — confirm a caller uses it.
    pub convergence_check_interval: usize,
    /// Checkpoint every this-many iterations (when fault tolerance is on).
    pub checkpoint_interval: usize,
    /// Prints progress to stdout when true.
    pub verbose: bool,
    /// Seed for reproducible runs.
    /// NOTE(review): currently ignored by the initialization routines — confirm intent.
    pub random_seed: Option<u64>,
}
70
impl Default for DistributedKMeansConfig {
    /// Sensible defaults: 100 iterations, 1e-4 tolerance, 4 workers,
    /// k-means++ seeding, all infrastructure features enabled, quiet output.
    fn default() -> Self {
        Self {
            max_iterations: 100,
            tolerance: 1e-4,
            n_workers: 4,
            init_method: InitializationMethod::KMeansPlusPlus,
            enable_fault_tolerance: true,
            enable_load_balancing: true,
            enable_monitoring: true,
            convergence_check_interval: 5,
            checkpoint_interval: 10,
            verbose: false,
            random_seed: None,
        }
    }
}
88
/// Strategy used to choose the initial centroids.
#[derive(Debug, Clone)]
pub enum InitializationMethod {
    /// Uniform random sample of k data points (without replacement).
    Random,
    /// k-means++ D²-weighted seeding.
    KMeansPlusPlus,
    /// Forgy initialization (random data points; identical to `Random` here).
    Forgy,
    /// Caller-supplied centroids; must be shaped `(k, n_features)`.
    Custom(Array2<f64>),
}
101
/// Snapshot of convergence state recorded after one iteration.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ConvergenceInfo {
    /// Iteration index this record describes.
    pub iteration: usize,
    /// Global inertia at the end of the iteration.
    pub inertia: f64,
    /// Centroid movement recorded for this iteration.
    pub centroid_movement: f64,
    /// Whether the run was considered converged at this point.
    pub converged: bool,
    /// Wall-clock time when the record was created.
    pub timestamp: SystemTime,
    /// Time spent in this iteration, in milliseconds.
    pub computation_time_ms: u64,
}
112
/// Final output of a distributed K-means run.
#[derive(Debug, Clone)]
pub struct ClusteringResult<F: Float> {
    /// Fitted cluster centers, shape `(k, n_features)`.
    pub centroids: Array2<F>,
    /// Cluster index for every input point.
    pub labels: Array1<usize>,
    /// Final sum of squared distances to assigned centroids.
    pub inertia: f64,
    /// Number of iterations actually executed.
    pub n_iterations: usize,
    /// Convergence record of the last iteration.
    pub convergence_info: ConvergenceInfo,
    /// Aggregate timing / efficiency statistics for the run.
    pub performance_stats: PerformanceStatistics,
}
129
/// Aggregate timing and efficiency metrics for one `fit` run.
#[derive(Debug, Clone)]
pub struct PerformanceStatistics {
    /// Total wall-clock time of the run, in milliseconds.
    pub total_time_ms: u64,
    /// Time attributed to partitioning / broadcasting / rebalancing.
    pub communication_time_ms: u64,
    /// Time attributed to centroid initialization and worker computation.
    pub computation_time_ms: u64,
    /// Time attributed to aggregating worker results.
    pub synchronization_time_ms: u64,
    /// Mean worker health score (see `calculate_worker_efficiency`).
    pub worker_efficiency: f64,
    /// Partition-evenness score in (0, 1]; 1.0 is perfectly balanced.
    pub load_balance_score: f64,
    /// Count of rebalancing / fault-tolerance events during the run.
    pub fault_tolerance_events: usize,
}
141
/// One worker's output for a single assignment/update pass.
#[derive(Debug, Clone)]
pub struct WorkerResult<F: Float> {
    /// Id of the worker (partition) that produced this result.
    pub worker_id: usize,
    /// Local per-cluster centroid means computed from this worker's points.
    pub local_centroids: Array2<F>,
    /// Nearest-centroid assignment for each local point.
    pub local_labels: Array1<usize>,
    /// Local sum of squared distances to assigned centroids.
    pub local_inertia: f64,
    /// Number of local points assigned to each cluster.
    pub point_counts: Array1<usize>,
    /// Wall-clock time of the local pass, in milliseconds.
    pub computation_time_ms: u64,
}
152
153impl<F: Float + FromPrimitive + Debug + Send + Sync + 'static> DistributedKMeans<F> {
154 pub fn new(k: usize, config: DistributedKMeansConfig) -> Result<Self> {
156 if k == 0 {
157 return Err(ClusteringError::InvalidInput(
158 "Number of clusters must be greater than 0".to_string(),
159 ));
160 }
161
162 if config.n_workers == 0 {
163 return Err(ClusteringError::InvalidInput(
164 "Number of workers must be greater than 0".to_string(),
165 ));
166 }
167
168 let partitioner_config = PartitioningConfig {
169 n_workers: config.n_workers,
170 ..Default::default()
171 };
172
173 let fault_tolerance_config = super::fault_tolerance::FaultToleranceConfig {
174 enabled: config.enable_fault_tolerance,
175 ..Default::default()
176 };
177
178 let load_balancing_config = super::load_balancing::LoadBalancingConfig {
179 enable_dynamic_balancing: config.enable_load_balancing,
180 ..Default::default()
181 };
182
183 let monitoring_config = super::monitoring::MonitoringConfig {
184 enable_detailed_monitoring: config.enable_monitoring,
185 ..Default::default()
186 };
187
188 Ok(Self {
189 k,
190 config,
191 centroids: None,
192 partitions: Vec::new(),
193 fault_coordinator: FaultToleranceCoordinator::new(fault_tolerance_config),
194 load_balancer: LoadBalancingCoordinator::new(load_balancing_config),
195 performance_monitor: PerformanceMonitor::new(monitoring_config),
196 message_coordinator: None,
197 partitioner: DataPartitioner::new(partitioner_config),
198 current_iteration: 0,
199 convergence_history: Vec::new(),
200 global_inertia: f64::INFINITY,
201 })
202 }
203
    /// Runs the full distributed K-means pipeline on `data`.
    ///
    /// Pipeline: validate input → register workers → partition data →
    /// initialize centroids → iterate (assign / aggregate / convergence
    /// check) until convergence or `max_iterations`, with optional
    /// periodic rebalancing and checkpointing.
    ///
    /// # Errors
    /// Returns `InvalidInput` for empty data, too few samples for `k` or
    /// the worker count, and propagates partitioning/iteration failures.
    pub fn fit(&mut self, data: ArrayView2<F>) -> Result<ClusteringResult<F>> {
        let start_time = Instant::now();
        let mut stats = PerformanceStatistics {
            total_time_ms: 0,
            communication_time_ms: 0,
            computation_time_ms: 0,
            synchronization_time_ms: 0,
            worker_efficiency: 0.0,
            load_balance_score: 0.0,
            fault_tolerance_events: 0,
        };

        self.validate_input(data)?;

        self.initialize_workers()?;

        // Partition the data across workers; counted as communication time.
        let partition_start = Instant::now();
        self.partitions = self.partitioner.partition_data(data)?;
        stats.communication_time_ms += partition_start.elapsed().as_millis() as u64;

        if self.config.verbose {
            println!("Data partitioned across {} workers", self.config.n_workers);
        }

        // Seed the centroids according to the configured init method.
        let init_start = Instant::now();
        self.centroids = Some(self.initialize_centroids(data)?);
        stats.computation_time_ms += init_start.elapsed().as_millis() as u64;

        let mut converged = false;
        self.current_iteration = 0;

        while self.current_iteration < self.config.max_iterations && !converged {
            let iteration_start = Instant::now();

            converged = self.perform_iteration(&mut stats)?;

            let iteration_time = iteration_start.elapsed().as_millis() as u64;
            self.update_convergence_history(iteration_time)?;

            // Consider rebalancing every 10 iterations when enabled.
            if self.config.enable_load_balancing && self.current_iteration.is_multiple_of(10) {
                self.check_and_rebalance(data, &mut stats)?;
            }

            // Checkpoint on the configured interval when fault tolerance is on.
            if self.config.enable_fault_tolerance
                && self
                    .current_iteration
                    .is_multiple_of(self.config.checkpoint_interval)
            {
                self.create_checkpoint()?;
            }

            self.current_iteration += 1;

            if self.config.verbose && self.current_iteration.is_multiple_of(10) {
                println!(
                    "Iteration {}: inertia = {:.6}",
                    self.current_iteration, self.global_inertia
                );
            }
        }

        stats.total_time_ms = start_time.elapsed().as_millis() as u64;
        stats.worker_efficiency = self.calculate_worker_efficiency();
        stats.load_balance_score = self.calculate_load_balance_score();

        // Gather per-partition labels into one global label array.
        let final_labels = self.collect_final_labels()?;
        let final_convergence =
            self.convergence_history
                .last()
                .cloned()
                .unwrap_or_else(|| ConvergenceInfo {
                    iteration: self.current_iteration,
                    inertia: self.global_inertia,
                    centroid_movement: 0.0,
                    converged,
                    timestamp: SystemTime::now(),
                    computation_time_ms: 0,
                });

        Ok(ClusteringResult {
            // unwrap is safe: centroids were set right after initialization above.
            centroids: self.centroids.as_ref().unwrap().clone(),
            labels: final_labels,
            inertia: self.global_inertia,
            n_iterations: self.current_iteration,
            convergence_info: final_convergence,
            performance_stats: stats,
        })
    }
303
304 fn validate_input(&self, data: ArrayView2<F>) -> Result<()> {
306 if data.nrows() == 0 {
307 return Err(ClusteringError::InvalidInput(
308 "Input data is empty".to_string(),
309 ));
310 }
311
312 if data.ncols() == 0 {
313 return Err(ClusteringError::InvalidInput(
314 "Input data has no features".to_string(),
315 ));
316 }
317
318 if data.nrows() < self.k {
319 return Err(ClusteringError::InvalidInput(format!(
320 "Number of samples ({}) must be at least k ({})",
321 data.nrows(),
322 self.k
323 )));
324 }
325
326 if data.nrows() < self.config.n_workers {
327 return Err(ClusteringError::InvalidInput(format!(
328 "Number of samples ({}) must be at least number of workers ({})",
329 data.nrows(),
330 self.config.n_workers
331 )));
332 }
333
334 Ok(())
335 }
336
337 fn initialize_workers(&mut self) -> Result<()> {
339 for worker_id in 0..self.config.n_workers {
341 self.fault_coordinator.register_worker(worker_id);
342 self.performance_monitor.register_worker(worker_id);
343 }
344
345 if self.config.n_workers > 1 {
347 let message_config = super::message_passing::MessagePassingConfig::default();
348 self.message_coordinator = Some(MessagePassingCoordinator::new(0, message_config));
349 }
350
351 Ok(())
352 }
353
354 fn initialize_centroids(&self, data: ArrayView2<F>) -> Result<Array2<F>> {
356 match &self.config.init_method {
357 InitializationMethod::Random => self.random_initialization(data),
358 InitializationMethod::KMeansPlusPlus => self.kmeans_plus_plus_initialization(data),
359 InitializationMethod::Forgy => self.forgy_initialization(data),
360 InitializationMethod::Custom(centroids) => {
361 if centroids.nrows() != self.k || centroids.ncols() != data.ncols() {
362 return Err(ClusteringError::InvalidInput(
363 "Custom centroids dimensions don't match".to_string(),
364 ));
365 }
366 let converted_centroids =
367 Array2::from_shape_fn((self.k, data.ncols()), |(i, j)| {
368 F::from(centroids[[i, j]]).unwrap_or_else(F::zero)
369 });
370 Ok(converted_centroids)
371 }
372 }
373 }
374
375 fn random_initialization(&self, data: ArrayView2<F>) -> Result<Array2<F>> {
377 use scirs2_core::random::seq::SliceRandom;
378
379 let mut rng = scirs2_core::random::rng();
380 let data_indices: Vec<usize> = (0..data.nrows()).collect();
381 let selected_indices: Vec<_> = data_indices
382 .as_slice()
383 .choose_multiple(&mut rng, self.k)
384 .cloned()
385 .collect();
386
387 let mut centroids = Array2::zeros((self.k, data.ncols()));
388 for (i, &data_idx) in selected_indices.iter().enumerate() {
389 centroids.row_mut(i).assign(&data.row(data_idx));
390 }
391
392 Ok(centroids)
393 }
394
395 fn kmeans_plus_plus_initialization(&self, data: ArrayView2<F>) -> Result<Array2<F>> {
397 use scirs2_core::random::Rng;
398
399 let mut rng = scirs2_core::random::rng();
400 let mut centroids = Array2::zeros((self.k, data.ncols()));
401
402 let first_idx = rng.random_range(0..data.nrows());
404 centroids.row_mut(0).assign(&data.row(first_idx));
405
406 for k in 1..self.k {
408 let mut distances = Array1::zeros(data.nrows());
409
410 for (i, point) in data.rows().into_iter().enumerate() {
412 let mut min_dist = F::infinity();
413 for centroid in centroids.rows().into_iter().take(k) {
414 let dist = euclidean_distance(point, centroid);
415 if dist < min_dist {
416 min_dist = dist;
417 }
418 }
419 distances[i] = min_dist.to_f64().unwrap_or(f64::INFINITY);
420 }
421
422 let total_dist: f64 = distances.iter().map(|&d| d * d).sum();
424 if total_dist <= 0.0 {
425 let random_idx = rng.random_range(0..data.nrows());
427 centroids.row_mut(k).assign(&data.row(random_idx));
428 } else {
429 let mut cumulative = 0.0;
430 let threshold = rng.random::<f64>() * total_dist;
431
432 let mut selected_idx = 0;
433 for (i, &dist) in distances.iter().enumerate() {
434 cumulative += dist * dist;
435 if cumulative >= threshold {
436 selected_idx = i;
437 break;
438 }
439 }
440 centroids.row_mut(k).assign(&data.row(selected_idx));
441 }
442 }
443
444 Ok(centroids)
445 }
446
    /// Forgy initialization: choose k random observations as centroids.
    ///
    /// Forgy's method IS uniform sampling of data points, so this
    /// intentionally delegates to `random_initialization`.
    fn forgy_initialization(&self, data: ArrayView2<F>) -> Result<Array2<F>> {
        self.random_initialization(data)
    }
452
453 fn perform_iteration(&mut self, stats: &mut PerformanceStatistics) -> Result<bool> {
455 let iteration_start = Instant::now();
456
457 if self.config.n_workers > 1 {
459 let broadcast_start = Instant::now();
460 self.broadcast_centroids()?;
461 stats.communication_time_ms += broadcast_start.elapsed().as_millis() as u64;
462 }
463
464 let compute_start = Instant::now();
466 let worker_results = self.compute_worker_assignments()?;
467 stats.computation_time_ms += compute_start.elapsed().as_millis() as u64;
468
469 let sync_start = Instant::now();
471 let (new_centroids, new_inertia) = self.aggregate_worker_results(&worker_results)?;
472 stats.synchronization_time_ms += sync_start.elapsed().as_millis() as u64;
473
474 let converged = self.check_convergence(&new_centroids, new_inertia)?;
476
477 self.centroids = Some(new_centroids);
479 self.global_inertia = new_inertia;
480
481 Ok(converged)
482 }
483
484 fn broadcast_centroids(&mut self) -> Result<()> {
486 if let (Some(ref centroids), Some(ref mut coordinator)) =
487 (&self.centroids, &mut self.message_coordinator)
488 {
489 let message = ClusteringMessage::UpdateCentroids {
490 round: self.current_iteration,
491 centroids: centroids.clone(),
492 };
493
494 coordinator.broadcast_message(message, MessagePriority::Normal)?;
495 }
496
497 Ok(())
498 }
499
500 fn compute_worker_assignments(&mut self) -> Result<Vec<WorkerResult<F>>> {
502 let mut results = Vec::new();
503
504 if let Some(ref centroids) = self.centroids {
505 for partition in &self.partitions {
506 let worker_start = Instant::now();
507
508 let mut labels = Array1::zeros(partition.data.nrows());
510 let mut local_inertia = F::zero();
511
512 for (i, point) in partition.data.rows().into_iter().enumerate() {
513 let mut min_dist = F::infinity();
514 let mut best_cluster = 0;
515
516 for (j, centroid) in centroids.rows().into_iter().enumerate() {
517 let dist = euclidean_distance(point, centroid);
518 if dist < min_dist {
519 min_dist = dist;
520 best_cluster = j;
521 }
522 }
523
524 labels[i] = best_cluster;
525 local_inertia = local_inertia + min_dist * min_dist;
526 }
527
528 let mut local_centroids = Array2::zeros((self.k, partition.data.ncols()));
530 let mut point_counts = Array1::zeros(self.k);
531
532 for (i, point) in partition.data.rows().into_iter().enumerate() {
533 let cluster = labels[i];
534 point_counts[cluster] += 1;
535
536 for (j, &value) in point.iter().enumerate() {
537 local_centroids[[cluster, j]] = local_centroids[[cluster, j]] + value;
538 }
539 }
540
541 for k in 0..self.k {
543 if point_counts[k] > 0 {
544 let count = F::from(point_counts[k]).unwrap();
545 for j in 0..partition.data.ncols() {
546 local_centroids[[k, j]] = local_centroids[[k, j]] / count;
547 }
548 }
549 }
550
551 let computation_time = worker_start.elapsed().as_millis() as u64;
552
553 results.push(WorkerResult {
554 worker_id: partition.workerid,
555 local_centroids,
556 local_labels: labels,
557 local_inertia: local_inertia.to_f64().unwrap_or(f64::INFINITY),
558 point_counts,
559 computation_time_ms: computation_time,
560 });
561
562 let throughput = partition.data.nrows() as f64 / (computation_time as f64 / 1000.0);
564 let efficiency = 1.0 / (1.0 + computation_time as f64 / 10000.0); self.performance_monitor.update_worker_metrics(
566 partition.workerid,
567 0.5, 0.4, throughput,
570 computation_time as f64,
571 )?;
572 }
573 }
574
575 Ok(results)
576 }
577
578 fn aggregate_worker_results(
580 &self,
581 worker_results: &[WorkerResult<F>],
582 ) -> Result<(Array2<F>, f64)> {
583 if worker_results.is_empty() {
584 return Err(ClusteringError::InvalidInput(
585 "No worker results to aggregate".to_string(),
586 ));
587 }
588
589 let n_features = worker_results[0].local_centroids.ncols();
590 let mut global_centroids = Array2::zeros((self.k, n_features));
591 let mut global_counts: Array1<usize> = Array1::zeros(self.k);
592 let mut global_inertia = 0.0;
593
594 for result in worker_results {
596 global_inertia += result.local_inertia;
597
598 for k in 0..self.k {
599 let count = F::from(result.point_counts[k]).unwrap();
600 global_counts[k] += result.point_counts[k];
601
602 for j in 0..n_features {
603 global_centroids[[k, j]] =
604 global_centroids[[k, j]] + result.local_centroids[[k, j]] * count;
605 }
606 }
607 }
608
609 for k in 0..self.k {
611 if global_counts[k] > 0 {
612 let count = F::from(global_counts[k]).unwrap();
613 for j in 0..n_features {
614 global_centroids[[k, j]] = global_centroids[[k, j]] / count;
615 }
616 }
617 }
618
619 Ok((global_centroids, global_inertia))
620 }
621
622 fn check_convergence(&self, new_centroids: &Array2<F>, newinertia: f64) -> Result<bool> {
624 if let Some(ref old_centroids) = self.centroids {
625 let mut max_movement = F::zero();
627 for (old_row, new_row) in old_centroids.rows().into_iter().zip(new_centroids.rows()) {
628 let movement = euclidean_distance(old_row, new_row);
629 if movement > max_movement {
630 max_movement = movement;
631 }
632 }
633
634 let movement_converged =
636 max_movement.to_f64().unwrap_or(f64::INFINITY) < self.config.tolerance;
637 let inertia_change = (self.global_inertia - newinertia).abs();
638 let inertia_converged =
639 inertia_change < self.config.tolerance * self.global_inertia.abs();
640
641 Ok(movement_converged || inertia_converged)
642 } else {
643 Ok(false)
644 }
645 }
646
647 fn update_convergence_history(&mut self, iteration_timems: u64) -> Result<()> {
649 let centroid_movement = if let Some(ref centroids) = self.centroids {
650 if self.convergence_history.is_empty() {
651 0.0
652 } else {
653 self.config.tolerance * 2.0 }
656 } else {
657 0.0
658 };
659
660 let converged = self.current_iteration >= self.config.max_iterations
661 || centroid_movement < self.config.tolerance;
662
663 let convergence_info = ConvergenceInfo {
664 iteration: self.current_iteration,
665 inertia: self.global_inertia,
666 centroid_movement,
667 converged,
668 timestamp: SystemTime::now(),
669 computation_time_ms: iteration_timems,
670 };
671
672 self.convergence_history.push(convergence_info);
673
674 Ok(())
675 }
676
677 fn check_and_rebalance(
679 &mut self,
680 data: ArrayView2<F>,
681 stats: &mut PerformanceStatistics,
682 ) -> Result<()> {
683 if !self.config.enable_load_balancing {
684 return Ok(());
685 }
686
687 if self.fault_coordinator.should_rebalance() {
689 let rebalance_start = Instant::now();
690
691 self.partitions = self.partitioner.partition_data(data)?;
693
694 stats.communication_time_ms += rebalance_start.elapsed().as_millis() as u64;
695 stats.fault_tolerance_events += 1;
696
697 if self.config.verbose {
698 println!(
699 "Load rebalancing performed at iteration {}",
700 self.current_iteration
701 );
702 }
703 }
704
705 Ok(())
706 }
707
    /// Persists the current algorithm state through the fault-tolerance
    /// coordinator so a failed run can be resumed.
    ///
    /// A no-op when fault tolerance is disabled in the config.
    fn create_checkpoint(&mut self) -> Result<()> {
        if !self.config.enable_fault_tolerance {
            return Ok(());
        }

        // Map each worker id to the partition id(s) it currently owns.
        let worker_assignments = self
            .partitions
            .iter()
            .map(|p| (p.workerid, vec![p.partition_id]))
            .collect();

        // NOTE(review): the empty-slice argument appears to be a placeholder —
        // confirm against `FaultToleranceCoordinator::create_checkpoint`'s
        // signature what it is expected to carry.
        self.fault_coordinator.create_checkpoint(
            self.current_iteration,
            self.centroids.as_ref(),
            self.global_inertia,
            &[],
            &worker_assignments,
        );

        Ok(())
    }
730
731 fn calculate_worker_efficiency(&self) -> f64 {
733 let worker_metrics = self.performance_monitor.get_worker_metrics();
734 if worker_metrics.is_empty() {
735 return 0.0;
736 }
737
738 let avg_health_score = worker_metrics.values().map(|m| m.health_score).sum::<f64>()
739 / worker_metrics.len() as f64;
740
741 avg_health_score
742 }
743
744 fn calculate_load_balance_score(&self) -> f64 {
746 if self.partitions.is_empty() {
747 return 1.0;
748 }
749
750 let partition_sizes: Vec<usize> = self.partitions.iter().map(|p| p.data.nrows()).collect();
751 let avg_size = partition_sizes.iter().sum::<usize>() as f64 / partition_sizes.len() as f64;
752
753 if avg_size == 0.0 {
754 return 1.0;
755 }
756
757 let variance = partition_sizes
758 .iter()
759 .map(|&size| (size as f64 - avg_size).powi(2))
760 .sum::<f64>()
761 / partition_sizes.len() as f64;
762
763 let coefficient_of_variation = variance.sqrt() / avg_size;
764 1.0 / (1.0 + coefficient_of_variation)
765 }
766
767 fn collect_final_labels(&self) -> Result<Array1<usize>> {
769 let total_points: usize = self.partitions.iter().map(|p| p.data.nrows()).sum();
770 let mut labels = Array1::zeros(total_points);
771 let mut offset = 0;
772
773 for partition in &self.partitions {
776 if let Some(ref partition_labels) = partition.labels {
777 let end_offset = offset + partition_labels.len();
778 labels
779 .slice_mut(s![offset..end_offset])
780 .assign(&Array1::from_vec(partition_labels.clone()).view());
781 offset = end_offset;
782 }
783 }
784
785 Ok(labels)
786 }
787
788 pub fn predict(&self, data: ArrayView2<F>) -> Result<Array1<usize>> {
790 if let Some(ref centroids) = self.centroids {
791 let mut labels = Array1::zeros(data.nrows());
792
793 for (i, point) in data.rows().into_iter().enumerate() {
794 let mut min_dist = F::infinity();
795 let mut best_cluster = 0;
796
797 for (j, centroid) in centroids.rows().into_iter().enumerate() {
798 let dist = euclidean_distance(point, centroid);
799 if dist < min_dist {
800 min_dist = dist;
801 best_cluster = j;
802 }
803 }
804
805 labels[i] = best_cluster;
806 }
807
808 Ok(labels)
809 } else {
810 Err(ClusteringError::InvalidInput(
811 "Model has not been fitted yet".to_string(),
812 ))
813 }
814 }
815
    /// Returns the fitted centroids, or `None` before `fit` has run.
    pub fn centroids(&self) -> Option<&Array2<F>> {
        self.centroids.as_ref()
    }

    /// Per-iteration convergence records collected during `fit`.
    pub fn convergence_history(&self) -> &[ConvergenceInfo] {
        &self.convergence_history
    }

    /// Latest global inertia; `f64::INFINITY` before fitting.
    pub fn inertia(&self) -> f64 {
        self.global_inertia
    }

    /// Number of iterations executed so far.
    pub fn n_iterations(&self) -> usize {
        self.current_iteration
    }

    /// Read-only access to the performance monitor.
    pub fn performance_monitor(&self) -> &PerformanceMonitor {
        &self.performance_monitor
    }

    /// Read-only access to the fault-tolerance coordinator.
    pub fn fault_coordinator(&self) -> &FaultToleranceCoordinator<F> {
        &self.fault_coordinator
    }
845}
846
#[cfg(test)]
mod tests {
    use super::*;
    use approx::assert_relative_eq;
    use scirs2_core::ndarray::Array2;

    /// Constructor succeeds for a valid k and leaves centroids unset.
    #[test]
    fn test_distributed_kmeans_creation() {
        let config = DistributedKMeansConfig::default();
        let kmeans = DistributedKMeans::<f64>::new(3, config);

        assert!(kmeans.is_ok());
        let kmeans = kmeans.unwrap();
        assert_eq!(kmeans.k, 3);
        assert!(kmeans.centroids.is_none());
    }

    /// Empty data and fewer samples than k are rejected; valid data passes.
    #[test]
    fn test_input_validation() {
        let config = DistributedKMeansConfig::default();
        let kmeans = DistributedKMeans::<f64>::new(3, config).unwrap();

        // No rows at all.
        let empty_data = Array2::<f64>::zeros((0, 2));
        assert!(kmeans.validate_input(empty_data.view()).is_err());

        // Fewer rows (2) than clusters (k = 3).
        let small_data = Array2::<f64>::zeros((2, 2));
        assert!(kmeans.validate_input(small_data.view()).is_err());

        let valid_data = Array2::<f64>::zeros((10, 2));
        assert!(kmeans.validate_input(valid_data.view()).is_ok());
    }

    /// Random initialization yields a (k, n_features) centroid matrix.
    #[test]
    fn test_random_initialization() {
        let config = DistributedKMeansConfig::default();
        let kmeans = DistributedKMeans::<f64>::new(3, config).unwrap();

        let data = Array2::from_shape_vec((10, 2), (0..20).map(|x| x as f64).collect()).unwrap();

        let centroids = kmeans.random_initialization(data.view()).unwrap();
        assert_eq!(centroids.shape(), &[3, 2]);
    }

    /// k-means++ picks two distinct seeds on well-separated data.
    #[test]
    fn test_kmeans_plus_plus_initialization() {
        let config = DistributedKMeansConfig::default();
        let kmeans = DistributedKMeans::<f64>::new(2, config).unwrap();

        let data = Array2::from_shape_vec(
            (6, 2),
            vec![
                0.0, 0.0, 1.0, 1.0, 10.0, 10.0, 11.0, 11.0, 5.0, 5.0, 6.0, 6.0,
            ],
        )
        .unwrap();

        let centroids = kmeans.kmeans_plus_plus_initialization(data.view()).unwrap();
        assert_eq!(centroids.shape(), &[2, 2]);

        // The two seeds must not coincide.
        let dist = euclidean_distance(centroids.row(0), centroids.row(1));
        assert!(dist > 0.0);
    }

    /// Points are assigned to the nearest of two manually-set centroids.
    #[test]
    fn test_predict() {
        let config = DistributedKMeansConfig::default();
        let mut kmeans = DistributedKMeans::<f64>::new(2, config).unwrap();

        // Centroids at the origin and at (10, 10).
        let centroids = Array2::from_shape_vec((2, 2), vec![0.0, 0.0, 10.0, 10.0]).unwrap();
        kmeans.centroids = Some(centroids);

        let test_data =
            Array2::from_shape_vec((4, 2), vec![1.0, 1.0, 9.0, 9.0, -1.0, -1.0, 11.0, 11.0])
                .unwrap();

        let labels = kmeans.predict(test_data.view()).unwrap();
        assert_eq!(labels.len(), 4);

        assert_eq!(labels[0], 0);
        assert_eq!(labels[1], 1);
        assert_eq!(labels[2], 0);
        assert_eq!(labels[3], 1);
    }

    /// Small centroid movement converges; large movement with a big
    /// relative inertia drop does not.
    #[test]
    fn test_convergence_check() {
        let config = DistributedKMeansConfig {
            tolerance: 0.1,
            ..Default::default()
        };
        let kmeans = DistributedKMeans::<f64>::new(2, config).unwrap();

        let old_centroids = Array2::from_shape_vec((2, 2), vec![0.0, 0.0, 1.0, 1.0]).unwrap();

        // Movement ≈ 0.07 < tolerance 0.1.
        let new_centroids_converged = Array2::from_shape_vec(
            (2, 2),
            vec![0.05, 0.05, 1.05, 1.05],
        )
        .unwrap();

        // Movement ≈ 0.71 > tolerance 0.1.
        let new_centroids_not_converged = Array2::from_shape_vec(
            (2, 2),
            vec![0.5, 0.5, 1.5, 1.5],
        )
        .unwrap();

        let mut kmeans_converged = kmeans;
        kmeans_converged.centroids = Some(old_centroids.clone());
        kmeans_converged.global_inertia = 100.0;

        assert!(kmeans_converged
            .check_convergence(&new_centroids_converged, 99.0)
            .unwrap());

        let mut kmeans_not_converged = DistributedKMeans::<f64>::new(
            2,
            DistributedKMeansConfig {
                tolerance: 0.1,
                ..Default::default()
            },
        )
        .unwrap();
        kmeans_not_converged.centroids = Some(old_centroids);
        kmeans_not_converged.global_inertia = 100.0;

        assert!(!kmeans_not_converged
            .check_convergence(&new_centroids_not_converged, 50.0)
            .unwrap());
    }

    /// Even partitions score near 1.0; skewed partitions score lower.
    #[test]
    fn test_load_balance_score() {
        let config = DistributedKMeansConfig::default();
        let mut kmeans = DistributedKMeans::<f64>::new(2, config).unwrap();

        // Two equal-size partitions (100 rows each).
        let partition1 = DataPartition::new(0, Array2::zeros((100, 2)), 0);
        let partition2 = DataPartition::new(1, Array2::zeros((100, 2)), 1);
        kmeans.partitions = vec![partition1, partition2];

        let balanced_score = kmeans.calculate_load_balance_score();
        assert!(balanced_score > 0.9);

        // Heavily skewed partitions (10 vs 190 rows).
        let partition1 = DataPartition::new(0, Array2::zeros((10, 2)), 0);
        let partition2 = DataPartition::new(1, Array2::zeros((190, 2)), 1);
        kmeans.partitions = vec![partition1, partition2];

        let imbalanced_score = kmeans.calculate_load_balance_score();
        assert!(imbalanced_score < balanced_score);
    }
}