scirs2_cluster/
input_validation.rs

1//! Enhanced input validation utilities
2//!
3//! This module provides comprehensive input validation functions that are
4//! compatible with SciPy's validation patterns and provide consistent
5//! error messages across all clustering algorithms.
6
7use scirs2_core::ndarray::{ArrayView1, ArrayView2, Axis};
8use scirs2_core::numeric::{Float, FromPrimitive};
9use std::fmt::Debug;
10
11use crate::error::{ClusteringError, Result};
12
/// Validation configuration for different algorithms.
///
/// Clustering algorithms differ slightly in what they require from their
/// input data. The `for_*` constructors provide per-algorithm presets and
/// [`Default`] provides a generic configuration.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct ValidationConfig {
    /// Minimum number of samples (rows) required.
    pub min_samples: usize,
    /// Sample count above which a performance warning is emitted;
    /// `None` disables the warning.
    pub max_samples_warning: Option<usize>,
    /// Minimum number of features (columns) required.
    pub min_features: usize,
    /// Whether to reject NaN and infinite values.
    pub check_finite: bool,
    /// Whether to allow empty data (zero samples or zero features).
    pub allow_empty: bool,
    /// Prefix placed at the start of error messages; when `None`, the
    /// validators use `"Clustering"`.
    pub error_prefix: Option<String>,
}
29
30impl Default for ValidationConfig {
31    fn default() -> Self {
32        Self {
33            min_samples: 2,
34            max_samples_warning: Some(10000),
35            min_features: 1,
36            check_finite: true,
37            allow_empty: false,
38            error_prefix: None,
39        }
40    }
41}
42
43impl ValidationConfig {
44    /// Create validation config for K-means
45    pub fn for_kmeans() -> Self {
46        Self {
47            min_samples: 1,
48            max_samples_warning: Some(50000),
49            min_features: 1,
50            check_finite: true,
51            allow_empty: false,
52            error_prefix: Some("K-means".to_string()),
53        }
54    }
55
56    /// Create validation config for hierarchical clustering
57    pub fn for_hierarchical() -> Self {
58        Self {
59            min_samples: 2,
60            max_samples_warning: Some(5000),
61            min_features: 1,
62            check_finite: true,
63            allow_empty: false,
64            error_prefix: Some("Hierarchical clustering".to_string()),
65        }
66    }
67
68    /// Create validation config for DBSCAN
69    pub fn for_dbscan() -> Self {
70        Self {
71            min_samples: 2,
72            max_samples_warning: Some(20000),
73            min_features: 1,
74            check_finite: true,
75            allow_empty: false,
76            error_prefix: Some("DBSCAN".to_string()),
77        }
78    }
79
80    /// Create validation config for spectral clustering
81    pub fn for_spectral() -> Self {
82        Self {
83            min_samples: 2,
84            max_samples_warning: Some(1000),
85            min_features: 1,
86            check_finite: true,
87            allow_empty: false,
88            error_prefix: Some("Spectral clustering".to_string()),
89        }
90    }
91}
92
93/// Comprehensive data validation for clustering algorithms
94///
95/// Validates input data according to the specified configuration and provides
96/// SciPy-compatible error messages.
97///
98/// # Arguments
99///
100/// * `data` - Input data matrix (n_samples × n_features)
101/// * `config` - Validation configuration
102///
103/// # Returns
104///
105/// * `Result<()>` - Ok if valid, detailed error if invalid
106///
107/// # Examples
108///
109/// ```
110/// use scirs2_core::ndarray::Array2;
111/// use scirs2_cluster::input_validation::{validate_clustering_data, ValidationConfig};
112///
113/// let data = Array2::from_shape_vec((10, 3), (0..30).map(|x| x as f64).collect()).unwrap();
114/// let config = ValidationConfig::for_kmeans();
115///
116/// assert!(validate_clustering_data(data.view(), &config).is_ok());
117/// ```
118#[allow(dead_code)]
119pub fn validate_clustering_data<F: Float + FromPrimitive + Debug + PartialOrd>(
120    data: ArrayView2<F>,
121    config: &ValidationConfig,
122) -> Result<()> {
123    let (n_samples, n_features) = data.dim();
124    let prefix = config.error_prefix.as_deref().unwrap_or("Clustering");
125
126    // Check empty data
127    if n_samples == 0 && !config.allow_empty {
128        return Err(ClusteringError::InvalidInput(format!(
129            "{}: Input data cannot be empty",
130            prefix
131        )));
132    }
133
134    if n_features == 0 && !config.allow_empty {
135        return Err(ClusteringError::InvalidInput(format!(
136            "{}: Input data must have at least one feature",
137            prefix
138        )));
139    }
140
141    // Check minimum requirements
142    if n_samples < config.min_samples {
143        return Err(ClusteringError::InvalidInput(format!(
144            "{}: Need at least {} samples, got {}",
145            prefix, config.min_samples, n_samples
146        )));
147    }
148
149    if n_features < config.min_features {
150        return Err(ClusteringError::InvalidInput(format!(
151            "{}: Need at least {} features, got {}",
152            prefix, config.min_features, n_features
153        )));
154    }
155
156    // Check for size warnings
157    if let Some(max_warn) = config.max_samples_warning {
158        if n_samples > max_warn {
159            eprintln!(
160                "Warning: {} with {} samples may be slow. Consider using a subset or more efficient algorithm.",
161                prefix, n_samples
162            );
163        }
164    }
165
166    // Check for finite values
167    if config.check_finite {
168        validate_finite_values(data, prefix)?;
169    }
170
171    Ok(())
172}
173
174/// Validate that all values in the data are finite
175#[allow(dead_code)]
176fn validate_finite_values<F: Float + Debug>(data: ArrayView2<F>, prefix: &str) -> Result<()> {
177    for (i, row) in data.axis_iter(Axis(0)).enumerate() {
178        for (j, &value) in row.iter().enumerate() {
179            if !value.is_finite() {
180                return Err(ClusteringError::InvalidInput(format!(
181                    "{}: Non-finite value {:?} at position ({}, {})",
182                    prefix, value, i, j
183                )));
184            }
185        }
186    }
187    Ok(())
188}
189
190/// Validate cluster count parameter
191///
192/// Ensures the number of clusters is valid for the given dataset.
193///
194/// # Arguments
195///
196/// * `n_clusters` - Number of clusters requested
197/// * `n_samples` - Number of samples in dataset
198/// * `algorithm` - Algorithm name for error messages
199///
200/// # Returns
201///
202/// * `Result<()>` - Ok if valid, error otherwise
203#[allow(dead_code)]
204pub fn validate_n_clusters(n_clusters: usize, nsamples: usize, algorithm: &str) -> Result<()> {
205    if n_clusters == 0 {
206        return Err(ClusteringError::InvalidInput(format!(
207            "{}: Number of clusters must be positive, got 0",
208            algorithm
209        )));
210    }
211
212    if n_clusters > nsamples {
213        return Err(ClusteringError::InvalidInput(format!(
214            "{}: Number of clusters ({}) cannot exceed number of samples ({})",
215            algorithm, n_clusters, nsamples
216        )));
217    }
218
219    Ok(())
220}
221
222/// Validate distance/similarity parameters
223///
224/// Checks that distance thresholds and similarity parameters are valid.
225///
226/// # Arguments
227///
228/// * `value` - Parameter value to validate
229/// * `param_name` - Parameter name for error messages
230/// * `min_value` - Minimum allowed value (inclusive)
231/// * `max_value` - Maximum allowed value (inclusive), None for no limit
232/// * `algorithm` - Algorithm name for error messages
233///
234/// # Returns
235///
236/// * `Result<()>` - Ok if valid, error otherwise
237#[allow(dead_code)]
238pub fn validate_distance_parameter<F: Float + FromPrimitive + Debug + PartialOrd>(
239    value: F,
240    param_name: &str,
241    min_value: Option<F>,
242    max_value: Option<F>,
243    algorithm: &str,
244) -> Result<()> {
245    if !value.is_finite() {
246        return Err(ClusteringError::InvalidInput(format!(
247            "{}: {} must be finite, got {:?}",
248            algorithm, param_name, value
249        )));
250    }
251
252    if let Some(min_val) = min_value {
253        if value < min_val {
254            return Err(ClusteringError::InvalidInput(format!(
255                "{}: {} must be >= {:?}, got {:?}",
256                algorithm, param_name, min_val, value
257            )));
258        }
259    }
260
261    if let Some(max_val) = max_value {
262        if value > max_val {
263            return Err(ClusteringError::InvalidInput(format!(
264                "{}: {} must be <= {:?}, got {:?}",
265                algorithm, param_name, max_val, value
266            )));
267        }
268    }
269
270    Ok(())
271}
272
273/// Validate integer parameters with bounds
274///
275/// Ensures integer parameters are within valid ranges.
276///
277/// # Arguments
278///
279/// * `value` - Parameter value to validate
280/// * `param_name` - Parameter name for error messages
281/// * `min_value` - Minimum allowed value (inclusive)
282/// * `max_value` - Maximum allowed value (inclusive), None for no limit
283/// * `algorithm` - Algorithm name for error messages
284///
285/// # Returns
286///
287/// * `Result<()>` - Ok if valid, error otherwise
288#[allow(dead_code)]
289pub fn validate_integer_parameter(
290    value: usize,
291    param_name: &str,
292    min_value: Option<usize>,
293    max_value: Option<usize>,
294    algorithm: &str,
295) -> Result<()> {
296    if let Some(min_val) = min_value {
297        if value < min_val {
298            return Err(ClusteringError::InvalidInput(format!(
299                "{}: {} must be >= {}, got {}",
300                algorithm, param_name, min_val, value
301            )));
302        }
303    }
304
305    if let Some(max_val) = max_value {
306        if value > max_val {
307            return Err(ClusteringError::InvalidInput(format!(
308                "{}: {} must be <= {}, got {}",
309                algorithm, param_name, max_val, value
310            )));
311        }
312    }
313
314    Ok(())
315}
316
317/// Validate sample weights
318///
319/// Ensures sample weights are valid (non-negative, finite, and consistent with data size).
320///
321/// # Arguments
322///
323/// * `weights` - Sample weights array
324/// * `n_samples` - Expected number of samples
325/// * `algorithm` - Algorithm name for error messages
326///
327/// # Returns
328///
329/// * `Result<()>` - Ok if valid, error otherwise
330#[allow(dead_code)]
331pub fn validate_sample_weights<F: Float + FromPrimitive + Debug + PartialOrd>(
332    weights: ArrayView1<F>,
333    n_samples: usize,
334    algorithm: &str,
335) -> Result<()> {
336    if weights.len() != n_samples {
337        return Err(ClusteringError::InvalidInput(format!(
338            "{}: Sample weights length ({}) must match number of samples ({})",
339            algorithm,
340            weights.len(),
341            n_samples
342        )));
343    }
344
345    for (i, &weight) in weights.iter().enumerate() {
346        if !weight.is_finite() {
347            return Err(ClusteringError::InvalidInput(format!(
348                "{}: Sample weight at index {} is not finite: {:?}",
349                algorithm, i, weight
350            )));
351        }
352
353        if weight < F::zero() {
354            return Err(ClusteringError::InvalidInput(format!(
355                "{}: Sample weight at index {} must be non-negative, got {:?}",
356                algorithm, i, weight
357            )));
358        }
359    }
360
361    // Check that not all weights are zero
362    let sum_weights = weights.iter().fold(F::zero(), |acc, &w| acc + w);
363    if sum_weights <= F::zero() {
364        return Err(ClusteringError::InvalidInput(format!(
365            "{}: Sum of sample weights must be positive",
366            algorithm
367        )));
368    }
369
370    Ok(())
371}
372
373/// Validate cluster initialization data
374///
375/// Validates initial cluster centers or assignments for clustering algorithms.
376///
377/// # Arguments
378///
379/// * `init_data` - Initial cluster centers (k × n_features)
380/// * `n_clusters` - Expected number of clusters
381/// * `n_features` - Expected number of features
382/// * `algorithm` - Algorithm name for error messages
383///
384/// # Returns
385///
386/// * `Result<()>` - Ok if valid, error otherwise
387#[allow(dead_code)]
388pub fn validate_cluster_initialization<F: Float + FromPrimitive + Debug + PartialOrd>(
389    init_data: ArrayView2<F>,
390    n_clusters: usize,
391    n_features: usize,
392    algorithm: &str,
393) -> Result<()> {
394    let (init_clusters, init_features) = init_data.dim();
395
396    if init_clusters != n_clusters {
397        return Err(ClusteringError::InvalidInput(format!(
398            "{}: Initial cluster centers must have {} clusters, got {}",
399            algorithm, n_clusters, init_clusters
400        )));
401    }
402
403    if init_features != n_features {
404        return Err(ClusteringError::InvalidInput(format!(
405            "{}: Initial cluster centers must have {} features, got {}",
406            algorithm, n_features, init_features
407        )));
408    }
409
410    // Check for finite values
411    for (i, row) in init_data.axis_iter(Axis(0)).enumerate() {
412        for (j, &value) in row.iter().enumerate() {
413            if !value.is_finite() {
414                return Err(ClusteringError::InvalidInput(format!(
415                    "{}: Non-finite value {:?} in initial cluster center at position ({}, {})",
416                    algorithm, value, i, j
417                )));
418            }
419        }
420    }
421
422    Ok(())
423}
424
425/// Validate convergence parameters
426///
427/// Validates convergence threshold and maximum iterations for iterative algorithms.
428///
429/// # Arguments
430///
431/// * `tolerance` - Convergence tolerance
432/// * `max_iterations` - Maximum number of iterations
433/// * `algorithm` - Algorithm name for error messages
434///
435/// # Returns
436///
437/// * `Result<()>` - Ok if valid, error otherwise
438#[allow(dead_code)]
439pub fn validate_convergence_parameters<F: Float + FromPrimitive + Debug + PartialOrd>(
440    tolerance: Option<F>,
441    max_iterations: Option<usize>,
442    algorithm: &str,
443) -> Result<()> {
444    if let Some(tol) = tolerance {
445        validate_distance_parameter(tol, "tolerance", Some(F::zero()), None, algorithm)?;
446    }
447
448    if let Some(max_iter) = max_iterations {
449        validate_integer_parameter(max_iter, "max_iterations", Some(1), None, algorithm)?;
450    }
451
452    Ok(())
453}
454
455/// Check for duplicate data points
456///
457/// Identifies if the dataset contains duplicate points, which can cause issues
458/// for some clustering algorithms.
459///
460/// # Arguments
461///
462/// * `data` - Input data matrix
463/// * `tolerance` - Tolerance for considering points as duplicates
464///
465/// # Returns
466///
467/// * `Result<Vec<(usize, usize)>>` - List of duplicate point pairs
468#[allow(dead_code)]
469pub fn check_duplicate_points<F: Float + FromPrimitive + Debug + PartialOrd>(
470    data: ArrayView2<F>,
471    tolerance: F,
472) -> Result<Vec<(usize, usize)>> {
473    let n_samples = data.shape()[0];
474    let mut duplicates = Vec::new();
475
476    for i in 0..n_samples {
477        for j in (i + 1)..n_samples {
478            let mut distance_squared = F::zero();
479            for k in 0..data.shape()[1] {
480                let diff = data[[i, k]] - data[[j, k]];
481                distance_squared = distance_squared + diff * diff;
482            }
483
484            if distance_squared <= tolerance * tolerance {
485                duplicates.push((i, j));
486            }
487        }
488    }
489
490    Ok(duplicates)
491}
492
493/// Validate and suggest appropriate clustering algorithm
494///
495/// Analyzes the dataset characteristics and suggests the most appropriate
496/// clustering algorithm with explanations.
497///
498/// # Arguments
499///
500/// * `data` - Input data matrix
501/// * `n_clusters` - Desired number of clusters (if known)
502///
503/// # Returns
504///
505/// * `Result<String>` - Recommendation message
506#[allow(dead_code)]
507pub fn suggest_clustering_algorithm<F: Float + FromPrimitive + Debug + PartialOrd>(
508    data: ArrayView2<F>,
509    n_clusters: Option<usize>,
510) -> Result<String> {
511    let (n_samples, n_features) = data.dim();
512
513    // Validate data first
514    let config = ValidationConfig::default();
515    validate_clustering_data(data, &config)?;
516
517    let mut suggestions = Vec::new();
518
519    // Analyze dataset characteristics
520    if n_samples < 100 {
521        suggestions
522            .push("Small dataset: Consider hierarchical clustering for interpretable results");
523    } else if n_samples > 10000 {
524        suggestions.push("Large dataset: K-means or DBSCAN recommended for efficiency");
525    }
526
527    if n_features > 50 {
528        suggestions.push(
529            "High-dimensional data: Consider spectral clustering or dimensionality reduction",
530        );
531    }
532
533    // Check for duplicates
534    let duplicates = check_duplicate_points(data, F::from_f64(1e-10).unwrap())?;
535    if !duplicates.is_empty() {
536        suggestions.push("Duplicate points detected: DBSCAN may handle noise well");
537    }
538
539    // Algorithm-specific recommendations
540    if let Some(k) = n_clusters {
541        if k <= 10 {
542            suggestions.push(
543                "Small number of clusters: K-means with k-means++ initialization recommended",
544            );
545        } else {
546            suggestions.push("Many clusters: Consider hierarchical clustering or DBSCAN");
547        }
548    } else {
549        suggestions.push(
550            "Unknown cluster count: DBSCAN or hierarchical clustering with automatic cut-off",
551        );
552    }
553
554    // Performance considerations
555    if n_samples > 5000 && n_features > 20 {
556        suggestions.push("Performance consideration: Use parallel implementations when available");
557    }
558
559    let recommendation = if suggestions.is_empty() {
560        "K-means with k-means++ initialization is a good general-purpose choice".to_string()
561    } else {
562        format!("Recommendations:\n{}", suggestions.join("\n- "))
563    };
564
565    Ok(recommendation)
566}
567
#[cfg(test)]
mod tests {
    use super::*;
    use scirs2_core::ndarray::{Array1, Array2};

