1use scirs2_core::ndarray::{Array1, Array2, ArrayView2};
34use scirs2_core::numeric::{Float, FromPrimitive};
35use std::fmt::Debug;
36
37use crate::error::{ClusteringError, Result};
38
39pub mod agglomerative;
41pub mod cluster_extraction;
42pub mod condensed_matrix;
43pub mod dendrogram;
44pub mod disjoint_set;
45pub mod leaf_ordering;
46pub mod linkage;
47pub mod optimized_ward;
48pub mod parallel_linkage;
49pub mod validation;
50pub mod visualization;
51
52pub use self::agglomerative::{cut_tree_by_distance, cut_tree_by_inconsistency};
54pub use self::cluster_extraction::{
55 estimate_optimal_clusters, extract_clusters_multi_criteria, prune_clusters,
56};
57pub use self::condensed_matrix::{
58 condensed_size, condensed_to_square, get_distance, points_from_condensed_size,
59 square_to_condensed, validate_condensed_matrix,
60};
61pub use self::dendrogram::{cophenet, dendrogram, inconsistent, optimal_leaf_ordering};
62pub use self::disjoint_set::DisjointSet;
63pub use self::leaf_ordering::{
64 apply_leaf_ordering, optimal_leaf_ordering_exact, optimal_leaf_ordering_heuristic,
65};
66pub use self::optimized_ward::{
67 lance_williams_ward_update, memory_efficient_ward_linkage, optimized_ward_linkage,
68};
69pub use self::validation::{
70 validate_cluster_consistency, validate_cluster_extraction_params, validate_distance_matrix,
71 validate_linkage_matrix, validate_monotonic_distances, validate_square_distance_matrix,
72};
73pub use self::visualization::{
74 create_dendrogramplot, get_color_palette, Branch, ColorScheme, ColorThreshold,
75 DendrogramConfig, DendrogramOrientation, DendrogramPlot, Leaf, LegendEntry, TruncateMode,
76};
77
/// Strategy for measuring the distance between two clusters when choosing
/// which pair to merge next during agglomerative hierarchical clustering.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum LinkageMethod {
    /// Single linkage: distance between the closest pair of points, one
    /// from each cluster.
    Single,

    /// Complete linkage: distance between the farthest pair of points, one
    /// from each cluster.
    Complete,

    /// Average linkage: mean pairwise distance between points of the two
    /// clusters.
    Average,

    /// Ward's method: merge that minimizes the increase in within-cluster
    /// variance. Dispatched to a dedicated optimized implementation by
    /// [`linkage`] / [`parallel_linkage`].
    Ward,

    /// Centroid linkage: distance between the cluster centroids.
    Centroid,

    /// Median linkage: like centroid linkage, but the merged cluster's
    /// center is the midpoint of the two merged centers.
    Median,

    /// Weighted average linkage (WPGMA-style weighting).
    Weighted,
}
102
/// Distance metric used to compare individual observations (rows of the
/// input data). See `compute_distances` for the exact formulas.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum Metric {
    /// Euclidean (L2) distance: square root of the sum of squared
    /// coordinate differences.
    Euclidean,

    /// Manhattan (L1) distance: sum of absolute coordinate differences.
    Manhattan,

    /// Chebyshev (L-infinity) distance: maximum absolute coordinate
    /// difference.
    Chebyshev,

    /// Cosine distance: `1 - cosine_similarity`; near-zero-norm vectors are
    /// treated as maximally dissimilar (distance 1).
    Cosine,

    /// Correlation distance: `1 - Pearson correlation`; zero-variance
    /// vectors yield distance 0.
    Correlation,
}
121
/// How the threshold argument of [`fcluster`] / [`fcluster_generic`] is
/// interpreted when cutting the dendrogram into flat clusters.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum ClusterCriterion {
    /// Interpret the threshold as the desired number of flat clusters.
    MaxClust,

    /// Cut the tree at a fixed merge-distance threshold.
    Distance,

    /// Cut the tree where the inconsistency coefficient exceeds the
    /// threshold.
    Inconsistent,
}
134
135#[allow(dead_code)]
137fn compute_distances<F: Float + FromPrimitive>(data: ArrayView2<F>, metric: Metric) -> Array1<F> {
138 let n_samples = data.shape()[0];
139 let n_features = data.shape()[1];
140
141 let num_distances = n_samples * (n_samples - 1) / 2;
143 let mut distances = Array1::zeros(num_distances);
144
145 let mut idx = 0;
146 for i in 0..n_samples {
147 for j in (i + 1)..n_samples {
148 let dist = match metric {
149 Metric::Euclidean => {
150 let mut sum = F::zero();
152 for k in 0..n_features {
153 let diff = data[[i, k]] - data[[j, k]];
154 sum = sum + diff * diff;
155 }
156 sum.sqrt()
157 }
158 Metric::Manhattan => {
159 let mut sum = F::zero();
161 for k in 0..n_features {
162 let diff = (data[[i, k]] - data[[j, k]]).abs();
163 sum = sum + diff;
164 }
165 sum
166 }
167 Metric::Chebyshev => {
168 let mut max_diff = F::zero();
170 for k in 0..n_features {
171 let diff = (data[[i, k]] - data[[j, k]]).abs();
172 if diff > max_diff {
173 max_diff = diff;
174 }
175 }
176 max_diff
177 }
178 Metric::Cosine => {
179 let mut dot_product = F::zero();
182 let mut norm_i = F::zero();
183 let mut norm_j = F::zero();
184
185 for k in 0..n_features {
186 let val_i = data[[i, k]];
187 let val_j = data[[j, k]];
188
189 dot_product = dot_product + val_i * val_j;
190 norm_i = norm_i + val_i * val_i;
191 norm_j = norm_j + val_j * val_j;
192 }
193
194 let norm_product = (norm_i * norm_j).sqrt();
195
196 if norm_product < F::from_f64(1e-10).unwrap() {
197 F::one()
199 } else {
200 F::one() - (dot_product / norm_product)
201 }
202 }
203 Metric::Correlation => {
204 let mut mean_i = F::zero();
209 let mut mean_j = F::zero();
210
211 for k in 0..n_features {
212 mean_i = mean_i + data[[i, k]];
213 mean_j = mean_j + data[[j, k]];
214 }
215
216 mean_i = mean_i / F::from_usize(n_features).unwrap();
217 mean_j = mean_j / F::from_usize(n_features).unwrap();
218
219 let mut numerator = F::zero();
221 let mut denom_i = F::zero();
222 let mut denom_j = F::zero();
223
224 for k in 0..n_features {
225 let diff_i = data[[i, k]] - mean_i;
226 let diff_j = data[[j, k]] - mean_j;
227
228 numerator = numerator + diff_i * diff_j;
229 denom_i = denom_i + diff_i * diff_i;
230 denom_j = denom_j + diff_j * diff_j;
231 }
232
233 let denom = (denom_i * denom_j).sqrt();
234
235 if denom < F::from_f64(1e-10).unwrap() {
236 F::zero()
238 } else {
239 F::one() - (numerator / denom)
240 }
241 }
242 };
243
244 distances[idx] = dist;
245 idx += 1;
246 }
247 }
248
249 distances
250}
251
/// Converts a condensed distance vector index into the `(i, j)` coordinates
/// (with `i < j`) of the corresponding entry in the equivalent `n x n`
/// square matrix.
///
/// Runs in O(n) by skipping whole rows of the condensed layout, replacing
/// the previous O(n^2) pair-by-pair enumeration. An out-of-range `idx`
/// returns `(0, 0)`, matching the original fall-through behavior.
#[allow(dead_code)]
pub fn condensed_index_to_coords(n: usize, idx: usize) -> (usize, usize) {
    let mut remaining = idx;

    for i in 0..n {
        // Row `i` of the condensed layout holds entries (i, i+1)..(i, n-1).
        let row_len = n - i - 1;
        if remaining < row_len {
            return (i, i + 1 + remaining);
        }
        remaining -= row_len;
    }

    // `idx` does not address any pair for an n-point condensed matrix.
    (0, 0)
}
277
278#[allow(dead_code)]
280pub fn coords_to_condensed_index(n: usize, i: usize, j: usize) -> Result<usize> {
281 if i == j {
282 return Err(ClusteringError::InvalidInput(
283 "Cannot compute diagonal index in condensed matrix".into(),
284 ));
285 }
286
287 if i >= n || j >= n {
288 return Err(ClusteringError::InvalidInput(format!(
289 "Indices ({}, {}) out of bounds for matrix size {}",
290 i, j, n
291 )));
292 }
293
294 let (i_min, j_min) = if i < j { (i, j) } else { (j, i) };
295 Ok((n * i_min) - ((i_min * (i_min + 1)) / 2) + (j_min - i_min - 1))
296}
297
298#[allow(dead_code)]
310pub fn linkage<
311 F: Float
312 + FromPrimitive
313 + Debug
314 + PartialOrd
315 + Send
316 + Sync
317 + scirs2_core::ndarray::ScalarOperand
318 + 'static,
319>(
320 data: ArrayView2<F>,
321 method: LinkageMethod,
322 metric: Metric,
323) -> Result<Array2<F>> {
324 let n_samples = data.shape()[0];
325
326 if n_samples < 2 {
327 return Err(ClusteringError::InvalidInput(
328 "Need at least 2 samples for hierarchical clustering".into(),
329 ));
330 }
331
332 if n_samples > 10000 {
333 eprintln!("Warning: Performing hierarchical clustering on {n_samples} samples. This may be slow and memory-intensive.");
336 }
337
338 if method == LinkageMethod::Ward {
340 return optimized_ward::optimized_ward_linkage(data, metric);
341 }
342
343 let distances = compute_distances(data, metric);
345
346 linkage::hierarchical_clustering(&distances, n_samples, method)
348}
349
350#[allow(dead_code)]
387pub fn parallel_linkage<
388 F: Float
389 + FromPrimitive
390 + Debug
391 + PartialOrd
392 + Send
393 + Sync
394 + std::iter::Sum
395 + scirs2_core::ndarray::ScalarOperand
396 + 'static,
397>(
398 data: ArrayView2<F>,
399 method: LinkageMethod,
400 metric: Metric,
401) -> Result<Array2<F>> {
402 let n_samples = data.shape()[0];
403
404 if n_samples < 2 {
405 return Err(ClusteringError::InvalidInput(
406 "Need at least 2 samples for hierarchical clustering".into(),
407 ));
408 }
409
410 if n_samples > 10000 {
411 eprintln!("Warning: Performing parallel hierarchical clustering on {n_samples} samples. This may still be slow for very large datasets.");
414 }
415
416 if method == LinkageMethod::Ward {
418 return optimized_ward::optimized_ward_linkage(data, metric);
419 }
420
421 let distances = compute_distances(data, metric);
423
424 parallel_linkage::parallel_hierarchical_clustering(&distances, n_samples, method)
426}
427
428#[allow(dead_code)]
445pub fn fcluster<F: Float + FromPrimitive + PartialOrd + Debug>(
446 z: &Array2<F>,
447 t: usize,
448 criterion: Option<ClusterCriterion>,
449) -> Result<Array1<usize>> {
450 let n_samples = z.shape()[0] + 1;
451 let crit = criterion.unwrap_or(ClusterCriterion::MaxClust);
452
453 match crit {
454 ClusterCriterion::MaxClust => {
455 if t == 0 || t > n_samples {
457 return Err(ClusteringError::InvalidInput(format!(
458 "Number of clusters must be between 1 and {}",
459 n_samples
460 )));
461 }
462
463 agglomerative::cut_tree(z, t)
464 }
465 ClusterCriterion::Distance => {
466 let t_float = F::from_usize(t).unwrap();
468 agglomerative::cut_tree_by_distance(z, t_float)
469 }
470 ClusterCriterion::Inconsistent => {
471 let t_float = F::from_usize(t).unwrap();
473
474 let inconsistency_matrix = dendrogram::inconsistent(z, None)?;
476
477 agglomerative::cut_tree_by_inconsistency(z, t_float, &inconsistency_matrix)
479 }
480 }
481}
482
483#[allow(dead_code)]
518pub fn fcluster_generic<F: Float + FromPrimitive + PartialOrd + Debug>(
519 z: &Array2<F>,
520 t: F,
521 criterion: ClusterCriterion,
522) -> Result<Array1<usize>> {
523 let n_samples = z.shape()[0] + 1;
524
525 match criterion {
526 ClusterCriterion::MaxClust => {
527 let n_clusters = t.to_usize().ok_or_else(|| {
529 ClusteringError::InvalidInput("Invalid number of clusters".into())
530 })?;
531
532 if n_clusters == 0 || n_clusters > n_samples {
533 return Err(ClusteringError::InvalidInput(format!(
534 "Number of clusters must be between 1 and {}",
535 n_samples
536 )));
537 }
538
539 agglomerative::cut_tree(z, n_clusters)
540 }
541 ClusterCriterion::Distance => {
542 agglomerative::cut_tree_by_distance(z, t)
544 }
545 ClusterCriterion::Inconsistent => {
546 let inconsistency_matrix = dendrogram::inconsistent(z, None)?;
549
550 agglomerative::cut_tree_by_inconsistency(z, t, &inconsistency_matrix)
552 }
553 }
554}
555
#[cfg(test)]
mod tests {
    use super::*;
    use approx::assert_abs_diff_eq;

