1use scirs2_core::ndarray::{ArrayView1, ArrayView2};
8use scirs2_core::numeric::{Float, FromPrimitive};
9use std::fmt::Debug;
10
11use crate::error::{ClusteringError, Result};
12
13#[allow(dead_code)]
36pub fn validate_linkage_matrix<
37 F: Float + FromPrimitive + Debug + PartialOrd + std::fmt::Display,
38>(
39 linkage_matrix: ArrayView2<F>,
40 n_observations: usize,
41) -> Result<()> {
42 let n_merges = linkage_matrix.shape()[0];
43 let n_cols = linkage_matrix.shape()[1];
44
45 if n_merges != n_observations - 1 {
47 return Err(ClusteringError::InvalidInput(format!(
48 "Linkage _matrix should have {} rows for {} observations, got {}",
49 n_observations - 1,
50 n_observations,
51 n_merges
52 )));
53 }
54
55 if n_cols != 4 {
56 return Err(ClusteringError::InvalidInput(format!(
57 "Linkage _matrix should have 4 columns, got {}",
58 n_cols
59 )));
60 }
61
62 for i in 0..n_merges {
64 let cluster1 = linkage_matrix[[i, 0]];
65 let cluster2 = linkage_matrix[[i, 1]];
66 let distance = linkage_matrix[[i, 2]];
67 let count = linkage_matrix[[i, 3]];
68
69 if !cluster1.is_finite()
71 || !cluster2.is_finite()
72 || !distance.is_finite()
73 || !count.is_finite()
74 {
75 return Err(ClusteringError::InvalidInput(format!(
76 "Non-finite values in linkage _matrix at row {}",
77 i
78 )));
79 }
80
81 let c1 = cluster1.to_usize().unwrap_or(usize::MAX);
83 let c2 = cluster2.to_usize().unwrap_or(usize::MAX);
84
85 let max_cluster_id = n_observations + i - 1;
87 if c1 >= n_observations + i || c2 >= n_observations + i {
88 return Err(ClusteringError::InvalidInput(format!(
89 "Invalid cluster indices at merge {}: {} and {} (max allowed: {})",
90 i, c1, c2, max_cluster_id
91 )));
92 }
93
94 if c1 == c2 {
96 return Err(ClusteringError::InvalidInput(format!(
97 "Self-merge detected at row {}: cluster {} merges with itself",
98 i, c1
99 )));
100 }
101
102 if distance < F::zero() {
104 return Err(ClusteringError::InvalidInput(format!(
105 "Negative merge distance at row {}: {}",
106 i, distance
107 )));
108 }
109
110 if count < F::from(2).unwrap() {
112 return Err(ClusteringError::InvalidInput(format!(
113 "Cluster count should be >= 2 at row {}, got {}",
114 i, count
115 )));
116 }
117 }
118
119 Ok(())
120}
121
122#[allow(dead_code)]
135pub fn validate_monotonic_distances<
136 F: Float + FromPrimitive + Debug + PartialOrd + std::fmt::Display,
137>(
138 linkage_matrix: ArrayView2<F>,
139 strict: bool,
140) -> Result<()> {
141 let n_merges = linkage_matrix.shape()[0];
142
143 for i in 1..n_merges {
144 let prev_distance = linkage_matrix[[i - 1, 2]];
145 let curr_distance = linkage_matrix[[i, 2]];
146
147 if strict {
148 if curr_distance <= prev_distance {
149 return Err(ClusteringError::InvalidInput(format!(
150 "Merge distances should be strictly increasing: {} <= {} at merge {}",
151 curr_distance, prev_distance, i
152 )));
153 }
154 } else if curr_distance < prev_distance - F::from(1e-10).unwrap() {
155 return Err(ClusteringError::InvalidInput(format!(
156 "Merge distances should be non-decreasing: {} < {} at merge {}",
157 curr_distance, prev_distance, i
158 )));
159 }
160 }
161
162 Ok(())
163}
164
165#[allow(dead_code)]
179pub fn validate_cluster_extraction_params<
180 F: Float + FromPrimitive + Debug + PartialOrd + std::fmt::Display,
181>(
182 linkage_matrix: ArrayView2<F>,
183 criterion: &str,
184 threshold: F,
185) -> Result<()> {
186 let n_observations = linkage_matrix.shape()[0] + 1;
188 validate_linkage_matrix(linkage_matrix, n_observations)?;
189
190 match criterion.to_lowercase().as_str() {
192 "distance" => {
193 if threshold < F::zero() {
194 return Err(ClusteringError::InvalidInput(
195 "Distance threshold must be non-negative".to_string(),
196 ));
197 }
198 }
199 "maxclust" => {
200 let max_clusters = threshold.to_usize().unwrap_or(0);
201 if max_clusters < 1 || max_clusters > n_observations {
202 return Err(ClusteringError::InvalidInput(format!(
203 "Number of clusters must be between 1 and {}, got {}",
204 n_observations, max_clusters
205 )));
206 }
207 }
208 "inconsistent" => {
209 if threshold < F::zero() {
210 return Err(ClusteringError::InvalidInput(
211 "Inconsistency threshold must be non-negative".to_string(),
212 ));
213 }
214 }
215 _ => {
216 return Err(ClusteringError::InvalidInput(format!(
217 "Unknown criterion '{}'. Valid options: 'distance', 'maxclust', 'inconsistent'",
218 criterion
219 )));
220 }
221 }
222
223 Ok(())
224}
225
226#[allow(dead_code)]
239pub fn validate_distance_matrix<
240 F: Float + FromPrimitive + Debug + PartialOrd + std::fmt::Display,
241>(
242 distance_matrix: ArrayView1<F>,
243 condensed: bool,
244) -> Result<()> {
245 let n_elements = distance_matrix.len();
246
247 if condensed {
248 let n_float = (1.0 + (1.0 + 8.0 * n_elements as f64).sqrt()) / 2.0;
251 let n = n_float as usize;
252
253 if n * (n - 1) / 2 != n_elements {
254 return Err(ClusteringError::InvalidInput(format!(
255 "Invalid condensed distance _matrix size: {} elements doesn't correspond to n*(n-1)/2 for any integer n",
256 n_elements
257 )));
258 }
259
260 if n < 2 {
261 return Err(ClusteringError::InvalidInput(
262 "Distance _matrix must represent at least 2 observations".to_string(),
263 ));
264 }
265 }
266
267 for (i, &distance) in distance_matrix.iter().enumerate() {
269 if !distance.is_finite() {
270 return Err(ClusteringError::InvalidInput(format!(
271 "Non-finite distance at index {}",
272 i
273 )));
274 }
275
276 if distance < F::zero() {
277 return Err(ClusteringError::InvalidInput(format!(
278 "Negative distance at index {}: {}",
279 i, distance
280 )));
281 }
282 }
283
284 Ok(())
285}
286
287#[allow(dead_code)]
301pub fn validate_square_distance_matrix<
302 F: Float + FromPrimitive + Debug + PartialOrd + std::fmt::Display,
303>(
304 distance_matrix: ArrayView2<F>,
305 check_symmetry: bool,
306 check_triangle_inequality: bool,
307) -> Result<()> {
308 let n = distance_matrix.shape()[0];
309 let m = distance_matrix.shape()[1];
310
311 if n != m {
313 return Err(ClusteringError::InvalidInput(format!(
314 "Distance _matrix must be square, got {}x{}",
315 n, m
316 )));
317 }
318
319 if n < 2 {
320 return Err(ClusteringError::InvalidInput(
321 "Distance _matrix must be at least 2x2".to_string(),
322 ));
323 }
324
325 for i in 0..n {
327 let diag_val = distance_matrix[[i, i]];
328 if !diag_val.is_finite() || diag_val.abs() > F::from(1e-10).unwrap() {
329 return Err(ClusteringError::InvalidInput(format!(
330 "Diagonal element at ({}, {}) should be zero, got {}",
331 i, i, diag_val
332 )));
333 }
334 }
335
336 for i in 0..n {
338 for j in 0..n {
339 let val = distance_matrix[[i, j]];
340 if !val.is_finite() {
341 return Err(ClusteringError::InvalidInput(format!(
342 "Non-finite distance at ({}, {})",
343 i, j
344 )));
345 }
346
347 if val < F::zero() {
348 return Err(ClusteringError::InvalidInput(format!(
349 "Negative distance at ({}, {}): {}",
350 i, j, val
351 )));
352 }
353 }
354 }
355
356 if check_symmetry {
358 for i in 0..n {
359 for j in (i + 1)..n {
360 let val_ij = distance_matrix[[i, j]];
361 let val_ji = distance_matrix[[j, i]];
362 let diff = (val_ij - val_ji).abs();
363
364 if diff > F::from(1e-10).unwrap() {
365 return Err(ClusteringError::InvalidInput(format!(
366 "Distance _matrix is not symmetric: d({}, {}) = {} != d({}, {}) = {}",
367 i, j, val_ij, j, i, val_ji
368 )));
369 }
370 }
371 }
372 }
373
374 if check_triangle_inequality {
376 for i in 0..n {
377 for j in 0..n {
378 for k in 0..n {
379 if i != j && j != k && i != k {
380 let d_ij = distance_matrix[[i, j]];
381 let d_jk = distance_matrix[[j, k]];
382 let d_ik = distance_matrix[[i, k]];
383
384 if d_ik > d_ij + d_jk + F::from(1e-10).unwrap() {
385 return Err(ClusteringError::InvalidInput(format!(
386 "Triangle _inequality violated: d({}, {}) = {} > d({}, {}) + d({}, {}) = {} + {}",
387 i, k, d_ik, i, j, j, k, d_ij, d_jk
388 )));
389 }
390 }
391 }
392 }
393 }
394 }
395
396 Ok(())
397}
398
399#[allow(dead_code)]
413pub fn validate_cluster_consistency<
414 F: Float + FromPrimitive + Debug + PartialOrd + std::fmt::Display,
415>(
416 linkage_matrix: ArrayView2<F>,
417 cluster_assignments: ArrayView1<usize>,
418) -> Result<()> {
419 let n_observations = linkage_matrix.shape()[0] + 1;
420
421 if cluster_assignments.len() != n_observations {
423 return Err(ClusteringError::InvalidInput(format!(
424 "Cluster _assignments length {} doesn't match number of observations {}",
425 cluster_assignments.len(),
426 n_observations
427 )));
428 }
429
430 validate_linkage_matrix(linkage_matrix, n_observations)?;
432
433 let max_cluster_id = cluster_assignments.iter().max().copied().unwrap_or(0);
435 let unique_clusters: std::collections::HashSet<_> =
436 cluster_assignments.iter().copied().collect();
437
438 for expected_id in 0..unique_clusters.len() {
440 if !unique_clusters.contains(&expected_id) {
441 return Err(ClusteringError::InvalidInput(format!(
442 "Cluster IDs should be contiguous starting from 0, missing ID {}",
443 expected_id
444 )));
445 }
446 }
447
448 if max_cluster_id >= n_observations {
449 return Err(ClusteringError::InvalidInput(format!(
450 "Maximum cluster ID {} should be less than number of observations {}",
451 max_cluster_id, n_observations
452 )));
453 }
454
455 Ok(())
456}
457
#[cfg(test)]
mod tests {
    use super::*;
    use scirs2_core::ndarray::{Array1, Array2};

