// tenflowers_dataset/data_quality.rs

//! Data quality metrics and drift detection
//!
//! This module provides comprehensive data quality assessment including:
//! - Statistical quality metrics (completeness, validity, consistency)
//! - Data drift detection (distribution shift, concept drift)
//! - Outlier detection and anomaly analysis
//! - Data profiling and summary statistics

use crate::Dataset;
use std::collections::HashMap;
use tenflowers_core::{Result, Tensor, TensorError};

/// Comprehensive data quality metrics
#[derive(Debug, Clone)]
pub struct DataQualityMetrics {
    /// Dataset name or identifier
    pub dataset_name: String,
    /// Total number of samples
    pub total_samples: usize,
    /// Completeness metrics (percentage of non-null values)
    pub completeness: HashMap<String, f64>,
    /// Validity metrics (percentage of values passing validation rules)
    pub validity: HashMap<String, f64>,
    /// Consistency metrics (internal consistency checks)
    pub consistency_score: f64,
    /// Timeliness (data freshness indicators)
    pub timeliness_score: Option<f64>,
    /// Accuracy estimates
    pub accuracy_estimates: HashMap<String, f64>,
    /// Uniqueness metrics (duplicate detection)
    pub uniqueness_score: f64,
    /// Overall quality score (0.0 to 1.0)
    pub overall_quality_score: f64,
    /// Detected issues
    pub issues: Vec<DataQualityIssue>,
}

/// Data quality issue description
#[derive(Debug, Clone)]
pub struct DataQualityIssue {
    /// Issue severity
    pub severity: IssueSeverity,
    /// Issue category
    pub category: IssueCategory,
    /// Description of the issue
    pub description: String,
    /// Affected fields/columns
    pub affected_fields: Vec<String>,
    /// Number of affected samples
    pub affected_count: usize,
    /// Suggested remediation
    pub remediation: Option<String>,
}

/// Issue severity levels
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum IssueSeverity {
    /// Critical issue requiring immediate attention
    Critical,
    /// High priority issue
    High,
    /// Medium priority issue
    Medium,
    /// Low priority issue
    Low,
    /// Informational only
    Info,
}

/// Issue categories
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum IssueCategory {
    /// Missing data issues
    Completeness,
    /// Invalid values
    Validity,
    /// Inconsistent data
    Consistency,
    /// Duplicate data
    Uniqueness,
    /// Accuracy issues
    Accuracy,
    /// Timeliness issues
    Timeliness,
    /// Statistical anomalies
    StatisticalAnomaly,
}

/// Data drift detection configuration
#[derive(Debug, Clone)]
pub struct DriftDetectionConfig {
    /// Reference dataset size (for baseline)
    pub reference_window_size: usize,
    /// Detection threshold (0.0 to 1.0)
    pub detection_threshold: f64,
    /// Statistical test to use
    pub statistical_test: StatisticalTest,
    /// Minimum number of samples for detection
    pub min_samples: usize,
    /// Enable distribution visualization
    pub enable_visualization: bool,
}

impl Default for DriftDetectionConfig {
    fn default() -> Self {
        Self {
            reference_window_size: 1000,
            detection_threshold: 0.05, // 5% significance level
            statistical_test: StatisticalTest::KolmogorovSmirnov,
            min_samples: 100,
            enable_visualization: false,
        }
    }
}

/// Statistical tests for drift detection
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum StatisticalTest {
    /// Kolmogorov-Smirnov test
    KolmogorovSmirnov,
    /// Chi-squared test
    ChiSquared,
    /// Population Stability Index (PSI)
    PopulationStabilityIndex,
    /// Kullback-Leibler divergence
    KullbackLeibler,
    /// Jensen-Shannon divergence
    JensenShannon,
}

/// Data drift detection result
#[derive(Debug, Clone)]
pub struct DriftDetectionResult {
    /// Whether drift was detected
    pub drift_detected: bool,
    /// Drift score (higher means more drift)
    pub drift_score: f64,
    /// Statistical test p-value (if applicable)
    pub p_value: Option<f64>,
    /// Distribution distance metric
    pub distance_metric: f64,
    /// Drift type
    pub drift_type: DriftType,
    /// Detailed analysis
    pub analysis: String,
    /// Affected features
    pub affected_features: Vec<String>,
}

/// Types of detected drift
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum DriftType {
    /// No drift detected
    NoDrift,
    /// Covariate shift (feature distribution change)
    CovariateShift,
    /// Concept drift (relationship change)
    ConceptDrift,
    /// Label drift (target distribution change)
    LabelDrift,
    /// Combined drift
    CombinedDrift,
}

/// Data quality analyzer
pub struct DataQualityAnalyzer {
    /// Configuration for quality checks
    config: QualityAnalysisConfig,
}

/// Configuration for quality analysis
#[derive(Debug, Clone)]
pub struct QualityAnalysisConfig {
    /// Check for missing values
    pub check_completeness: bool,
    /// Check for invalid values
    pub check_validity: bool,
    /// Check for duplicates
    pub check_duplicates: bool,
    /// Check for outliers
    pub check_outliers: bool,
    /// Outlier detection method
    pub outlier_method: OutlierDetectionMethod,
    /// Outlier threshold (e.g., number of standard deviations)
    pub outlier_threshold: f64,
    /// Maximum number of unique values to track
    pub max_unique_values: usize,
}

impl Default for QualityAnalysisConfig {
    fn default() -> Self {
        Self {
            check_completeness: true,
            check_validity: true,
            check_duplicates: true,
            check_outliers: true,
            outlier_method: OutlierDetectionMethod::IQR,
            outlier_threshold: 1.5,
            max_unique_values: 10000,
        }
    }
}

/// Outlier detection methods
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum OutlierDetectionMethod {
    /// Interquartile Range (IQR) method
    IQR,
    /// Z-score method
    ZScore,
    /// Modified Z-score method
    ModifiedZScore,
    /// Isolation Forest
    IsolationForest,
}

impl DataQualityAnalyzer {
    /// Create a new data quality analyzer
    pub fn new(config: QualityAnalysisConfig) -> Self {
        Self { config }
    }
}

impl Default for DataQualityAnalyzer {
    /// Create an analyzer with the default configuration
    fn default() -> Self {
        Self::new(QualityAnalysisConfig::default())
    }
}

impl DataQualityAnalyzer {
    /// Analyze data quality for a dataset
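    ///
    /// A minimal usage sketch (illustrative; `dataset` stands for any value
    /// implementing `Dataset<T>`, e.g. a `TensorDataset` as in the tests below):
    ///
    /// ```rust,ignore
    /// let analyzer = DataQualityAnalyzer::new(QualityAnalysisConfig {
    ///     outlier_method: OutlierDetectionMethod::ZScore,
    ///     outlier_threshold: 3.0,
    ///     ..QualityAnalysisConfig::default()
    /// });
    /// let metrics = analyzer.analyze(&dataset, "train_split")?;
    /// println!("overall quality: {:.2}", metrics.overall_quality_score);
    /// ```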
    pub fn analyze<T>(
        &self,
        dataset: &dyn Dataset<T>,
        dataset_name: impl Into<String>,
    ) -> Result<DataQualityMetrics>
    where
        T: Clone
            + Default
            + scirs2_core::numeric::Zero
            + scirs2_core::numeric::Float
            + Send
            + Sync
            + 'static,
    {
        let dataset_name = dataset_name.into();
        let total_samples = dataset.len();

        if total_samples == 0 {
            return Ok(DataQualityMetrics {
                dataset_name,
                total_samples: 0,
                completeness: HashMap::new(),
                validity: HashMap::new(),
                consistency_score: 0.0,
                timeliness_score: None,
                accuracy_estimates: HashMap::new(),
                uniqueness_score: 0.0,
                overall_quality_score: 0.0,
                issues: vec![DataQualityIssue {
                    severity: IssueSeverity::Critical,
                    category: IssueCategory::Completeness,
                    description: "Dataset is empty".to_string(),
                    affected_fields: vec![],
                    affected_count: 0,
                    remediation: Some("Add data to the dataset".to_string()),
                }],
            });
        }

        let mut metrics = DataQualityMetrics {
            dataset_name,
            total_samples,
            completeness: HashMap::new(),
            validity: HashMap::new(),
            consistency_score: 1.0,
            timeliness_score: None,
            accuracy_estimates: HashMap::new(),
            uniqueness_score: 1.0,
            overall_quality_score: 0.0,
            issues: Vec::new(),
        };

        // Check completeness
        if self.config.check_completeness {
            self.check_completeness(dataset, &mut metrics)?;
        }

        // Note: `check_validity` exists in the config, but no validity pass is
        // implemented in this module yet; the flag is reserved for future checks.

        // Check for duplicates
        if self.config.check_duplicates {
            self.check_duplicates(dataset, &mut metrics)?;
        }

        // Check for outliers
        if self.config.check_outliers {
            self.check_outliers(dataset, &mut metrics)?;
        }

        // Calculate overall quality score
        metrics.overall_quality_score = self.calculate_overall_score(&metrics);

        Ok(metrics)
    }

    fn check_completeness<T>(
        &self,
        dataset: &dyn Dataset<T>,
        metrics: &mut DataQualityMetrics,
    ) -> Result<()>
    where
        T: Clone + Default + scirs2_core::numeric::Zero + PartialEq + Send + Sync + 'static,
    {
        let mut non_zero_count = 0;

        // Sample up to 1000 items; a sample counts as complete when any of its
        // feature values is non-zero.
        for i in 0..dataset.len().min(1000) {
            if let Ok((features, _)) = dataset.get(i) {
                if let Some(data) = features.as_slice() {
                    if data.iter().any(|x| *x != T::zero()) {
                        non_zero_count += 1;
                    }
                }
            }
        }

        let samples_checked = dataset.len().min(1000);
        let completeness_score = if samples_checked > 0 {
            non_zero_count as f64 / samples_checked as f64
        } else {
            0.0
        };

        metrics
            .completeness
            .insert("features".to_string(), completeness_score);

        if completeness_score < 0.9 {
            metrics.issues.push(DataQualityIssue {
                severity: if completeness_score < 0.5 {
                    IssueSeverity::High
                } else {
                    IssueSeverity::Medium
                },
                category: IssueCategory::Completeness,
                description: format!("Low completeness score: {:.2}%", completeness_score * 100.0),
                affected_fields: vec!["features".to_string()],
                affected_count: ((1.0 - completeness_score) * samples_checked as f64) as usize,
                remediation: Some(
                    "Investigate missing data and apply imputation if appropriate".to_string(),
                ),
            });
        }

        Ok(())
    }

    fn check_duplicates<T>(
        &self,
        dataset: &dyn Dataset<T>,
        metrics: &mut DataQualityMetrics,
    ) -> Result<()>
    where
        T: Clone
            + Default
            + scirs2_core::numeric::Zero
            + scirs2_core::numeric::Float
            + Send
            + Sync
            + 'static,
    {
        use std::collections::HashSet;

        // Check for duplicate samples using approximate comparison for floats
        let samples_to_check = dataset.len().min(1000);
        let mut seen_samples: HashSet<String> = HashSet::new();
        let mut duplicate_count = 0usize;

        for i in 0..samples_to_check {
            if let Ok((features, _labels)) = dataset.get(i) {
                // Create a fingerprint from the tensor values
                if let Some(data_slice) = features.as_slice() {
                    // Create a deterministic string representation
                    // For floats, we round to a fixed precision to avoid precision issues
                    let fingerprint: String = data_slice
                        .iter()
                        .map(|v| {
                            let f = v.to_f64().unwrap_or(0.0);
                            format!("{:.6}", f) // 6 decimal places for fingerprinting
                        })
                        .collect::<Vec<_>>()
                        .join(",");

                    if !seen_samples.insert(fingerprint) {
                        duplicate_count += 1;
                    }
                }
            }
        }

        // Calculate uniqueness score: (unique samples) / (total samples checked)
        let unique_count = samples_to_check - duplicate_count;
        metrics.uniqueness_score = if samples_to_check > 0 {
            unique_count as f64 / samples_to_check as f64
        } else {
            1.0
        };

        Ok(())
    }

    fn check_outliers<T>(
        &self,
        dataset: &dyn Dataset<T>,
        metrics: &mut DataQualityMetrics,
    ) -> Result<()>
    where
        T: Clone
            + Default
            + scirs2_core::numeric::Zero
            + scirs2_core::numeric::Float
            + Send
            + Sync
            + 'static,
    {
        let mut values: Vec<f64> = Vec::new();

        // Collect values for outlier analysis
        for i in 0..dataset.len().min(1000) {
            if let Ok((features, _)) = dataset.get(i) {
                if let Some(data) = features.as_slice() {
                    for &val in data {
                        values.push(val.to_f64().unwrap_or(0.0));
                    }
                }
            }
        }

        if values.is_empty() {
            return Ok(());
        }

        // Calculate statistics
        let outlier_count = match self.config.outlier_method {
            OutlierDetectionMethod::IQR => self.detect_outliers_iqr(&values),
            OutlierDetectionMethod::ZScore => self.detect_outliers_zscore(&values),
            // ModifiedZScore and IsolationForest are not implemented yet;
            // they currently report no outliers.
            _ => 0,
        };

        if outlier_count > 0 {
            let outlier_percentage = outlier_count as f64 / values.len() as f64;
            if outlier_percentage > 0.05 {
                // More than 5% outliers
                metrics.issues.push(DataQualityIssue {
                    severity: IssueSeverity::Medium,
                    category: IssueCategory::StatisticalAnomaly,
                    description: format!(
                        "High percentage of outliers detected: {:.2}%",
                        outlier_percentage * 100.0
                    ),
                    affected_fields: vec!["features".to_string()],
                    affected_count: outlier_count,
                    remediation: Some(
                        "Review outlier values and consider outlier removal or transformation"
                            .to_string(),
                    ),
                });
            }
        }

        Ok(())
    }

    fn detect_outliers_iqr(&self, values: &[f64]) -> usize {
        if values.len() < 4 {
            return 0;
        }

        let mut sorted = values.to_vec();
        sorted.sort_by(|a, b| a.partial_cmp(b).unwrap_or(std::cmp::Ordering::Equal));

        // Approximate quartiles by sorted-index position
        let q1_idx = sorted.len() / 4;
        let q3_idx = (3 * sorted.len()) / 4;
        let q1 = sorted[q1_idx];
        let q3 = sorted[q3_idx];
        let iqr = q3 - q1;

        let lower_bound = q1 - self.config.outlier_threshold * iqr;
        let upper_bound = q3 + self.config.outlier_threshold * iqr;

        values
            .iter()
            .filter(|&&v| v < lower_bound || v > upper_bound)
            .count()
    }

    fn detect_outliers_zscore(&self, values: &[f64]) -> usize {
        if values.is_empty() {
            return 0;
        }

        let mean = values.iter().sum::<f64>() / values.len() as f64;
        let variance =
            values.iter().map(|&v| (v - mean).powi(2)).sum::<f64>() / values.len() as f64;
        let std_dev = variance.sqrt();

        if std_dev == 0.0 {
            return 0;
        }

        values
            .iter()
            .filter(|&&v| ((v - mean) / std_dev).abs() > self.config.outlier_threshold)
            .count()
    }

    fn calculate_overall_score(&self, metrics: &DataQualityMetrics) -> f64 {
        let mut scores = Vec::new();

        // Completeness score
        if let Some(&completeness) = metrics.completeness.get("features") {
            scores.push(completeness);
        }

        // Uniqueness score
        scores.push(metrics.uniqueness_score);

        // Consistency score
        scores.push(metrics.consistency_score);

        // Reduce score based on issues
        let issue_penalty = metrics.issues.iter().fold(0.0, |acc, issue| {
            acc + match issue.severity {
                IssueSeverity::Critical => 0.3,
                IssueSeverity::High => 0.2,
                IssueSeverity::Medium => 0.1,
                IssueSeverity::Low => 0.05,
                IssueSeverity::Info => 0.0,
            }
        });

        let base_score = if scores.is_empty() {
            0.0
        } else {
            scores.iter().sum::<f64>() / scores.len() as f64
        };

        (base_score - issue_penalty).clamp(0.0, 1.0)
    }

    /// Generate a human-readable quality report
    pub fn generate_report(&self, metrics: &DataQualityMetrics) -> String {
        let mut report = format!(
            "Data Quality Report: {}\n\
             ================================\n\
             Total Samples: {}\n\
             Overall Quality Score: {:.2}%\n\n",
            metrics.dataset_name,
            metrics.total_samples,
            metrics.overall_quality_score * 100.0
        );

        // Completeness
        if !metrics.completeness.is_empty() {
            report.push_str("Completeness:\n");
            for (field, score) in &metrics.completeness {
                report.push_str(&format!("  {}: {:.2}%\n", field, score * 100.0));
            }
            report.push('\n');
        }

        // Uniqueness
        report.push_str(&format!(
            "Uniqueness Score: {:.2}%\n\n",
            metrics.uniqueness_score * 100.0
        ));

        // Issues
        if !metrics.issues.is_empty() {
            report.push_str(&format!("Detected Issues ({}):\n", metrics.issues.len()));
            for (i, issue) in metrics.issues.iter().enumerate() {
                report.push_str(&format!(
                    "  {}. [{:?}] [{:?}] {}\n",
                    i + 1,
                    issue.severity,
                    issue.category,
                    issue.description
                ));
                if !issue.affected_fields.is_empty() {
                    report.push_str(&format!(
                        "     Affected fields: {}\n",
                        issue.affected_fields.join(", ")
                    ));
                }
                if let Some(remediation) = &issue.remediation {
                    report.push_str(&format!("     Remediation: {}\n", remediation));
                }
            }
        } else {
            report.push_str("No issues detected.\n");
        }

        report
    }
}

/// Extension trait to add quality analysis to any dataset
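///
/// A minimal usage sketch (illustrative; tensor shapes mirror the unit tests
/// below, and the `?` calls assume a `Result` context):
///
/// ```rust,ignore
/// use crate::TensorDataset;
/// use tenflowers_core::Tensor;
///
/// let features = Tensor::<f32>::from_vec(vec![1.0, 2.0, 3.0, 4.0], &[2, 2])?;
/// let labels = Tensor::<f32>::from_vec(vec![0.0, 1.0], &[2])?;
/// let dataset = TensorDataset::new(features, labels);
///
/// let metrics = dataset.analyze_quality("train_split")?;
/// println!("score: {:.2}", metrics.overall_quality_score);
/// println!("{}", dataset.quality_report("train_split")?);
/// ```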
pub trait DataQualityExt<T>: Dataset<T> + Sized {
    /// Analyze data quality
    fn analyze_quality(&self, name: impl Into<String>) -> Result<DataQualityMetrics>
    where
        T: Clone
            + Default
            + scirs2_core::numeric::Zero
            + scirs2_core::numeric::Float
            + Send
            + Sync
            + 'static,
    {
        let analyzer = DataQualityAnalyzer::default();
        analyzer.analyze(self, name)
    }

    /// Generate a quality report
    fn quality_report(&self, name: impl Into<String>) -> Result<String>
    where
        T: Clone
            + Default
            + scirs2_core::numeric::Zero
            + scirs2_core::numeric::Float
            + Send
            + Sync
            + 'static,
    {
        let metrics = self.analyze_quality(name)?;
        let analyzer = DataQualityAnalyzer::default();
        Ok(analyzer.generate_report(&metrics))
    }
}

// Blanket implementation for all datasets
impl<T, D: Dataset<T>> DataQualityExt<T> for D {}

// ──────────────────────────────────────────────────────────────────────────────
// Drift metrics: PSI, KS two-sample test, Jensen-Shannon divergence
// ──────────────────────────────────────────────────────────────────────────────

/// Drift analysis report combining multiple statistical measures.
#[derive(Debug, Clone)]
pub struct DriftReport {
    /// Population Stability Index (PSI). Values below 0.1 indicate a stable
    /// distribution, 0.1–0.2 a moderate shift, and above 0.2 a significant shift.
    pub psi: f64,
    /// Kolmogorov-Smirnov two-sample test statistic (max |ECDF_a − ECDF_b|). Range [0, 1].
    pub ks_statistic: f64,
    /// Jensen-Shannon divergence (log base 2). Range [0, 1]; 0 = identical distributions.
    pub jsd: f64,
    /// `true` when PSI > 0.2 or KS > 0.1; a coarse flag for downstream alerting.
    pub is_significant_drift: bool,
}

/// Compute the Population Stability Index (PSI) between a reference and a current distribution.
///
/// Both slices are binned into `n_bins` equal-width bins spanning
/// `[min(all values), max(all values)]`. Epsilon smoothing prevents ln(0).
///
/// # Errors
/// Returns an error if either slice is empty or `n_bins == 0`.
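///
/// With reference bin share `r_i` and current bin share `c_i`, the code below
/// computes `PSI = Σ_i (c_i − r_i) · ln(c_i / r_i)`.
///
/// # Examples
///
/// A minimal sketch (illustrative values):
///
/// ```rust,ignore
/// let reference: Vec<f64> = (0..1000).map(|i| (i % 100) as f64).collect();
/// let shifted: Vec<f64> = reference.iter().map(|v| v + 25.0).collect();
/// let psi = population_stability_index(&reference, &shifted, 10)?;
/// assert!(psi > 0.1); // a clearly visible shift
/// ```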
pub fn population_stability_index(
    reference: &[f64],
    current: &[f64],
    n_bins: usize,
) -> Result<f64> {
    if reference.is_empty() {
        return Err(TensorError::invalid_argument(
            "reference slice is empty".to_string(),
        ));
    }
    if current.is_empty() {
        return Err(TensorError::invalid_argument(
            "current slice is empty".to_string(),
        ));
    }
    if n_bins == 0 {
        return Err(TensorError::invalid_argument(
            "n_bins must be > 0".to_string(),
        ));
    }

    let min_val = reference
        .iter()
        .chain(current.iter())
        .cloned()
        .fold(f64::INFINITY, f64::min);
    let max_val = reference
        .iter()
        .chain(current.iter())
        .cloned()
        .fold(f64::NEG_INFINITY, f64::max);

    if (max_val - min_val).abs() < f64::EPSILON {
        return Ok(0.0);
    }

    let bin_width = (max_val - min_val) / n_bins as f64;
    let epsilon = 1e-9_f64;

    let count_bins = |samples: &[f64]| -> Vec<f64> {
        let n = samples.len() as f64;
        let mut counts = vec![0_usize; n_bins];
        for &v in samples {
            let idx = ((v - min_val) / bin_width).floor() as usize;
            counts[idx.min(n_bins - 1)] += 1;
        }
        counts
            .into_iter()
            .map(|c| (c as f64 + epsilon) / (n + n_bins as f64 * epsilon))
            .collect()
    };

    let ref_pct = count_bins(reference);
    let cur_pct = count_bins(current);

    let psi = ref_pct
        .iter()
        .zip(cur_pct.iter())
        .map(|(&r, &c)| (c - r) * (c / r).ln())
        .sum::<f64>();

    Ok(psi)
}

/// Compute the Kolmogorov-Smirnov two-sample test statistic.
///
/// Returns `max |ECDF_a(x) − ECDF_b(x)|` over all observed values.
///
/// # Errors
/// Returns an error if either slice is empty.
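///
/// # Examples
///
/// A minimal sketch (illustrative values):
///
/// ```rust,ignore
/// let a: Vec<f64> = (0..100).map(|i| i as f64).collect();
/// let b: Vec<f64> = (0..100).map(|i| i as f64 + 50.0).collect();
/// let ks = ks_two_sample(&a, &b)?;
/// assert!((ks - 0.5).abs() < 1e-9); // half of each sample does not overlap
/// ```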
pub fn ks_two_sample(sample_a: &[f64], sample_b: &[f64]) -> Result<f64> {
    if sample_a.is_empty() {
        return Err(TensorError::invalid_argument(
            "sample_a is empty".to_string(),
        ));
    }
    if sample_b.is_empty() {
        return Err(TensorError::invalid_argument(
            "sample_b is empty".to_string(),
        ));
    }

    let na = sample_a.len() as f64;
    let nb = sample_b.len() as f64;

    let mut sorted_a = sample_a.to_vec();
    let mut sorted_b = sample_b.to_vec();
    sorted_a.sort_by(|x, y| x.partial_cmp(y).unwrap_or(std::cmp::Ordering::Equal));
    sorted_b.sort_by(|x, y| x.partial_cmp(y).unwrap_or(std::cmp::Ordering::Equal));

    let mut ia = 0_usize;
    let mut ib = 0_usize;
    let mut max_diff = 0.0_f64;

    while ia < sorted_a.len() || ib < sorted_b.len() {
        let x = match (sorted_a.get(ia), sorted_b.get(ib)) {
            (Some(&a), Some(&b)) => a.min(b),
            (Some(&a), None) => a,
            (None, Some(&b)) => b,
            (None, None) => break,
        };

        while ia < sorted_a.len() && sorted_a[ia] <= x {
            ia += 1;
        }
        while ib < sorted_b.len() && sorted_b[ib] <= x {
            ib += 1;
        }

        let ecdf_a = ia as f64 / na;
        let ecdf_b = ib as f64 / nb;
        let diff = (ecdf_a - ecdf_b).abs();
        if diff > max_diff {
            max_diff = diff;
        }
    }

    Ok(max_diff)
}

/// Compute the Jensen-Shannon divergence (log base 2) between two distributions.
///
/// Both `p` and `q` are treated as unnormalized histograms; they are normalized
/// inside the function. Returns a value in [0, 1] (0 = identical).
///
/// # Errors
/// Returns an error if either slice is empty or sums to zero, or if lengths differ.
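///
/// With the mixture `M = (P + Q) / 2`, the code below computes
/// `JSD(P, Q) = KL(P ‖ M) / 2 + KL(Q ‖ M) / 2` using base-2 logarithms.
///
/// # Examples
///
/// A minimal sketch (illustrative histograms):
///
/// ```rust,ignore
/// let p = vec![9.0, 1.0];
/// let q = vec![1.0, 9.0];
/// let jsd = jensen_shannon_divergence(&p, &q)?;
/// assert!(jsd > 0.3); // well-separated histograms
/// ```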
pub fn jensen_shannon_divergence(p: &[f64], q: &[f64]) -> Result<f64> {
    if p.is_empty() || q.is_empty() {
        return Err(TensorError::invalid_argument(
            "p and q must be non-empty".to_string(),
        ));
    }
    if p.len() != q.len() {
        return Err(TensorError::invalid_argument(
            "p and q must have the same length".to_string(),
        ));
    }

    let sum_p: f64 = p.iter().sum();
    let sum_q: f64 = q.iter().sum();

    if sum_p <= 0.0 {
        return Err(TensorError::invalid_argument("p sums to zero".to_string()));
    }
    if sum_q <= 0.0 {
        return Err(TensorError::invalid_argument("q sums to zero".to_string()));
    }

    let norm_p: Vec<f64> = p.iter().map(|&v| v / sum_p).collect();
    let norm_q: Vec<f64> = q.iter().map(|&v| v / sum_q).collect();

    let m: Vec<f64> = norm_p
        .iter()
        .zip(norm_q.iter())
        .map(|(&pi, &qi)| (pi + qi) * 0.5)
        .collect();

    let kl_div = |dist: &[f64], mix: &[f64]| -> f64 {
        dist.iter()
            .zip(mix.iter())
            .filter(|(&pi, &mi)| pi > 0.0 && mi > 0.0)
            .map(|(&pi, &mi)| pi * (pi / mi).log2())
            .sum::<f64>()
    };

    let jsd = 0.5 * kl_div(&norm_p, &m) + 0.5 * kl_div(&norm_q, &m);
    Ok(jsd.clamp(0.0, 1.0))
}

/// Run PSI, KS, and JSD on a pair of 1-D sample arrays and return a combined `DriftReport`.
///
/// Drift is flagged as significant when PSI > 0.2 or KS > 0.1.
///
/// # Errors
/// Propagates errors from the underlying statistical functions.
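///
/// # Examples
///
/// A minimal sketch (illustrative values):
///
/// ```rust,ignore
/// let reference: Vec<f64> = (0..500).map(|i| (i % 50) as f64).collect();
/// let current: Vec<f64> = (0..500).map(|i| (i % 50) as f64 + 100.0).collect();
/// let report = compute_drift(&reference, &current)?;
/// assert!(report.is_significant_drift); // fully shifted, so KS = 1.0
/// ```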
pub fn compute_drift(reference: &[f64], current: &[f64]) -> Result<DriftReport> {
    let psi = population_stability_index(reference, current, 20)?;
    let ks_statistic = ks_two_sample(reference, current)?;

    let n_bins = 20_usize;
    let min_val = reference
        .iter()
        .chain(current.iter())
        .cloned()
        .fold(f64::INFINITY, f64::min);
    let max_val = reference
        .iter()
        .chain(current.iter())
        .cloned()
        .fold(f64::NEG_INFINITY, f64::max);

    let jsd = if (max_val - min_val).abs() < f64::EPSILON {
        0.0
    } else {
        let bin_width = (max_val - min_val) / n_bins as f64;
        let mut hist_ref = vec![0_f64; n_bins];
        let mut hist_cur = vec![0_f64; n_bins];

        for &v in reference {
            let idx = ((v - min_val) / bin_width).floor() as usize;
            hist_ref[idx.min(n_bins - 1)] += 1.0;
        }
        for &v in current {
            let idx = ((v - min_val) / bin_width).floor() as usize;
            hist_cur[idx.min(n_bins - 1)] += 1.0;
        }

        jensen_shannon_divergence(&hist_ref, &hist_cur)?
    };

    let is_significant_drift = psi > 0.2 || ks_statistic > 0.1;

    Ok(DriftReport {
        psi,
        ks_statistic,
        jsd,
        is_significant_drift,
    })
}

#[cfg(test)]
mod tests {
    use super::*;
    use crate::TensorDataset;
    use tenflowers_core::Tensor;

    #[test]
    fn test_quality_analyzer_creation() {
        let analyzer = DataQualityAnalyzer::default();
        assert!(analyzer.config.check_completeness);
        assert!(analyzer.config.check_validity);
    }

    #[test]
    fn test_empty_dataset_quality() {
        let features =
            Tensor::<f32>::from_vec(vec![], &[0, 1]).expect("test: tensor creation should succeed");
        let labels =
            Tensor::<f32>::from_vec(vec![], &[0]).expect("test: tensor creation should succeed");
        let dataset = TensorDataset::new(features, labels);

        let analyzer = DataQualityAnalyzer::default();
        let metrics = analyzer
            .analyze(&dataset, "test_dataset")
            .expect("test: operation should succeed");

        assert_eq!(metrics.total_samples, 0);
        assert_eq!(metrics.overall_quality_score, 0.0);
        assert!(!metrics.issues.is_empty());
    }

    #[test]
    fn test_quality_extension_trait() {
        let features = Tensor::<f32>::from_vec(vec![1.0, 2.0, 3.0, 4.0], &[2, 2])
            .expect("test: tensor creation should succeed");
        let labels = Tensor::<f32>::from_vec(vec![0.0, 1.0], &[2])
            .expect("test: tensor creation should succeed");
        let dataset = TensorDataset::new(features, labels);

        let metrics = dataset
            .analyze_quality("test_dataset")
            .expect("test: operation should succeed");
        assert_eq!(metrics.total_samples, 2);
        assert!(metrics.overall_quality_score > 0.0);
    }

    #[test]
    fn test_outlier_detection_iqr() {
        let config = QualityAnalysisConfig::default();
        let analyzer = DataQualityAnalyzer::new(config);

        let values = vec![1.0, 2.0, 3.0, 4.0, 5.0, 100.0]; // 100.0 is an outlier
        let outlier_count = analyzer.detect_outliers_iqr(&values);

        assert!(outlier_count > 0);
    }

    #[test]
    fn test_drift_detection_config() {
        let config = DriftDetectionConfig::default();
        assert_eq!(config.reference_window_size, 1000);
        assert_eq!(config.detection_threshold, 0.05);
        assert_eq!(config.statistical_test, StatisticalTest::KolmogorovSmirnov);
    }

    #[test]
    fn test_psi_identical_distributions_is_zero() {
        let data: Vec<f64> = (0..100).map(|i| i as f64).collect();
        let psi =
            population_stability_index(&data, &data, 10).expect("PSI should compute without error");
        assert!(
            psi < 1e-6,
            "PSI of identical distributions should be < 1e-6, got {}",
            psi
        );
    }

    #[test]
    fn test_ks_identical_sorted_is_zero() {
        let data: Vec<f64> = (0..50).map(|i| i as f64).collect();
        let ks = ks_two_sample(&data, &data).expect("KS statistic should compute without error");
        assert!(
            ks < 1e-10,
            "KS of identical distributions should be 0, got {}",
            ks
        );
    }

    #[test]
    fn test_jsd_identical_is_zero() {
        let data: Vec<f64> = vec![0.1, 0.2, 0.3, 0.2, 0.1, 0.05, 0.05];
        let jsd =
            jensen_shannon_divergence(&data, &data).expect("JSD should compute without error");
        assert!(
            jsd < 1e-10,
            "JSD of identical distributions should be 0, got {}",
            jsd
        );
    }

    #[test]
    fn test_psi_shifted_distribution_positive() {
        let reference: Vec<f64> = (0..100).map(|i| i as f64 * 0.1).collect();
        let current: Vec<f64> = (0..100).map(|i| 50.0 + i as f64 * 0.1).collect();
        let psi = population_stability_index(&reference, &current, 10)
            .expect("PSI should compute without error");
        assert!(
            psi > 0.1,
            "PSI of shifted distributions should be > 0.1, got {}",
            psi
        );
    }
}