scirs2_cluster/metrics/core/basic.rs

use scirs2_core::ndarray::{Array1, Array2, ArrayView1, ArrayView2, Axis};
use scirs2_core::numeric::{Float, FromPrimitive};
use std::fmt::Debug;

use crate::error::{ClusteringError, Result};
use crate::metrics::silhouette_score;

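/// Computes the Davies-Bouldin index for a labelled clustering.
///
/// Each cluster's scatter is the mean Euclidean distance of its members to the
/// cluster centroid. For every pair of clusters `(i, j)` the ratio
/// `(s_i + s_j) / d(c_i, c_j)` is formed, and the index is the average over all
/// clusters of the worst (largest) such ratio. Lower values indicate more
/// compact, better separated clusters.
///
/// Samples with negative labels are treated as noise and ignored. Returns an
/// error if `data` and `labels` disagree in length or fewer than two clusters
/// are present.
///
/// # Examples
///
/// A minimal sketch mirroring the unit test below; the import path for this
/// function is an assumption, so adjust it to the crate's actual re-export:
///
/// ```ignore
/// use scirs2_core::ndarray::{Array1, Array2};
/// use scirs2_cluster::metrics::davies_bouldin_score; // assumed path
///
/// let data =
///     Array2::from_shape_vec((4, 2), vec![0.0, 0.0, 0.1, 0.1, 5.0, 5.0, 5.1, 5.1]).unwrap();
/// let labels = Array1::from_vec(vec![0, 0, 1, 1]);
/// let score: f64 = davies_bouldin_score(data.view(), labels.view()).unwrap();
/// assert!(score >= 0.0);
/// ```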
pub fn davies_bouldin_score<F>(data: ArrayView2<F>, labels: ArrayView1<i32>) -> Result<F>
where
    F: Float + FromPrimitive + Debug + PartialOrd + 'static,
{
    if data.shape()[0] != labels.shape()[0] {
        return Err(ClusteringError::InvalidInput(
            "Data and labels must have the same number of samples".to_string(),
        ));
    }

    // Collect the distinct non-negative labels; negative labels are treated as noise.
    let mut unique_labels = Vec::new();
    for &label in labels.iter() {
        if label >= 0 && !unique_labels.contains(&label) {
            unique_labels.push(label);
        }
    }

    let n_clusters = unique_labels.len();

    if n_clusters < 2 {
        return Err(ClusteringError::InvalidInput(
            "Davies-Bouldin score requires at least 2 clusters".to_string(),
        ));
    }

    // Accumulate per-cluster sums, then divide by cluster sizes to get the centroids.
    let mut centers = Array2::<F>::zeros((n_clusters, data.shape()[1]));
    let mut cluster_sizes = vec![0; n_clusters];

    for (i, &label) in labels.iter().enumerate() {
        if label >= 0 {
            let cluster_idx = unique_labels.iter().position(|&l| l == label).unwrap();
            centers
                .row_mut(cluster_idx)
                .scaled_add(F::one(), &data.row(i));
            cluster_sizes[cluster_idx] += 1;
        }
    }

    for (i, &size) in cluster_sizes.iter().enumerate() {
        if size > 0 {
            centers
                .row_mut(i)
                .mapv_inplace(|x| x / F::from(size).unwrap());
        }
    }

    // Scatter: mean Euclidean distance of each cluster's samples to its centroid.
    let mut scatter = vec![F::zero(); n_clusters];
    for (i, &label) in labels.iter().enumerate() {
        if label >= 0 {
            let cluster_idx = unique_labels.iter().position(|&l| l == label).unwrap();
            let center = centers.row(cluster_idx);
            let diff = &data.row(i) - &center;
            let distance = diff.dot(&diff).sqrt();
            scatter[cluster_idx] = scatter[cluster_idx] + distance;
        }
    }

    for (i, &size) in cluster_sizes.iter().enumerate() {
        if size > 0 {
            scatter[i] = scatter[i] / F::from(size).unwrap();
        }
    }

    // For each cluster, take the worst ratio of combined scatter to centroid distance,
    // then average those maxima over all clusters.
    let mut db_index = F::zero();

    for i in 0..n_clusters {
        let mut max_ratio = F::zero();

        for j in 0..n_clusters {
            if i != j {
                let between_distance = (&centers.row(i) - &centers.row(j))
                    .mapv(|x| x * x)
                    .sum()
                    .sqrt();

                if between_distance > F::zero() {
                    let ratio = (scatter[i] + scatter[j]) / between_distance;
                    if ratio > max_ratio {
                        max_ratio = ratio;
                    }
                }
            }
        }

        db_index = db_index + max_ratio;
    }

    db_index = db_index / F::from(n_clusters).unwrap();
    Ok(db_index)
}

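/// Computes the Calinski-Harabasz score (variance ratio criterion).
///
/// The score is the ratio of between-cluster dispersion (SSB) to within-cluster
/// dispersion (SSW), scaled by `(n - k) / (k - 1)`, where `n` is the number of
/// valid samples and `k` the number of clusters. Higher values indicate denser,
/// better separated clusters. If SSW is exactly zero, positive infinity is
/// returned.
///
/// Samples with negative labels are ignored. Returns an error if `data` and
/// `labels` disagree in length, fewer than two clusters are present, or the
/// number of clusters is not smaller than the number of samples.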
pub fn calinski_harabasz_score<F>(data: ArrayView2<F>, labels: ArrayView1<i32>) -> Result<F>
where
    F: Float + FromPrimitive + Debug + PartialOrd + 'static,
{
    if data.shape()[0] != labels.shape()[0] {
        return Err(ClusteringError::InvalidInput(
            "Data and labels must have the same number of samples".to_string(),
        ));
    }

    let n_samples = data.shape()[0];
    let n_features = data.shape()[1];

    // Collect the distinct non-negative labels; negative labels are treated as noise.
    let mut unique_labels = Vec::new();
    for &label in labels.iter() {
        if label >= 0 && !unique_labels.contains(&label) {
            unique_labels.push(label);
        }
    }

    let n_clusters = unique_labels.len();

    if n_clusters < 2 {
        return Err(ClusteringError::InvalidInput(
            "Calinski-Harabasz score requires at least 2 clusters".to_string(),
        ));
    }

    if n_clusters >= n_samples {
        return Err(ClusteringError::InvalidInput(
            "Number of clusters must be less than number of samples".to_string(),
        ));
    }

    // Overall mean of all valid (non-noise) samples.
    let mut overall_mean = Array1::<F>::zeros(n_features);
    let mut valid_samples = 0;

    for (i, &label) in labels.iter().enumerate() {
        if label >= 0 {
            overall_mean.scaled_add(F::one(), &data.row(i));
            valid_samples += 1;
        }
    }

    overall_mean.mapv_inplace(|x| x / F::from(valid_samples).unwrap());

    // Per-cluster centroids.
    let mut centers = Array2::<F>::zeros((n_clusters, n_features));
    let mut cluster_sizes = vec![0; n_clusters];

    for (i, &label) in labels.iter().enumerate() {
        if label >= 0 {
            let cluster_idx = unique_labels.iter().position(|&l| l == label).unwrap();
            centers
                .row_mut(cluster_idx)
                .scaled_add(F::one(), &data.row(i));
            cluster_sizes[cluster_idx] += 1;
        }
    }

    for (i, &size) in cluster_sizes.iter().enumerate() {
        if size > 0 {
            centers
                .row_mut(i)
                .mapv_inplace(|x| x / F::from(size).unwrap());
        }
    }

    // Between-cluster sum of squares.
    let mut ssb = F::zero();
    for (i, &size) in cluster_sizes.iter().enumerate() {
        if size > 0 {
            let diff = &centers.row(i) - &overall_mean;
            ssb = ssb + F::from(size).unwrap() * diff.dot(&diff);
        }
    }

    // Within-cluster sum of squares.
    let mut ssw = F::zero();
    for (i, &label) in labels.iter().enumerate() {
        if label >= 0 {
            let cluster_idx = unique_labels.iter().position(|&l| l == label).unwrap();
            let diff = &data.row(i) - &centers.row(cluster_idx);
            ssw = ssw + diff.dot(&diff);
        }
    }

    if ssw == F::zero() {
        return Ok(F::infinity());
    }

    let score = (ssb / ssw) * F::from(valid_samples - n_clusters).unwrap()
        / F::from(n_clusters - 1).unwrap();

    Ok(score)
}

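/// Computes the mean silhouette coefficient over all samples.
///
/// This is a convenience wrapper that delegates directly to [`silhouette_score`].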
pub fn mean_silhouette_score<F>(data: ArrayView2<F>, labels: ArrayView1<i32>) -> Result<F>
where
    F: Float + FromPrimitive + 'static,
{
    silhouette_score(data, labels)
}

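/// Computes the adjusted Rand index (ARI) between two label assignments.
///
/// A contingency table of the two labelings is built, pair counts are derived
/// from it, and the plain Rand index is corrected for chance:
/// `ARI = (index - expected_index) / (max_index - expected_index)`.
/// The result is 1 for identical clusterings (up to a permutation of labels),
/// is close to 0 for random assignments, and can be negative.
///
/// Returns an error if the label arrays differ in length or are empty.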
pub fn adjusted_rand_index<F>(
    labels_true: ArrayView1<i32>,
    labels_pred: ArrayView1<i32>,
) -> Result<F>
where
    F: Float + FromPrimitive + Debug + 'static,
{
    if labels_true.len() != labels_pred.len() {
        return Err(ClusteringError::InvalidInput(
            "Labels arrays must have the same length".to_string(),
        ));
    }

    let n = labels_true.len();
    if n == 0 {
        return Err(ClusteringError::InvalidInput(
            "Empty labels arrays".to_string(),
        ));
    }

    // Map each distinct label to a contiguous index.
    let mut true_labels = std::collections::HashSet::new();
    let mut pred_labels = std::collections::HashSet::new();

    for &label in labels_true.iter() {
        true_labels.insert(label);
    }
    for &label in labels_pred.iter() {
        pred_labels.insert(label);
    }

    let n_true = true_labels.len();
    let n_pred = pred_labels.len();

    let true_label_map: std::collections::HashMap<i32, usize> = true_labels
        .iter()
        .enumerate()
        .map(|(i, &label)| (label, i))
        .collect();
    let pred_label_map: std::collections::HashMap<i32, usize> = pred_labels
        .iter()
        .enumerate()
        .map(|(i, &label)| (label, i))
        .collect();

    // Contingency table of the two labelings.
    let mut contingency = Array2::<usize>::zeros((n_true, n_pred));
    for i in 0..n {
        let true_idx = true_label_map[&labels_true[i]];
        let pred_idx = pred_label_map[&labels_pred[i]];
        contingency[[true_idx, pred_idx]] += 1;
    }

    // Pair counts: sum of C(n_ij, 2) over cells, and C(n_i, 2) over row/column sums.
    let sum_comb_c = contingency
        .iter()
        .map(|&n_ij| {
            if n_ij >= 2 {
                (n_ij * (n_ij - 1)) / 2
            } else {
                0
            }
        })
        .sum::<usize>();

    let sum_a = contingency
        .sum_axis(Axis(1))
        .iter()
        .map(|&n_i| if n_i >= 2 { (n_i * (n_i - 1)) / 2 } else { 0 })
        .sum::<usize>();

    let sum_b = contingency
        .sum_axis(Axis(0))
        .iter()
        .map(|&n_j| if n_j >= 2 { (n_j * (n_j - 1)) / 2 } else { 0 })
        .sum::<usize>();

    let n_choose_2 = if n >= 2 { (n * (n - 1)) / 2 } else { 0 };

    // Chance-corrected Rand index.
    let expected_index =
        F::from(sum_a).unwrap() * F::from(sum_b).unwrap() / F::from(n_choose_2).unwrap();
    let max_index = (F::from(sum_a).unwrap() + F::from(sum_b).unwrap()) / F::from(2.0).unwrap();
    let index = F::from(sum_comb_c).unwrap();

    if max_index == expected_index {
        return Ok(F::zero());
    }

    let ari = (index - expected_index) / (max_index - expected_index);
    Ok(ari)
}

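/// Computes the normalized mutual information (NMI) between two label assignments.
///
/// The mutual information of the two labelings is divided by a normalizer built
/// from their entropies, chosen via `average_method`: `"arithmetic"`,
/// `"geometric"`, `"min"`, or `"max"`. The result lies in `[0, 1]`, with 1 for
/// identical clusterings (up to a permutation of labels). Empty inputs, and the
/// case where both labelings have zero entropy, return 1.
///
/// Returns an error if the label arrays differ in length or `average_method`
/// is not one of the values listed above.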
pub fn normalized_mutual_info<F>(
    labels_true: ArrayView1<i32>,
    labels_pred: ArrayView1<i32>,
    average_method: &str,
) -> Result<F>
where
    F: Float + FromPrimitive + Debug + 'static,
{
    if labels_true.len() != labels_pred.len() {
        return Err(ClusteringError::InvalidInput(
            "Labels arrays must have the same length".to_string(),
        ));
    }

    let n = labels_true.len();
    if n == 0 {
        return Ok(F::one());
    }

    let mi = mutual_info::<F>(labels_true, labels_pred)?;

    let h_true = entropy::<F>(labels_true)?;
    let h_pred = entropy::<F>(labels_pred)?;

    if h_true == F::zero() && h_pred == F::zero() {
        return Ok(F::one());
    }

    let normalizer = match average_method {
        "arithmetic" => (h_true + h_pred) / F::from(2.0).unwrap(),
        "geometric" => (h_true * h_pred).sqrt(),
        "min" => h_true.min(h_pred),
        "max" => h_true.max(h_pred),
        _ => {
            return Err(ClusteringError::InvalidInput(
                "Invalid average method. Use 'arithmetic', 'geometric', 'min', or 'max'"
                    .to_string(),
            ))
        }
    };

    if normalizer == F::zero() {
        return Ok(F::zero());
    }

    Ok(mi / normalizer)
}

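/// Computes homogeneity, completeness, and the V-measure of a predicted
/// clustering against ground-truth labels.
///
/// Homogeneity is `1 - H(true | pred) / H(true)`, completeness is
/// `1 - H(pred | true) / H(pred)`, and the V-measure is their harmonic mean.
/// All three values lie in `[0, 1]`. Degenerate cases (empty input, or either
/// labeling having zero entropy) return `(1, 1, 1)`.
///
/// Returns an error if the label arrays differ in length.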
pub fn homogeneity_completeness_v_measure<F>(
    labels_true: ArrayView1<i32>,
    labels_pred: ArrayView1<i32>,
) -> Result<(F, F, F)>
where
    F: Float + FromPrimitive + Debug + 'static,
{
    if labels_true.len() != labels_pred.len() {
        return Err(ClusteringError::InvalidInput(
            "Labels arrays must have the same length".to_string(),
        ));
    }

    let n = labels_true.len();
    if n == 0 {
        return Ok((F::one(), F::one(), F::one()));
    }

    let h_true = entropy::<F>(labels_true)?;
    let h_pred = entropy::<F>(labels_pred)?;

    if h_true == F::zero() {
        return Ok((F::one(), F::one(), F::one()));
    }
    if h_pred == F::zero() {
        return Ok((F::one(), F::one(), F::one()));
    }

    let h_true_given_pred = conditional_entropy::<F>(labels_true, labels_pred)?;
    let h_pred_given_true = conditional_entropy::<F>(labels_pred, labels_true)?;

    let homogeneity = if h_pred == F::zero() {
        F::one()
    } else {
        F::one() - h_true_given_pred / h_true
    };

    let completeness = if h_true == F::zero() {
        F::one()
    } else {
        F::one() - h_pred_given_true / h_pred
    };

    let v_measure = if homogeneity + completeness == F::zero() {
        F::zero()
    } else {
        F::from(2.0).unwrap() * homogeneity * completeness / (homogeneity + completeness)
    };

    Ok((homogeneity, completeness, v_measure))
}

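/// Mutual information (in nats) between two label assignments, computed from
/// their contingency matrix.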
fn mutual_info<F>(labels_true: ArrayView1<i32>, labels_pred: ArrayView1<i32>) -> Result<F>
where
    F: Float + FromPrimitive + Debug + 'static,
{
    let n = labels_true.len() as f64;
    let contingency = build_contingency_matrix(labels_true, labels_pred)?;

    let mut mi = F::zero();
    let n_rows = contingency.shape()[0];
    let n_cols = contingency.shape()[1];

    let row_sums = contingency.sum_axis(Axis(1));
    let col_sums = contingency.sum_axis(Axis(0));

    for i in 0..n_rows {
        for j in 0..n_cols {
            let n_ij = contingency[[i, j]] as f64;
            if n_ij > 0.0 {
                let n_i = row_sums[i] as f64;
                let n_j = col_sums[j] as f64;
                let term = n_ij / n * (n_ij / (n_i * n_j / n)).ln();
                mi = mi + F::from(term).unwrap();
            }
        }
    }

    Ok(mi)
}

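/// Shannon entropy (in nats) of a label assignment.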
fn entropy<F>(labels: ArrayView1<i32>) -> Result<F>
where
    F: Float + FromPrimitive + Debug + 'static,
{
    let n = labels.len() as f64;
    let mut label_counts = std::collections::HashMap::new();

    for &label in labels.iter() {
        *label_counts.entry(label).or_insert(0) += 1;
    }

    let mut h = F::zero();
    for &count in label_counts.values() {
        if count > 0 {
            let p = count as f64 / n;
            h = h - F::from(p * p.ln()).unwrap();
        }
    }

    Ok(h)
}

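/// Builds the contingency matrix of two label assignments: rows are indexed by
/// the distinct values of `labels_true` (in sorted order), columns by those of
/// `labels_pred`, and each cell counts the samples with that label pair.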
fn build_contingency_matrix(
    labels_true: ArrayView1<i32>,
    labels_pred: ArrayView1<i32>,
) -> Result<Array2<usize>> {
    let mut true_labels = std::collections::BTreeSet::new();
    let mut pred_labels = std::collections::BTreeSet::new();

    for &label in labels_true.iter() {
        true_labels.insert(label);
    }
    for &label in labels_pred.iter() {
        pred_labels.insert(label);
    }

    let true_label_map: std::collections::HashMap<i32, usize> = true_labels
        .iter()
        .enumerate()
        .map(|(i, &label)| (label, i))
        .collect();
    let pred_label_map: std::collections::HashMap<i32, usize> = pred_labels
        .iter()
        .enumerate()
        .map(|(i, &label)| (label, i))
        .collect();

    let mut contingency = Array2::<usize>::zeros((true_labels.len(), pred_labels.len()));
    for i in 0..labels_true.len() {
        let true_idx = true_label_map[&labels_true[i]];
        let pred_idx = pred_label_map[&labels_pred[i]];
        contingency[[true_idx, pred_idx]] += 1;
    }

    Ok(contingency)
}

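/// Conditional entropy `H(X | Y)` (in nats) of `labels_x` given `labels_y`.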
fn conditional_entropy<F>(labels_x: ArrayView1<i32>, labels_y: ArrayView1<i32>) -> Result<F>
where
    F: Float + FromPrimitive + Debug + 'static,
{
    let n = labels_x.len() as f64;
    let contingency = build_contingency_matrix(labels_x, labels_y)?;

    let mut h_xy = F::zero();
    let col_sums = contingency.sum_axis(Axis(0));

    for j in 0..contingency.shape()[1] {
        let n_j = col_sums[j] as f64;
        if n_j > 0.0 {
            for i in 0..contingency.shape()[0] {
                let n_ij = contingency[[i, j]] as f64;
                if n_ij > 0.0 {
                    let term = n_ij / n * (n_ij / n_j).ln();
                    h_xy = h_xy - F::from(term).unwrap();
                }
            }
        }
    }

    Ok(h_xy)
}

#[cfg(test)]
mod tests {
    use super::*;
    use scirs2_core::ndarray::Array2;

    #[test]
    fn test_davies_bouldin_score() {
        let data =
            Array2::from_shape_vec((4, 2), vec![0.0, 0.0, 0.1, 0.1, 5.0, 5.0, 5.1, 5.1]).unwrap();
        let labels = Array1::from_vec(vec![0, 0, 1, 1]);

        let score = davies_bouldin_score(data.view(), labels.view()).unwrap();
        assert!(score >= 0.0);
    }

    #[test]
    fn test_calinski_harabasz_score() {
        let data =
            Array2::from_shape_vec((4, 2), vec![0.0, 0.0, 0.1, 0.1, 5.0, 5.0, 5.1, 5.1]).unwrap();
        let labels = Array1::from_vec(vec![0, 0, 1, 1]);

        let score = calinski_harabasz_score(data.view(), labels.view()).unwrap();
        assert!(score > 0.0);
    }

    #[test]
    fn test_adjusted_rand_index() {
        let labels_true = Array1::from_vec(vec![0, 0, 1, 1, 2, 2]);
        let labels_pred = Array1::from_vec(vec![0, 0, 2, 2, 1, 1]);

        let ari: f64 = adjusted_rand_index(labels_true.view(), labels_pred.view()).unwrap();
        assert!(ari >= -1.0 && ari <= 1.0);
    }

    #[test]
    fn test_normalized_mutual_info() {
        let labels_true = Array1::from_vec(vec![0, 0, 1, 1]);
        let labels_pred = Array1::from_vec(vec![0, 0, 1, 1]);

        let nmi: f64 =
            normalized_mutual_info(labels_true.view(), labels_pred.view(), "arithmetic").unwrap();
        assert!((nmi - 1.0).abs() < 1e-6);
    }

    #[test]
    fn test_homogeneity_completeness_v_measure() {
        let labels_true = Array1::from_vec(vec![0, 0, 1, 1, 2, 2]);
        let labels_pred = Array1::from_vec(vec![0, 0, 1, 1, 1, 1]);

        let (h, c, v): (f64, f64, f64) =
            homogeneity_completeness_v_measure(labels_true.view(), labels_pred.view()).unwrap();

        assert!(h >= 0.0 && h <= 1.0);
        assert!(c >= 0.0 && c <= 1.0);
        assert!(v >= 0.0 && v <= 1.0);
    }
}