scirs2_cluster/hierarchy/
validation.rs

1//! Validation utilities for hierarchical clustering
2//!
3//! This module provides functions to validate linkage matrices and other
4//! hierarchical clustering data structures to ensure they meet mathematical
5//! requirements and are suitable for downstream analysis.
6
7use scirs2_core::ndarray::{ArrayView1, ArrayView2};
8use scirs2_core::numeric::{Float, FromPrimitive};
9use std::fmt::Debug;
10
11use crate::error::{ClusteringError, Result};
12
13/// Validates a linkage matrix for correctness
14///
15/// This function performs comprehensive validation of a linkage matrix
16/// to ensure it meets the mathematical requirements for hierarchical clustering.
17///
18/// # Arguments
19///
20/// * `linkage_matrix` - The linkage matrix to validate (n-1 × 4)
21/// * `n_observations` - Expected number of original observations
22///
23/// # Returns
24///
25/// * `Result<()>` - Ok if valid, error with detailed message if invalid
26///
27/// # Validation Checks
28///
29/// 1. Matrix dimensions are correct (n-1 rows, 4 columns)
30/// 2. Cluster indices are valid and in proper range
31/// 3. Merge distances are non-negative and monotonic (for single/complete linkage)
32/// 4. Cluster sizes are consistent and >= 2
33/// 5. No self-merges (cluster merging with itself)
34/// 6. All values are finite
35#[allow(dead_code)]
36pub fn validate_linkage_matrix<
37    F: Float + FromPrimitive + Debug + PartialOrd + std::fmt::Display,
38>(
39    linkage_matrix: ArrayView2<F>,
40    n_observations: usize,
41) -> Result<()> {
42    let n_merges = linkage_matrix.shape()[0];
43    let n_cols = linkage_matrix.shape()[1];
44
45    // Check dimensions
46    if n_merges != n_observations - 1 {
47        return Err(ClusteringError::InvalidInput(format!(
48            "Linkage _matrix should have {} rows for {} observations, got {}",
49            n_observations - 1,
50            n_observations,
51            n_merges
52        )));
53    }
54
55    if n_cols != 4 {
56        return Err(ClusteringError::InvalidInput(format!(
57            "Linkage _matrix should have 4 columns, got {}",
58            n_cols
59        )));
60    }
61
62    // Validate each merge
63    for i in 0..n_merges {
64        let cluster1 = linkage_matrix[[i, 0]];
65        let cluster2 = linkage_matrix[[i, 1]];
66        let distance = linkage_matrix[[i, 2]];
67        let count = linkage_matrix[[i, 3]];
68
69        // Check that all values are finite
70        if !cluster1.is_finite()
71            || !cluster2.is_finite()
72            || !distance.is_finite()
73            || !count.is_finite()
74        {
75            return Err(ClusteringError::InvalidInput(format!(
76                "Non-finite values in linkage _matrix at row {}",
77                i
78            )));
79        }
80
81        // Convert to usize for index checking
82        let c1 = cluster1.to_usize().unwrap_or(usize::MAX);
83        let c2 = cluster2.to_usize().unwrap_or(usize::MAX);
84
85        // Check cluster indices are valid
86        let max_cluster_id = n_observations + i - 1;
87        if c1 >= n_observations + i || c2 >= n_observations + i {
88            return Err(ClusteringError::InvalidInput(format!(
89                "Invalid cluster indices at merge {}: {} and {} (max allowed: {})",
90                i, c1, c2, max_cluster_id
91            )));
92        }
93
94        // Check no self-merge
95        if c1 == c2 {
96            return Err(ClusteringError::InvalidInput(format!(
97                "Self-merge detected at row {}: cluster {} merges with itself",
98                i, c1
99            )));
100        }
101
102        // Check distance is non-negative
103        if distance < F::zero() {
104            return Err(ClusteringError::InvalidInput(format!(
105                "Negative merge distance at row {}: {}",
106                i, distance
107            )));
108        }
109
110        // Check cluster count is at least 2 (since it's a merge)
111        if count < F::from(2).unwrap() {
112            return Err(ClusteringError::InvalidInput(format!(
113                "Cluster count should be >= 2 at row {}, got {}",
114                i, count
115            )));
116        }
117    }
118
119    Ok(())
120}
121
122/// Validates that merge distances are monotonic (for certain linkage methods)
123///
124/// For single and complete linkage, merge distances should be non-decreasing.
125///
126/// # Arguments
127///
128/// * `linkage_matrix` - The linkage matrix to check
129/// * `strict` - If true, requires strictly increasing distances
130///
131/// # Returns
132///
133/// * `Result<()>` - Ok if monotonic, error otherwise
134#[allow(dead_code)]
135pub fn validate_monotonic_distances<
136    F: Float + FromPrimitive + Debug + PartialOrd + std::fmt::Display,
137>(
138    linkage_matrix: ArrayView2<F>,
139    strict: bool,
140) -> Result<()> {
141    let n_merges = linkage_matrix.shape()[0];
142
143    for i in 1..n_merges {
144        let prev_distance = linkage_matrix[[i - 1, 2]];
145        let curr_distance = linkage_matrix[[i, 2]];
146
147        if strict {
148            if curr_distance <= prev_distance {
149                return Err(ClusteringError::InvalidInput(format!(
150                    "Merge distances should be strictly increasing: {} <= {} at merge {}",
151                    curr_distance, prev_distance, i
152                )));
153            }
154        } else if curr_distance < prev_distance - F::from(1e-10).unwrap() {
155            return Err(ClusteringError::InvalidInput(format!(
156                "Merge distances should be non-decreasing: {} < {} at merge {}",
157                curr_distance, prev_distance, i
158            )));
159        }
160    }
161
162    Ok(())
163}
164
165/// Validates cluster extraction parameters
166///
167/// Ensures that parameters for flat cluster extraction are valid.
168///
169/// # Arguments
170///
171/// * `linkage_matrix` - The linkage matrix
172/// * `criterion` - Criterion type ("distance", "maxclust", "inconsistent")
173/// * `threshold` - Threshold value for the criterion
174///
175/// # Returns
176///
177/// * `Result<()>` - Ok if valid, error otherwise
178#[allow(dead_code)]
179pub fn validate_cluster_extraction_params<
180    F: Float + FromPrimitive + Debug + PartialOrd + std::fmt::Display,
181>(
182    linkage_matrix: ArrayView2<F>,
183    criterion: &str,
184    threshold: F,
185) -> Result<()> {
186    // First validate the linkage _matrix itself
187    let n_observations = linkage_matrix.shape()[0] + 1;
188    validate_linkage_matrix(linkage_matrix, n_observations)?;
189
190    // Check criterion is valid
191    match criterion.to_lowercase().as_str() {
192        "distance" => {
193            if threshold < F::zero() {
194                return Err(ClusteringError::InvalidInput(
195                    "Distance threshold must be non-negative".to_string(),
196                ));
197            }
198        }
199        "maxclust" => {
200            let max_clusters = threshold.to_usize().unwrap_or(0);
201            if max_clusters < 1 || max_clusters > n_observations {
202                return Err(ClusteringError::InvalidInput(format!(
203                    "Number of clusters must be between 1 and {}, got {}",
204                    n_observations, max_clusters
205                )));
206            }
207        }
208        "inconsistent" => {
209            if threshold < F::zero() {
210                return Err(ClusteringError::InvalidInput(
211                    "Inconsistency threshold must be non-negative".to_string(),
212                ));
213            }
214        }
215        _ => {
216            return Err(ClusteringError::InvalidInput(format!(
217                "Unknown criterion '{}'. Valid options: 'distance', 'maxclust', 'inconsistent'",
218                criterion
219            )));
220        }
221    }
222
223    Ok(())
224}
225
226/// Validates that a distance matrix is suitable for clustering
227///
228/// Checks properties required for distance matrices used in hierarchical clustering.
229///
230/// # Arguments
231///
232/// * `distance_matrix` - Distance matrix (condensed or square form)
233/// * `condensed` - Whether the matrix is in condensed form
234///
235/// # Returns
236///
237/// * `Result<()>` - Ok if valid, error otherwise
238#[allow(dead_code)]
239pub fn validate_distance_matrix<
240    F: Float + FromPrimitive + Debug + PartialOrd + std::fmt::Display,
241>(
242    distance_matrix: ArrayView1<F>,
243    condensed: bool,
244) -> Result<()> {
245    let n_elements = distance_matrix.len();
246
247    if condensed {
248        // For condensed form, we should have n*(n-1)/2 elements
249        // Solve n*(n-1)/2 = n_elements for n
250        let n_float = (1.0 + (1.0 + 8.0 * n_elements as f64).sqrt()) / 2.0;
251        let n = n_float as usize;
252
253        if n * (n - 1) / 2 != n_elements {
254            return Err(ClusteringError::InvalidInput(format!(
255                "Invalid condensed distance _matrix size: {} elements doesn't correspond to n*(n-1)/2 for any integer n",
256                n_elements
257            )));
258        }
259
260        if n < 2 {
261            return Err(ClusteringError::InvalidInput(
262                "Distance _matrix must represent at least 2 observations".to_string(),
263            ));
264        }
265    }
266
267    // Check all distances are non-negative and finite
268    for (i, &distance) in distance_matrix.iter().enumerate() {
269        if !distance.is_finite() {
270            return Err(ClusteringError::InvalidInput(format!(
271                "Non-finite distance at index {}",
272                i
273            )));
274        }
275
276        if distance < F::zero() {
277            return Err(ClusteringError::InvalidInput(format!(
278                "Negative distance at index {}: {}",
279                i, distance
280            )));
281        }
282    }
283
284    Ok(())
285}
286
287/// Validates that a square distance matrix has required properties
288///
289/// Checks symmetry, zero diagonal, and metric properties.
290///
291/// # Arguments
292///
293/// * `distance_matrix` - Square distance matrix
294/// * `check_symmetry` - Whether to check matrix symmetry
295/// * `check_triangle_inequality` - Whether to check triangle inequality
296///
297/// # Returns
298///
299/// * `Result<()>` - Ok if valid, error otherwise
300#[allow(dead_code)]
301pub fn validate_square_distance_matrix<
302    F: Float + FromPrimitive + Debug + PartialOrd + std::fmt::Display,
303>(
304    distance_matrix: ArrayView2<F>,
305    check_symmetry: bool,
306    check_triangle_inequality: bool,
307) -> Result<()> {
308    let n = distance_matrix.shape()[0];
309    let m = distance_matrix.shape()[1];
310
311    // Check square _matrix
312    if n != m {
313        return Err(ClusteringError::InvalidInput(format!(
314            "Distance _matrix must be square, got {}x{}",
315            n, m
316        )));
317    }
318
319    if n < 2 {
320        return Err(ClusteringError::InvalidInput(
321            "Distance _matrix must be at least 2x2".to_string(),
322        ));
323    }
324
325    // Check diagonal is zero
326    for i in 0..n {
327        let diag_val = distance_matrix[[i, i]];
328        if !diag_val.is_finite() || diag_val.abs() > F::from(1e-10).unwrap() {
329            return Err(ClusteringError::InvalidInput(format!(
330                "Diagonal element at ({}, {}) should be zero, got {}",
331                i, i, diag_val
332            )));
333        }
334    }
335
336    // Check all elements are finite and non-negative
337    for i in 0..n {
338        for j in 0..n {
339            let val = distance_matrix[[i, j]];
340            if !val.is_finite() {
341                return Err(ClusteringError::InvalidInput(format!(
342                    "Non-finite distance at ({}, {})",
343                    i, j
344                )));
345            }
346
347            if val < F::zero() {
348                return Err(ClusteringError::InvalidInput(format!(
349                    "Negative distance at ({}, {}): {}",
350                    i, j, val
351                )));
352            }
353        }
354    }
355
356    // Check _symmetry
357    if check_symmetry {
358        for i in 0..n {
359            for j in (i + 1)..n {
360                let val_ij = distance_matrix[[i, j]];
361                let val_ji = distance_matrix[[j, i]];
362                let diff = (val_ij - val_ji).abs();
363
364                if diff > F::from(1e-10).unwrap() {
365                    return Err(ClusteringError::InvalidInput(format!(
366                        "Distance _matrix is not symmetric: d({}, {}) = {} != d({}, {}) = {}",
367                        i, j, val_ij, j, i, val_ji
368                    )));
369                }
370            }
371        }
372    }
373
374    // Check triangle _inequality
375    if check_triangle_inequality {
376        for i in 0..n {
377            for j in 0..n {
378                for k in 0..n {
379                    if i != j && j != k && i != k {
380                        let d_ij = distance_matrix[[i, j]];
381                        let d_jk = distance_matrix[[j, k]];
382                        let d_ik = distance_matrix[[i, k]];
383
384                        if d_ik > d_ij + d_jk + F::from(1e-10).unwrap() {
385                            return Err(ClusteringError::InvalidInput(format!(
386                                "Triangle _inequality violated: d({}, {}) = {} > d({}, {}) + d({}, {}) = {} + {}",
387                                i, k, d_ik, i, j, j, k, d_ij, d_jk
388                            )));
389                        }
390                    }
391                }
392            }
393        }
394    }
395
396    Ok(())
397}
398
399/// Checks consistency of cluster assignments with a linkage matrix
400///
401/// Validates that flat cluster assignments are consistent with the
402/// hierarchical structure defined by the linkage matrix.
403///
404/// # Arguments
405///
406/// * `linkage_matrix` - The linkage matrix
407/// * `cluster_assignments` - Flat cluster assignments for each observation
408///
409/// # Returns
410///
411/// * `Result<()>` - Ok if consistent, error otherwise
412#[allow(dead_code)]
413pub fn validate_cluster_consistency<
414    F: Float + FromPrimitive + Debug + PartialOrd + std::fmt::Display,
415>(
416    linkage_matrix: ArrayView2<F>,
417    cluster_assignments: ArrayView1<usize>,
418) -> Result<()> {
419    let n_observations = linkage_matrix.shape()[0] + 1;
420
421    // Check dimensions
422    if cluster_assignments.len() != n_observations {
423        return Err(ClusteringError::InvalidInput(format!(
424            "Cluster _assignments length {} doesn't match number of observations {}",
425            cluster_assignments.len(),
426            n_observations
427        )));
428    }
429
430    // First validate the linkage _matrix
431    validate_linkage_matrix(linkage_matrix, n_observations)?;
432
433    // Check that cluster IDs are in valid range
434    let max_cluster_id = cluster_assignments.iter().max().copied().unwrap_or(0);
435    let unique_clusters: std::collections::HashSet<_> =
436        cluster_assignments.iter().copied().collect();
437
438    // Cluster IDs should be contiguous starting from 0
439    for expected_id in 0..unique_clusters.len() {
440        if !unique_clusters.contains(&expected_id) {
441            return Err(ClusteringError::InvalidInput(format!(
442                "Cluster IDs should be contiguous starting from 0, missing ID {}",
443                expected_id
444            )));
445        }
446    }
447
448    if max_cluster_id >= n_observations {
449        return Err(ClusteringError::InvalidInput(format!(
450            "Maximum cluster ID {} should be less than number of observations {}",
451            max_cluster_id, n_observations
452        )));
453    }
454
455    Ok(())
456}
457
#[cfg(test)]
mod tests {
    use super::*;
    use scirs2_core::ndarray::{Array1, Array2};