    #[test]
    fn test_validate_clustering_data() {
        // Valid data
        let data = Array2::from_shape_vec((10, 3), (0..30).map(|x| x as f64).collect()).unwrap();
        let config = ValidationConfig::default();
        assert!(validate_clustering_data(data.view(), &config).is_ok());

        // Too few samples
        let small_data = Array2::from_shape_vec((1, 3), vec![1.0, 2.0, 3.0]).unwrap();
        assert!(validate_clustering_data(small_data.view(), &config).is_err());

        // Non-finite values
        let invalid_data =
            Array2::from_shape_vec((3, 2), vec![1.0, 2.0, f64::NAN, 4.0, 5.0, 6.0]).unwrap();
        assert!(validate_clustering_data(invalid_data.view(), &config).is_err());
    }

    #[test]
    fn test_validate_n_clusters() {
        assert!(validate_n_clusters(3, 10, "Test").is_ok());
        assert!(validate_n_clusters(10, 10, "Test").is_ok()); // Boundary: one cluster per sample
        assert!(validate_n_clusters(0, 10, "Test").is_err()); // Zero clusters
        assert!(validate_n_clusters(15, 10, "Test").is_err()); // More clusters than samples
    }

    #[test]
    fn test_validate_distance_parameter() {
        assert!(validate_distance_parameter(1.0, "eps", Some(0.0), Some(10.0), "Test").is_ok());
        assert!(validate_distance_parameter(-1.0, "eps", Some(0.0), None, "Test").is_err());
        assert!(validate_distance_parameter(11.0, "eps", None, Some(10.0), "Test").is_err());
        assert!(validate_distance_parameter(f64::NAN, "eps", None, None, "Test").is_err());
    }