    /// Builds a valid 3-merge linkage matrix describing 4 observations.
    /// Rows are `[cluster1, cluster2, distance, count]`.
    fn valid_linkage() -> Array2<f64> {
        Array2::from_shape_vec(
            (3, 4),
            vec![
                0.0, 1.0, 0.5, 2.0, //
                2.0, 3.0, 0.8, 2.0, //
                4.0, 5.0, 1.2, 4.0, //
            ],
        )
        .unwrap()
    }

    #[test]
    fn test_validate_linkage_matrix_valid() {
        let linkage = valid_linkage();

        let result = validate_linkage_matrix(linkage.view(), 4);
        assert!(
            result.is_ok(),
            "Valid linkage matrix should pass validation"
        );
    }

    #[test]
    fn test_validate_linkage_matrix_wrong_dimensions() {
        // Only 2 rows for 4 observations (3 are required).
        let linkage =
            Array2::from_shape_vec((2, 4), vec![0.0, 1.0, 0.5, 2.0, 2.0, 3.0, 0.8, 2.0]).unwrap();

        let result = validate_linkage_matrix(linkage.view(), 4);
        assert!(result.is_err(), "Wrong dimensions should fail validation");
    }

    #[test]
    fn test_validate_linkage_matrix_negative_distance() {
        let linkage = Array2::from_shape_vec(
            (3, 4),
            vec![
                0.0, 1.0, -0.5, 2.0, //
                2.0, 3.0, 0.8, 2.0, //
                4.0, 5.0, 1.2, 4.0, //
            ],
        )
        .unwrap();

        let result = validate_linkage_matrix(linkage.view(), 4);
        assert!(result.is_err(), "Negative distance should fail validation");
    }