    /// Builds a 3×4 linkage matrix (three merges over four observations) from
    /// rows of `[cluster1, cluster2, distance, count]`.
    fn linkage_from_rows(rows: [[f64; 4]; 3]) -> Array2<f64> {
        let flat: Vec<f64> = rows.iter().flatten().copied().collect();
        Array2::from_shape_vec((3, 4), flat).unwrap()
    }

    /// A well-formed linkage matrix for 4 points with increasing distances.
    fn valid_linkage() -> Array2<f64> {
        linkage_from_rows([
            [0.0, 1.0, 0.5, 2.0], // Merge clusters 0 and 1 at distance 0.5
            [2.0, 3.0, 0.8, 2.0], // Merge clusters 2 and 3 at distance 0.8
            [4.0, 5.0, 1.2, 4.0], // Merge clusters 4 and 5 at distance 1.2
        ])
    }

    #[test]
    fn test_validate_linkage_matrix_valid() {
        let linkage = valid_linkage();

        let result = validate_linkage_matrix(linkage.view(), 4);
        assert!(
            result.is_ok(),
            "Valid linkage matrix should pass validation"
        );
    }

    #[test]
    fn test_validate_linkage_matrix_wrong_dimensions() {
        // Wrong number of rows: 2 merges cannot describe 4 observations
        let linkage =
            Array2::from_shape_vec((2, 4), vec![0.0, 1.0, 0.5, 2.0, 2.0, 3.0, 0.8, 2.0]).unwrap();

        let result = validate_linkage_matrix(linkage.view(), 4);
        assert!(result.is_err(), "Wrong dimensions should fail validation");
    }

