scirs2_metrics/clustering/internal_metrics.rs
//! Internal clustering metrics
//!
//! This module provides functions for evaluating clustering algorithms using internal metrics,
//! which assess clustering quality without external ground truth. These include
//! silhouette score, Davies-Bouldin index, Calinski-Harabasz index, and Dunn index.

use scirs2_core::ndarray::{Array1, Array2, ArrayBase, Data, Dimension, Ix2};
use scirs2_core::numeric::{Float, NumCast};
use std::collections::HashMap;

use super::{calculate_distance, group_by_labels, pairwise_distances};
use crate::error::{MetricsError, Result};

/// Structure containing detailed silhouette analysis results
#[derive(Debug, Clone)]
pub struct SilhouetteAnalysis<F: Float> {
    /// Sample-wise silhouette scores
    pub sample_values: Vec<F>,

    /// Mean silhouette score for all samples
    pub mean_score: F,

    /// Mean silhouette score for each cluster
    pub cluster_scores: HashMap<usize, F>,

    /// Sorted indices for visualization (samples ordered by cluster and silhouette value)
    pub sorted_indices: Vec<usize>,

    /// Original cluster labels mapped to consecutive integers (for visualization)
    pub cluster_mapping: HashMap<usize, usize>,

    /// Samples per cluster (ordered by cluster_mapping)
    pub cluster_sizes: Vec<usize>,
}

/// Calculates the silhouette score for a clustering
///
/// The silhouette score measures how similar an object is to its own cluster
/// compared to other clusters. It ranges from -1 to 1, where a high value
/// indicates that the object is well matched to its own cluster and poorly
/// matched to neighboring clusters.
///
/// # Arguments
///
/// * `x` - Array of shape (n_samples, n_features) - The data
/// * `labels` - Array of shape (n_samples,) - Predicted labels for each sample
/// * `metric` - Distance metric to use. Currently only 'euclidean' is supported.
///
/// # Returns
///
/// * The mean silhouette coefficient for all samples
///
/// # Examples
///
/// ```
/// use scirs2_core::ndarray::{array, Array2};
/// use scirs2_metrics::clustering::silhouette_score;
///
/// // Create a small dataset with 2 clusters
/// let x = Array2::from_shape_vec((6, 2), vec![
///     1.0, 2.0,
///     1.5, 1.8,
///     1.2, 2.2,
///     5.0, 6.0,
///     5.2, 5.8,
///     5.5, 6.2,
/// ]).unwrap();
///
/// let labels = array![0, 0, 0, 1, 1, 1];
///
/// let score = silhouette_score(&x, &labels, "euclidean").unwrap();
/// assert!(score > 0.8); // High score for well-separated clusters
/// ```
#[allow(dead_code)]
pub fn silhouette_score<F, S1, S2, D>(
    x: &ArrayBase<S1, Ix2>,
    labels: &ArrayBase<S2, D>,
    metric: &str,
) -> Result<F>
where
    F: Float + NumCast + std::fmt::Debug + scirs2_core::ndarray::ScalarOperand + 'static,
    S1: Data<Elem = F>,
    S2: Data<Elem = usize>,
    D: Dimension,
{
    let silhouette_results = silhouette_analysis(x, labels, metric)?;
    Ok(silhouette_results.mean_score)
}

/// Calculate silhouette samples for a clustering
///
/// Returns the silhouette score for each sample in the clustering, which can be
/// useful for more detailed analysis than just the mean silhouette score.
///
/// # Arguments
///
/// * `x` - Array of shape (n_samples, n_features) - The data
/// * `labels` - Array of shape (n_samples,) - Predicted labels for each sample
/// * `metric` - Distance metric to use. Currently only 'euclidean' is supported.
///
/// # Returns
///
/// * Vector of length n_samples containing the silhouette score for each sample
///
/// # Examples
///
/// ```
/// use scirs2_core::ndarray::{array, Array2};
/// use scirs2_metrics::clustering::silhouette_samples;
///
/// // Create a small dataset with 2 clusters
/// let x = Array2::from_shape_vec((6, 2), vec![
///     1.0, 2.0, 1.5, 1.8, 1.2, 2.2, // Cluster 0
///     5.0, 6.0, 5.2, 5.8, 5.5, 6.2, // Cluster 1
/// ]).unwrap();
///
/// let labels = array![0, 0, 0, 1, 1, 1];
///
/// let samples = silhouette_samples(&x, &labels, "euclidean").unwrap();
/// assert_eq!(samples.len(), 6);
///
/// // Calculate the mean silhouette score manually
/// let mean_score = samples.iter().sum::<f64>() / samples.len() as f64;
/// assert!(mean_score > 0.8); // High score for well-separated clusters
/// ```
#[allow(dead_code)]
pub fn silhouette_samples<F, S1, S2, D>(
    x: &ArrayBase<S1, Ix2>,
    labels: &ArrayBase<S2, D>,
    metric: &str,
) -> Result<Vec<F>>
where
    F: Float + NumCast + std::fmt::Debug + scirs2_core::ndarray::ScalarOperand + 'static,
    S1: Data<Elem = F>,
    S2: Data<Elem = usize>,
    D: Dimension,
{
    let analysis = silhouette_analysis(x, labels, metric)?;
    Ok(analysis.sample_values)
}

/// Calculate silhouette scores per cluster
///
/// Returns the mean silhouette score for each cluster, allowing you to
/// identify which clusters are more cohesive than others.
///
/// # Arguments
///
/// * `x` - Array of shape (n_samples, n_features) - The data
/// * `labels` - Array of shape (n_samples,) - Predicted labels for each sample
/// * `metric` - Distance metric to use. Currently only 'euclidean' is supported.
///
/// # Returns
///
/// * HashMap mapping cluster labels to their mean silhouette scores
///
/// # Examples
///
/// ```
/// use scirs2_core::ndarray::{array, Array2};
/// use scirs2_metrics::clustering::silhouette_scores_per_cluster;
///
/// // Create a small dataset with 3 clusters
/// let x = Array2::from_shape_vec((9, 2), vec![
///     1.0, 2.0, 1.5, 1.8, 1.2, 2.2, // Cluster 0
///     5.0, 6.0, 5.2, 5.8, 5.5, 6.2, // Cluster 1
///     9.0, 10.0, 9.2, 9.8, 9.5, 10.2, // Cluster 2
/// ]).unwrap();
///
/// let labels = array![0, 0, 0, 1, 1, 1, 2, 2, 2];
///
/// let cluster_scores = silhouette_scores_per_cluster(&x, &labels, "euclidean").unwrap();
/// assert_eq!(cluster_scores.len(), 3);
/// assert!(cluster_scores[&0] > 0.5);
/// assert!(cluster_scores[&1] > 0.5);
/// assert!(cluster_scores[&2] > 0.5);
/// ```
#[allow(dead_code)]
pub fn silhouette_scores_per_cluster<F, S1, S2, D>(
    x: &ArrayBase<S1, Ix2>,
    labels: &ArrayBase<S2, D>,
    metric: &str,
) -> Result<HashMap<usize, F>>
where
    F: Float + NumCast + std::fmt::Debug + scirs2_core::ndarray::ScalarOperand + 'static,
    S1: Data<Elem = F>,
    S2: Data<Elem = usize>,
    D: Dimension,
{
    let analysis = silhouette_analysis(x, labels, metric)?;
    Ok(analysis.cluster_scores)
}

/// Calculates detailed silhouette information for a clustering
///
/// This function provides sample-wise silhouette scores, cluster-wise averages,
/// and ordering information for visualization. It's an enhanced version of
/// silhouette_score that returns more detailed information.
///
/// # Arguments
///
/// * `x` - Array of shape (n_samples, n_features) - The data
/// * `labels` - Array of shape (n_samples,) - Predicted labels for each sample
/// * `metric` - Distance metric to use. Currently only 'euclidean' is supported.
///
/// # Returns
///
/// * `SilhouetteAnalysis` struct containing detailed silhouette information
///
/// # Examples
///
/// ```
/// use scirs2_core::ndarray::{array, Array2};
/// use scirs2_metrics::clustering::silhouette_analysis;
///
/// // Create a small dataset with 3 clusters
/// let x = Array2::from_shape_vec((9, 2), vec![
///     1.0, 2.0, 1.5, 1.8, 1.2, 2.2, // Cluster 0
///     5.0, 6.0, 5.2, 5.8, 5.5, 6.2, // Cluster 1
///     9.0, 10.0, 9.2, 9.8, 9.5, 10.2, // Cluster 2
/// ]).unwrap();
///
/// let labels = array![0, 0, 0, 1, 1, 1, 2, 2, 2];
///
/// let analysis = silhouette_analysis(&x, &labels, "euclidean").unwrap();
///
/// // Get overall silhouette score
/// let score = analysis.mean_score;
/// assert!(score > 0.8); // High score for well-separated clusters
///
/// // Get cluster-wise silhouette scores
/// for (cluster, score) in &analysis.cluster_scores {
///     println!("Cluster {} silhouette score: {}", cluster, score);
/// }
///
/// // Access individual sample silhouette values
/// for (i, &value) in analysis.sample_values.iter().enumerate() {
///     println!("Sample {} silhouette value: {}", i, value);
/// }
/// ```
#[allow(dead_code)]
pub fn silhouette_analysis<F, S1, S2, D>(
    x: &ArrayBase<S1, Ix2>,
    labels: &ArrayBase<S2, D>,
    metric: &str,
) -> Result<SilhouetteAnalysis<F>>
where
    F: Float + NumCast + std::fmt::Debug + scirs2_core::ndarray::ScalarOperand + 'static,
    S1: Data<Elem = F>,
    S2: Data<Elem = usize>,
    D: Dimension,
{
    // Check that the metric is supported
    if metric != "euclidean" {
        return Err(MetricsError::InvalidInput(format!(
            "Unsupported metric: {metric}. Only 'euclidean' is currently supported."
        )));
    }

    // Check that x and labels have the same number of samples
    let n_samples = x.shape()[0];
    if n_samples != labels.len() {
        return Err(MetricsError::InvalidInput(format!(
            "x has {} samples, but labels has {} samples",
            n_samples,
            labels.len()
        )));
    }

    // Check that there are at least 2 samples
    if n_samples < 2 {
        return Err(MetricsError::InvalidInput(
            "n_samples must be at least 2".to_string(),
        ));
    }

    // Group samples by label
    let clusters = group_by_labels(x, labels)?;

    // Check that there are at least 2 clusters
    if clusters.len() < 2 {
        return Err(MetricsError::InvalidInput(
            "Number of labels is 1. Silhouette analysis is undefined for a single cluster."
                .to_string(),
        ));
    }

    // Check that all clusters have at least 1 sample
    let empty_clusters: Vec<_> = clusters
        .iter()
        .filter(|(_, samples)| samples.is_empty())
        .map(|(&label, _)| label)
        .collect();

    if !empty_clusters.is_empty() {
        return Err(MetricsError::InvalidInput(format!(
            "Empty clusters found: {empty_clusters:?}"
        )));
    }

    // Compute distance matrix (more efficient than recomputing distances)
    let distances = pairwise_distances(x, metric)?;

    // Compute silhouette scores for each sample
    let mut silhouette_values = Vec::with_capacity(n_samples);
    let mut sample_clusters = Vec::with_capacity(n_samples);

    for i in 0..n_samples {
        let label_i = labels.iter().nth(i).ok_or_else(|| {
            MetricsError::InvalidInput(format!("Could not access index {i} in labels"))
        })?;
        let cluster_i = &clusters[label_i];
        sample_clusters.push(*label_i);

        // Calculate the mean intra-cluster distance (a)
        let mut a = F::zero();
        let mut count_a = 0;

        for &j in cluster_i {
            if i == j {
                continue;
            }
            a = a + distances[[i, j]];
            count_a += 1;
        }

        // Handle a single-sample cluster (a stays at 0)
        if count_a > 0 {
            a = a / F::from(count_a).unwrap();
        }

        // Calculate the mean nearest-cluster distance (b)
        let mut b = None;

        for (label_j, cluster_j) in &clusters {
            if *label_j == *label_i {
                continue;
            }

            // Calculate mean distance to this cluster
            let mut cluster_dist = F::zero();
            for &j in cluster_j {
                cluster_dist = cluster_dist + distances[[i, j]];
            }
            let cluster_dist = cluster_dist / F::from(cluster_j.len()).unwrap();

            // Update b if this is the closest cluster
            if let Some(current_b) = b {
                if cluster_dist < current_b {
                    b = Some(cluster_dist);
                }
            } else {
                b = Some(cluster_dist);
            }
        }

        // Calculate silhouette score
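        // s(i) = (b - a) / max(a, b), written out as branches below to avoid an
        // explicit max; positive values mean the sample is closer to its own
        // cluster than to the nearest other cluster, negative values the opposite.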
        let s = if let Some(b) = b {
            if a < b {
                F::one() - a / b
            } else if a > b {
                b / a - F::one()
            } else {
                F::zero()
            }
        } else {
            F::zero() // Will never happen if there are at least 2 clusters
        };

        silhouette_values.push(s);
    }

    // Calculate mean silhouette score
    let sum = silhouette_values
        .iter()
        .fold(F::zero(), |acc, &val| acc + val);
    let mean_score = sum / F::from(n_samples).unwrap();

    // Calculate cluster-wise silhouette scores
    let mut cluster_scores = HashMap::new();
    for (label, indices) in &clusters {
        let mut cluster_sum = F::zero();
        for &idx in indices {
            cluster_sum = cluster_sum + silhouette_values[idx];
        }
        let cluster_mean = cluster_sum / F::from(indices.len()).unwrap();
        cluster_scores.insert(*label, cluster_mean);
    }

    // Create a mapping from original cluster labels to consecutive integers (for visualization)
    let unique_labels: Vec<_> = clusters.keys().cloned().collect();
    let mut cluster_mapping = HashMap::new();
    for (i, &label) in unique_labels.iter().enumerate() {
        cluster_mapping.insert(label, i);
    }

    // Create list of cluster sizes
    let mut cluster_sizes = vec![0; cluster_mapping.len()];
    for (label, indices) in &clusters {
        let mapped_idx = cluster_mapping[label];
        cluster_sizes[mapped_idx] = indices.len();
    }

    // Create sorted indices for visualization
    // First, sort by cluster mapping
    // Then, within each cluster, sort by silhouette value (descending)
    let mut samples_with_scores: Vec<(usize, F, usize)> = silhouette_values
        .iter()
        .enumerate()
        .map(|(i, &s)| (i, s, cluster_mapping[&sample_clusters[i]]))
        .collect();

    // Sort by cluster (ascending), then by silhouette value (descending)
    samples_with_scores.sort_by(|a, b| a.2.cmp(&b.2).then(b.1.partial_cmp(&a.1).unwrap()));

    let sorted_indices = samples_with_scores.iter().map(|&(i, _, _)| i).collect();

    Ok(SilhouetteAnalysis {
        sample_values: silhouette_values,
        mean_score,
        cluster_scores,
        sorted_indices,
        cluster_mapping,
        cluster_sizes,
    })
}

/// Calculates the Davies-Bouldin index for a clustering
///
/// The Davies-Bouldin index measures the average similarity between clusters,
/// where the similarity is a ratio of within-cluster distances to between-cluster distances.
/// The lower the value, the better the clustering.
///
/// # Arguments
///
/// * `x` - Array of shape (n_samples, n_features) - The data
/// * `labels` - Array of shape (n_samples,) - Predicted labels for each sample
///
/// # Returns
///
/// * The Davies-Bouldin index
///
/// # Examples
///
/// ```
/// use scirs2_core::ndarray::{array, Array2};
/// use scirs2_metrics::clustering::davies_bouldin_score;
///
/// // Create a small dataset with 2 clusters
/// let x = Array2::from_shape_vec((6, 2), vec![
///     1.0, 2.0,
///     1.5, 1.8,
///     1.2, 2.2,
///     5.0, 6.0,
///     5.2, 5.8,
///     5.5, 6.2,
/// ]).unwrap();
///
/// let labels = array![0, 0, 0, 1, 1, 1];
///
/// let score = davies_bouldin_score(&x, &labels).unwrap();
/// assert!(score < 0.5); // Low score for well-separated clusters
/// ```
#[allow(dead_code)]
pub fn davies_bouldin_score<F, S1, S2, D>(
    x: &ArrayBase<S1, Ix2>,
    labels: &ArrayBase<S2, D>,
) -> Result<F>
where
    F: Float + NumCast + std::fmt::Debug + scirs2_core::ndarray::ScalarOperand + 'static,
    S1: Data<Elem = F>,
    S2: Data<Elem = usize>,
    D: Dimension,
{
    // Check that x and labels have the same number of samples
    let n_samples = x.shape()[0];
    if n_samples != labels.len() {
        return Err(MetricsError::InvalidInput(format!(
            "x has {} samples, but labels has {} samples",
            n_samples,
            labels.len()
        )));
    }

    // Check that there are at least 2 samples
    if n_samples < 2 {
        return Err(MetricsError::InvalidInput(
            "n_samples must be at least 2".to_string(),
        ));
    }

    // Group samples by label
    let clusters = group_by_labels(x, labels)?;

    // Check that there are at least 2 clusters
    if clusters.len() < 2 {
        return Err(MetricsError::InvalidInput(
            "Number of labels is 1. Davies-Bouldin index is undefined for a single cluster."
                .to_string(),
        ));
    }

    // Compute centroids for each cluster
    let mut centroids = HashMap::new();
    for (&label, indices) in &clusters {
        let mut centroid = Array1::<F>::zeros(x.shape()[1]);
        for &idx in indices {
            centroid = centroid + x.row(idx).to_owned();
        }
        centroid = centroid / F::from(indices.len()).unwrap();
        centroids.insert(label, centroid);
    }

    // Compute average distance to centroid for each cluster
    let mut avg_distances = HashMap::new();
    for (&label, indices) in &clusters {
        let centroid = centroids.get(&label).unwrap();
        let mut total_distance = F::zero();
        for &idx in indices {
            total_distance = total_distance
                + calculate_distance(&x.row(idx).to_vec(), &centroid.to_vec(), "euclidean")?;
        }
        let avg_distance = total_distance / F::from(indices.len()).unwrap();
        avg_distances.insert(label, avg_distance);
    }

    // Compute Davies-Bouldin index
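    // DB = (1/k) * sum_i max_{j != i} (s_i + s_j) / d(c_i, c_j), where s_i is the
    // average distance of cluster i's samples to its centroid c_i, and d(c_i, c_j)
    // is the distance between centroids.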
    let mut db_index = F::zero();
    let labels_vec: Vec<_> = clusters.keys().cloned().collect();

    for i in 0..labels_vec.len() {
        let label_i = labels_vec[i];
        let centroid_i = centroids.get(&label_i).unwrap();
        let avg_dist_i = avg_distances.get(&label_i).unwrap();

        let mut max_ratio = F::zero();
        for (j, &label_j) in labels_vec.iter().enumerate() {
            if i == j {
                continue;
            }
            let centroid_j = centroids.get(&label_j).unwrap();
            let avg_dist_j = avg_distances.get(&label_j).unwrap();

            // Distance between centroids
            let centroid_dist =
                calculate_distance(&centroid_i.to_vec(), &centroid_j.to_vec(), "euclidean")?;

            // Ratio of sum of intra-cluster distances to inter-cluster distance
            let ratio = (*avg_dist_i + *avg_dist_j) / centroid_dist;

            // Update max ratio
            if ratio > max_ratio {
                max_ratio = ratio;
            }
        }

        db_index = db_index + max_ratio;
    }

    // Normalize by number of clusters
    Ok(db_index / F::from(labels_vec.len()).unwrap())
}

/// Calculates the Calinski-Harabasz index (Variance Ratio Criterion)
///
/// The Calinski-Harabasz index is defined as the ratio of the between-cluster
/// dispersion to the within-cluster dispersion. Higher values indicate better clustering.
///
/// # Arguments
///
/// * `x` - Array of shape (n_samples, n_features) - The data
/// * `labels` - Array of shape (n_samples,) - Predicted labels for each sample
///
/// # Returns
///
/// * The Calinski-Harabasz index
///
/// # Examples
///
/// ```
/// use scirs2_core::ndarray::{array, Array2};
/// use scirs2_metrics::clustering::calinski_harabasz_score;
///
/// // Create a small dataset with 2 clusters
/// let x = Array2::from_shape_vec((6, 2), vec![
///     1.0, 2.0,
///     1.5, 1.8,
///     1.2, 2.2,
///     5.0, 6.0,
///     5.2, 5.8,
///     5.5, 6.2,
/// ]).unwrap();
///
/// let labels = array![0, 0, 0, 1, 1, 1];
///
/// let score = calinski_harabasz_score(&x, &labels).unwrap();
/// assert!(score > 50.0); // High score for well-separated clusters
/// ```
#[allow(dead_code)]
pub fn calinski_harabasz_score<F, S1, S2, D>(
    x: &ArrayBase<S1, Ix2>,
    labels: &ArrayBase<S2, D>,
) -> Result<F>
where
    F: Float + NumCast + std::fmt::Debug + scirs2_core::ndarray::ScalarOperand + 'static,
    S1: Data<Elem = F>,
    S2: Data<Elem = usize>,
    D: Dimension,
{
    // Check that x and labels have the same number of samples
    let n_samples = x.shape()[0];
    if n_samples != labels.len() {
        return Err(MetricsError::InvalidInput(format!(
            "x has {} samples, but labels has {} samples",
            n_samples,
            labels.len()
        )));
    }

    // Check that there are at least 2 samples
    if n_samples < 2 {
        return Err(MetricsError::InvalidInput(
            "n_samples must be at least 2".to_string(),
        ));
    }

    // Group samples by label
    let clusters = group_by_labels(x, labels)?;

    // Check that there are at least 2 clusters
    if clusters.len() < 2 {
        return Err(MetricsError::InvalidInput(
            "Number of labels is 1. Calinski-Harabasz index is undefined for a single cluster."
                .to_string(),
        ));
    }

    // Compute the global centroid
    let mut global_centroid = Array1::<F>::zeros(x.shape()[1]);
    for i in 0..n_samples {
        global_centroid = global_centroid + x.row(i).to_owned();
    }
    global_centroid = global_centroid / F::from(n_samples).unwrap();

    // Compute centroids for each cluster
    let mut centroids = HashMap::new();
    for (&label, indices) in &clusters {
        let mut centroid = Array1::<F>::zeros(x.shape()[1]);
        for &idx in indices {
            centroid = centroid + x.row(idx).to_owned();
        }
        centroid = centroid / F::from(indices.len()).unwrap();
        centroids.insert(label, centroid);
    }

    // Compute between-cluster dispersion
    let mut between_disp = F::zero();
    for (label, indices) in &clusters {
        let cluster_size = F::from(indices.len()).unwrap();
        let centroid = centroids.get(label).unwrap();

        // Calculate squared distance between cluster centroid and global centroid
        let mut squared_dist = F::zero();
        for (c, g) in centroid.iter().zip(global_centroid.iter()) {
            let diff = *c - *g;
            squared_dist = squared_dist + diff * diff;
        }

        between_disp = between_disp + cluster_size * squared_dist;
    }

    // Compute within-cluster dispersion
    let mut within_disp = F::zero();
    for (label, indices) in &clusters {
        let centroid = centroids.get(label).unwrap();

        let mut cluster_disp = F::zero();
        for &idx in indices {
            let mut squared_dist = F::zero();
            for (x_val, c_val) in x.row(idx).iter().zip(centroid.iter()) {
                let diff = *x_val - *c_val;
                squared_dist = squared_dist + diff * diff;
            }
            cluster_disp = cluster_disp + squared_dist;
        }

        within_disp = within_disp + cluster_disp;
    }

    // Handle edge cases
    if within_disp <= F::epsilon() {
        return Err(MetricsError::CalculationError(
            "Within-cluster dispersion is zero".to_string(),
        ));
    }

    // Calculate Calinski-Harabasz index
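    // CH = [B / (k - 1)] / [W / (n - k)], where B is the between-cluster dispersion,
    // W the within-cluster dispersion, k the number of clusters, and n the number of samples.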
    let n_clusters = F::from(clusters.len()).unwrap();
    Ok(
        between_disp * (F::from(n_samples - clusters.len()).unwrap())
            / (within_disp * (n_clusters - F::one())),
    )
}

/// Calculates the Dunn index for a clustering
///
/// The Dunn index is the ratio of the minimum inter-cluster distance to the maximum
/// intra-cluster distance. Higher values indicate better clustering.
///
/// # Arguments
///
/// * `x` - Array of shape (n_samples, n_features) - The data
/// * `labels` - Array of shape (n_samples,) - Predicted labels for each sample
///
/// # Returns
///
/// * The Dunn index
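///
/// # Examples
///
/// A minimal sketch in the style of the other metrics in this module; it assumes
/// `dunn_index` is re-exported from `scirs2_metrics::clustering`, and the assertion
/// threshold is illustrative rather than a guaranteed value.
///
/// ```
/// use scirs2_core::ndarray::{array, Array2};
/// use scirs2_metrics::clustering::dunn_index;
///
/// // Create a small dataset with 2 well-separated clusters
/// let x = Array2::from_shape_vec((6, 2), vec![
///     1.0, 2.0,
///     1.5, 1.8,
///     1.2, 2.2,
///     5.0, 6.0,
///     5.2, 5.8,
///     5.5, 6.2,
/// ]).unwrap();
///
/// let labels = array![0, 0, 0, 1, 1, 1];
///
/// // The smallest between-cluster distance (~5.4) divided by the largest
/// // within-cluster distance (~0.54) gives a Dunn index well above 1
/// let score = dunn_index(&x, &labels).unwrap();
/// assert!(score > 1.0);
/// ```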
#[allow(dead_code)]
pub fn dunn_index<F, S1, S2, D>(x: &ArrayBase<S1, Ix2>, labels: &ArrayBase<S2, D>) -> Result<F>
where
    F: Float + NumCast + std::fmt::Debug + 'static,
    S1: Data<Elem = F>,
    S2: Data<Elem = usize>,
    D: Dimension,
{
    // Check that x and labels have the same number of samples
    let n_samples = x.shape()[0];
    if n_samples != labels.len() {
        return Err(MetricsError::InvalidInput(format!(
            "x has {} samples, but labels has {} samples",
            n_samples,
            labels.len()
        )));
    }

    // Check that there are at least 2 samples
    if n_samples < 2 {
        return Err(MetricsError::InvalidInput(
            "n_samples must be at least 2".to_string(),
        ));
    }

    // Group samples by label
    let clusters = group_by_labels(x, labels)?;

    // Check that there are at least 2 clusters
    if clusters.len() < 2 {
        return Err(MetricsError::InvalidInput(
            "Number of labels is 1. Dunn index is undefined for a single cluster.".to_string(),
        ));
    }

    // Compute distance matrix
    let distances = pairwise_distances(x, "euclidean")?;

    // Calculate maximum intra-cluster distance for each cluster
    let mut max_intra_cluster_distance = F::zero();
    for indices in clusters.values() {
        let mut max_distance = F::zero();
        for (i, &idx1) in indices.iter().enumerate() {
            for &idx2 in &indices[i + 1..] {
                let dist = distances[[idx1, idx2]];
                if dist > max_distance {
                    max_distance = dist;
                }
            }
        }
        if max_distance > max_intra_cluster_distance {
            max_intra_cluster_distance = max_distance;
        }
    }

    if max_intra_cluster_distance <= F::epsilon() {
        return Err(MetricsError::CalculationError(
            "Maximum intra-cluster distance is zero".to_string(),
        ));
    }

    // Calculate minimum inter-cluster distance
    let mut min_inter_cluster_distance = F::infinity();
    let cluster_labels: Vec<_> = clusters.keys().collect();

    for i in 0..cluster_labels.len() {
        for j in i + 1..cluster_labels.len() {
            let cluster_i = &clusters[cluster_labels[i]];
            let cluster_j = &clusters[cluster_labels[j]];

            let mut min_distance = F::infinity();
            for &idx1 in cluster_i {
                for &idx2 in cluster_j {
                    let dist = distances[[idx1, idx2]];
                    if dist < min_distance {
                        min_distance = dist;
                    }
                }
            }

            if min_distance < min_inter_cluster_distance {
                min_inter_cluster_distance = min_distance;
            }
        }
    }

    // Calculate Dunn index
    Ok(min_inter_cluster_distance / max_intra_cluster_distance)
}