    // Two well-separated 2D groups of three points each; Ward linkage must
    // produce a full (n-1)-row linkage matrix with positive merge heights.
    #[test]
    fn test_linkage_simple() {
        let data = Array2::from_shape_vec(
            (6, 2),
            vec![1.0, 2.0, 1.2, 1.8, 0.8, 1.9, 3.7, 4.2, 3.9, 3.9, 4.2, 4.1],
        )
        .unwrap();

        let linkage_matrix = linkage(data.view(), LinkageMethod::Ward, Metric::Euclidean).unwrap();

        // 6 samples -> 5 merges, 4 columns (left, right, distance, size).
        assert_eq!(linkage_matrix.shape(), &[5, 4]);

        // First merge has positive distance and joins exactly 2 points.
        assert!(linkage_matrix[[0, 2]] > 0.0); assert_eq!(linkage_matrix[[0, 3]] as usize, 2); }

    // Cutting the same two-group data into 2 clusters must recover the
    // groups: {0,1,2} together, {3,4,5} together, and the groups distinct.
    #[test]
    fn test_fcluster() {
        let data = Array2::from_shape_vec(
            (6, 2),
            vec![1.0, 2.0, 1.2, 1.8, 0.8, 1.9, 3.7, 4.2, 3.9, 3.9, 4.2, 4.1],
        )
        .unwrap();

        let linkage_matrix = linkage(data.view(), LinkageMethod::Ward, Metric::Euclidean).unwrap();

        let labels = fcluster(&linkage_matrix, 2, None).unwrap();

        assert_eq!(labels.len(), 6);

        assert_eq!(labels[0], labels[1]);
        assert_eq!(labels[1], labels[2]);
        assert_eq!(labels[3], labels[4]);
        assert_eq!(labels[4], labels[5]);
        assert_ne!(labels[0], labels[3]);
    }

