1use scirs2_core::ndarray::{ArrayView1, ArrayView2, Axis};
8use scirs2_core::numeric::{Float, FromPrimitive};
9use std::fmt::Debug;
10
11use crate::error::{ClusteringError, Result};
12
/// Rules applied by [`validate_clustering_data`] when checking an input matrix.
///
/// Use [`Default`] for general-purpose settings or one of the `for_*`
/// constructors for a preset labelled with an algorithm name.
#[derive(Debug, Clone)]
pub struct ValidationConfig {
    /// Smallest number of samples (rows) that is accepted.
    pub min_samples: usize,
    /// Sample count above which a slowness warning is printed to stderr;
    /// `None` disables the warning.
    pub max_samples_warning: Option<usize>,
    /// Smallest number of features (columns) that is accepted.
    pub min_features: usize,
    /// When `true`, every element must be finite (neither NaN nor infinite).
    pub check_finite: bool,
    /// When `true`, the dedicated "empty input" checks are skipped; the
    /// `min_samples` and `min_features` limits are still enforced.
    pub allow_empty: bool,
    /// Label placed at the start of error messages; `"Clustering"` is used
    /// when this is `None`.
    pub error_prefix: Option<String>,
}

impl Default for ValidationConfig {
    /// General-purpose settings: at least 2 samples and 1 feature, finite
    /// values required, a warning above 10000 samples and no custom label.
    fn default() -> Self {
        Self::preset(None, 2, 10000)
    }
}

impl ValidationConfig {
    /// Shared constructor behind `default()` and every `for_*` preset.
    ///
    /// Only the label, the minimum sample count and the warning threshold
    /// differ between presets; the remaining fields are always the same.
    fn preset(label: Option<&str>, min_samples: usize, warn_above: usize) -> Self {
        Self {
            min_samples,
            max_samples_warning: Some(warn_above),
            min_features: 1,
            check_finite: true,
            allow_empty: false,
            error_prefix: label.map(String::from),
        }
    }

    /// Preset labelled `"K-means"`: one sample is enough, warning above 50000 samples.
    pub fn for_kmeans() -> Self {
        Self::preset(Some("K-means"), 1, 50000)
    }

    /// Preset labelled `"Hierarchical clustering"`: two samples minimum,
    /// warning above 5000 samples.
    pub fn for_hierarchical() -> Self {
        Self::preset(Some("Hierarchical clustering"), 2, 5000)
    }

    /// Preset labelled `"DBSCAN"`: two samples minimum, warning above 20000 samples.
    pub fn for_dbscan() -> Self {
        Self::preset(Some("DBSCAN"), 2, 20000)
    }

