sklears_preprocessing/
data_quality.rs

1//! Data Quality Validation Framework
2//!
3//! Comprehensive data quality checks for preprocessing pipelines.
4//! Validates data before and after transformations to ensure correctness.
5
6use scirs2_core::ndarray::Array2;
7use sklears_core::prelude::SklearsError;
8use std::collections::HashMap;
9
/// Data quality report with comprehensive statistics
///
/// Produced by `DataQualityValidator::validate`; aggregates per-feature
/// statistics, cross-feature warnings, and an overall quality score.
#[derive(Debug, Clone)]
pub struct DataQualityReport {
    /// Number of samples (rows) in the validated matrix
    pub n_samples: usize,
    /// Number of features (columns) in the validated matrix
    pub n_features: usize,
    /// Missing value statistics per feature
    pub missing_stats: Vec<MissingStats>,
    /// Outlier statistics per feature
    pub outlier_stats: Vec<OutlierStats>,
    /// Distribution statistics per feature
    pub distribution_stats: Vec<DistributionStats>,
    /// Warnings for highly correlated feature pairs
    pub correlation_warnings: Vec<CorrelationWarning>,
    /// Data quality score (0-100): 100 minus per-issue penalties, floored at 0
    pub quality_score: f64,
    /// List of detected issues
    pub issues: Vec<QualityIssue>,
}
30
/// Missing value statistics for a feature
#[derive(Debug, Clone)]
pub struct MissingStats {
    /// Column index of the feature within the input matrix
    pub feature_idx: usize,
    /// Number of NaN entries found in this feature
    pub missing_count: usize,
    /// Missing entries as a percentage of all samples (0-100)
    pub missing_percentage: f64,
}
38
/// Outlier statistics for a feature
#[derive(Debug, Clone)]
pub struct OutlierStats {
    /// Column index of the feature within the input matrix
    pub feature_idx: usize,
    /// Number of values flagged as outliers
    pub outlier_count: usize,
    /// Outliers as a percentage of all samples (0-100)
    pub outlier_percentage: f64,
    /// Indices of the flagged values
    pub outlier_indices: Vec<usize>,
}
47
/// Distribution statistics for a feature
///
/// All statistics are computed over the feature's non-NaN values; an
/// all-NaN feature yields NaN statistics and is marked constant.
#[derive(Debug, Clone)]
pub struct DistributionStats {
    /// Column index of the feature within the input matrix
    pub feature_idx: usize,
    /// Arithmetic mean
    pub mean: f64,
    /// Population standard deviation (divides by n, not n-1)
    pub std: f64,
    /// Smallest value
    pub min: f64,
    /// Largest value
    pub max: f64,
    /// Median (50th percentile, nearest-rank)
    pub median: f64,
    /// First quartile (25th percentile, nearest-rank)
    pub q25: f64,
    /// Third quartile (75th percentile, nearest-rank)
    pub q75: f64,
    /// Standardized third moment (0 for symmetric data)
    pub skewness: f64,
    /// Excess kurtosis (0 for a normal distribution)
    pub kurtosis: f64,
    /// Approximate count of distinct values (1e-10 tolerance)
    pub unique_count: usize,
    /// True when the standard deviation is effectively zero (< 1e-10)
    pub constant: bool,
}
64
/// Correlation warning between features
#[derive(Debug, Clone)]
pub struct CorrelationWarning {
    /// Column index of the first feature
    pub feature_i: usize,
    /// Column index of the second feature (always > `feature_i`)
    pub feature_j: usize,
    /// Pearson correlation coefficient between the two features
    pub correlation: f64,
}
72
/// Data quality issue
#[derive(Debug, Clone)]
pub struct QualityIssue {
    /// How serious the issue is
    pub severity: IssueSeverity,
    /// Which class of problem was detected
    pub category: IssueCategory,
    /// Human-readable description of the issue
    pub description: String,
    /// Column indices of the features involved (empty for dataset-wide issues)
    pub affected_features: Vec<usize>,
}
81
/// Issue severity level
///
/// Each level carries a fixed penalty in the quality score calculation.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum IssueSeverity {
    /// Severe issue (e.g. more than 50% missing values); costs 20 quality points
    Critical,
    /// Issue worth reviewing before proceeding; costs 10 quality points
    Warning,
    /// Informational finding; costs 2 quality points
    Info,
}
89
/// Issue category
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum IssueCategory {
    /// NaN entries above the configured missing-value threshold
    MissingValues,
    /// Excessive outliers in a feature
    Outliers,
    /// Constant or near-constant (low variance) features
    ConstantFeatures,
    /// Feature pairs with correlation above the configured threshold
    HighCorrelation,
    /// Duplicate samples in the dataset
    Duplicates,
    /// Data type problems (reserved; not emitted by the current validator)
    DataType,
    /// Out-of-range values (reserved; not emitted by the current validator)
    Range,
    /// Distribution anomalies (reserved; not emitted by the current validator)
    Distribution,
}
102
/// Configuration for data quality validation
#[derive(Debug, Clone)]
pub struct DataQualityConfig {
    /// Missing value threshold for warnings, as a percentage 0-100 (default: 10.0)
    pub missing_threshold: f64,
    /// Outlier detection method (default: `ZScore`)
    pub outlier_method: OutlierMethod,
    /// Outlier threshold: std deviations for the z-score methods, or the
    /// IQR multiplier for the IQR method (default: 3.0)
    pub outlier_threshold: f64,
    /// Absolute correlation above which a warning is emitted (default: 0.95)
    pub correlation_threshold: f64,
    /// Check for duplicate samples (default: true)
    pub check_duplicates: bool,
    /// Check for constant features (default: true)
    pub check_constant_features: bool,
    /// Standard-deviation threshold below which a non-constant feature is
    /// flagged as near-constant (default: 1e-8). Compared against std, not variance.
    pub near_constant_threshold: f64,
}
121
impl Default for DataQualityConfig {
    /// Sensible defaults: warn above 10% missing values, z-score outlier
    /// detection at 3 standard deviations, correlation warnings above 0.95,
    /// with duplicate and constant-feature checks enabled.
    fn default() -> Self {
        Self {
            missing_threshold: 10.0,
            outlier_method: OutlierMethod::ZScore,
            outlier_threshold: 3.0,
            correlation_threshold: 0.95,
            check_duplicates: true,
            check_constant_features: true,
            near_constant_threshold: 1e-8,
        }
    }
}
135
/// Outlier detection method
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum OutlierMethod {
    /// Distance from the mean measured in population standard deviations
    ZScore,
    /// Tukey fences: values outside [Q1 - k*IQR, Q3 + k*IQR]
    IQR,
    /// Robust z-score based on the median and MAD (0.6745 scale factor)
    ModifiedZScore,
}
143
/// Data quality validator
///
/// Runs the checks configured in a `DataQualityConfig` against a numeric
/// matrix and produces a `DataQualityReport`.
pub struct DataQualityValidator {
    // Thresholds and toggles controlling which checks run and how strict they are.
    config: DataQualityConfig,
}
148
149impl DataQualityValidator {
150    /// Create a new validator with default configuration
151    pub fn new() -> Self {
152        Self {
153            config: DataQualityConfig::default(),
154        }
155    }
156
157    /// Create a validator with custom configuration
158    pub fn with_config(config: DataQualityConfig) -> Self {
159        Self { config }
160    }
161
162    /// Validate data and generate quality report
163    pub fn validate(&self, x: &Array2<f64>) -> Result<DataQualityReport, SklearsError> {
164        let n_samples = x.nrows();
165        let n_features = x.ncols();
166
167        if n_samples == 0 || n_features == 0 {
168            return Err(SklearsError::InvalidInput("Empty dataset".to_string()));
169        }
170
171        let mut issues = Vec::new();
172
173        // Compute missing value statistics
174        let missing_stats = self.compute_missing_stats(x);
175        for stat in &missing_stats {
176            if stat.missing_percentage > self.config.missing_threshold {
177                issues.push(QualityIssue {
178                    severity: if stat.missing_percentage > 50.0 {
179                        IssueSeverity::Critical
180                    } else {
181                        IssueSeverity::Warning
182                    },
183                    category: IssueCategory::MissingValues,
184                    description: format!(
185                        "Feature {} has {:.2}% missing values",
186                        stat.feature_idx, stat.missing_percentage
187                    ),
188                    affected_features: vec![stat.feature_idx],
189                });
190            }
191        }
192
193        // Compute outlier statistics
194        let outlier_stats = self.compute_outlier_stats(x);
195        for stat in &outlier_stats {
196            if stat.outlier_percentage > 5.0 {
197                issues.push(QualityIssue {
198                    severity: IssueSeverity::Warning,
199                    category: IssueCategory::Outliers,
200                    description: format!(
201                        "Feature {} has {:.2}% outliers",
202                        stat.feature_idx, stat.outlier_percentage
203                    ),
204                    affected_features: vec![stat.feature_idx],
205                });
206            }
207        }
208
209        // Compute distribution statistics
210        let distribution_stats = self.compute_distribution_stats(x);
211        if self.config.check_constant_features {
212            for stat in &distribution_stats {
213                if stat.constant {
214                    issues.push(QualityIssue {
215                        severity: IssueSeverity::Warning,
216                        category: IssueCategory::ConstantFeatures,
217                        description: format!("Feature {} is constant", stat.feature_idx),
218                        affected_features: vec![stat.feature_idx],
219                    });
220                } else if stat.std < self.config.near_constant_threshold {
221                    issues.push(QualityIssue {
222                        severity: IssueSeverity::Info,
223                        category: IssueCategory::ConstantFeatures,
224                        description: format!(
225                            "Feature {} has very low variance: {}",
226                            stat.feature_idx, stat.std
227                        ),
228                        affected_features: vec![stat.feature_idx],
229                    });
230                }
231            }
232        }
233
234        // Compute correlations
235        let correlation_warnings = self.compute_correlation_warnings(x);
236        for warning in &correlation_warnings {
237            issues.push(QualityIssue {
238                severity: IssueSeverity::Info,
239                category: IssueCategory::HighCorrelation,
240                description: format!(
241                    "Features {} and {} are highly correlated: {:.3}",
242                    warning.feature_i, warning.feature_j, warning.correlation
243                ),
244                affected_features: vec![warning.feature_i, warning.feature_j],
245            });
246        }
247
248        // Check for duplicates
249        if self.config.check_duplicates {
250            let duplicate_count = self.count_duplicate_samples(x);
251            if duplicate_count > 0 {
252                issues.push(QualityIssue {
253                    severity: IssueSeverity::Info,
254                    category: IssueCategory::Duplicates,
255                    description: format!(
256                        "Found {} duplicate samples ({:.2}%)",
257                        duplicate_count,
258                        (duplicate_count as f64 / n_samples as f64) * 100.0
259                    ),
260                    affected_features: vec![],
261                });
262            }
263        }
264
265        // Calculate quality score
266        let quality_score = self.calculate_quality_score(&issues, n_samples, n_features);
267
268        Ok(DataQualityReport {
269            n_samples,
270            n_features,
271            missing_stats,
272            outlier_stats,
273            distribution_stats,
274            correlation_warnings,
275            quality_score,
276            issues,
277        })
278    }
279
280    /// Compute missing value statistics
281    fn compute_missing_stats(&self, x: &Array2<f64>) -> Vec<MissingStats> {
282        let n_samples = x.nrows();
283
284        (0..x.ncols())
285            .map(|j| {
286                let col = x.column(j);
287                let missing_count = col.iter().filter(|v| v.is_nan()).count();
288                let missing_percentage = (missing_count as f64 / n_samples as f64) * 100.0;
289
290                MissingStats {
291                    feature_idx: j,
292                    missing_count,
293                    missing_percentage,
294                }
295            })
296            .collect()
297    }
298
299    /// Compute outlier statistics
300    fn compute_outlier_stats(&self, x: &Array2<f64>) -> Vec<OutlierStats> {
301        (0..x.ncols())
302            .map(|j| {
303                let col = x.column(j);
304                let values: Vec<f64> = col.iter().copied().filter(|v| !v.is_nan()).collect();
305
306                if values.is_empty() {
307                    return OutlierStats {
308                        feature_idx: j,
309                        outlier_count: 0,
310                        outlier_percentage: 0.0,
311                        outlier_indices: vec![],
312                    };
313                }
314
315                let outlier_indices = match self.config.outlier_method {
316                    OutlierMethod::ZScore => self.detect_outliers_zscore(&values, j, x.nrows()),
317                    OutlierMethod::IQR => self.detect_outliers_iqr(&values, j, x.nrows()),
318                    OutlierMethod::ModifiedZScore => {
319                        self.detect_outliers_modified_zscore(&values, j, x.nrows())
320                    }
321                };
322
323                let outlier_count = outlier_indices.len();
324                let outlier_percentage = (outlier_count as f64 / x.nrows() as f64) * 100.0;
325
326                OutlierStats {
327                    feature_idx: j,
328                    outlier_count,
329                    outlier_percentage,
330                    outlier_indices,
331                }
332            })
333            .collect()
334    }
335
336    /// Detect outliers using Z-score method
337    fn detect_outliers_zscore(&self, values: &[f64], _col_idx: usize, n_rows: usize) -> Vec<usize> {
338        let mean = values.iter().sum::<f64>() / values.len() as f64;
339        let variance = values.iter().map(|v| (v - mean).powi(2)).sum::<f64>() / values.len() as f64;
340        let std = variance.sqrt();
341
342        if std < 1e-10 {
343            return vec![];
344        }
345
346        let mut outliers = Vec::new();
347        let mut value_idx = 0;
348
349        for i in 0..n_rows {
350            if value_idx < values.len() {
351                let z_score = (values[value_idx] - mean).abs() / std;
352                if z_score > self.config.outlier_threshold {
353                    outliers.push(i);
354                }
355                value_idx += 1;
356            }
357        }
358
359        outliers
360    }
361
362    /// Detect outliers using IQR method
363    fn detect_outliers_iqr(&self, values: &[f64], _col_idx: usize, n_rows: usize) -> Vec<usize> {
364        let mut sorted = values.to_vec();
365        sorted.sort_by(|a, b| a.partial_cmp(b).unwrap_or(std::cmp::Ordering::Equal));
366
367        let q1_idx = sorted.len() / 4;
368        let q3_idx = (sorted.len() * 3) / 4;
369        let q1 = sorted[q1_idx];
370        let q3 = sorted[q3_idx];
371        let iqr = q3 - q1;
372
373        if iqr < 1e-10 {
374            return vec![];
375        }
376
377        let lower_bound = q1 - self.config.outlier_threshold * iqr;
378        let upper_bound = q3 + self.config.outlier_threshold * iqr;
379
380        let mut outliers = Vec::new();
381        let mut value_idx = 0;
382
383        for i in 0..n_rows {
384            if value_idx < values.len() {
385                let val = values[value_idx];
386                if val < lower_bound || val > upper_bound {
387                    outliers.push(i);
388                }
389                value_idx += 1;
390            }
391        }
392
393        outliers
394    }
395
396    /// Detect outliers using Modified Z-score method
397    fn detect_outliers_modified_zscore(
398        &self,
399        values: &[f64],
400        _col_idx: usize,
401        n_rows: usize,
402    ) -> Vec<usize> {
403        let mut sorted = values.to_vec();
404        sorted.sort_by(|a, b| a.partial_cmp(b).unwrap_or(std::cmp::Ordering::Equal));
405
406        let median = sorted[sorted.len() / 2];
407        let mad = {
408            let deviations: Vec<f64> = sorted.iter().map(|v| (v - median).abs()).collect();
409            let mut dev_sorted = deviations.clone();
410            dev_sorted.sort_by(|a, b| a.partial_cmp(b).unwrap_or(std::cmp::Ordering::Equal));
411            dev_sorted[dev_sorted.len() / 2]
412        };
413
414        if mad < 1e-10 {
415            return vec![];
416        }
417
418        let mut outliers = Vec::new();
419        let mut value_idx = 0;
420
421        for i in 0..n_rows {
422            if value_idx < values.len() {
423                let modified_z = 0.6745 * (values[value_idx] - median).abs() / mad;
424                if modified_z > self.config.outlier_threshold {
425                    outliers.push(i);
426                }
427                value_idx += 1;
428            }
429        }
430
431        outliers
432    }
433
434    /// Compute distribution statistics
435    fn compute_distribution_stats(&self, x: &Array2<f64>) -> Vec<DistributionStats> {
436        (0..x.ncols())
437            .map(|j| {
438                let col = x.column(j);
439                let values: Vec<f64> = col.iter().copied().filter(|v| !v.is_nan()).collect();
440
441                if values.is_empty() {
442                    return DistributionStats {
443                        feature_idx: j,
444                        mean: f64::NAN,
445                        std: f64::NAN,
446                        min: f64::NAN,
447                        max: f64::NAN,
448                        median: f64::NAN,
449                        q25: f64::NAN,
450                        q75: f64::NAN,
451                        skewness: f64::NAN,
452                        kurtosis: f64::NAN,
453                        unique_count: 0,
454                        constant: true,
455                    };
456                }
457
458                let mean = values.iter().sum::<f64>() / values.len() as f64;
459                let variance =
460                    values.iter().map(|v| (v - mean).powi(2)).sum::<f64>() / values.len() as f64;
461                let std = variance.sqrt();
462
463                let mut sorted = values.clone();
464                sorted.sort_by(|a, b| a.partial_cmp(b).unwrap_or(std::cmp::Ordering::Equal));
465
466                let min = sorted.first().copied().unwrap_or(f64::NAN);
467                let max = sorted.last().copied().unwrap_or(f64::NAN);
468                let median = sorted[sorted.len() / 2];
469                let q25 = sorted[sorted.len() / 4];
470                let q75 = sorted[(sorted.len() * 3) / 4];
471
472                let skewness = if std > 1e-10 {
473                    values
474                        .iter()
475                        .map(|v| ((v - mean) / std).powi(3))
476                        .sum::<f64>()
477                        / values.len() as f64
478                } else {
479                    0.0
480                };
481
482                let kurtosis = if std > 1e-10 {
483                    values
484                        .iter()
485                        .map(|v| ((v - mean) / std).powi(4))
486                        .sum::<f64>()
487                        / values.len() as f64
488                        - 3.0
489                } else {
490                    0.0
491                };
492
493                // Count unique values (with epsilon comparison for floats)
494                let mut unique_values = Vec::new();
495                for &v in &values {
496                    if !unique_values.iter().any(|&uv: &f64| (uv - v).abs() < 1e-10) {
497                        unique_values.push(v);
498                    }
499                }
500                let unique_count = unique_values.len();
501
502                let constant = std < 1e-10;
503
504                DistributionStats {
505                    feature_idx: j,
506                    mean,
507                    std,
508                    min,
509                    max,
510                    median,
511                    q25,
512                    q75,
513                    skewness,
514                    kurtosis,
515                    unique_count,
516                    constant,
517                }
518            })
519            .collect()
520    }
521
522    /// Compute correlation warnings
523    fn compute_correlation_warnings(&self, x: &Array2<f64>) -> Vec<CorrelationWarning> {
524        let mut warnings = Vec::new();
525
526        for i in 0..x.ncols() {
527            for j in (i + 1)..x.ncols() {
528                let corr = self.compute_correlation(x, i, j);
529                if corr.abs() > self.config.correlation_threshold {
530                    warnings.push(CorrelationWarning {
531                        feature_i: i,
532                        feature_j: j,
533                        correlation: corr,
534                    });
535                }
536            }
537        }
538
539        warnings
540    }
541
542    /// Compute Pearson correlation between two features
543    fn compute_correlation(&self, x: &Array2<f64>, i: usize, j: usize) -> f64 {
544        let col_i = x.column(i);
545        let col_j = x.column(j);
546
547        let pairs: Vec<(f64, f64)> = col_i
548            .iter()
549            .zip(col_j.iter())
550            .filter(|(a, b)| !a.is_nan() && !b.is_nan())
551            .map(|(&a, &b)| (a, b))
552            .collect();
553
554        if pairs.len() < 2 {
555            return 0.0;
556        }
557
558        let mean_i = pairs.iter().map(|(a, _)| a).sum::<f64>() / pairs.len() as f64;
559        let mean_j = pairs.iter().map(|(_, b)| b).sum::<f64>() / pairs.len() as f64;
560
561        let mut cov = 0.0;
562        let mut var_i = 0.0;
563        let mut var_j = 0.0;
564
565        for (a, b) in &pairs {
566            let di = a - mean_i;
567            let dj = b - mean_j;
568            cov += di * dj;
569            var_i += di * di;
570            var_j += dj * dj;
571        }
572
573        if var_i < 1e-10 || var_j < 1e-10 {
574            return 0.0;
575        }
576
577        cov / (var_i * var_j).sqrt()
578    }
579
580    /// Count duplicate samples
581    fn count_duplicate_samples(&self, x: &Array2<f64>) -> usize {
582        let mut seen = HashMap::new();
583        let mut duplicates = 0;
584
585        for i in 0..x.nrows() {
586            let row: Vec<_> = x.row(i).iter().copied().collect();
587            *seen.entry(format!("{:?}", row)).or_insert(0) += 1;
588        }
589
590        for count in seen.values() {
591            if *count > 1 {
592                duplicates += count - 1;
593            }
594        }
595
596        duplicates
597    }
598
599    /// Calculate overall quality score (0-100)
600    fn calculate_quality_score(
601        &self,
602        issues: &[QualityIssue],
603        _n_samples: usize,
604        _n_features: usize,
605    ) -> f64 {
606        let mut score: f64 = 100.0;
607
608        for issue in issues {
609            let penalty: f64 = match issue.severity {
610                IssueSeverity::Critical => 20.0,
611                IssueSeverity::Warning => 10.0,
612                IssueSeverity::Info => 2.0,
613            };
614            score -= penalty;
615        }
616
617        score.max(0.0)
618    }
619}
620
621impl Default for DataQualityValidator {
622    fn default() -> Self {
623        Self::new()
624    }
625}
626
627impl DataQualityReport {
628    /// Print a human-readable summary of the report
629    pub fn print_summary(&self) {
630        println!("Data Quality Report");
631        println!("==================");
632        println!("Samples: {}, Features: {}", self.n_samples, self.n_features);
633        println!("Quality Score: {:.1}/100", self.quality_score);
634        println!();
635
636        if !self.issues.is_empty() {
637            println!("Issues Found: {}", self.issues.len());
638            println!();
639
640            for issue in &self.issues {
641                let severity_str = match issue.severity {
642                    IssueSeverity::Critical => "CRITICAL",
643                    IssueSeverity::Warning => "WARNING",
644                    IssueSeverity::Info => "INFO",
645                };
646                println!("[{}] {}", severity_str, issue.description);
647            }
648        } else {
649            println!("No issues detected!");
650        }
651    }
652
653    /// Get issues by severity
654    pub fn issues_by_severity(&self, severity: IssueSeverity) -> Vec<&QualityIssue> {
655        self.issues
656            .iter()
657            .filter(|issue| issue.severity == severity)
658            .collect()
659    }
660
661    /// Get issues by category
662    pub fn issues_by_category(&self, category: IssueCategory) -> Vec<&QualityIssue> {
663        self.issues
664            .iter()
665            .filter(|issue| issue.category == category)
666            .collect()
667    }
668}
669
#[cfg(test)]
mod tests {
    use super::*;
    use scirs2_core::random::essentials::Normal;
    use scirs2_core::random::{seeded_rng, Distribution};