    #[test]
    fn test_validate_linkage_matrix_negative_distance() {
        let linkage = linkage_from_rows([
            [0.0, 1.0, -0.5, 2.0], // Negative distance
            [2.0, 3.0, 0.8, 2.0],
            [4.0, 5.0, 1.2, 4.0],
        ]);

        let result = validate_linkage_matrix(linkage.view(), 4);
        assert!(result.is_err(), "Negative distance should fail validation");
    }

    #[test]
    fn test_validate_linkage_matrix_self_merge() {
        let linkage = linkage_from_rows([
            [0.0, 0.0, 0.5, 2.0], // Self-merge
            [2.0, 3.0, 0.8, 2.0],
            [4.0, 5.0, 1.2, 4.0],
        ]);

        let result = validate_linkage_matrix(linkage.view(), 4);
        assert!(result.is_err(), "Self-merge should fail validation");
    }

    #[test]
    fn test_validate_monotonic_distances_valid() {
        let linkage = valid_linkage();

        let result = validate_monotonic_distances(linkage.view(), false);
        assert!(result.is_ok(), "Monotonic distances should pass validation");
    }

    #[test]
    fn test_validate_monotonic_distances_invalid() {
        let linkage = linkage_from_rows([
            [0.0, 1.0, 1.2, 2.0], // Higher distance first
            [2.0, 3.0, 0.8, 2.0], // Lower distance second
            [4.0, 5.0, 1.5, 4.0],
        ]);

        let result = validate_monotonic_distances(linkage.view(), false);
        assert!(
            result.is_err(),
            "Non-monotonic distances should fail validation"
        );
    }