    /// Preset labelled `"Spectral clustering"`: two samples minimum,
    /// warning above 1000 samples.
    pub fn for_spectral() -> Self {
        Self::preset(Some("Spectral clustering"), 2, 1000)
    }
}
92
93#[allow(dead_code)]
119pub fn validate_clustering_data<F: Float + FromPrimitive + Debug + PartialOrd>(
120 data: ArrayView2<F>,
121 config: &ValidationConfig,
122) -> Result<()> {
123 let (n_samples, n_features) = data.dim();
124 let prefix = config.error_prefix.as_deref().unwrap_or("Clustering");
125
126 if n_samples == 0 && !config.allow_empty {
128 return Err(ClusteringError::InvalidInput(format!(
129 "{}: Input data cannot be empty",
130 prefix
131 )));
132 }
133
134 if n_features == 0 && !config.allow_empty {
135 return Err(ClusteringError::InvalidInput(format!(
136 "{}: Input data must have at least one feature",
137 prefix
138 )));
139 }
140
141 if n_samples < config.min_samples {
143 return Err(ClusteringError::InvalidInput(format!(
144 "{}: Need at least {} samples, got {}",
145 prefix, config.min_samples, n_samples
146 )));
147 }
148
149 if n_features < config.min_features {
150 return Err(ClusteringError::InvalidInput(format!(
151 "{}: Need at least {} features, got {}",
152 prefix, config.min_features, n_features
153 )));
154 }
155
156 if let Some(max_warn) = config.max_samples_warning {
158 if n_samples > max_warn {
159 eprintln!(
160 "Warning: {} with {} samples may be slow. Consider using a subset or more efficient algorithm.",
161 prefix, n_samples
162 );
163 }
164 }
165
166 if config.check_finite {
168 validate_finite_values(data, prefix)?;
169 }
170
171 Ok(())
172}
173
174#[allow(dead_code)]
176fn validate_finite_values<F: Float + Debug>(data: ArrayView2<F>, prefix: &str) -> Result<()> {
177 for (i, row) in data.axis_iter(Axis(0)).enumerate() {
178 for (j, &value) in row.iter().enumerate() {
179 if !value.is_finite() {
180 return Err(ClusteringError::InvalidInput(format!(
181 "{}: Non-finite value {:?} at position ({}, {})",
182 prefix, value, i, j
183 )));
184 }
185 }
186 }
187 Ok(())
188}
189
190#[allow(dead_code)]
204pub fn validate_n_clusters(n_clusters: usize, nsamples: usize, algorithm: &str) -> Result<()> {
205 if n_clusters == 0 {
206 return Err(ClusteringError::InvalidInput(format!(
207 "{}: Number of clusters must be positive, got 0",
208 algorithm
209 )));
210 }
211
212 if n_clusters > nsamples {
213 return Err(ClusteringError::InvalidInput(format!(
214 "{}: Number of clusters ({}) cannot exceed number of samples ({})",
215 algorithm, n_clusters, nsamples
216 )));
217 }
218
219 Ok(())
220}
221
222#[allow(dead_code)]
238pub fn validate_distance_parameter<F: Float + FromPrimitive + Debug + PartialOrd>(
239 value: F,
240 param_name: &str,
241 min_value: Option<F>,
242 max_value: Option<F>,
243 algorithm: &str,
244) -> Result<()> {
245 if !value.is_finite() {
246 return Err(ClusteringError::InvalidInput(format!(
247 "{}: {} must be finite, got {:?}",
248 algorithm, param_name, value
249 )));
250 }
251
252 if let Some(min_val) = min_value {
253 if value < min_val {
254 return Err(ClusteringError::InvalidInput(format!(
255 "{}: {} must be >= {:?}, got {:?}",
256 algorithm, param_name, min_val, value
257 )));
258 }
259 }
260
261 if let Some(max_val) = max_value {
262 if value > max_val {
263 return Err(ClusteringError::InvalidInput(format!(
264 "{}: {} must be <= {:?}, got {:?}",
265 algorithm, param_name, max_val, value
266 )));
267 }
268 }
269
270 Ok(())
271}
272
273#[allow(dead_code)]
289pub fn validate_integer_parameter(
290 value: usize,
291 param_name: &str,
292 min_value: Option<usize>,
293 max_value: Option<usize>,
294 algorithm: &str,
295) -> Result<()> {
296 if let Some(min_val) = min_value {
297 if value < min_val {
298 return Err(ClusteringError::InvalidInput(format!(
299 "{}: {} must be >= {}, got {}",
300 algorithm, param_name, min_val, value
301 )));
302 }
303 }
304
305 if let Some(max_val) = max_value {
306 if value > max_val {
307 return Err(ClusteringError::InvalidInput(format!(
308 "{}: {} must be <= {}, got {}",
309 algorithm, param_name, max_val, value
310 )));
311 }
312 }
313
314 Ok(())
315}
316
317#[allow(dead_code)]
331pub fn validate_sample_weights<F: Float + FromPrimitive + Debug + PartialOrd>(
332 weights: ArrayView1<F>,
333 n_samples: usize,
334 algorithm: &str,
335) -> Result<()> {
336 if weights.len() != n_samples {
337 return Err(ClusteringError::InvalidInput(format!(
338 "{}: Sample weights length ({}) must match number of samples ({})",
339 algorithm,
340 weights.len(),
341 n_samples
342 )));
343 }
344
345 for (i, &weight) in weights.iter().enumerate() {
346 if !weight.is_finite() {
347 return Err(ClusteringError::InvalidInput(format!(
348 "{}: Sample weight at index {} is not finite: {:?}",
349 algorithm, i, weight
350 )));
351 }
352
353 if weight < F::zero() {
354 return Err(ClusteringError::InvalidInput(format!(
355 "{}: Sample weight at index {} must be non-negative, got {:?}",
356 algorithm, i, weight
357 )));
358 }
359 }
360
361 let sum_weights = weights.iter().fold(F::zero(), |acc, &w| acc + w);
363 if sum_weights <= F::zero() {
364 return Err(ClusteringError::InvalidInput(format!(
365 "{}: Sum of sample weights must be positive",
366 algorithm
367 )));
368 }
369
370 Ok(())
371}
372
373#[allow(dead_code)]
388pub fn validate_cluster_initialization<F: Float + FromPrimitive + Debug + PartialOrd>(
389 init_data: ArrayView2<F>,
390 n_clusters: usize,
391 n_features: usize,
392 algorithm: &str,
393) -> Result<()> {
394 let (init_clusters, init_features) = init_data.dim();
395
396 if init_clusters != n_clusters {
397 return Err(ClusteringError::InvalidInput(format!(
398 "{}: Initial cluster centers must have {} clusters, got {}",
399 algorithm, n_clusters, init_clusters
400 )));
401 }
402
403 if init_features != n_features {
404 return Err(ClusteringError::InvalidInput(format!(
405 "{}: Initial cluster centers must have {} features, got {}",
406 algorithm, n_features, init_features
407 )));
408 }
409
410 for (i, row) in init_data.axis_iter(Axis(0)).enumerate() {
412 for (j, &value) in row.iter().enumerate() {
413 if !value.is_finite() {
414 return Err(ClusteringError::InvalidInput(format!(
415 "{}: Non-finite value {:?} in initial cluster center at position ({}, {})",
416 algorithm, value, i, j
417 )));
418 }
419 }
420 }
421
422 Ok(())
423}
424
425#[allow(dead_code)]
439pub fn validate_convergence_parameters<F: Float + FromPrimitive + Debug + PartialOrd>(
440 tolerance: Option<F>,
441 max_iterations: Option<usize>,
442 algorithm: &str,
443) -> Result<()> {
444 if let Some(tol) = tolerance {
445 validate_distance_parameter(tol, "tolerance", Some(F::zero()), None, algorithm)?;
446 }
447
448 if let Some(max_iter) = max_iterations {
449 validate_integer_parameter(max_iter, "max_iterations", Some(1), None, algorithm)?;
450 }
451
452 Ok(())
453}
454
455#[allow(dead_code)]
469pub fn check_duplicate_points<F: Float + FromPrimitive + Debug + PartialOrd>(
470 data: ArrayView2<F>,
471 tolerance: F,
472) -> Result<Vec<(usize, usize)>> {
473 let n_samples = data.shape()[0];
474 let mut duplicates = Vec::new();
475
476 for i in 0..n_samples {
477 for j in (i + 1)..n_samples {
478 let mut distance_squared = F::zero();
479 for k in 0..data.shape()[1] {
480 let diff = data[[i, k]] - data[[j, k]];
481 distance_squared = distance_squared + diff * diff;
482 }
483
484 if distance_squared <= tolerance * tolerance {
485 duplicates.push((i, j));
486 }
487 }
488 }
489
490 Ok(duplicates)
491}
492
493#[allow(dead_code)]
507pub fn suggest_clustering_algorithm<F: Float + FromPrimitive + Debug + PartialOrd>(
508 data: ArrayView2<F>,
509 n_clusters: Option<usize>,
510) -> Result<String> {
511 let (n_samples, n_features) = data.dim();
512
513 let config = ValidationConfig::default();
515 validate_clustering_data(data, &config)?;
516
517 let mut suggestions = Vec::new();
518
519 if n_samples < 100 {
521 suggestions
522 .push("Small dataset: Consider hierarchical clustering for interpretable results");
523 } else if n_samples > 10000 {
524 suggestions.push("Large dataset: K-means or DBSCAN recommended for efficiency");
525 }
526
527 if n_features > 50 {
528 suggestions.push(
529 "High-dimensional data: Consider spectral clustering or dimensionality reduction",
530 );
531 }
532
533 let duplicates = check_duplicate_points(data, F::from_f64(1e-10).unwrap())?;
535 if !duplicates.is_empty() {
536 suggestions.push("Duplicate points detected: DBSCAN may handle noise well");
537 }
538
539 if let Some(k) = n_clusters {
541 if k <= 10 {
542 suggestions.push(
543 "Small number of clusters: K-means with k-means++ initialization recommended",
544 );
545 } else {
546 suggestions.push("Many clusters: Consider hierarchical clustering or DBSCAN");
547 }
548 } else {
549 suggestions.push(
550 "Unknown cluster count: DBSCAN or hierarchical clustering with automatic cut-off",
551 );
552 }
553
554 if n_samples > 5000 && n_features > 20 {
556 suggestions.push("Performance consideration: Use parallel implementations when available");
557 }
558
559 let recommendation = if suggestions.is_empty() {
560 "K-means with k-means++ initialization is a good general-purpose choice".to_string()
561 } else {
562 format!("Recommendations:\n{}", suggestions.join("\n- "))
563 };
564
565 Ok(recommendation)
566}
567
#[cfg(test)]
mod tests {
    use super::*;
    use scirs2_core::ndarray::{Array1, Array2};