    /// Build an `nrows x ncols` matrix of standard-normal draws from a seeded RNG.
    fn generate_test_data(nrows: usize, ncols: usize, seed: u64) -> Array2<f64> {
        let mut rng = seeded_rng(seed);
        let dist = Normal::new(0.0, 1.0).unwrap();

        let samples: Vec<f64> = std::iter::repeat_with(|| dist.sample(&mut rng))
            .take(nrows * ncols)
            .collect();

        Array2::from_shape_vec((nrows, ncols), samples).unwrap()
    }

    #[test]
    fn test_data_quality_validator_basic() {
        let data = generate_test_data(100, 5, 42);
        let report = DataQualityValidator::new().validate(&data).unwrap();

        assert_eq!(report.n_samples, 100);
        assert_eq!(report.n_features, 5);
        assert!(report.quality_score > 0.0);
    }

    #[test]
    fn test_missing_value_detection() {
        let mut data = generate_test_data(100, 3, 123);

        // Blank out the first 20 rows of column 0.
        for row in 0..20 {
            data[[row, 0]] = f64::NAN;
        }

        let report = DataQualityValidator::new().validate(&data).unwrap();

        let col0 = &report.missing_stats[0];
        assert_eq!(col0.missing_count, 20);
        assert!((col0.missing_percentage - 20.0).abs() < 0.1);
    }

    #[test]
    fn test_constant_feature_detection() {
        let mut data = generate_test_data(50, 3, 456);

        // Overwrite column 1 with a single repeated value.
        for row in 0..data.nrows() {
            data[[row, 1]] = 5.0;
        }

        let report = DataQualityValidator::new().validate(&data).unwrap();

        assert!(!report
            .issues_by_category(IssueCategory::ConstantFeatures)
            .is_empty());
    }

    #[test]
    fn test_outlier_detection() {
        let mut data = generate_test_data(100, 2, 789);

        // Inject two extreme values into column 0.
        data[[0, 0]] = 100.0;
        data[[1, 0]] = -100.0;

        let report = DataQualityValidator::new().validate(&data).unwrap();

        assert!(report.outlier_stats[0].outlier_count > 0);
    }

    #[test]
    fn test_quality_score_calculation() {
        let data = generate_test_data(100, 5, 321);
        let report = DataQualityValidator::new().validate(&data).unwrap();

        // Clean random data should score well.
        assert!(report.quality_score > 80.0);
    }
}