    #[test]
    fn test_validate_condensed_distance_matrix() {
        // Valid condensed matrix for 4 points: 4*3/2 = 6 elements
        let distances = Array1::from_vec(vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0]);

        let result = validate_distance_matrix(distances.view(), true);
        assert!(
            result.is_ok(),
            "Valid condensed distance matrix should pass"
        );
    }

    #[test]
    fn test_validate_condensed_distance_matrix_invalid_size() {
        // Invalid size: 5 elements doesn't correspond to n*(n-1)/2 for any n
        let distances = Array1::from_vec(vec![1.0, 2.0, 3.0, 4.0, 5.0]);

        let result = validate_distance_matrix(distances.view(), true);
        assert!(result.is_err(), "Invalid condensed matrix size should fail");
    }

    #[test]
    fn test_validate_cluster_extraction_params() {
        let linkage = valid_linkage();

        // Valid parameters
        assert!(validate_cluster_extraction_params(linkage.view(), "distance", 1.0).is_ok());
        assert!(validate_cluster_extraction_params(linkage.view(), "maxclust", 3.0).is_ok());
        assert!(validate_cluster_extraction_params(linkage.view(), "inconsistent", 0.5).is_ok());

        // Invalid parameters
        assert!(validate_cluster_extraction_params(linkage.view(), "distance", -1.0).is_err());
        assert!(validate_cluster_extraction_params(linkage.view(), "maxclust", 0.0).is_err());
        assert!(validate_cluster_extraction_params(linkage.view(), "invalid", 1.0).is_err());
    }
}