1use scirs2_core::ndarray::Array2;
7use sklears_core::prelude::SklearsError;
8use std::collections::HashMap;
9
/// Aggregated result of running [`DataQualityValidator::validate`] over a matrix.
#[derive(Debug, Clone)]
pub struct DataQualityReport {
    /// Number of rows (samples) in the validated matrix.
    pub n_samples: usize,
    /// Number of columns (features) in the validated matrix.
    pub n_features: usize,
    /// Per-feature NaN (missing value) counts and percentages.
    pub missing_stats: Vec<MissingStats>,
    /// Per-feature outlier counts, percentages, and flagged row indices.
    pub outlier_stats: Vec<OutlierStats>,
    /// Per-feature summary statistics (moments, quantiles, uniqueness).
    pub distribution_stats: Vec<DistributionStats>,
    /// Feature pairs whose absolute correlation exceeded the configured threshold.
    pub correlation_warnings: Vec<CorrelationWarning>,
    /// Overall score: starts at 100 and is reduced per issue by severity,
    /// clamped to a minimum of 0.
    pub quality_score: f64,
    /// All issues detected during validation, in the order checks were run.
    pub issues: Vec<QualityIssue>,
}
30
/// Missing-value (NaN) statistics for a single feature column.
#[derive(Debug, Clone)]
pub struct MissingStats {
    /// Column index this entry describes.
    pub feature_idx: usize,
    /// Number of NaN entries in the column.
    pub missing_count: usize,
    /// `missing_count` as a percentage of the total row count (0..=100).
    pub missing_percentage: f64,
}
38
/// Outlier-detection results for a single feature column.
#[derive(Debug, Clone)]
pub struct OutlierStats {
    /// Column index this entry describes.
    pub feature_idx: usize,
    /// Number of rows flagged as outliers.
    pub outlier_count: usize,
    /// `outlier_count` as a percentage of the total row count (0..=100).
    pub outlier_percentage: f64,
    /// Row indices of the samples flagged as outliers.
    pub outlier_indices: Vec<usize>,
}
47
/// Summary statistics for a single feature column, computed over the
/// non-NaN values only. All fields are NaN when the column has no finite values.
#[derive(Debug, Clone)]
pub struct DistributionStats {
    /// Column index this entry describes.
    pub feature_idx: usize,
    /// Arithmetic mean.
    pub mean: f64,
    /// Population standard deviation (divides by N, not N-1).
    pub std: f64,
    /// Smallest value.
    pub min: f64,
    /// Largest value.
    pub max: f64,
    /// Element at position len/2 of the sorted values (approximate median).
    pub median: f64,
    /// Element at position len/4 of the sorted values (approximate 25th percentile).
    pub q25: f64,
    /// Element at position 3*len/4 of the sorted values (approximate 75th percentile).
    pub q75: f64,
    /// Standardized third moment; 0.0 when std is (near) zero.
    pub skewness: f64,
    /// Excess kurtosis (standardized fourth moment minus 3); 0.0 when std is (near) zero.
    pub kurtosis: f64,
    /// Count of values that differ from each other by more than 1e-10.
    pub unique_count: usize,
    /// True when the column's standard deviation is below 1e-10
    /// (also set for columns containing no finite values).
    pub constant: bool,
}
64
/// A pair of features whose absolute Pearson correlation exceeded the
/// configured threshold.
#[derive(Debug, Clone)]
pub struct CorrelationWarning {
    /// Index of the first feature in the pair (always < `feature_j`).
    pub feature_i: usize,
    /// Index of the second feature in the pair.
    pub feature_j: usize,
    /// The computed correlation coefficient (signed).
    pub correlation: f64,
}
72
/// A single data-quality problem detected during validation.
#[derive(Debug, Clone)]
pub struct QualityIssue {
    /// How serious the issue is (drives the quality-score penalty).
    pub severity: IssueSeverity,
    /// Which kind of check produced this issue.
    pub category: IssueCategory,
    /// Human-readable description of the issue.
    pub description: String,
    /// Column indices involved; empty for dataset-wide issues (e.g. duplicates).
    pub affected_features: Vec<usize>,
}
81
/// Severity level of a [`QualityIssue`]; determines the quality-score penalty
/// (Critical: 20, Warning: 10, Info: 2 points).
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum IssueSeverity {
    /// Severe problem (e.g. a feature more than 50% missing).
    Critical,
    /// Notable problem worth investigating.
    Warning,
    /// Informational finding.
    Info,
}
89
/// The kind of quality check a [`QualityIssue`] originated from.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum IssueCategory {
    /// NaN entries above the configured missing threshold.
    MissingValues,
    /// Outlier fraction above 5% for a feature.
    Outliers,
    /// Constant or near-constant features.
    ConstantFeatures,
    /// Feature pairs with |correlation| above the configured threshold.
    HighCorrelation,
    /// Duplicate sample rows.
    Duplicates,
    /// Reserved: data-type problems (not emitted by the current checks).
    DataType,
    /// Reserved: out-of-range values (not emitted by the current checks).
    Range,
    /// Reserved: distributional problems (not emitted by the current checks).
    Distribution,
}
102
/// Tunable thresholds and switches for [`DataQualityValidator`].
#[derive(Debug, Clone)]
pub struct DataQualityConfig {
    /// Missing-value percentage (0..=100) above which a feature is flagged.
    pub missing_threshold: f64,
    /// Which outlier-detection algorithm to use.
    pub outlier_method: OutlierMethod,
    /// Cutoff for the chosen method: a (modified) z-score limit for the
    /// z-score methods, or the IQR fence multiplier for the IQR method.
    pub outlier_threshold: f64,
    /// Absolute Pearson correlation above which a feature pair is flagged.
    pub correlation_threshold: f64,
    /// Whether to scan for duplicate sample rows.
    pub check_duplicates: bool,
    /// Whether to flag constant and near-constant features.
    pub check_constant_features: bool,
    /// Standard deviation below which a non-constant feature is reported
    /// as near-constant (Info severity).
    pub near_constant_threshold: f64,
}
121
122impl Default for DataQualityConfig {
123 fn default() -> Self {
124 Self {
125 missing_threshold: 10.0,
126 outlier_method: OutlierMethod::ZScore,
127 outlier_threshold: 3.0,
128 correlation_threshold: 0.95,
129 check_duplicates: true,
130 check_constant_features: true,
131 near_constant_threshold: 1e-8,
132 }
133 }
134}
135
/// Strategy used to flag outlying values within a feature column.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum OutlierMethod {
    /// Flag values whose |value - mean| / std exceeds the threshold.
    ZScore,
    /// Flag values outside [Q1 - k*IQR, Q3 + k*IQR], k = threshold.
    IQR,
    /// Flag values whose median/MAD-based modified z-score exceeds the
    /// threshold; more robust to extreme values than the plain z-score.
    ModifiedZScore,
}
143
/// Runs a battery of data-quality checks (missing values, outliers,
/// constant features, high correlation, duplicates) over an `Array2<f64>`.
pub struct DataQualityValidator {
    // Thresholds and switches governing which checks run and how strict they are.
    config: DataQualityConfig,
}
148
149impl DataQualityValidator {
150 pub fn new() -> Self {
152 Self {
153 config: DataQualityConfig::default(),
154 }
155 }
156
157 pub fn with_config(config: DataQualityConfig) -> Self {
159 Self { config }
160 }
161
162 pub fn validate(&self, x: &Array2<f64>) -> Result<DataQualityReport, SklearsError> {
164 let n_samples = x.nrows();
165 let n_features = x.ncols();
166
167 if n_samples == 0 || n_features == 0 {
168 return Err(SklearsError::InvalidInput("Empty dataset".to_string()));
169 }
170
171 let mut issues = Vec::new();
172
173 let missing_stats = self.compute_missing_stats(x);
175 for stat in &missing_stats {
176 if stat.missing_percentage > self.config.missing_threshold {
177 issues.push(QualityIssue {
178 severity: if stat.missing_percentage > 50.0 {
179 IssueSeverity::Critical
180 } else {
181 IssueSeverity::Warning
182 },
183 category: IssueCategory::MissingValues,
184 description: format!(
185 "Feature {} has {:.2}% missing values",
186 stat.feature_idx, stat.missing_percentage
187 ),
188 affected_features: vec![stat.feature_idx],
189 });
190 }
191 }
192
193 let outlier_stats = self.compute_outlier_stats(x);
195 for stat in &outlier_stats {
196 if stat.outlier_percentage > 5.0 {
197 issues.push(QualityIssue {
198 severity: IssueSeverity::Warning,
199 category: IssueCategory::Outliers,
200 description: format!(
201 "Feature {} has {:.2}% outliers",
202 stat.feature_idx, stat.outlier_percentage
203 ),
204 affected_features: vec![stat.feature_idx],
205 });
206 }
207 }
208
209 let distribution_stats = self.compute_distribution_stats(x);
211 if self.config.check_constant_features {
212 for stat in &distribution_stats {
213 if stat.constant {
214 issues.push(QualityIssue {
215 severity: IssueSeverity::Warning,
216 category: IssueCategory::ConstantFeatures,
217 description: format!("Feature {} is constant", stat.feature_idx),
218 affected_features: vec![stat.feature_idx],
219 });
220 } else if stat.std < self.config.near_constant_threshold {
221 issues.push(QualityIssue {
222 severity: IssueSeverity::Info,
223 category: IssueCategory::ConstantFeatures,
224 description: format!(
225 "Feature {} has very low variance: {}",
226 stat.feature_idx, stat.std
227 ),
228 affected_features: vec![stat.feature_idx],
229 });
230 }
231 }
232 }
233
234 let correlation_warnings = self.compute_correlation_warnings(x);
236 for warning in &correlation_warnings {
237 issues.push(QualityIssue {
238 severity: IssueSeverity::Info,
239 category: IssueCategory::HighCorrelation,
240 description: format!(
241 "Features {} and {} are highly correlated: {:.3}",
242 warning.feature_i, warning.feature_j, warning.correlation
243 ),
244 affected_features: vec![warning.feature_i, warning.feature_j],
245 });
246 }
247
248 if self.config.check_duplicates {
250 let duplicate_count = self.count_duplicate_samples(x);
251 if duplicate_count > 0 {
252 issues.push(QualityIssue {
253 severity: IssueSeverity::Info,
254 category: IssueCategory::Duplicates,
255 description: format!(
256 "Found {} duplicate samples ({:.2}%)",
257 duplicate_count,
258 (duplicate_count as f64 / n_samples as f64) * 100.0
259 ),
260 affected_features: vec![],
261 });
262 }
263 }
264
265 let quality_score = self.calculate_quality_score(&issues, n_samples, n_features);
267
268 Ok(DataQualityReport {
269 n_samples,
270 n_features,
271 missing_stats,
272 outlier_stats,
273 distribution_stats,
274 correlation_warnings,
275 quality_score,
276 issues,
277 })
278 }
279
280 fn compute_missing_stats(&self, x: &Array2<f64>) -> Vec<MissingStats> {
282 let n_samples = x.nrows();
283
284 (0..x.ncols())
285 .map(|j| {
286 let col = x.column(j);
287 let missing_count = col.iter().filter(|v| v.is_nan()).count();
288 let missing_percentage = (missing_count as f64 / n_samples as f64) * 100.0;
289
290 MissingStats {
291 feature_idx: j,
292 missing_count,
293 missing_percentage,
294 }
295 })
296 .collect()
297 }
298
299 fn compute_outlier_stats(&self, x: &Array2<f64>) -> Vec<OutlierStats> {
301 (0..x.ncols())
302 .map(|j| {
303 let col = x.column(j);
304 let values: Vec<f64> = col.iter().copied().filter(|v| !v.is_nan()).collect();
305
306 if values.is_empty() {
307 return OutlierStats {
308 feature_idx: j,
309 outlier_count: 0,
310 outlier_percentage: 0.0,
311 outlier_indices: vec![],
312 };
313 }
314
315 let outlier_indices = match self.config.outlier_method {
316 OutlierMethod::ZScore => self.detect_outliers_zscore(&values, j, x.nrows()),
317 OutlierMethod::IQR => self.detect_outliers_iqr(&values, j, x.nrows()),
318 OutlierMethod::ModifiedZScore => {
319 self.detect_outliers_modified_zscore(&values, j, x.nrows())
320 }
321 };
322
323 let outlier_count = outlier_indices.len();
324 let outlier_percentage = (outlier_count as f64 / x.nrows() as f64) * 100.0;
325
326 OutlierStats {
327 feature_idx: j,
328 outlier_count,
329 outlier_percentage,
330 outlier_indices,
331 }
332 })
333 .collect()
334 }
335
336 fn detect_outliers_zscore(&self, values: &[f64], _col_idx: usize, n_rows: usize) -> Vec<usize> {
338 let mean = values.iter().sum::<f64>() / values.len() as f64;
339 let variance = values.iter().map(|v| (v - mean).powi(2)).sum::<f64>() / values.len() as f64;
340 let std = variance.sqrt();
341
342 if std < 1e-10 {
343 return vec![];
344 }
345
346 let mut outliers = Vec::new();
347 let mut value_idx = 0;
348
349 for i in 0..n_rows {
350 if value_idx < values.len() {
351 let z_score = (values[value_idx] - mean).abs() / std;
352 if z_score > self.config.outlier_threshold {
353 outliers.push(i);
354 }
355 value_idx += 1;
356 }
357 }
358
359 outliers
360 }
361
362 fn detect_outliers_iqr(&self, values: &[f64], _col_idx: usize, n_rows: usize) -> Vec<usize> {
364 let mut sorted = values.to_vec();
365 sorted.sort_by(|a, b| a.partial_cmp(b).unwrap_or(std::cmp::Ordering::Equal));
366
367 let q1_idx = sorted.len() / 4;
368 let q3_idx = (sorted.len() * 3) / 4;
369 let q1 = sorted[q1_idx];
370 let q3 = sorted[q3_idx];
371 let iqr = q3 - q1;
372
373 if iqr < 1e-10 {
374 return vec![];
375 }
376
377 let lower_bound = q1 - self.config.outlier_threshold * iqr;
378 let upper_bound = q3 + self.config.outlier_threshold * iqr;
379
380 let mut outliers = Vec::new();
381 let mut value_idx = 0;
382
383 for i in 0..n_rows {
384 if value_idx < values.len() {
385 let val = values[value_idx];
386 if val < lower_bound || val > upper_bound {
387 outliers.push(i);
388 }
389 value_idx += 1;
390 }
391 }
392
393 outliers
394 }
395
396 fn detect_outliers_modified_zscore(
398 &self,
399 values: &[f64],
400 _col_idx: usize,
401 n_rows: usize,
402 ) -> Vec<usize> {
403 let mut sorted = values.to_vec();
404 sorted.sort_by(|a, b| a.partial_cmp(b).unwrap_or(std::cmp::Ordering::Equal));
405
406 let median = sorted[sorted.len() / 2];
407 let mad = {
408 let deviations: Vec<f64> = sorted.iter().map(|v| (v - median).abs()).collect();
409 let mut dev_sorted = deviations.clone();
410 dev_sorted.sort_by(|a, b| a.partial_cmp(b).unwrap_or(std::cmp::Ordering::Equal));
411 dev_sorted[dev_sorted.len() / 2]
412 };
413
414 if mad < 1e-10 {
415 return vec![];
416 }
417
418 let mut outliers = Vec::new();
419 let mut value_idx = 0;
420
421 for i in 0..n_rows {
422 if value_idx < values.len() {
423 let modified_z = 0.6745 * (values[value_idx] - median).abs() / mad;
424 if modified_z > self.config.outlier_threshold {
425 outliers.push(i);
426 }
427 value_idx += 1;
428 }
429 }
430
431 outliers
432 }
433
434 fn compute_distribution_stats(&self, x: &Array2<f64>) -> Vec<DistributionStats> {
436 (0..x.ncols())
437 .map(|j| {
438 let col = x.column(j);
439 let values: Vec<f64> = col.iter().copied().filter(|v| !v.is_nan()).collect();
440
441 if values.is_empty() {
442 return DistributionStats {
443 feature_idx: j,
444 mean: f64::NAN,
445 std: f64::NAN,
446 min: f64::NAN,
447 max: f64::NAN,
448 median: f64::NAN,
449 q25: f64::NAN,
450 q75: f64::NAN,
451 skewness: f64::NAN,
452 kurtosis: f64::NAN,
453 unique_count: 0,
454 constant: true,
455 };
456 }
457
458 let mean = values.iter().sum::<f64>() / values.len() as f64;
459 let variance =
460 values.iter().map(|v| (v - mean).powi(2)).sum::<f64>() / values.len() as f64;
461 let std = variance.sqrt();
462
463 let mut sorted = values.clone();
464 sorted.sort_by(|a, b| a.partial_cmp(b).unwrap_or(std::cmp::Ordering::Equal));
465
466 let min = sorted.first().copied().unwrap_or(f64::NAN);
467 let max = sorted.last().copied().unwrap_or(f64::NAN);
468 let median = sorted[sorted.len() / 2];
469 let q25 = sorted[sorted.len() / 4];
470 let q75 = sorted[(sorted.len() * 3) / 4];
471
472 let skewness = if std > 1e-10 {
473 values
474 .iter()
475 .map(|v| ((v - mean) / std).powi(3))
476 .sum::<f64>()
477 / values.len() as f64
478 } else {
479 0.0
480 };
481
482 let kurtosis = if std > 1e-10 {
483 values
484 .iter()
485 .map(|v| ((v - mean) / std).powi(4))
486 .sum::<f64>()
487 / values.len() as f64
488 - 3.0
489 } else {
490 0.0
491 };
492
493 let mut unique_values = Vec::new();
495 for &v in &values {
496 if !unique_values.iter().any(|&uv: &f64| (uv - v).abs() < 1e-10) {
497 unique_values.push(v);
498 }
499 }
500 let unique_count = unique_values.len();
501
502 let constant = std < 1e-10;
503
504 DistributionStats {
505 feature_idx: j,
506 mean,
507 std,
508 min,
509 max,
510 median,
511 q25,
512 q75,
513 skewness,
514 kurtosis,
515 unique_count,
516 constant,
517 }
518 })
519 .collect()
520 }
521
522 fn compute_correlation_warnings(&self, x: &Array2<f64>) -> Vec<CorrelationWarning> {
524 let mut warnings = Vec::new();
525
526 for i in 0..x.ncols() {
527 for j in (i + 1)..x.ncols() {
528 let corr = self.compute_correlation(x, i, j);
529 if corr.abs() > self.config.correlation_threshold {
530 warnings.push(CorrelationWarning {
531 feature_i: i,
532 feature_j: j,
533 correlation: corr,
534 });
535 }
536 }
537 }
538
539 warnings
540 }
541
542 fn compute_correlation(&self, x: &Array2<f64>, i: usize, j: usize) -> f64 {
544 let col_i = x.column(i);
545 let col_j = x.column(j);
546
547 let pairs: Vec<(f64, f64)> = col_i
548 .iter()
549 .zip(col_j.iter())
550 .filter(|(a, b)| !a.is_nan() && !b.is_nan())
551 .map(|(&a, &b)| (a, b))
552 .collect();
553
554 if pairs.len() < 2 {
555 return 0.0;
556 }
557
558 let mean_i = pairs.iter().map(|(a, _)| a).sum::<f64>() / pairs.len() as f64;
559 let mean_j = pairs.iter().map(|(_, b)| b).sum::<f64>() / pairs.len() as f64;
560
561 let mut cov = 0.0;
562 let mut var_i = 0.0;
563 let mut var_j = 0.0;
564
565 for (a, b) in &pairs {
566 let di = a - mean_i;
567 let dj = b - mean_j;
568 cov += di * dj;
569 var_i += di * di;
570 var_j += dj * dj;
571 }
572
573 if var_i < 1e-10 || var_j < 1e-10 {
574 return 0.0;
575 }
576
577 cov / (var_i * var_j).sqrt()
578 }
579
580 fn count_duplicate_samples(&self, x: &Array2<f64>) -> usize {
582 let mut seen = HashMap::new();
583 let mut duplicates = 0;
584
585 for i in 0..x.nrows() {
586 let row: Vec<_> = x.row(i).iter().copied().collect();
587 *seen.entry(format!("{:?}", row)).or_insert(0) += 1;
588 }
589
590 for count in seen.values() {
591 if *count > 1 {
592 duplicates += count - 1;
593 }
594 }
595
596 duplicates
597 }
598
599 fn calculate_quality_score(
601 &self,
602 issues: &[QualityIssue],
603 _n_samples: usize,
604 _n_features: usize,
605 ) -> f64 {
606 let mut score: f64 = 100.0;
607
608 for issue in issues {
609 let penalty: f64 = match issue.severity {
610 IssueSeverity::Critical => 20.0,
611 IssueSeverity::Warning => 10.0,
612 IssueSeverity::Info => 2.0,
613 };
614 score -= penalty;
615 }
616
617 score.max(0.0)
618 }
619}
620
621impl Default for DataQualityValidator {
622 fn default() -> Self {
623 Self::new()
624 }
625}
626
627impl DataQualityReport {
628 pub fn print_summary(&self) {
630 println!("Data Quality Report");
631 println!("==================");
632 println!("Samples: {}, Features: {}", self.n_samples, self.n_features);
633 println!("Quality Score: {:.1}/100", self.quality_score);
634 println!();
635
636 if !self.issues.is_empty() {
637 println!("Issues Found: {}", self.issues.len());
638 println!();
639
640 for issue in &self.issues {
641 let severity_str = match issue.severity {
642 IssueSeverity::Critical => "CRITICAL",
643 IssueSeverity::Warning => "WARNING",
644 IssueSeverity::Info => "INFO",
645 };
646 println!("[{}] {}", severity_str, issue.description);
647 }
648 } else {
649 println!("No issues detected!");
650 }
651 }
652
653 pub fn issues_by_severity(&self, severity: IssueSeverity) -> Vec<&QualityIssue> {
655 self.issues
656 .iter()
657 .filter(|issue| issue.severity == severity)
658 .collect()
659 }
660
661 pub fn issues_by_category(&self, category: IssueCategory) -> Vec<&QualityIssue> {
663 self.issues
664 .iter()
665 .filter(|issue| issue.category == category)
666 .collect()
667 }
668}
669
#[cfg(test)]
mod tests {
    use super::*;
    use scirs2_core::random::essentials::Normal;
    use scirs2_core::random::{seeded_rng, Distribution};

