scirs2_metrics/clustering/
internal_metrics.rs

1//! Internal clustering metrics
2//!
3//! This module provides functions for evaluating clustering algorithms using internal metrics,
4//! which assess clustering quality without external ground truth. These include
5//! silhouette score, Davies-Bouldin index, Calinski-Harabasz index, and Dunn index.
6
7use scirs2_core::ndarray::{Array1, Array2, ArrayBase, Data, Dimension, Ix2};
8use scirs2_core::numeric::{Float, NumCast};
9use std::collections::HashMap;
10
11use super::{calculate_distance, group_by_labels, pairwise_distances};
12use crate::error::{MetricsError, Result};
13
14/// Structure containing detailed silhouette analysis results
15#[derive(Debug, Clone)]
16pub struct SilhouetteAnalysis<F: Float> {
17    /// Sample-wise silhouette scores
18    pub sample_values: Vec<F>,
19
20    /// Mean silhouette score for all samples
21    pub mean_score: F,
22
23    /// Mean silhouette score for each cluster
24    pub cluster_scores: HashMap<usize, F>,
25
26    /// Sorted indices for visualization (samples ordered by cluster and silhouette value)
27    pub sorted_indices: Vec<usize>,
28
29    /// Original cluster labels mapped to consecutive integers (for visualization)
30    pub cluster_mapping: HashMap<usize, usize>,
31
32    /// Samples per cluster (ordered by cluster_mapping)
33    pub cluster_sizes: Vec<usize>,
34}
35
36/// Calculates the silhouette score for a clustering
37///
38/// The silhouette score measures how similar an object is to its own cluster
39/// compared to other clusters. The silhouette score ranges from -1 to 1, where
40/// a high value indicates that the object is well matched to its own cluster
41/// and poorly matched to neighboring clusters.
42///
43/// # Arguments
44///
45/// * `x` - Array of shape (n_samples, n_features) - The data
46/// * `labels` - Array of shape (n_samples,) - Predicted labels for each sample
47/// * `metric` - Distance metric to use. Currently only 'euclidean' is supported.
48///
49/// # Returns
50///
51/// * The mean silhouette coefficient for all samples
52///
53/// # Examples
54///
55/// ```
56/// use scirs2_core::ndarray::{array, Array2};
57/// use scirs2_metrics::clustering::silhouette_score;
58///
59/// // Create a small dataset with 2 clusters
60/// let x = Array2::from_shape_vec((6, 2), vec![
61///     1.0, 2.0,
62///     1.5, 1.8,
63///     1.2, 2.2,
64///     5.0, 6.0,
65///     5.2, 5.8,
66///     5.5, 6.2,
67/// ]).unwrap();
68///
69/// let labels = array![0, 0, 0, 1, 1, 1];
70///
71/// let score = silhouette_score(&x, &labels, "euclidean").unwrap();
72/// assert!(score > 0.8); // High score for well-separated clusters
73/// ```
74#[allow(dead_code)]
75pub fn silhouette_score<F, S1, S2, D>(
76    x: &ArrayBase<S1, Ix2>,
77    labels: &ArrayBase<S2, D>,
78    metric: &str,
79) -> Result<F>
80where
81    F: Float + NumCast + std::fmt::Debug + scirs2_core::ndarray::ScalarOperand + 'static,
82    S1: Data<Elem = F>,
83    S2: Data<Elem = usize>,
84    D: Dimension,
85{
86    let silhouette_results = silhouette_analysis(x, labels, metric)?;
87    Ok(silhouette_results.mean_score)
88}
89
90/// Calculate silhouette samples for a clustering
91///
92/// Returns the silhouette score for each sample in the clustering, which can be
93/// useful for more detailed analysis than just the mean silhouette score.
94///
95/// # Arguments
96///
97/// * `x` - Array of shape (n_samples, n_features) - The data
98/// * `labels` - Array of shape (n_samples,) - Predicted labels for each sample
99/// * `metric` - Distance metric to use. Currently only 'euclidean' is supported.
100///
101/// # Returns
102///
103/// * Array of shape (n_samples,) containing the silhouette score for each sample
104///
105/// # Examples
106///
107/// ```
108/// use scirs2_core::ndarray::{array, Array2};
109/// use scirs2_metrics::clustering::silhouette_samples;
110///
111/// // Create a small dataset with 2 clusters
112/// let x = Array2::from_shape_vec((6, 2), vec![
113///     1.0, 2.0, 1.5, 1.8, 1.2, 2.2, // Cluster 0
114///     5.0, 6.0, 5.2, 5.8, 5.5, 6.2, // Cluster 1
115/// ]).unwrap();
116///
117/// let labels = array![0, 0, 0, 1, 1, 1];
118///
119/// let samples = silhouette_samples(&x, &labels, "euclidean").unwrap();
120/// assert_eq!(samples.len(), 6);
121///
122/// // Calculate the mean silhouette score manually
123/// let mean_score = samples.iter().sum::<f64>() / samples.len() as f64;
124/// assert!(mean_score > 0.8); // High score for well-separated clusters
125/// ```
126#[allow(dead_code)]
127pub fn silhouette_samples<F, S1, S2, D>(
128    x: &ArrayBase<S1, Ix2>,
129    labels: &ArrayBase<S2, D>,
130    metric: &str,
131) -> Result<Vec<F>>
132where
133    F: Float + NumCast + std::fmt::Debug + scirs2_core::ndarray::ScalarOperand + 'static,
134    S1: Data<Elem = F>,
135    S2: Data<Elem = usize>,
136    D: Dimension,
137{
138    let analysis = silhouette_analysis(x, labels, metric)?;
139    Ok(analysis.sample_values)
140}
141
142/// Calculate silhouette scores per cluster
143///
144/// Returns the mean silhouette score for each cluster, allowing you to
145/// identify which clusters are more cohesive than others.
146///
147/// # Arguments
148///
149/// * `x` - Array of shape (n_samples, n_features) - The data
150/// * `labels` - Array of shape (n_samples,) - Predicted labels for each sample
151/// * `metric` - Distance metric to use. Currently only 'euclidean' is supported.
152///
153/// # Returns
154///
155/// * HashMap mapping cluster labels to their mean silhouette scores
156///
157/// # Examples
158///
159/// ```
160/// use scirs2_core::ndarray::{array, Array2};
161/// use scirs2_metrics::clustering::silhouette_scores_per_cluster;
162///
163/// // Create a small dataset with 3 clusters
164/// let x = Array2::from_shape_vec((9, 2), vec![
165///     1.0, 2.0, 1.5, 1.8, 1.2, 2.2,  // Cluster 0
166///     5.0, 6.0, 5.2, 5.8, 5.5, 6.2,  // Cluster 1
167///     9.0, 10.0, 9.2, 9.8, 9.5, 10.2, // Cluster 2
168/// ]).unwrap();
169///
170/// let labels = array![0, 0, 0, 1, 1, 1, 2, 2, 2];
171///
172/// let cluster_scores = silhouette_scores_per_cluster(&x, &labels, "euclidean").unwrap();
173/// assert_eq!(cluster_scores.len(), 3);
174/// assert!(cluster_scores[&0] > 0.5);
175/// assert!(cluster_scores[&1] > 0.5);
176/// assert!(cluster_scores[&2] > 0.5);
177/// ```
178#[allow(dead_code)]
179pub fn silhouette_scores_per_cluster<F, S1, S2, D>(
180    x: &ArrayBase<S1, Ix2>,
181    labels: &ArrayBase<S2, D>,
182    metric: &str,
183) -> Result<HashMap<usize, F>>
184where
185    F: Float + NumCast + std::fmt::Debug + scirs2_core::ndarray::ScalarOperand + 'static,
186    S1: Data<Elem = F>,
187    S2: Data<Elem = usize>,
188    D: Dimension,
189{
190    let analysis = silhouette_analysis(x, labels, metric)?;
191    Ok(analysis.cluster_scores)
192}
193
194/// Calculates detailed silhouette information for a clustering
195///
196/// This function provides sample-wise silhouette scores, cluster-wise averages,
197/// and ordering information for visualization. It's an enhanced version of
198/// silhouette_score that returns more detailed information.
199///
200/// # Arguments
201///
202/// * `x` - Array of shape (n_samples, n_features) - The data
203/// * `labels` - Array of shape (n_samples,) - Predicted labels for each sample
204/// * `metric` - Distance metric to use. Currently only 'euclidean' is supported.
205///
206/// # Returns
207///
208/// * `SilhouetteAnalysis` struct containing detailed silhouette information
209///
210/// # Examples
211///
212/// ```
213/// use scirs2_core::ndarray::{array, Array2};
214/// use scirs2_metrics::clustering::silhouette_analysis;
215///
216/// // Create a small dataset with 3 clusters
217/// let x = Array2::from_shape_vec((9, 2), vec![
218///     1.0, 2.0, 1.5, 1.8, 1.2, 2.2,  // Cluster 0
219///     5.0, 6.0, 5.2, 5.8, 5.5, 6.2,  // Cluster 1
220///     9.0, 10.0, 9.2, 9.8, 9.5, 10.2, // Cluster 2
221/// ]).unwrap();
222///
223/// let labels = array![0, 0, 0, 1, 1, 1, 2, 2, 2];
224///
225/// let analysis = silhouette_analysis(&x, &labels, "euclidean").unwrap();
226///
227/// // Get overall silhouette score
228/// let score = analysis.mean_score;
229/// assert!(score > 0.8); // High score for well-separated clusters
230///
231/// // Get cluster-wise silhouette scores
232/// for (cluster, score) in &analysis.cluster_scores {
233///     println!("Cluster {} silhouette score: {}", cluster, score);
234/// }
235///
236/// // Access individual sample silhouette values
237/// for (i, &value) in analysis.sample_values.iter().enumerate() {
238///     println!("Sample {} silhouette value: {}", i, value);
239/// }
240/// ```
241#[allow(dead_code)]
242pub fn silhouette_analysis<F, S1, S2, D>(
243    x: &ArrayBase<S1, Ix2>,
244    labels: &ArrayBase<S2, D>,
245    metric: &str,
246) -> Result<SilhouetteAnalysis<F>>
247where
248    F: Float + NumCast + std::fmt::Debug + scirs2_core::ndarray::ScalarOperand + 'static,
249    S1: Data<Elem = F>,
250    S2: Data<Elem = usize>,
251    D: Dimension,
252{
253    // Check that the metric is supported
254    if metric != "euclidean" {
255        return Err(MetricsError::InvalidInput(format!(
256            "Unsupported metric: {metric}. Only 'euclidean' is currently supported."
257        )));
258    }
259
260    // Check that x and labels have the same number of samples
261    let n_samples = x.shape()[0];
262    if n_samples != labels.len() {
263        return Err(MetricsError::InvalidInput(format!(
264            "x has {} samples, but labels has {} samples",
265            n_samples,
266            labels.len()
267        )));
268    }
269
270    // Check that there are at least 2 samples
271    if n_samples < 2 {
272        return Err(MetricsError::InvalidInput(
273            "n_samples must be at least 2".to_string(),
274        ));
275    }
276
277    // Group samples by label
278    let clusters = group_by_labels(x, labels)?;
279
280    // Check that there are at least 2 clusters
281    if clusters.len() < 2 {
282        return Err(MetricsError::InvalidInput(
283            "Number of labels is 1. Silhouette analysis is undefined for a single cluster."
284                .to_string(),
285        ));
286    }
287
288    // Check that all clusters have at least 1 sample
289    let empty_clusters: Vec<_> = clusters
290        .iter()
291        .filter(|(_, samples)| samples.is_empty())
292        .map(|(&label, _)| label)
293        .collect();
294
295    if !empty_clusters.is_empty() {
296        return Err(MetricsError::InvalidInput(format!(
297            "Empty clusters found: {empty_clusters:?}"
298        )));
299    }
300
301    // Compute distance matrix (more efficient than recomputing distances)
302    let distances = pairwise_distances(x, metric)?;
303
304    // Compute silhouette scores for each sample
305    let mut silhouette_values = Vec::with_capacity(n_samples);
306    let mut sample_clusters = Vec::with_capacity(n_samples);
307
308    for i in 0..n_samples {
309        let label_i = labels.iter().nth(i).ok_or_else(|| {
310            MetricsError::InvalidInput(format!("Could not access index {i} in labels"))
311        })?;
312        let cluster_i = &clusters[label_i];
313        sample_clusters.push(*label_i);
314
315        // Calculate the mean intra-cluster distance (a)
316        let mut a = F::zero();
317        let mut count_a = 0;
318
319        for &j in cluster_i {
320            if i == j {
321                continue;
322            }
323            a = a + distances[[i, j]];
324            count_a += 1;
325        }
326
327        // Handle single sample in cluster (set a to 0)
328        if count_a > 0 {
329            a = a / F::from(count_a).unwrap();
330        }
331
332        // Calculate the mean nearest-cluster distance (b)
333        let mut b = None;
334
335        for (label_j, cluster_j) in &clusters {
336            if *label_j == *label_i {
337                continue;
338            }
339
340            // Calculate mean distance to this cluster
341            let mut cluster_dist = F::zero();
342            for &j in cluster_j {
343                cluster_dist = cluster_dist + distances[[i, j]];
344            }
345            let cluster_dist = cluster_dist / F::from(cluster_j.len()).unwrap();
346
347            // Update b if this is the closest cluster
348            if let Some(current_b) = b {
349                if cluster_dist < current_b {
350                    b = Some(cluster_dist);
351                }
352            } else {
353                b = Some(cluster_dist);
354            }
355        }
356
357        // Calculate silhouette score
358        let s = if let Some(b) = b {
359            if a < b {
360                F::one() - a / b
361            } else if a > b {
362                b / a - F::one()
363            } else {
364                F::zero()
365            }
366        } else {
367            F::zero() // Will never happen if there are at least 2 clusters
368        };
369
370        silhouette_values.push(s);
371    }
372
373    // Calculate mean silhouette score
374    let sum = silhouette_values
375        .iter()
376        .fold(F::zero(), |acc, &val| acc + val);
377    let mean_score = sum / F::from(n_samples).unwrap();
378
379    // Calculate cluster-wise silhouette scores
380    let mut cluster_scores = HashMap::new();
381    for (label, indices) in &clusters {
382        let mut cluster_sum = F::zero();
383        for &idx in indices {
384            cluster_sum = cluster_sum + silhouette_values[idx];
385        }
386        let cluster_mean = cluster_sum / F::from(indices.len()).unwrap();
387        cluster_scores.insert(*label, cluster_mean);
388    }
389
390    // Create a mapping from original cluster labels to consecutive integers (for visualization)
391    let unique_labels: Vec<_> = clusters.keys().cloned().collect();
392    let mut cluster_mapping = HashMap::new();
393    for (i, &label) in unique_labels.iter().enumerate() {
394        cluster_mapping.insert(label, i);
395    }
396
397    // Create list of cluster sizes
398    let mut cluster_sizes = vec![0; cluster_mapping.len()];
399    for (label, indices) in &clusters {
400        let mapped_idx = cluster_mapping[label];
401        cluster_sizes[mapped_idx] = indices.len();
402    }
403
404    // Create sorted indices for visualization
405    // First, sort by cluster mapping
406    // Then, within each cluster, sort by silhouette value (descending)
407    let mut samples_with_scores: Vec<(usize, F, usize)> = silhouette_values
408        .iter()
409        .enumerate()
410        .map(|(i, &s)| (i, s, cluster_mapping[&sample_clusters[i]]))
411        .collect();
412
413    // Sort by cluster (ascending), then by silhouette value (descending)
414    samples_with_scores.sort_by(|a, b| a.2.cmp(&b.2).then(b.1.partial_cmp(&a.1).unwrap()));
415
416    let sorted_indices = samples_with_scores.iter().map(|&(i, _, _)| i).collect();
417
418    Ok(SilhouetteAnalysis {
419        sample_values: silhouette_values,
420        mean_score,
421        cluster_scores,
422        sorted_indices,
423        cluster_mapping,
424        cluster_sizes,
425    })
426}
427
428/// Calculates the Davies-Bouldin index for a clustering
429///
430/// The Davies-Bouldin index measures the average similarity between clusters,
431/// where the similarity is a ratio of within-cluster distances to between-cluster distances.
432/// The lower the value, the better the clustering.
433///
434/// # Arguments
435///
436/// * `x` - Array of shape (n_samples, n_features) - The data
437/// * `labels` - Array of shape (n_samples,) - Predicted labels for each sample
438///
439/// # Returns
440///
441/// * The Davies-Bouldin index
442///
443/// # Examples
444///
445/// ```
446/// use scirs2_core::ndarray::{array, Array2};
447/// use scirs2_metrics::clustering::davies_bouldin_score;
448///
449/// // Create a small dataset with 2 clusters
450/// let x = Array2::from_shape_vec((6, 2), vec![
451///     1.0, 2.0,
452///     1.5, 1.8,
453///     1.2, 2.2,
454///     5.0, 6.0,
455///     5.2, 5.8,
456///     5.5, 6.2,
457/// ]).unwrap();
458///
459/// let labels = array![0, 0, 0, 1, 1, 1];
460///
461/// let score = davies_bouldin_score(&x, &labels).unwrap();
462/// assert!(score < 0.5); // Low score for well-separated clusters
463/// ```
464#[allow(dead_code)]
465pub fn davies_bouldin_score<F, S1, S2, D>(
466    x: &ArrayBase<S1, Ix2>,
467    labels: &ArrayBase<S2, D>,
468) -> Result<F>
469where
470    F: Float + NumCast + std::fmt::Debug + scirs2_core::ndarray::ScalarOperand + 'static,
471    S1: Data<Elem = F>,
472    S2: Data<Elem = usize>,
473    D: Dimension,
474{
475    // Check that x and labels have the same number of samples
476    let n_samples = x.shape()[0];
477    if n_samples != labels.len() {
478        return Err(MetricsError::InvalidInput(format!(
479            "x has {} samples, but labels has {} samples",
480            n_samples,
481            labels.len()
482        )));
483    }
484
485    // Check that there are at least 2 samples
486    if n_samples < 2 {
487        return Err(MetricsError::InvalidInput(
488            "n_samples must be at least 2".to_string(),
489        ));
490    }
491
492    // Group samples by label
493    let clusters = group_by_labels(x, labels)?;
494
495    // Check that there are at least 2 clusters
496    if clusters.len() < 2 {
497        return Err(MetricsError::InvalidInput(
498            "Number of labels is 1. Davies-Bouldin index is undefined for a single cluster."
499                .to_string(),
500        ));
501    }
502
503    // Compute centroids for each cluster
504    let mut centroids = HashMap::new();
505    for (&label, indices) in &clusters {
506        let mut centroid = Array1::<F>::zeros(x.shape()[1]);
507        for &idx in indices {
508            centroid = centroid + x.row(idx).to_owned();
509        }
510        centroid = centroid / F::from(indices.len()).unwrap();
511        centroids.insert(label, centroid);
512    }
513
514    // Compute average distance to centroid for each cluster
515    let mut avg_distances = HashMap::new();
516    for (&label, indices) in &clusters {
517        let centroid = centroids.get(&label).unwrap();
518        let mut total_distance = F::zero();
519        for &idx in indices {
520            total_distance = total_distance
521                + calculate_distance(&x.row(idx).to_vec(), &centroid.to_vec(), "euclidean")?;
522        }
523        let avg_distance = total_distance / F::from(indices.len()).unwrap();
524        avg_distances.insert(label, avg_distance);
525    }
526
527    // Compute Davies-Bouldin index
528    let mut db_index = F::zero();
529    let labels_vec: Vec<_> = clusters.keys().cloned().collect();
530
531    for i in 0..labels_vec.len() {
532        let label_i = labels_vec[i];
533        let centroid_i = centroids.get(&label_i).unwrap();
534        let avg_dist_i = avg_distances.get(&label_i).unwrap();
535
536        let mut max_ratio = F::zero();
537        for (j, &label_j) in labels_vec.iter().enumerate() {
538            if i == j {
539                continue;
540            }
541            let centroid_j = centroids.get(&label_j).unwrap();
542            let avg_dist_j = avg_distances.get(&label_j).unwrap();
543
544            // Distance between centroids
545            let centroid_dist =
546                calculate_distance(&centroid_i.to_vec(), &centroid_j.to_vec(), "euclidean")?;
547
548            // Ratio of sum of intra-cluster distances to inter-cluster distance
549            let ratio = (*avg_dist_i + *avg_dist_j) / centroid_dist;
550
551            // Update max ratio
552            if ratio > max_ratio {
553                max_ratio = ratio;
554            }
555        }
556
557        db_index = db_index + max_ratio;
558    }
559
560    // Normalize by number of clusters
561    Ok(db_index / F::from(labels_vec.len()).unwrap())
562}
563
564/// Calculates the Calinski-Harabasz index (Variance Ratio Criterion)
565///
566/// The Calinski-Harabasz index is defined as the ratio of the between-clusters
567/// dispersion and the within-cluster dispersion. Higher values indicate better clustering.
568///
569/// # Arguments
570///
571/// * `x` - Array of shape (n_samples, n_features) - The data
572/// * `labels` - Array of shape (n_samples,) - Predicted labels for each sample
573///
574/// # Returns
575///
576/// * The Calinski-Harabasz index
577///
578/// # Examples
579///
580/// ```
581/// use scirs2_core::ndarray::{array, Array2};
582/// use scirs2_metrics::clustering::calinski_harabasz_score;
583///
584/// // Create a small dataset with 2 clusters
585/// let x = Array2::from_shape_vec((6, 2), vec![
586///     1.0, 2.0,
587///     1.5, 1.8,
588///     1.2, 2.2,
589///     5.0, 6.0,
590///     5.2, 5.8,
591///     5.5, 6.2,
592/// ]).unwrap();
593///
594/// let labels = array![0, 0, 0, 1, 1, 1];
595///
596/// let score = calinski_harabasz_score(&x, &labels).unwrap();
597/// assert!(score > 50.0); // High score for well-separated clusters
598/// ```
599#[allow(dead_code)]
600pub fn calinski_harabasz_score<F, S1, S2, D>(
601    x: &ArrayBase<S1, Ix2>,
602    labels: &ArrayBase<S2, D>,
603) -> Result<F>
604where
605    F: Float + NumCast + std::fmt::Debug + scirs2_core::ndarray::ScalarOperand + 'static,
606    S1: Data<Elem = F>,
607    S2: Data<Elem = usize>,
608    D: Dimension,
609{
610    // Check that x and labels have the same number of samples
611    let n_samples = x.shape()[0];
612    if n_samples != labels.len() {
613        return Err(MetricsError::InvalidInput(format!(
614            "x has {} samples, but labels has {} samples",
615            n_samples,
616            labels.len()
617        )));
618    }
619
620    // Check that there are at least 2 samples
621    if n_samples < 2 {
622        return Err(MetricsError::InvalidInput(
623            "n_samples must be at least 2".to_string(),
624        ));
625    }
626
627    // Group samples by label
628    let clusters = group_by_labels(x, labels)?;
629
630    // Check that there are at least 2 clusters
631    if clusters.len() < 2 {
632        return Err(MetricsError::InvalidInput(
633            "Number of labels is 1. Calinski-Harabasz index is undefined for a single cluster."
634                .to_string(),
635        ));
636    }
637
638    // Compute the global centroid
639    let mut global_centroid = Array1::<F>::zeros(x.shape()[1]);
640    for i in 0..n_samples {
641        global_centroid = global_centroid + x.row(i).to_owned();
642    }
643    global_centroid = global_centroid / F::from(n_samples).unwrap();
644
645    // Compute centroids for each cluster
646    let mut centroids = HashMap::new();
647    for (&label, indices) in &clusters {
648        let mut centroid = Array1::<F>::zeros(x.shape()[1]);
649        for &idx in indices {
650            centroid = centroid + x.row(idx).to_owned();
651        }
652        centroid = centroid / F::from(indices.len()).unwrap();
653        centroids.insert(label, centroid);
654    }
655
656    // Compute between-cluster dispersion
657    let mut between_disp = F::zero();
658    for (label, indices) in &clusters {
659        let cluster_size = F::from(indices.len()).unwrap();
660        let centroid = centroids.get(label).unwrap();
661
662        // Calculate squared distance between cluster centroid and global centroid
663        let mut squared_dist = F::zero();
664        for (c, g) in centroid.iter().zip(global_centroid.iter()) {
665            let diff = *c - *g;
666            squared_dist = squared_dist + diff * diff;
667        }
668
669        between_disp = between_disp + cluster_size * squared_dist;
670    }
671
672    // Compute within-cluster dispersion
673    let mut within_disp = F::zero();
674    for (label, indices) in &clusters {
675        let centroid = centroids.get(label).unwrap();
676
677        let mut cluster_disp = F::zero();
678        for &idx in indices {
679            let mut squared_dist = F::zero();
680            for (x_val, c_val) in x.row(idx).iter().zip(centroid.iter()) {
681                let diff = *x_val - *c_val;
682                squared_dist = squared_dist + diff * diff;
683            }
684            cluster_disp = cluster_disp + squared_dist;
685        }
686
687        within_disp = within_disp + cluster_disp;
688    }
689
690    // Handle edge cases
691    if within_disp <= F::epsilon() {
692        return Err(MetricsError::CalculationError(
693            "Within-cluster dispersion is zero".to_string(),
694        ));
695    }
696
697    // Calculate Calinski-Harabasz index
698    let n_clusters = F::from(clusters.len()).unwrap();
699    Ok(
700        between_disp * (F::from(n_samples - clusters.len()).unwrap())
701            / (within_disp * (n_clusters - F::one())),
702    )
703}
704
705/// Calculates the Dunn index for a clustering
706///
707/// The Dunn index is the ratio of the minimum inter-cluster distance to the maximum
708/// intra-cluster distance. Higher values indicate better clustering.
709///
710/// # Arguments
711///
712/// * `x` - Array of shape (n_samples, n_features) - The data
713/// * `labels` - Array of shape (n_samples,) - Predicted labels for each sample
714///
715/// # Returns
716///
717/// * The Dunn index
718#[allow(dead_code)]
719pub fn dunn_index<F, S1, S2, D>(x: &ArrayBase<S1, Ix2>, labels: &ArrayBase<S2, D>) -> Result<F>
720where
721    F: Float + NumCast + std::fmt::Debug + 'static,
722    S1: Data<Elem = F>,
723    S2: Data<Elem = usize>,
724    D: Dimension,
725{
726    // Check that x and labels have the same number of samples
727    let n_samples = x.shape()[0];
728    if n_samples != labels.len() {
729        return Err(MetricsError::InvalidInput(format!(
730            "x has {} samples, but labels has {} samples",
731            n_samples,
732            labels.len()
733        )));
734    }
735
736    // Check that there are at least 2 samples
737    if n_samples < 2 {
738        return Err(MetricsError::InvalidInput(
739            "n_samples must be at least 2".to_string(),
740        ));
741    }
742
743    // Group samples by label
744    let clusters = group_by_labels(x, labels)?;
745
746    // Check that there are at least 2 clusters
747    if clusters.len() < 2 {
748        return Err(MetricsError::InvalidInput(
749            "Number of labels is 1. Dunn index is undefined for a single cluster.".to_string(),
750        ));
751    }
752
753    // Compute distance matrix
754    let distances = pairwise_distances(x, "euclidean")?;
755
756    // Calculate maximum intra-cluster distance for each cluster
757    let mut max_intra_cluster_distance = F::zero();
758    for indices in clusters.values() {
759        let mut max_distance = F::zero();
760        for (i, &idx1) in indices.iter().enumerate() {
761            for &idx2 in &indices[i + 1..] {
762                let dist = distances[[idx1, idx2]];
763                if dist > max_distance {
764                    max_distance = dist;
765                }
766            }
767        }
768        if max_distance > max_intra_cluster_distance {
769            max_intra_cluster_distance = max_distance;
770        }
771    }
772
773    if max_intra_cluster_distance <= F::epsilon() {
774        return Err(MetricsError::CalculationError(
775            "Maximum intra-cluster distance is zero".to_string(),
776        ));
777    }
778
779    // Calculate minimum inter-cluster distance
780    let mut min_inter_cluster_distance = F::infinity();
781    let cluster_labels: Vec<_> = clusters.keys().collect();
782
783    for i in 0..cluster_labels.len() {
784        for j in i + 1..cluster_labels.len() {
785            let cluster_i = &clusters[cluster_labels[i]];
786            let cluster_j = &clusters[cluster_labels[j]];
787
788            let mut min_distance = F::infinity();
789            for &idx1 in cluster_i {
790                for &idx2 in cluster_j {
791                    let dist = distances[[idx1, idx2]];
792                    if dist < min_distance {
793                        min_distance = dist;
794                    }
795                }
796            }
797
798            if min_distance < min_inter_cluster_distance {
799                min_inter_cluster_distance = min_distance;
800            }
801        }
802    }
803
804    // Calculate Dunn index
805    Ok(min_inter_cluster_distance / max_intra_cluster_distance)
806}