    #[test]
    fn test_validate_linkage_matrix_self_merge() {
        // First row merges cluster 0 with itself.
        let linkage = Array2::from_shape_vec(
            (3, 4),
            vec![
                0.0, 0.0, 0.5, 2.0, //
                2.0, 3.0, 0.8, 2.0, //
                4.0, 5.0, 1.2, 4.0, //
            ],
        )
        .unwrap();

        let result = validate_linkage_matrix(linkage.view(), 4);
        assert!(result.is_err(), "Self-merge should fail validation");
    }

    #[test]
    fn test_validate_monotonic_distances_valid() {
        let linkage = valid_linkage();

        let result = validate_monotonic_distances(linkage.view(), false);
        assert!(result.is_ok(), "Monotonic distances should pass validation");
    }

    #[test]
    fn test_validate_monotonic_distances_invalid() {
        // Distances 1.2 -> 0.8 decrease between the first two merges.
        let linkage = Array2::from_shape_vec(
            (3, 4),
            vec![
                0.0, 1.0, 1.2, 2.0, //
                2.0, 3.0, 0.8, 2.0, //
                4.0, 5.0, 1.5, 4.0, //
            ],
        )
        .unwrap();

        let result = validate_monotonic_distances(linkage.view(), false);
        assert!(
            result.is_err(),
            "Non-monotonic distances should fail validation"
        );
    }