    // Unit-square corners give known closed-form distances for each metric;
    // condensed index 0 is the pair (0,1), index 2 is the pair (0,3).
    #[test]
    fn test_distance_metrics() {
        let data =
            Array2::from_shape_vec((4, 2), vec![0.0, 0.0, 1.0, 0.0, 0.0, 1.0, 1.0, 1.0]).unwrap();

        let euclidean_distances = compute_distances(data.view(), Metric::Euclidean);
        let manhattan_distances = compute_distances(data.view(), Metric::Manhattan);
        let chebyshev_distances = compute_distances(data.view(), Metric::Chebyshev);

        // 4 points -> 6 condensed pairs; adjacent corners are 1 apart.
        assert_eq!(euclidean_distances.len(), 6); assert_abs_diff_eq!(euclidean_distances[0], 1.0, epsilon = 1e-10);

        // Diagonal pair (0,0)-(1,1): Euclidean sqrt(2).
        assert_abs_diff_eq!(
            euclidean_distances[2],
            std::f64::consts::SQRT_2,
            epsilon = 1e-10
        );

        assert_abs_diff_eq!(manhattan_distances[0], 1.0, epsilon = 1e-10);

        // Same diagonal pair: Manhattan 2, Chebyshev 1.
        assert_abs_diff_eq!(manhattan_distances[2], 2.0, epsilon = 1e-10);

        assert_abs_diff_eq!(chebyshev_distances[2], 1.0, epsilon = 1e-10);
    }