    /// Builds a `rows x cols` matrix holding 0.0, 1.0, 2.0, ... in row-major order.
    fn ramp(rows: usize, cols: usize) -> Array2<f64> {
        let values: Vec<f64> = (0..rows * cols).map(|v| v as f64).collect();
        Array2::from_shape_vec((rows, cols), values).unwrap()
    }

    #[test]
    fn test_validate_clustering_data() {
        let config = ValidationConfig::default();

        // Ten finite samples satisfy every default rule.
        let valid = ramp(10, 3);
        assert!(validate_clustering_data(valid.view(), &config).is_ok());

        // A single sample is below the default minimum of two.
        let too_few = Array2::from_shape_vec((1, 3), vec![1.0, 2.0, 3.0]).unwrap();
        assert!(validate_clustering_data(too_few.view(), &config).is_err());

        // A NaN entry trips the finiteness check.
        let with_nan =
            Array2::from_shape_vec((3, 2), vec![1.0, 2.0, f64::NAN, 4.0, 5.0, 6.0]).unwrap();
        assert!(validate_clustering_data(with_nan.view(), &config).is_err());
    }

    #[test]
    fn test_validate_n_clusters() {
        // (requested clusters, expected to pass) for a 10-sample dataset.
        let cases = vec![(3, true), (0, false), (15, false)];
        for (requested, should_pass) in cases {
            assert_eq!(
                validate_n_clusters(requested, 10, "Test").is_ok(),
                should_pass
            );
        }
    }