    // Builds an nrows x ncols matrix of standard-normal samples from a
    // seeded RNG so each test is deterministic.
    fn generate_test_data(nrows: usize, ncols: usize, seed: u64) -> Array2<f64> {
        let mut rng = seeded_rng(seed);
        let normal = Normal::new(0.0, 1.0).unwrap();

        let data: Vec<f64> = (0..nrows * ncols)
            .map(|_| normal.sample(&mut rng))
            .collect();

        Array2::from_shape_vec((nrows, ncols), data).unwrap()
    }

    // Clean random data should validate without error and report its shape.
    #[test]
    fn test_data_quality_validator_basic() {
        let x = generate_test_data(100, 5, 42);
        let validator = DataQualityValidator::new();
        let report = validator.validate(&x).unwrap();

        assert_eq!(report.n_samples, 100);
        assert_eq!(report.n_features, 5);
        assert!(report.quality_score > 0.0);
    }

    // Injecting 20 NaNs into column 0 of 100 rows should yield a 20% missing rate.
    #[test]
    fn test_missing_value_detection() {
        let mut x = generate_test_data(100, 3, 123);

        for i in 0..20 {
            x[[i, 0]] = f64::NAN;
        }

        let validator = DataQualityValidator::new();
        let report = validator.validate(&x).unwrap();

        let missing_in_col0 = &report.missing_stats[0];
        assert_eq!(missing_in_col0.missing_count, 20);
        assert!((missing_in_col0.missing_percentage - 20.0).abs() < 0.1);
    }