    #[test]
    fn test_validate_integer_parameter() {
        assert!(validate_integer_parameter(5, "k", Some(1), Some(10), "Test").is_ok());
        assert!(validate_integer_parameter(0, "k", Some(1), None, "Test").is_err());
        assert!(validate_integer_parameter(11, "k", None, Some(10), "Test").is_err());
    }

    #[test]
    fn test_validate_sample_weights() {
        let weights = Array1::from_vec(vec![1.0, 2.0, 3.0]);
        assert!(validate_sample_weights(weights.view(), 3, "Test").is_ok());

        let negative_weights = Array1::from_vec(vec![1.0, -2.0, 3.0]);
        assert!(validate_sample_weights(negative_weights.view(), 3, "Test").is_err());

        let wrong_size = Array1::from_vec(vec![1.0, 2.0]);
        assert!(validate_sample_weights(wrong_size.view(), 3, "Test").is_err());

        let all_zero = Array1::from_vec(vec![0.0, 0.0]);
        assert!(validate_sample_weights(all_zero.view(), 2, "Test").is_err());
    }

    #[test]
    fn test_validate_cluster_initialization() {
        let centers =
            Array2::from_shape_vec((2, 3), vec![0.0, 1.0, 2.0, 3.0, 4.0, 5.0]).unwrap();
        assert!(validate_cluster_initialization(centers.view(), 2, 3, "Test").is_ok());
        assert!(validate_cluster_initialization(centers.view(), 3, 3, "Test").is_err());
        assert!(validate_cluster_initialization(centers.view(), 2, 4, "Test").is_err());

        let non_finite =
            Array2::from_shape_vec((2, 2), vec![0.0, f64::INFINITY, 1.0, 2.0]).unwrap();
        assert!(validate_cluster_initialization(non_finite.view(), 2, 2, "Test").is_err());
    }