    #[test]
    fn test_validate_distance_parameter() {
        let in_range = validate_distance_parameter(1.0, "eps", Some(0.0), Some(10.0), "Test");
        assert!(in_range.is_ok());

        let below_minimum = validate_distance_parameter(-1.0, "eps", Some(0.0), None, "Test");
        assert!(below_minimum.is_err());

        let not_a_number = validate_distance_parameter(f64::NAN, "eps", None, None, "Test");
        assert!(not_a_number.is_err());
    }

    #[test]
    fn test_validate_sample_weights() {
        // (weights, expected to pass) when three samples are expected.
        let cases = vec![
            (vec![1.0, 2.0, 3.0], true),
            (vec![1.0, -2.0, 3.0], false),
            (vec![1.0, 2.0], false),
        ];
        for (raw, should_pass) in cases {
            let weights = Array1::from_vec(raw);
            assert_eq!(
                validate_sample_weights(weights.view(), 3, "Test").is_ok(),
                should_pass
            );
        }
    }

    #[test]
    fn test_check_duplicate_points() {
        // Rows 0 and 1 are identical; rows 2 and 3 are distinct.
        let points =
            Array2::from_shape_vec((4, 2), vec![1.0, 2.0, 1.0, 2.0, 3.0, 4.0, 5.0, 6.0]).unwrap();
        let found = check_duplicate_points(points.view(), 1e-10).unwrap();
        assert_eq!(found, vec![(0, 1)]);
    }

    #[test]
    fn test_suggest_clustering_algorithm() {
        let data = ramp(100, 5);
        let suggestion = suggest_clustering_algorithm(data.view(), Some(3)).unwrap();
        assert!(!suggestion.is_empty());
        assert!(suggestion.contains("K-means") || suggestion.contains("recommendation"));
    }

    #[test]
    fn test_validation_configs() {
        let expectations = vec![
            (ValidationConfig::for_kmeans(), 1),
            (ValidationConfig::for_hierarchical(), 2),
            (ValidationConfig::for_dbscan(), 2),
        ];
        for (config, min_samples) in expectations {
            assert_eq!(config.min_samples, min_samples);
        }
    }
}