// scirs2_cluster/metrics/core/basic.rs

//! Basic clustering evaluation metrics
//!
//! This module provides fundamental clustering evaluation metrics including
//! Davies-Bouldin score, Calinski-Harabasz score, Adjusted Rand Index,
//! and Normalized Mutual Information.

use scirs2_core::ndarray::{Array1, Array2, ArrayView1, ArrayView2, Axis};
use scirs2_core::numeric::{Float, FromPrimitive};
use std::fmt::Debug;

use crate::error::{ClusteringError, Result};
use crate::metrics::silhouette_score;

/// Davies-Bouldin score for clustering evaluation.
///
/// The Davies-Bouldin index measures the average similarity between clusters,
/// where similarity is the ratio of within-cluster to between-cluster distances.
/// A lower score indicates better clustering.
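///
/// With cluster centroids c_i, mean within-cluster distances s_i, and
/// centroid separation d(c_i, c_j), the index is
/// DB = (1/k) * sum_i max_{j != i} (s_i + s_j) / d(c_i, c_j).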
///
/// # Arguments
///
/// * `data` - Input data (n_samples x n_features)
/// * `labels` - Cluster labels for each sample
///
/// # Returns
///
/// The Davies-Bouldin score (lower is better)
///
/// # Example
///
/// ```
/// use scirs2_core::ndarray::{Array1, Array2};
/// use scirs2_cluster::metrics::davies_bouldin_score;
///
/// let data = Array2::from_shape_vec((4, 2), vec![
///     0.0, 0.0,
///     0.1, 0.1,
///     5.0, 5.0,
///     5.1, 5.1,
/// ]).unwrap();
/// let labels = Array1::from_vec(vec![0, 0, 1, 1]);
///
/// let score = davies_bouldin_score(data.view(), labels.view()).unwrap();
/// assert!(score < 0.5);  // Should be low for well-separated clusters
/// ```
pub fn davies_bouldin_score<F>(data: ArrayView2<F>, labels: ArrayView1<i32>) -> Result<F>
where
    F: Float + FromPrimitive + Debug + PartialOrd + 'static,
{
    if data.shape()[0] != labels.shape()[0] {
        return Err(ClusteringError::InvalidInput(
            "Data and labels must have the same number of samples".to_string(),
        ));
    }

    // Find unique cluster labels (negative labels are treated as noise and ignored)
    let mut unique_labels = Vec::new();
    for &label in labels.iter() {
        if label >= 0 && !unique_labels.contains(&label) {
            unique_labels.push(label);
        }
    }

    let n_clusters = unique_labels.len();

    if n_clusters < 2 {
        return Err(ClusteringError::InvalidInput(
            "Davies-Bouldin score requires at least 2 clusters".to_string(),
        ));
    }

    // Compute cluster centers
    let mut centers = Array2::<F>::zeros((n_clusters, data.shape()[1]));
    let mut cluster_sizes = vec![0; n_clusters];

    for (i, &label) in labels.iter().enumerate() {
        if label >= 0 {
            let cluster_idx = unique_labels.iter().position(|&l| l == label).unwrap();
            centers
                .row_mut(cluster_idx)
                .scaled_add(F::one(), &data.row(i));
            cluster_sizes[cluster_idx] += 1;
        }
    }

    // Normalize to get averages
    for (i, &size) in cluster_sizes.iter().enumerate() {
        if size > 0 {
            centers
                .row_mut(i)
                .mapv_inplace(|x| x / F::from(size).unwrap());
        }
    }

    // Compute within-cluster scatter (sum of distances to the centroid)
    let mut scatter = vec![F::zero(); n_clusters];
    for (i, &label) in labels.iter().enumerate() {
        if label >= 0 {
            let cluster_idx = unique_labels.iter().position(|&l| l == label).unwrap();
            let center = centers.row(cluster_idx);
            let diff = &data.row(i) - &center;
            let distance = diff.dot(&diff).sqrt();
            scatter[cluster_idx] = scatter[cluster_idx] + distance;
        }
    }

    // Compute average within-cluster scatter
    for (i, &size) in cluster_sizes.iter().enumerate() {
        if size > 0 {
            scatter[i] = scatter[i] / F::from(size).unwrap();
        }
    }

    // Compute Davies-Bouldin index
    let mut db_index = F::zero();

    for i in 0..n_clusters {
        let mut max_ratio = F::zero();

        for j in 0..n_clusters {
            if i != j {
                let between_distance = (&centers.row(i) - &centers.row(j))
                    .mapv(|x| x * x)
                    .sum()
                    .sqrt();

                if between_distance > F::zero() {
                    let ratio = (scatter[i] + scatter[j]) / between_distance;
                    if ratio > max_ratio {
                        max_ratio = ratio;
                    }
                }
            }
        }

        db_index = db_index + max_ratio;
    }

    db_index = db_index / F::from(n_clusters).unwrap();
    Ok(db_index)
}

/// Calinski-Harabasz score for clustering evaluation.
///
/// Also known as the Variance Ratio Criterion, this score computes the ratio of
/// between-cluster dispersion to within-cluster dispersion.
/// A higher score indicates better defined clusters.
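///
/// With k clusters over n samples, CH = (SSB / (k - 1)) / (SSW / (n - k)),
/// where SSB is the between-cluster and SSW the within-cluster sum of squares.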
///
/// # Arguments
///
/// * `data` - Input data (n_samples x n_features)
/// * `labels` - Cluster labels for each sample
///
/// # Returns
///
/// The Calinski-Harabasz score (higher is better)
///
/// # Example
///
/// ```
/// use scirs2_core::ndarray::{Array1, Array2};
/// use scirs2_cluster::metrics::calinski_harabasz_score;
///
/// let data = Array2::from_shape_vec((4, 2), vec![
///     0.0, 0.0,
///     0.1, 0.1,
///     5.0, 5.0,
///     5.1, 5.1,
/// ]).unwrap();
/// let labels = Array1::from_vec(vec![0, 0, 1, 1]);
///
/// let score = calinski_harabasz_score(data.view(), labels.view()).unwrap();
/// assert!(score > 50.0);  // Should be high for well-separated clusters
/// ```
pub fn calinski_harabasz_score<F>(data: ArrayView2<F>, labels: ArrayView1<i32>) -> Result<F>
where
    F: Float + FromPrimitive + Debug + PartialOrd + 'static,
{
    if data.shape()[0] != labels.shape()[0] {
        return Err(ClusteringError::InvalidInput(
            "Data and labels must have the same number of samples".to_string(),
        ));
    }

    let n_samples = data.shape()[0];
    let n_features = data.shape()[1];

    // Find unique cluster labels
    let mut unique_labels = Vec::new();
    for &label in labels.iter() {
        if label >= 0 && !unique_labels.contains(&label) {
            unique_labels.push(label);
        }
    }

    let n_clusters = unique_labels.len();

    if n_clusters < 2 {
        return Err(ClusteringError::InvalidInput(
            "Calinski-Harabasz score requires at least 2 clusters".to_string(),
        ));
    }

    if n_clusters >= n_samples {
        return Err(ClusteringError::InvalidInput(
            "Number of clusters must be less than number of samples".to_string(),
        ));
    }

    // Compute overall mean
    let mut overall_mean = Array1::<F>::zeros(n_features);
    let mut valid_samples = 0;

    for (i, &label) in labels.iter().enumerate() {
        if label >= 0 {
            overall_mean.scaled_add(F::one(), &data.row(i));
            valid_samples += 1;
        }
    }

    overall_mean.mapv_inplace(|x| x / F::from(valid_samples).unwrap());

    // Compute cluster centers and sizes
    let mut centers = Array2::<F>::zeros((n_clusters, n_features));
    let mut cluster_sizes = vec![0; n_clusters];

    for (i, &label) in labels.iter().enumerate() {
        if label >= 0 {
            let cluster_idx = unique_labels.iter().position(|&l| l == label).unwrap();
            centers
                .row_mut(cluster_idx)
                .scaled_add(F::one(), &data.row(i));
            cluster_sizes[cluster_idx] += 1;
        }
    }

    // Normalize to get averages
    for (i, &size) in cluster_sizes.iter().enumerate() {
        if size > 0 {
            centers
                .row_mut(i)
                .mapv_inplace(|x| x / F::from(size).unwrap());
        }
    }

    // Compute between-group sum of squares (SSB)
    let mut ssb = F::zero();
    for (i, &size) in cluster_sizes.iter().enumerate() {
        if size > 0 {
            let diff = &centers.row(i) - &overall_mean;
            ssb = ssb + F::from(size).unwrap() * diff.dot(&diff);
        }
    }

    // Compute within-group sum of squares (SSW)
    let mut ssw = F::zero();
    for (i, &label) in labels.iter().enumerate() {
        if label >= 0 {
            let cluster_idx = unique_labels.iter().position(|&l| l == label).unwrap();
            let diff = &data.row(i) - &centers.row(cluster_idx);
            ssw = ssw + diff.dot(&diff);
        }
    }

    // Calculate score
    if ssw == F::zero() {
        return Ok(F::infinity());
    }

    let score = (ssb / ssw) * F::from(valid_samples - n_clusters).unwrap()
        / F::from(n_clusters - 1).unwrap();

    Ok(score)
}

/// Mean silhouette coefficient over all samples.
///
/// Convenience wrapper that computes the mean silhouette coefficient using
/// the Euclidean distance metric by default.
///
/// # Example
///
/// ```
/// use scirs2_core::ndarray::{Array1, Array2};
/// use scirs2_cluster::metrics::mean_silhouette_score;
///
/// let data = Array2::from_shape_vec((4, 2), vec![
///     0.0, 0.0,
///     0.1, 0.1,
///     5.0, 5.0,
///     5.1, 5.1,
/// ]).unwrap();
/// let labels = Array1::from_vec(vec![0, 0, 1, 1]);
///
/// let score = mean_silhouette_score(data.view(), labels.view()).unwrap();
/// ```
pub fn mean_silhouette_score<F>(data: ArrayView2<F>, labels: ArrayView1<i32>) -> Result<F>
where
    F: Float + FromPrimitive + 'static,
{
    silhouette_score(data, labels)
}

/// Adjusted Rand Index for comparing two clusterings.
///
/// The Adjusted Rand Index (ARI) is a measure of the similarity between two data clusterings,
/// adjusted for chance. It has a value between -1 and 1, where:
/// - 1 indicates perfect agreement
/// - 0 indicates agreement no better than random chance
/// - Negative values indicate agreement worse than random chance
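///
/// In pair-counting terms, with contingency counts n_ij and marginals a_i, b_j,
/// ARI = (Index - ExpectedIndex) / (MaxIndex - ExpectedIndex), where
/// Index = sum_ij C(n_ij, 2), ExpectedIndex = sum_i C(a_i, 2) * sum_j C(b_j, 2) / C(n, 2),
/// and MaxIndex = (sum_i C(a_i, 2) + sum_j C(b_j, 2)) / 2.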
///
/// # Arguments
///
/// * `labels_true` - Ground truth cluster labels
/// * `labels_pred` - Predicted cluster labels
///
/// # Returns
///
/// The Adjusted Rand Index score
///
/// # Example
///
/// ```
/// use scirs2_core::ndarray::Array1;
/// use scirs2_cluster::metrics::adjusted_rand_index;
///
/// let labels_true = Array1::from_vec(vec![0, 0, 1, 1, 2, 2]);
/// let labels_pred = Array1::from_vec(vec![0, 0, 2, 2, 1, 1]);
///
/// let ari: f64 = adjusted_rand_index(labels_true.view(), labels_pred.view()).unwrap();
/// assert!(ari > 0.0);  // Should be positive for similar clusterings
/// ```
pub fn adjusted_rand_index<F>(
    labels_true: ArrayView1<i32>,
    labels_pred: ArrayView1<i32>,
) -> Result<F>
where
    F: Float + FromPrimitive + Debug + 'static,
{
    if labels_true.len() != labels_pred.len() {
        return Err(ClusteringError::InvalidInput(
            "Labels arrays must have the same length".to_string(),
        ));
    }

    let n = labels_true.len();
    if n < 2 {
        return Err(ClusteringError::InvalidInput(
            "Adjusted Rand Index requires at least 2 samples".to_string(),
        ));
    }

    // Build contingency table between the two labelings
    let contingency = build_contingency_matrix(labels_true, labels_pred)?;

    // Sum of C(n_ij, 2) over all cells
    let sum_comb_c = contingency
        .iter()
        .map(|&n_ij| {
            if n_ij >= 2 {
                (n_ij * (n_ij - 1)) / 2
            } else {
                0
            }
        })
        .sum::<usize>();

    // Sum of C(a_i, 2) over the row marginals (true clusters)
    let sum_a = contingency
        .sum_axis(Axis(1))
        .iter()
        .map(|&n_i| if n_i >= 2 { (n_i * (n_i - 1)) / 2 } else { 0 })
        .sum::<usize>();

    // Sum of C(b_j, 2) over the column marginals (predicted clusters)
    let sum_b = contingency
        .sum_axis(Axis(0))
        .iter()
        .map(|&n_j| if n_j >= 2 { (n_j * (n_j - 1)) / 2 } else { 0 })
        .sum::<usize>();

    let n_choose_2 = (n * (n - 1)) / 2;

    // Calculate expected and maximum index values
    let expected_index =
        F::from(sum_a).unwrap() * F::from(sum_b).unwrap() / F::from(n_choose_2).unwrap();
    let max_index = (F::from(sum_a).unwrap() + F::from(sum_b).unwrap()) / F::from(2.0).unwrap();
    let index = F::from(sum_comb_c).unwrap();

    // Handle edge cases (e.g. both labelings are a single cluster)
    if max_index == expected_index {
        return Ok(F::zero());
    }

    // Calculate ARI
    let ari = (index - expected_index) / (max_index - expected_index);
    Ok(ari)
}

/// Normalized Mutual Information (NMI) for comparing two clusterings.
///
/// Normalized Mutual Information is a normalization of the Mutual Information (MI) score
/// that scales the result between 0 (no mutual information) and 1 (perfect correlation).
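///
/// Concretely, NMI(U, V) = MI(U, V) / normalizer(H(U), H(V)), where the
/// normalizer is the arithmetic mean, geometric mean, minimum, or maximum of
/// the two entropies, as selected by `average_method`.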
///
/// # Arguments
///
/// * `labels_true` - Ground truth cluster labels
/// * `labels_pred` - Predicted cluster labels
/// * `average_method` - Method to compute the normalizer ('arithmetic', 'geometric', 'min', 'max')
///
/// # Returns
///
/// The Normalized Mutual Information score
///
/// # Example
///
/// ```
/// use scirs2_core::ndarray::Array1;
/// use scirs2_cluster::metrics::normalized_mutual_info;
///
/// let labels_true = Array1::from_vec(vec![0, 0, 1, 1]);
/// let labels_pred = Array1::from_vec(vec![0, 0, 1, 1]);
///
/// let nmi: f64 = normalized_mutual_info(labels_true.view(), labels_pred.view(), "arithmetic").unwrap();
/// assert!((nmi - 1.0).abs() < 1e-6);  // Perfect agreement
/// ```
pub fn normalized_mutual_info<F>(
    labels_true: ArrayView1<i32>,
    labels_pred: ArrayView1<i32>,
    average_method: &str,
) -> Result<F>
where
    F: Float + FromPrimitive + Debug + 'static,
{
    if labels_true.len() != labels_pred.len() {
        return Err(ClusteringError::InvalidInput(
            "Labels arrays must have the same length".to_string(),
        ));
    }

    let n = labels_true.len();
    if n == 0 {
        return Ok(F::one());
    }

    // Compute mutual information
    let mi = mutual_info::<F>(labels_true, labels_pred)?;

    // Compute entropies
    let h_true = entropy::<F>(labels_true)?;
    let h_pred = entropy::<F>(labels_pred)?;

    // Handle edge cases
    if h_true == F::zero() && h_pred == F::zero() {
        return Ok(F::one());
    }

    // Compute normalization
    let normalizer = match average_method {
        "arithmetic" => (h_true + h_pred) / F::from(2.0).unwrap(),
        "geometric" => (h_true * h_pred).sqrt(),
        "min" => h_true.min(h_pred),
        "max" => h_true.max(h_pred),
        _ => {
            return Err(ClusteringError::InvalidInput(
                "Invalid average method. Use 'arithmetic', 'geometric', 'min', or 'max'"
                    .to_string(),
            ))
        }
    };

    if normalizer == F::zero() {
        return Ok(F::zero());
    }

    Ok(mi / normalizer)
}

/// Homogeneity, completeness and V-measure metrics for clustering evaluation.
///
/// These metrics are useful for evaluating clustering quality when ground truth is available:
/// - Homogeneity: each cluster contains only members of a single class.
/// - Completeness: all members of a given class are assigned to the same cluster.
/// - V-measure: harmonic mean of homogeneity and completeness.
///
/// All scores are between 0 and 1, where 1 indicates perfect clustering.
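///
/// Concretely, h = 1 - H(C|K) / H(C) and c = 1 - H(K|C) / H(K), where C are the
/// true classes and K the predicted clusters, and V = 2hc / (h + c).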
///
/// # Arguments
///
/// * `labels_true` - Ground truth cluster labels
/// * `labels_pred` - Predicted cluster labels
///
/// # Returns
///
/// Tuple of (homogeneity, completeness, v_measure)
///
/// # Example
///
/// ```
/// use scirs2_core::ndarray::Array1;
/// use scirs2_cluster::metrics::homogeneity_completeness_v_measure;
///
/// let labels_true = Array1::from_vec(vec![0, 0, 1, 1, 2, 2]);
/// let labels_pred = Array1::from_vec(vec![0, 0, 1, 1, 1, 1]);
///
/// let (h, c, v): (f64, f64, f64) = homogeneity_completeness_v_measure(labels_true.view(), labels_pred.view()).unwrap();
/// assert!(h > 0.5);  // Moderate homogeneity (one cluster mixes two classes)
/// assert!(c > 0.9);  // High completeness (each class sits entirely in one cluster)
/// ```
pub fn homogeneity_completeness_v_measure<F>(
    labels_true: ArrayView1<i32>,
    labels_pred: ArrayView1<i32>,
) -> Result<(F, F, F)>
where
    F: Float + FromPrimitive + Debug + 'static,
{
    if labels_true.len() != labels_pred.len() {
        return Err(ClusteringError::InvalidInput(
            "Labels arrays must have the same length".to_string(),
        ));
    }

    let n = labels_true.len();
    if n == 0 {
        return Ok((F::one(), F::one(), F::one()));
    }

    // Compute entropies
    let h_true = entropy::<F>(labels_true)?;
    let h_pred = entropy::<F>(labels_pred)?;

    // Compute conditional entropies
    let h_true_given_pred = conditional_entropy::<F>(labels_true, labels_pred)?;
    let h_pred_given_true = conditional_entropy::<F>(labels_pred, labels_true)?;

    // Homogeneity: trivially perfect when the ground truth has a single class
    let homogeneity = if h_true == F::zero() {
        F::one()
    } else {
        F::one() - h_true_given_pred / h_true
    };

    // Completeness: trivially perfect when the prediction has a single cluster
    let completeness = if h_pred == F::zero() {
        F::one()
    } else {
        F::one() - h_pred_given_true / h_pred
    };

    // V-measure: harmonic mean of homogeneity and completeness
    let v_measure = if homogeneity + completeness == F::zero() {
        F::zero()
    } else {
        F::from(2.0).unwrap() * homogeneity * completeness / (homogeneity + completeness)
    };

    Ok((homogeneity, completeness, v_measure))
}

// Helper functions

/// Compute mutual information between two label assignments.
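/// Uses natural logarithms: MI = sum_ij (n_ij / n) * ln(n * n_ij / (n_i * n_j)),
/// so the result is reported in nats.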
fn mutual_info<F>(labels_true: ArrayView1<i32>, labels_pred: ArrayView1<i32>) -> Result<F>
where
    F: Float + FromPrimitive + Debug + 'static,
{
    let n = labels_true.len() as f64;
    let contingency = build_contingency_matrix(labels_true, labels_pred)?;

    let mut mi = F::zero();
    let n_rows = contingency.shape()[0];
    let n_cols = contingency.shape()[1];

    // Compute marginal sums
    let row_sums = contingency.sum_axis(Axis(1));
    let col_sums = contingency.sum_axis(Axis(0));

    for i in 0..n_rows {
        for j in 0..n_cols {
            let n_ij = contingency[[i, j]] as f64;
            if n_ij > 0.0 {
                let n_i = row_sums[i] as f64;
                let n_j = col_sums[j] as f64;
                let term = n_ij / n * (n_ij / (n_i * n_j / n)).ln();
                mi = mi + F::from(term).unwrap();
            }
        }
    }

    Ok(mi)
}

/// Compute entropy of a label assignment.
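/// H = -sum_k p_k * ln(p_k), where p_k is the fraction of samples with label k.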
fn entropy<F>(labels: ArrayView1<i32>) -> Result<F>
where
    F: Float + FromPrimitive + Debug + 'static,
{
    let n = labels.len() as f64;
    let mut label_counts = std::collections::HashMap::new();

    for &label in labels.iter() {
        *label_counts.entry(label).or_insert(0) += 1;
    }

    let mut h = F::zero();
    for &count in label_counts.values() {
        if count > 0 {
            let p = count as f64 / n;
            h = h - F::from(p * p.ln()).unwrap();
        }
    }

    Ok(h)
}

/// Build contingency matrix for two label assignments.
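/// Rows follow the sorted order of distinct `labels_true` values and columns
/// the sorted order of distinct `labels_pred` values (BTreeSet iteration order).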
fn build_contingency_matrix(
    labels_true: ArrayView1<i32>,
    labels_pred: ArrayView1<i32>,
) -> Result<Array2<usize>> {
    let mut true_labels = std::collections::BTreeSet::new();
    let mut pred_labels = std::collections::BTreeSet::new();

    for &label in labels_true.iter() {
        true_labels.insert(label);
    }
    for &label in labels_pred.iter() {
        pred_labels.insert(label);
    }

    let true_label_map: std::collections::HashMap<i32, usize> = true_labels
        .iter()
        .enumerate()
        .map(|(i, &label)| (label, i))
        .collect();
    let pred_label_map: std::collections::HashMap<i32, usize> = pred_labels
        .iter()
        .enumerate()
        .map(|(i, &label)| (label, i))
        .collect();

    let mut contingency = Array2::<usize>::zeros((true_labels.len(), pred_labels.len()));
    for i in 0..labels_true.len() {
        let true_idx = true_label_map[&labels_true[i]];
        let pred_idx = pred_label_map[&labels_pred[i]];
        contingency[[true_idx, pred_idx]] += 1;
    }

    Ok(contingency)
}

/// Compute conditional entropy H(X|Y).
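/// Computed from the contingency table as H(X|Y) = -sum_ij (n_ij / n) * ln(n_ij / n_j),
/// where n_j is the total count of the j-th Y label.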
fn conditional_entropy<F>(labels_x: ArrayView1<i32>, labels_y: ArrayView1<i32>) -> Result<F>
where
    F: Float + FromPrimitive + Debug + 'static,
{
    let n = labels_x.len() as f64;
    let contingency = build_contingency_matrix(labels_x, labels_y)?;

    let mut h_xy = F::zero();
    let col_sums = contingency.sum_axis(Axis(0));

    for j in 0..contingency.shape()[1] {
        let n_j = col_sums[j] as f64;
        if n_j > 0.0 {
            for i in 0..contingency.shape()[0] {
                let n_ij = contingency[[i, j]] as f64;
                if n_ij > 0.0 {
                    let term = n_ij / n * (n_ij / n_j).ln();
                    h_xy = h_xy - F::from(term).unwrap();
                }
            }
        }
    }

    Ok(h_xy)
}

#[cfg(test)]
mod tests {
    use super::*;
    use scirs2_core::ndarray::Array2;

    #[test]
    fn test_davies_bouldin_score() {
        let data =
            Array2::from_shape_vec((4, 2), vec![0.0, 0.0, 0.1, 0.1, 5.0, 5.0, 5.1, 5.1]).unwrap();
        let labels = Array1::from_vec(vec![0, 0, 1, 1]);

        let score = davies_bouldin_score(data.view(), labels.view()).unwrap();
        assert!(score >= 0.0);
    }

    #[test]
    fn test_calinski_harabasz_score() {
        let data =
            Array2::from_shape_vec((4, 2), vec![0.0, 0.0, 0.1, 0.1, 5.0, 5.0, 5.1, 5.1]).unwrap();
        let labels = Array1::from_vec(vec![0, 0, 1, 1]);

        let score = calinski_harabasz_score(data.view(), labels.view()).unwrap();
        assert!(score > 0.0);
    }

    #[test]
    fn test_adjusted_rand_index() {
        let labels_true = Array1::from_vec(vec![0, 0, 1, 1, 2, 2]);
        let labels_pred = Array1::from_vec(vec![0, 0, 2, 2, 1, 1]);

        let ari: f64 = adjusted_rand_index(labels_true.view(), labels_pred.view()).unwrap();
        assert!(ari >= -1.0 && ari <= 1.0);
    }

    #[test]
    fn test_normalized_mutual_info() {
        let labels_true = Array1::from_vec(vec![0, 0, 1, 1]);
        let labels_pred = Array1::from_vec(vec![0, 0, 1, 1]);

        let nmi: f64 =
            normalized_mutual_info(labels_true.view(), labels_pred.view(), "arithmetic").unwrap();
        assert!((nmi - 1.0).abs() < 1e-6);
    }

    #[test]
    fn test_homogeneity_completeness_v_measure() {
        let labels_true = Array1::from_vec(vec![0, 0, 1, 1, 2, 2]);
        let labels_pred = Array1::from_vec(vec![0, 0, 1, 1, 1, 1]);

        let (h, c, v): (f64, f64, f64) =
            homogeneity_completeness_v_measure(labels_true.view(), labels_pred.view()).unwrap();

        assert!(h >= 0.0 && h <= 1.0);
        assert!(c >= 0.0 && c <= 1.0);
        assert!(v >= 0.0 && v <= 1.0);
    }
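
    // Extra sanity checks following directly from the standard definitions:
    // an exact label match must score a perfect ARI, and a single-cluster
    // prediction is perfectly complete but not at all homogeneous.
    #[test]
    fn test_adjusted_rand_index_perfect_match() {
        let labels = Array1::from_vec(vec![0, 0, 1, 1, 2, 2]);
        let ari: f64 = adjusted_rand_index(labels.view(), labels.view()).unwrap();
        assert!((ari - 1.0).abs() < 1e-10);
    }

    #[test]
    fn test_single_cluster_prediction() {
        let labels_true = Array1::from_vec(vec![0, 0, 1, 1]);
        let labels_pred = Array1::from_vec(vec![0, 0, 0, 0]);
        let (h, c, v): (f64, f64, f64) =
            homogeneity_completeness_v_measure(labels_true.view(), labels_pred.view()).unwrap();
        assert!(h.abs() < 1e-10);
        assert!((c - 1.0).abs() < 1e-10);
        assert!(v.abs() < 1e-10);
    }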
}