// rustkernel_ml/clustering.rs

1//! Clustering kernels.
2//!
3//! This module provides machine learning clustering algorithms:
4//! - K-Means (Lloyd's algorithm with K-Means++ initialization)
5//! - DBSCAN (density-based clustering)
6//! - Hierarchical clustering (agglomerative)
7
8use crate::ring_messages::{
9    K2KCentroidAggregation, K2KCentroidBroadcast, K2KCentroidBroadcastAck, K2KKMeansSync,
10    K2KKMeansSyncResponse, K2KPartialCentroid, KMeansAssignResponse, KMeansAssignRing,
11    KMeansQueryResponse, KMeansQueryRing, KMeansUpdateResponse, KMeansUpdateRing, from_fixed_point,
12    to_fixed_point, unpack_coordinates,
13};
14use crate::types::{ClusteringResult, DataMatrix, DistanceMetric};
15use rand::prelude::*;
16use ringkernel_core::RingContext;
17use rustkernel_core::traits::RingKernelHandler;
18use rustkernel_core::{domain::Domain, kernel::KernelMetadata, traits::GpuKernel};
19
20// ============================================================================
21// K-Means Clustering Kernel
22// ============================================================================
23
/// K-Means clustering state for Ring mode operations.
///
/// Mutated in place by `assign_step`/`update_step`; access is guarded by
/// the `RwLock` inside `KMeans`.
#[derive(Debug, Clone, Default)]
pub struct KMeansState {
    /// Current centroids, stored row-major as k * n_features values.
    pub centroids: Vec<f64>,
    /// Input data reference (stored for query operations); `None` until `initialize` is called.
    pub data: Option<DataMatrix>,
    /// Number of clusters.
    pub k: usize,
    /// Number of features per point.
    pub n_features: usize,
    /// Current iteration (incremented by each update step).
    pub iteration: u32,
    /// Current inertia (sum of squared distances to assigned centroids).
    pub inertia: f64,
    /// Whether the algorithm has converged.
    pub converged: bool,
    /// Current cluster assignments, one label per sample.
    pub labels: Vec<usize>,
}
44
/// K-Means clustering kernel.
///
/// Implements Lloyd's algorithm with K-Means++ initialization.
#[derive(Debug)]
pub struct KMeans {
    // Static kernel metadata exposed through the `GpuKernel` trait.
    metadata: KernelMetadata,
    /// Internal state for Ring mode operations.
    /// The `RwLock` is why `Clone` is implemented manually below.
    state: std::sync::RwLock<KMeansState>,
}
54
55impl Clone for KMeans {
56    fn clone(&self) -> Self {
57        Self {
58            metadata: self.metadata.clone(),
59            state: std::sync::RwLock::new(self.state.read().unwrap().clone()),
60        }
61    }
62}
63
64impl Default for KMeans {
65    fn default() -> Self {
66        Self::new()
67    }
68}
69
70impl KMeans {
71    /// Create a new K-Means kernel.
72    #[must_use]
73    pub fn new() -> Self {
74        Self {
75            metadata: KernelMetadata::batch("ml/kmeans-cluster", Domain::StatisticalML)
76                .with_description("K-Means clustering with K-Means++ initialization")
77                .with_throughput(20_000)
78                .with_latency_us(50.0),
79            state: std::sync::RwLock::new(KMeansState::default()),
80        }
81    }
82
83    /// Initialize the kernel with data and k for Ring mode operations.
84    pub fn initialize(&self, data: DataMatrix, k: usize) {
85        let centroids = Self::kmeans_plus_plus_init(&data, k);
86        let n = data.n_samples;
87        let n_features = data.n_features;
88
89        let mut state = self.state.write().unwrap();
90        *state = KMeansState {
91            centroids,
92            data: Some(data),
93            k,
94            n_features,
95            iteration: 0,
96            inertia: 0.0,
97            converged: false,
98            labels: vec![0; n],
99        };
100    }
101
102    /// Perform one E-step (assignment) on internal state.
103    /// Returns the total inertia (sum of squared distances).
104    pub fn assign_step(&self) -> f64 {
105        let mut state = self.state.write().unwrap();
106
107        // Check if data exists
108        let data = match state.data {
109            Some(ref d) => d.clone(),
110            None => return 0.0,
111        };
112
113        let n = data.n_samples;
114        let d_features = state.n_features;
115        let mut total_inertia = 0.0;
116
117        // Clone centroids to avoid borrow conflict
118        let centroids = state.centroids.clone();
119
120        // Compute assignments
121        let mut new_labels = vec![0usize; n];
122        for i in 0..n {
123            let point = data.row(i);
124            let mut min_dist = f64::MAX;
125            let mut min_cluster = 0;
126
127            for (c, centroid) in centroids.chunks(d_features).enumerate() {
128                let dist = Self::euclidean_distance(point, centroid);
129                if dist < min_dist {
130                    min_dist = dist;
131                    min_cluster = c;
132                }
133            }
134            new_labels[i] = min_cluster;
135            total_inertia += min_dist * min_dist;
136        }
137
138        // Update state
139        state.labels = new_labels;
140        state.inertia = total_inertia;
141        total_inertia
142    }
143
144    /// Perform one M-step (centroid update) on internal state.
145    /// Returns the maximum centroid shift.
146    pub fn update_step(&self) -> f64 {
147        let mut state = self.state.write().unwrap();
148        let Some(ref data) = state.data else {
149            return 0.0;
150        };
151
152        let n = data.n_samples;
153        let d = state.n_features;
154        let k = state.k;
155
156        let mut new_centroids = vec![0.0f64; k * d];
157        let mut counts = vec![0usize; k];
158
159        for i in 0..n {
160            let cluster = state.labels[i];
161            counts[cluster] += 1;
162            let point = data.row(i);
163            for j in 0..d {
164                new_centroids[cluster * d + j] += point[j];
165            }
166        }
167
168        // Normalize centroids
169        for c in 0..k {
170            if counts[c] > 0 {
171                for j in 0..d {
172                    new_centroids[c * d + j] /= counts[c] as f64;
173                }
174            }
175        }
176
177        // Calculate maximum shift
178        let max_shift = state
179            .centroids
180            .chunks(d)
181            .zip(new_centroids.chunks(d))
182            .map(|(old, new)| Self::euclidean_distance(old, new))
183            .fold(0.0f64, f64::max);
184
185        state.centroids = new_centroids;
186        state.iteration += 1;
187        max_shift
188    }
189
190    /// Query the nearest cluster for a point.
191    pub fn query_point(&self, point: &[f64]) -> (usize, f64) {
192        let state = self.state.read().unwrap();
193        let d = state.n_features;
194
195        let mut min_dist = f64::MAX;
196        let mut min_cluster = 0;
197
198        for (c, centroid) in state.centroids.chunks(d).enumerate() {
199            let dist = Self::euclidean_distance(point, centroid);
200            if dist < min_dist {
201                min_dist = dist;
202                min_cluster = c;
203            }
204        }
205
206        (min_cluster, min_dist)
207    }
208
    /// Get current iteration count.
    ///
    /// Incremented once per `update_step`; reset to 0 by `initialize`.
    pub fn current_iteration(&self) -> u32 {
        self.state.read().unwrap().iteration
    }
213
    /// Get current inertia.
    ///
    /// Value recorded by the most recent `assign_step` (0.0 right after
    /// `initialize`).
    pub fn current_inertia(&self) -> f64 {
        self.state.read().unwrap().inertia
    }
218
219    /// Run K-Means clustering.
220    ///
221    /// # Arguments
222    /// * `data` - Input data matrix (n_samples x n_features)
223    /// * `k` - Number of clusters
224    /// * `max_iterations` - Maximum number of iterations
225    /// * `tolerance` - Convergence threshold for centroid movement
226    pub fn compute(
227        data: &DataMatrix,
228        k: usize,
229        max_iterations: u32,
230        tolerance: f64,
231    ) -> ClusteringResult {
232        let n = data.n_samples;
233        let d = data.n_features;
234
235        if n == 0 || k == 0 || k > n {
236            return ClusteringResult {
237                labels: Vec::new(),
238                n_clusters: 0,
239                centroids: Vec::new(),
240                inertia: 0.0,
241                iterations: 0,
242                converged: true,
243            };
244        }
245
246        // K-Means++ initialization
247        let mut centroids = Self::kmeans_plus_plus_init(data, k);
248        let mut labels = vec![0usize; n];
249        let mut converged = false;
250        let mut iterations = 0u32;
251
252        for iter in 0..max_iterations {
253            iterations = iter + 1;
254
255            // Assignment step: assign each point to nearest centroid
256            for i in 0..n {
257                let point = data.row(i);
258                let mut min_dist = f64::MAX;
259                let mut min_cluster = 0;
260
261                for (c, centroid) in centroids.chunks(d).enumerate() {
262                    let dist = Self::euclidean_distance(point, centroid);
263                    if dist < min_dist {
264                        min_dist = dist;
265                        min_cluster = c;
266                    }
267                }
268                labels[i] = min_cluster;
269            }
270
271            // Update step: recalculate centroids
272            let mut new_centroids = vec![0.0f64; k * d];
273            let mut counts = vec![0usize; k];
274
275            for i in 0..n {
276                let cluster = labels[i];
277                counts[cluster] += 1;
278                let point = data.row(i);
279                for j in 0..d {
280                    new_centroids[cluster * d + j] += point[j];
281                }
282            }
283
284            // Normalize centroids
285            for c in 0..k {
286                if counts[c] > 0 {
287                    for j in 0..d {
288                        new_centroids[c * d + j] /= counts[c] as f64;
289                    }
290                }
291            }
292
293            // Check convergence
294            let max_shift = centroids
295                .chunks(d)
296                .zip(new_centroids.chunks(d))
297                .map(|(old, new)| Self::euclidean_distance(old, new))
298                .fold(0.0f64, f64::max);
299
300            centroids = new_centroids;
301
302            if max_shift < tolerance {
303                converged = true;
304                break;
305            }
306        }
307
308        // Calculate inertia (sum of squared distances to centroids)
309        let inertia: f64 = (0..n)
310            .map(|i| {
311                let point = data.row(i);
312                let centroid_start = labels[i] * d;
313                let centroid = &centroids[centroid_start..centroid_start + d];
314                let dist = Self::euclidean_distance(point, centroid);
315                dist * dist
316            })
317            .sum();
318
319        ClusteringResult {
320            labels,
321            n_clusters: k,
322            centroids,
323            inertia,
324            iterations,
325            converged,
326        }
327    }
328
329    /// K-Means++ initialization.
330    fn kmeans_plus_plus_init(data: &DataMatrix, k: usize) -> Vec<f64> {
331        let n = data.n_samples;
332        let d = data.n_features;
333        let mut rng = rand::rng();
334        let mut centroids = Vec::with_capacity(k * d);
335
336        // Choose first centroid randomly
337        let first_idx = rng.random_range(0..n);
338        centroids.extend_from_slice(data.row(first_idx));
339
340        let mut distances = vec![f64::MAX; n];
341
342        // Choose remaining centroids
343        for _ in 1..k {
344            // Update distances to nearest centroid
345            for i in 0..n {
346                let point = data.row(i);
347                let last_centroid = &centroids[centroids.len() - d..];
348                let dist = Self::euclidean_distance(point, last_centroid);
349                distances[i] = distances[i].min(dist);
350            }
351
352            // Choose next centroid with probability proportional to D^2
353            let total: f64 = distances.iter().map(|d| d * d).sum();
354            let threshold = rng.random::<f64>() * total;
355
356            let mut cumsum = 0.0;
357            let mut next_idx = 0;
358            for (i, &dist) in distances.iter().enumerate() {
359                cumsum += dist * dist;
360                if cumsum >= threshold {
361                    next_idx = i;
362                    break;
363                }
364            }
365
366            centroids.extend_from_slice(data.row(next_idx));
367        }
368
369        centroids
370    }
371
372    /// Euclidean distance between two vectors.
373    fn euclidean_distance(a: &[f64], b: &[f64]) -> f64 {
374        a.iter()
375            .zip(b.iter())
376            .map(|(x, y)| (x - y).powi(2))
377            .sum::<f64>()
378            .sqrt()
379    }
380}
381
impl GpuKernel for KMeans {
    /// Metadata describing this kernel (id, domain, throughput/latency hints).
    fn metadata(&self) -> &KernelMetadata {
        &self.metadata
    }
}
387
388// ============================================================================
389// KMeans RingKernelHandler Implementations
390// ============================================================================
391
392/// RingKernelHandler for KMeans assignment step (E-step).
393#[async_trait::async_trait]
394impl RingKernelHandler<KMeansAssignRing, KMeansAssignResponse> for KMeans {
395    async fn handle(
396        &self,
397        _ctx: &mut RingContext,
398        msg: KMeansAssignRing,
399    ) -> Result<KMeansAssignResponse> {
400        // Perform assignment step on internal state
401        let inertia = self.assign_step();
402
403        let state = self.state.read().unwrap();
404        let points_assigned = state.labels.len() as u32;
405
406        Ok(KMeansAssignResponse {
407            request_id: msg.id.0,
408            iteration: msg.iteration,
409            inertia_fp: to_fixed_point(inertia),
410            points_assigned,
411        })
412    }
413}
414
415/// RingKernelHandler for KMeans update step (M-step).
416#[async_trait::async_trait]
417impl RingKernelHandler<KMeansUpdateRing, KMeansUpdateResponse> for KMeans {
418    async fn handle(
419        &self,
420        _ctx: &mut RingContext,
421        msg: KMeansUpdateRing,
422    ) -> Result<KMeansUpdateResponse> {
423        // Perform update step on internal state
424        let max_shift = self.update_step();
425        let converged = max_shift < 1e-6;
426
427        // Update convergence status in state
428        if converged {
429            let mut state = self.state.write().unwrap();
430            state.converged = true;
431        }
432
433        Ok(KMeansUpdateResponse {
434            request_id: msg.id.0,
435            iteration: msg.iteration,
436            max_shift_fp: to_fixed_point(max_shift),
437            converged,
438        })
439    }
440}
441
442/// RingKernelHandler for point queries.
443#[async_trait::async_trait]
444impl RingKernelHandler<KMeansQueryRing, KMeansQueryResponse> for KMeans {
445    async fn handle(
446        &self,
447        _ctx: &mut RingContext,
448        msg: KMeansQueryRing,
449    ) -> Result<KMeansQueryResponse> {
450        // Unpack the query point coordinates
451        let point = unpack_coordinates(&msg.point, msg.n_dims as usize);
452
453        // Query the nearest cluster using internal state
454        let (cluster, distance) = self.query_point(&point);
455
456        Ok(KMeansQueryResponse {
457            request_id: msg.id.0,
458            cluster: cluster as u32,
459            distance_fp: to_fixed_point(distance),
460        })
461    }
462}
463
464/// RingKernelHandler for K2K partial centroid updates.
465///
466/// Aggregates partial centroid contributions from distributed workers.
467#[async_trait::async_trait]
468impl RingKernelHandler<K2KPartialCentroid, K2KCentroidAggregation> for KMeans {
469    async fn handle(
470        &self,
471        _ctx: &mut RingContext,
472        msg: K2KPartialCentroid,
473    ) -> Result<K2KCentroidAggregation> {
474        let n_dims = msg.n_dims as usize;
475        let cluster_id = msg.cluster_id as usize;
476        let mut new_centroid = [0i64; 8];
477
478        // Compute new centroid from partial sums
479        if msg.point_count > 0 {
480            for i in 0..n_dims.min(8) {
481                new_centroid[i] = msg.coord_sum_fp[i] / msg.point_count as i64;
482            }
483        }
484
485        // Calculate shift from old centroid in internal state
486        let shift = {
487            let state = self.state.read().unwrap();
488            let d = state.n_features;
489            if cluster_id < state.k && d > 0 {
490                let old_centroid = &state.centroids[cluster_id * d..(cluster_id + 1) * d];
491                let new_coords: Vec<f64> = new_centroid[..d.min(8)]
492                    .iter()
493                    .map(|&v| from_fixed_point(v))
494                    .collect();
495                Self::euclidean_distance(old_centroid, &new_coords)
496            } else {
497                0.0
498            }
499        };
500
501        Ok(K2KCentroidAggregation {
502            request_id: msg.id.0,
503            cluster_id: msg.cluster_id,
504            iteration: msg.iteration,
505            new_centroid_fp: new_centroid,
506            total_points: msg.point_count,
507            shift_fp: to_fixed_point(shift),
508        })
509    }
510}
511
512/// RingKernelHandler for K2K iteration sync.
513///
514/// Synchronizes distributed KMeans workers after each iteration.
515/// In a single-instance setting, validates iteration state and returns convergence status.
516#[async_trait::async_trait]
517impl RingKernelHandler<K2KKMeansSync, K2KKMeansSyncResponse> for KMeans {
518    async fn handle(
519        &self,
520        _ctx: &mut RingContext,
521        msg: K2KKMeansSync,
522    ) -> Result<K2KKMeansSyncResponse> {
523        let state = self.state.read().unwrap();
524
525        // Verify iteration matches internal state
526        let current_iteration = state.iteration as u64;
527        let all_synced = msg.iteration <= current_iteration;
528
529        // Use reported values for single-worker case
530        // In distributed setting, would aggregate across workers
531        let global_shift = from_fixed_point(msg.max_shift_fp);
532        let converged = global_shift < 1e-6 || state.converged;
533
534        Ok(K2KKMeansSyncResponse {
535            request_id: msg.id.0,
536            iteration: msg.iteration,
537            all_synced,
538            global_inertia_fp: msg.local_inertia_fp,
539            global_max_shift_fp: msg.max_shift_fp,
540            converged,
541        })
542    }
543}
544
545/// RingKernelHandler for K2K centroid broadcast.
546///
547/// Receives new centroids broadcast from coordinator.
548#[async_trait::async_trait]
549impl RingKernelHandler<K2KCentroidBroadcast, K2KCentroidBroadcastAck> for KMeans {
550    async fn handle(
551        &self,
552        _ctx: &mut RingContext,
553        msg: K2KCentroidBroadcast,
554    ) -> Result<K2KCentroidBroadcastAck> {
555        // In a distributed setting, this would update local centroids
556        Ok(K2KCentroidBroadcastAck {
557            request_id: msg.id.0,
558            worker_id: 0, // Would be actual worker ID
559            iteration: msg.iteration,
560            applied: true,
561        })
562    }
563}
564
565// ============================================================================
566// DBSCAN Clustering Kernel
567// ============================================================================
568
/// DBSCAN clustering kernel.
///
/// Density-based spatial clustering of applications with noise.
/// Stateless: all work happens in the associated `compute` function.
#[derive(Debug, Clone)]
pub struct DBSCAN {
    // Static kernel metadata exposed through the `GpuKernel` trait.
    metadata: KernelMetadata,
}
576
577impl Default for DBSCAN {
578    fn default() -> Self {
579        Self::new()
580    }
581}
582
583impl DBSCAN {
584    /// Create a new DBSCAN kernel.
585    #[must_use]
586    pub fn new() -> Self {
587        Self {
588            metadata: KernelMetadata::batch("ml/dbscan-cluster", Domain::StatisticalML)
589                .with_description("Density-based clustering with GPU union-find")
590                .with_throughput(1_000)
591                .with_latency_us(10_000.0),
592        }
593    }
594
595    /// Run DBSCAN clustering.
596    ///
597    /// # Arguments
598    /// * `data` - Input data matrix
599    /// * `eps` - Maximum distance for neighborhood
600    /// * `min_samples` - Minimum points to form a dense region
601    /// * `metric` - Distance metric to use
602    pub fn compute(
603        data: &DataMatrix,
604        eps: f64,
605        min_samples: usize,
606        metric: DistanceMetric,
607    ) -> ClusteringResult {
608        let n = data.n_samples;
609
610        if n == 0 {
611            return ClusteringResult {
612                labels: Vec::new(),
613                n_clusters: 0,
614                centroids: Vec::new(),
615                inertia: 0.0,
616                iterations: 1,
617                converged: true,
618            };
619        }
620
621        // -1 = unvisited, -2 = noise, >= 0 = cluster label
622        let mut labels = vec![-1i64; n];
623        let mut current_cluster = 0i64;
624
625        // Precompute neighborhoods (for efficiency)
626        let neighborhoods: Vec<Vec<usize>> = (0..n)
627            .map(|i| Self::get_neighbors(data, i, eps, metric))
628            .collect();
629
630        for i in 0..n {
631            if labels[i] != -1 {
632                continue; // Already processed
633            }
634
635            let neighbors = &neighborhoods[i];
636
637            if neighbors.len() < min_samples {
638                labels[i] = -2; // Mark as noise
639                continue;
640            }
641
642            // Start new cluster
643            labels[i] = current_cluster;
644            let mut seed_set: Vec<usize> = neighbors.clone();
645            let mut j = 0;
646
647            while j < seed_set.len() {
648                let q = seed_set[j];
649                j += 1;
650
651                if labels[q] == -2 {
652                    labels[q] = current_cluster; // Change noise to border
653                }
654
655                if labels[q] != -1 {
656                    continue; // Already processed
657                }
658
659                labels[q] = current_cluster;
660
661                let q_neighbors = &neighborhoods[q];
662                if q_neighbors.len() >= min_samples {
663                    // Add new neighbors to seed set
664                    for &neighbor in q_neighbors {
665                        if !seed_set.contains(&neighbor) {
666                            seed_set.push(neighbor);
667                        }
668                    }
669                }
670            }
671
672            current_cluster += 1;
673        }
674
675        // Convert labels to usize (noise stays as max value)
676        let n_clusters = current_cluster as usize;
677        let labels: Vec<usize> = labels
678            .iter()
679            .map(|&l| if l < 0 { usize::MAX } else { l as usize })
680            .collect();
681
682        // Calculate centroids for each cluster
683        let d = data.n_features;
684        let mut centroids = vec![0.0f64; n_clusters * d];
685        let mut counts = vec![0usize; n_clusters];
686
687        for i in 0..n {
688            if labels[i] < n_clusters {
689                let cluster = labels[i];
690                counts[cluster] += 1;
691                for j in 0..d {
692                    centroids[cluster * d + j] += data.row(i)[j];
693                }
694            }
695        }
696
697        for c in 0..n_clusters {
698            if counts[c] > 0 {
699                for j in 0..d {
700                    centroids[c * d + j] /= counts[c] as f64;
701                }
702            }
703        }
704
705        ClusteringResult {
706            labels,
707            n_clusters,
708            centroids,
709            inertia: 0.0,
710            iterations: 1,
711            converged: true,
712        }
713    }
714
715    /// Get neighbors within eps distance.
716    fn get_neighbors(
717        data: &DataMatrix,
718        point_idx: usize,
719        eps: f64,
720        metric: DistanceMetric,
721    ) -> Vec<usize> {
722        let n = data.n_samples;
723        let point = data.row(point_idx);
724
725        (0..n)
726            .filter(|&i| {
727                let other = data.row(i);
728                let dist = metric.compute(point, other);
729                dist <= eps
730            })
731            .collect()
732    }
733}
734
impl GpuKernel for DBSCAN {
    /// Metadata describing this kernel (id, domain, throughput/latency hints).
    fn metadata(&self) -> &KernelMetadata {
        &self.metadata
    }
}
740
741// ============================================================================
742// Hierarchical Clustering Kernel
743// ============================================================================
744
/// Linkage method for hierarchical clustering.
///
/// Selects how the distance between a merged cluster and the remaining
/// clusters is recomputed in `update_distances`.
#[derive(Debug, Clone, Copy, PartialEq)]
pub enum LinkageMethod {
    /// Single linkage (minimum of the two cluster distances)
    Single,
    /// Complete linkage (maximum of the two cluster distances)
    Complete,
    /// Average linkage (UPGMA; size-weighted mean of the distances)
    Average,
    /// Ward's method (minimize within-cluster variance)
    Ward,
}
757
/// Hierarchical clustering kernel.
///
/// Agglomerative hierarchical clustering with various linkage methods.
/// Stateless: all work happens in the associated `compute` function.
#[derive(Debug, Clone)]
pub struct HierarchicalClustering {
    // Static kernel metadata exposed through the `GpuKernel` trait.
    metadata: KernelMetadata,
}
765
766impl Default for HierarchicalClustering {
767    fn default() -> Self {
768        Self::new()
769    }
770}
771
772impl HierarchicalClustering {
773    /// Create a new hierarchical clustering kernel.
774    #[must_use]
775    pub fn new() -> Self {
776        Self {
777            metadata: KernelMetadata::batch("ml/hierarchical-cluster", Domain::StatisticalML)
778                .with_description("Agglomerative hierarchical clustering")
779                .with_throughput(500)
780                .with_latency_us(50_000.0),
781        }
782    }
783
784    /// Run hierarchical clustering.
785    ///
786    /// # Arguments
787    /// * `data` - Input data matrix
788    /// * `n_clusters` - Number of clusters to form
789    /// * `linkage` - Linkage method
790    /// * `metric` - Distance metric
791    pub fn compute(
792        data: &DataMatrix,
793        n_clusters: usize,
794        linkage: LinkageMethod,
795        metric: DistanceMetric,
796    ) -> ClusteringResult {
797        let n = data.n_samples;
798
799        if n == 0 || n_clusters == 0 {
800            return ClusteringResult {
801                labels: Vec::new(),
802                n_clusters: 0,
803                centroids: Vec::new(),
804                inertia: 0.0,
805                iterations: 0,
806                converged: true,
807            };
808        }
809
810        // Initialize each point as its own cluster
811        let mut labels: Vec<usize> = (0..n).collect();
812        let mut active_clusters: Vec<bool> = vec![true; n];
813        let mut cluster_sizes: Vec<usize> = vec![1; n];
814
815        // Compute initial distance matrix
816        let mut distances = Self::compute_distance_matrix(data, metric);
817
818        // Merge clusters until we have n_clusters
819        let mut current_n_clusters = n;
820
821        while current_n_clusters > n_clusters {
822            // Find closest pair of clusters
823            let (c1, c2) = Self::find_closest_clusters(&distances, &active_clusters, n);
824
825            if c1 == c2 {
826                break;
827            }
828
829            // Merge c2 into c1
830            for label in &mut labels {
831                if *label == c2 {
832                    *label = c1;
833                }
834            }
835
836            // Update distances based on linkage
837            Self::update_distances(
838                &mut distances,
839                c1,
840                c2,
841                n,
842                linkage,
843                &cluster_sizes,
844                &active_clusters,
845            );
846
847            cluster_sizes[c1] += cluster_sizes[c2];
848            active_clusters[c2] = false;
849            current_n_clusters -= 1;
850        }
851
852        // Renumber labels to be contiguous
853        let mut label_map = std::collections::HashMap::new();
854        let mut next_label = 0usize;
855
856        for label in &mut labels {
857            let new_label = *label_map.entry(*label).or_insert_with(|| {
858                let l = next_label;
859                next_label += 1;
860                l
861            });
862            *label = new_label;
863        }
864
865        // Calculate centroids
866        let d = data.n_features;
867        let final_n_clusters = next_label;
868        let mut centroids = vec![0.0f64; final_n_clusters * d];
869        let mut counts = vec![0usize; final_n_clusters];
870
871        for i in 0..n {
872            let cluster = labels[i];
873            counts[cluster] += 1;
874            for j in 0..d {
875                centroids[cluster * d + j] += data.row(i)[j];
876            }
877        }
878
879        for c in 0..final_n_clusters {
880            if counts[c] > 0 {
881                for j in 0..d {
882                    centroids[c * d + j] /= counts[c] as f64;
883                }
884            }
885        }
886
887        ClusteringResult {
888            labels,
889            n_clusters: final_n_clusters,
890            centroids,
891            inertia: 0.0,
892            iterations: (n - n_clusters) as u32,
893            converged: true,
894        }
895    }
896
897    fn compute_distance_matrix(data: &DataMatrix, metric: DistanceMetric) -> Vec<f64> {
898        let n = data.n_samples;
899        let mut distances = vec![f64::MAX; n * n];
900
901        for i in 0..n {
902            for j in 0..n {
903                if i != j {
904                    distances[i * n + j] = metric.compute(data.row(i), data.row(j));
905                }
906            }
907        }
908
909        distances
910    }
911
912    fn find_closest_clusters(distances: &[f64], active: &[bool], n: usize) -> (usize, usize) {
913        let mut min_dist = f64::MAX;
914        let mut min_i = 0;
915        let mut min_j = 0;
916
917        for i in 0..n {
918            if !active[i] {
919                continue;
920            }
921            for j in (i + 1)..n {
922                if !active[j] {
923                    continue;
924                }
925                let dist = distances[i * n + j];
926                if dist < min_dist {
927                    min_dist = dist;
928                    min_i = i;
929                    min_j = j;
930                }
931            }
932        }
933
934        (min_i, min_j)
935    }
936
937    fn update_distances(
938        distances: &mut [f64],
939        c1: usize,
940        c2: usize,
941        n: usize,
942        linkage: LinkageMethod,
943        cluster_sizes: &[usize],
944        active: &[bool],
945    ) {
946        for k in 0..n {
947            if !active[k] || k == c1 || k == c2 {
948                continue;
949            }
950
951            let d1 = distances[c1 * n + k];
952            let d2 = distances[c2 * n + k];
953
954            let new_dist = match linkage {
955                LinkageMethod::Single => d1.min(d2),
956                LinkageMethod::Complete => d1.max(d2),
957                LinkageMethod::Average => {
958                    let n1 = cluster_sizes[c1] as f64;
959                    let n2 = cluster_sizes[c2] as f64;
960                    (n1 * d1 + n2 * d2) / (n1 + n2)
961                }
962                LinkageMethod::Ward => {
963                    let n1 = cluster_sizes[c1] as f64;
964                    let n2 = cluster_sizes[c2] as f64;
965                    let nk = cluster_sizes[k] as f64;
966                    let total = n1 + n2 + nk;
967                    ((n1 + nk) * d1 * d1 + (n2 + nk) * d2 * d2
968                        - nk * distances[c1 * n + c2].powi(2))
969                        / total
970                }
971            };
972
973            distances[c1 * n + k] = new_dist;
974            distances[k * n + c1] = new_dist;
975        }
976    }
977}
978
/// Exposes kernel metadata so the kernel registry/dispatcher can identify
/// this kernel.
impl GpuKernel for HierarchicalClustering {
    /// Returns the kernel's static metadata (id, domain, ...).
    fn metadata(&self) -> &KernelMetadata {
        &self.metadata
    }
}
984
985// ============================================================================
986// BatchKernel Implementations
987// ============================================================================
988
989use crate::messages::{
990    DBSCANInput, DBSCANOutput, HierarchicalInput, HierarchicalOutput, KMeansInput, KMeansOutput,
991    Linkage,
992};
993use async_trait::async_trait;
994use rustkernel_core::error::Result;
995use rustkernel_core::traits::BatchKernel;
996use std::time::Instant;
997
998/// K-Means batch kernel implementation.
999impl KMeans {
1000    /// Execute K-Means clustering as a batch operation.
1001    ///
1002    /// Convenience method for batch clustering.
1003    pub async fn cluster_batch(&self, input: KMeansInput) -> Result<KMeansOutput> {
1004        let start = Instant::now();
1005        let result = Self::compute(&input.data, input.k, input.max_iterations, input.tolerance);
1006        let compute_time_us = start.elapsed().as_micros() as u64;
1007
1008        Ok(KMeansOutput {
1009            result,
1010            compute_time_us,
1011        })
1012    }
1013}
1014
// Batch execution entry point for K-Means; delegates to `KMeans::cluster_batch`.
#[async_trait]
impl BatchKernel<KMeansInput, KMeansOutput> for KMeans {
    async fn execute(&self, input: KMeansInput) -> Result<KMeansOutput> {
        self.cluster_batch(input).await
    }
}
1021
1022/// DBSCAN batch kernel implementation.
1023#[async_trait]
1024impl BatchKernel<DBSCANInput, DBSCANOutput> for DBSCAN {
1025    async fn execute(&self, input: DBSCANInput) -> Result<DBSCANOutput> {
1026        let start = Instant::now();
1027        let result = Self::compute(&input.data, input.eps, input.min_samples, input.metric);
1028        let compute_time_us = start.elapsed().as_micros() as u64;
1029
1030        Ok(DBSCANOutput {
1031            result,
1032            compute_time_us,
1033        })
1034    }
1035}
1036
1037/// Hierarchical clustering batch kernel implementation.
1038#[async_trait]
1039impl BatchKernel<HierarchicalInput, HierarchicalOutput> for HierarchicalClustering {
1040    async fn execute(&self, input: HierarchicalInput) -> Result<HierarchicalOutput> {
1041        let start = Instant::now();
1042        let linkage_method = match input.linkage {
1043            Linkage::Single => LinkageMethod::Single,
1044            Linkage::Complete => LinkageMethod::Complete,
1045            Linkage::Average => LinkageMethod::Average,
1046            Linkage::Ward => LinkageMethod::Ward,
1047        };
1048        let result = Self::compute(&input.data, input.n_clusters, linkage_method, input.metric);
1049        let compute_time_us = start.elapsed().as_micros() as u64;
1050
1051        Ok(HierarchicalOutput {
1052            result,
1053            compute_time_us,
1054        })
1055    }
1056}
1057
#[cfg(test)]
mod tests {
    use super::*;