    #[test]
    fn test_validate_convergence_parameters() {
        assert!(validate_convergence_parameters(Some(1e-4), Some(100), "Test").is_ok());
        assert!(validate_convergence_parameters(Some(-1.0), None, "Test").is_err());
        assert!(validate_convergence_parameters::<f64>(None, Some(0), "Test").is_err());
        assert!(validate_convergence_parameters::<f64>(None, None, "Test").is_ok());
    }

    #[test]
    fn test_check_duplicate_points() {
        let data =
            Array2::from_shape_vec((4, 2), vec![1.0, 2.0, 1.0, 2.0, 3.0, 4.0, 5.0, 6.0]).unwrap();
        let duplicates = check_duplicate_points(data.view(), 1e-10).unwrap();
        assert_eq!(duplicates.len(), 1); // Points 0 and 1 are identical
        assert_eq!(duplicates[0], (0, 1));
    }

    #[test]
    fn test_suggest_clustering_algorithm() {
        let data = Array2::from_shape_vec((100, 5), (0..500).map(|x| x as f64).collect()).unwrap();
        let suggestion = suggest_clustering_algorithm(data.view(), Some(3)).unwrap();
        assert!(!suggestion.is_empty());
        assert!(suggestion.contains("K-means") || suggestion.contains("Recommendations"));
    }

    #[test]
    fn test_validation_configs() {
        let kmeans_config = ValidationConfig::for_kmeans();
        assert_eq!(kmeans_config.min_samples, 1);

        let hierarchical_config = ValidationConfig::for_hierarchical();
        assert_eq!(hierarchical_config.min_samples, 2);

        let dbscan_config = ValidationConfig::for_dbscan();
        assert_eq!(dbscan_config.min_samples, 2);
    }
}