    // Every linkage method must produce a well-formed linkage matrix and a
    // full label vector when cut to 2 clusters.
    #[test]
    fn test_hierarchy_with_different_linkage_methods() {
        let data = Array2::from_shape_vec(
            (6, 2),
            vec![1.0, 2.0, 1.2, 1.8, 0.8, 1.9, 3.7, 4.2, 3.9, 3.9, 4.2, 4.1],
        )
        .unwrap();

        let methods = vec![
            LinkageMethod::Single,
            LinkageMethod::Complete,
            LinkageMethod::Average,
            LinkageMethod::Ward,
        ];

        for method in methods {
            let linkage_matrix = linkage(data.view(), method, Metric::Euclidean).unwrap();

            assert_eq!(linkage_matrix.shape(), &[5, 4]);

            let labels = fcluster(&linkage_matrix, 2, None).unwrap();

            assert_eq!(labels.len(), 6);
        }
    }

    // Inconsistency-based cutting must label every sample with an id below
    // the sample count.
    #[test]
    fn test_fcluster_inconsistent_criterion() {
        let data = Array2::from_shape_vec(
            (6, 2),
            vec![1.0, 2.0, 1.2, 1.8, 0.8, 1.9, 3.7, 4.2, 3.9, 3.9, 4.2, 4.1],
        )
        .unwrap();

        let linkage_matrix = linkage(data.view(), LinkageMethod::Ward, Metric::Euclidean).unwrap();

        let labels =
            fcluster_generic(&linkage_matrix, 1.0, ClusterCriterion::Inconsistent).unwrap();

        assert_eq!(labels.len(), 6);

        assert!(labels.iter().all(|&l| l < 6));
    }