    #[test]
    fn test_validate_condensed_distance_matrix() {
        // 6 elements = 4 * 3 / 2, i.e. 4 observations.
        let distances = Array1::from_vec(vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0]);

        let result = validate_distance_matrix(distances.view(), true);
        assert!(
            result.is_ok(),
            "Valid condensed distance matrix should pass"
        );
    }

    #[test]
    fn test_validate_condensed_distance_matrix_invalid_size() {
        // 5 is not of the form n * (n - 1) / 2.
        let distances = Array1::from_vec(vec![1.0, 2.0, 3.0, 4.0, 5.0]);

        let result = validate_distance_matrix(distances.view(), true);
        assert!(result.is_err(), "Invalid condensed matrix size should fail");
    }

    #[test]
    fn test_validate_cluster_extraction_params() {
        let linkage = valid_linkage();

        // Accepted criterion / threshold combinations.
        assert!(validate_cluster_extraction_params(linkage.view(), "distance", 1.0).is_ok());
        assert!(validate_cluster_extraction_params(linkage.view(), "maxclust", 3.0).is_ok());
        assert!(validate_cluster_extraction_params(linkage.view(), "inconsistent", 0.5).is_ok());

        // Rejected combinations.
        assert!(validate_cluster_extraction_params(linkage.view(), "distance", -1.0).is_err());
        assert!(validate_cluster_extraction_params(linkage.view(), "maxclust", 0.0).is_err());
        assert!(validate_cluster_extraction_params(linkage.view(), "invalid", 1.0).is_err());
    }

    #[test]
    fn test_validate_square_distance_matrix_valid() {
        let distances = Array2::from_shape_vec(
            (3, 3),
            vec![
                0.0, 1.0, 2.0, //
                1.0, 0.0, 1.5, //
                2.0, 1.5, 0.0, //
            ],
        )
        .unwrap();

        let result = validate_square_distance_matrix(distances.view(), true, true);
        assert!(result.is_ok(), "Valid square distance matrix should pass");
    }

    #[test]
    fn test_validate_square_distance_matrix_nonzero_diagonal() {
        let distances = Array2::from_shape_vec((2, 2), vec![0.5, 1.0, 1.0, 0.0]).unwrap();

        let result = validate_square_distance_matrix(distances.view(), false, false);
        assert!(result.is_err(), "Non-zero diagonal should fail validation");
    }

    #[test]
    fn test_validate_square_distance_matrix_asymmetric() {
        let distances = Array2::from_shape_vec((2, 2), vec![0.0, 1.0, 2.0, 0.0]).unwrap();

        assert!(
            validate_square_distance_matrix(distances.view(), true, false).is_err(),
            "Asymmetric matrix should fail when symmetry is checked"
        );
        assert!(
            validate_square_distance_matrix(distances.view(), false, false).is_ok(),
            "Asymmetric matrix should pass when symmetry is not checked"
        );
    }

    #[test]
    fn test_validate_square_distance_matrix_triangle_inequality() {
        // d(0, 2) = 5 exceeds d(0, 1) + d(1, 2) = 2.
        let distances = Array2::from_shape_vec(
            (3, 3),
            vec![
                0.0, 1.0, 5.0, //
                1.0, 0.0, 1.0, //
                5.0, 1.0, 0.0, //
            ],
        )
        .unwrap();

        assert!(
            validate_square_distance_matrix(distances.view(), true, true).is_err(),
            "Triangle inequality violation should fail when checked"
        );
        assert!(
            validate_square_distance_matrix(distances.view(), true, false).is_ok(),
            "Triangle inequality violation should pass when not checked"
        );
    }

    #[test]
    fn test_validate_cluster_consistency() {
        let linkage = valid_linkage();

        let contiguous = Array1::from_vec(vec![0usize, 0, 1, 1]);
        assert!(
            validate_cluster_consistency(linkage.view(), contiguous.view()).is_ok(),
            "Contiguous cluster IDs should pass"
        );

        let with_gap = Array1::from_vec(vec![0usize, 0, 2, 2]);
        assert!(
            validate_cluster_consistency(linkage.view(), with_gap.view()).is_err(),
            "Non-contiguous cluster IDs should fail"
        );

        let wrong_length = Array1::from_vec(vec![0usize, 0, 1]);
        assert!(
            validate_cluster_consistency(linkage.view(), wrong_length.view()).is_err(),
            "Assignment length mismatch should fail"
        );
    }
}