    /// Six 2-D points forming two well-separated groups: indices 0..=2
    /// near the origin, indices 3..=5 near (10, 10).
    fn create_two_clusters() -> DataMatrix {
        let rows: [&[f64]; 6] = [
            &[0.0, 0.0],
            &[0.1, 0.1],
            &[0.2, 0.0],
            &[10.0, 10.0],
            &[10.1, 10.1],
            &[10.2, 10.0],
        ];
        DataMatrix::from_rows(&rows)
    }

    #[test]
    fn test_kmeans_metadata() {
        let kernel = KMeans::new();
        let meta = kernel.metadata();
        assert_eq!(meta.id, "ml/kmeans-cluster");
        assert_eq!(meta.domain, Domain::StatisticalML);
    }

    #[test]
    fn test_kmeans_two_clusters() {
        let result = KMeans::compute(&create_two_clusters(), 2, 100, 1e-6);

        assert_eq!(result.n_clusters, 2);
        assert!(result.converged);

        // Points 0..=2 share one label, 3..=5 share the other.
        for (a, b) in [(0, 1), (1, 2), (3, 4), (4, 5)] {
            assert_eq!(result.labels[a], result.labels[b]);
        }
        assert_ne!(result.labels[0], result.labels[3]);
    }

    #[test]
    fn test_dbscan_two_clusters() {
        let result = DBSCAN::compute(&create_two_clusters(), 1.0, 2, DistanceMetric::Euclidean);

        assert_eq!(result.n_clusters, 2);

        // Each dense group gets a single label, distinct between groups.
        for (a, b) in [(0, 1), (3, 4)] {
            assert_eq!(result.labels[a], result.labels[b]);
        }
        assert_ne!(result.labels[0], result.labels[3]);
    }

    #[test]
    fn test_hierarchical_two_clusters() {
        let result = HierarchicalClustering::compute(
            &create_two_clusters(),
            2,
            LinkageMethod::Complete,
            DistanceMetric::Euclidean,
        );

        assert_eq!(result.n_clusters, 2);

        // Points 0..=2 merge into one cluster, 3..=5 into the other.
        for (a, b) in [(0, 1), (1, 2), (3, 4)] {
            assert_eq!(result.labels[a], result.labels[b]);
        }
        assert_ne!(result.labels[0], result.labels[3]);
    }
}