    // Smoke-test all three criteria of the generic (float-threshold) API.
    #[test]
    fn test_fcluster_generic_all_criteria() {
        let data = Array2::from_shape_vec(
            (6, 2),
            vec![1.0, 2.0, 1.2, 1.8, 0.8, 1.9, 3.7, 4.2, 3.9, 3.9, 4.2, 4.1],
        )
        .unwrap();

        let linkage_matrix = linkage(data.view(), LinkageMethod::Ward, Metric::Euclidean).unwrap();

        // MaxClust with t = 2.0 must yield exactly two distinct labels.
        let labels_maxclust =
            fcluster_generic(&linkage_matrix, 2.0, ClusterCriterion::MaxClust).unwrap();
        assert_eq!(labels_maxclust.len(), 6);
        let unique_maxclust: std::collections::HashSet<_> =
            labels_maxclust.iter().cloned().collect();
        assert_eq!(unique_maxclust.len(), 2);

        let labels_distance =
            fcluster_generic(&linkage_matrix, 2.5, ClusterCriterion::Distance).unwrap();
        assert_eq!(labels_distance.len(), 6);

        let labels_inconsistent =
            fcluster_generic(&linkage_matrix, 0.5, ClusterCriterion::Inconsistent).unwrap();
        assert_eq!(labels_inconsistent.len(), 6);
    }
}