    // A column set to a single value everywhere must be flagged as a
    // constant feature.
    #[test]
    fn test_constant_feature_detection() {
        let mut x = generate_test_data(50, 3, 456);

        for i in 0..x.nrows() {
            x[[i, 1]] = 5.0;
        }

        let validator = DataQualityValidator::new();
        let report = validator.validate(&x).unwrap();

        let constant_issues: Vec<_> = report.issues_by_category(IssueCategory::ConstantFeatures);

        assert!(!constant_issues.is_empty());
    }

    // Extreme values (+/-100 among N(0,1) samples) must be detected as outliers.
    #[test]
    fn test_outlier_detection() {
        let mut x = generate_test_data(100, 2, 789);

        x[[0, 0]] = 100.0;
        x[[1, 0]] = -100.0;

        let validator = DataQualityValidator::new();
        let report = validator.validate(&x).unwrap();

        let outliers_in_col0 = &report.outlier_stats[0];
        assert!(outliers_in_col0.outlier_count > 0);
    }

    // Clean random data should keep a high score (few/no penalties applied).
    #[test]
    fn test_quality_score_calculation() {
        let x = generate_test_data(100, 5, 321);
        let validator = DataQualityValidator::new();
        let report = validator.validate(&x).unwrap();

        assert!(report.quality_score > 80.0);
    }
}