1use crate::Dataset;
10use std::collections::HashMap;
11use tenflowers_core::{Result, Tensor, TensorError};
12
/// Aggregated data-quality metrics for a single dataset, as produced by
/// [`DataQualityAnalyzer::analyze`].
#[derive(Debug, Clone)]
pub struct DataQualityMetrics {
    /// Name supplied by the caller to identify the dataset in reports.
    pub dataset_name: String,
    /// Total number of samples in the dataset.
    pub total_samples: usize,
    /// Per-field completeness scores in `[0, 1]` (fraction of samples with data).
    pub completeness: HashMap<String, f64>,
    /// Per-field validity scores in `[0, 1]`.
    // NOTE(review): not populated by any check visible in this file — confirm producer.
    pub validity: HashMap<String, f64>,
    /// Consistency score in `[0, 1]`; initialized to 1.0 for non-empty datasets.
    pub consistency_score: f64,
    /// Optional timeliness score; `None` when no timing information is available.
    pub timeliness_score: Option<f64>,
    /// Per-field accuracy estimates.
    // NOTE(review): not populated by any check visible in this file — confirm producer.
    pub accuracy_estimates: HashMap<String, f64>,
    /// Fraction of unique samples in `[0, 1]` (1.0 means no duplicates found).
    pub uniqueness_score: f64,
    /// Overall quality score in `[0, 1]`: mean of sub-scores minus issue penalties.
    pub overall_quality_score: f64,
    /// Issues detected during analysis.
    pub issues: Vec<DataQualityIssue>,
}
37
/// A single quality problem discovered during analysis.
#[derive(Debug, Clone)]
pub struct DataQualityIssue {
    /// How serious the issue is (also drives the overall-score penalty).
    pub severity: IssueSeverity,
    /// Which quality dimension the issue belongs to.
    pub category: IssueCategory,
    /// Human-readable description of the problem.
    pub description: String,
    /// Names of the affected fields (empty when the issue is dataset-wide).
    pub affected_fields: Vec<String>,
    /// Estimated number of affected samples or values.
    pub affected_count: usize,
    /// Suggested remediation, if one is known.
    pub remediation: Option<String>,
}
54
/// Severity level of a [`DataQualityIssue`].
///
/// The severity determines the penalty subtracted from the overall quality
/// score (Critical 0.3, High 0.2, Medium 0.1, Low 0.05, Info 0.0).
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum IssueSeverity {
    /// Blocks meaningful use of the dataset (e.g. the dataset is empty).
    Critical,
    /// Major problem that should be addressed before use.
    High,
    /// Noticeable problem worth investigating.
    Medium,
    /// Minor problem.
    Low,
    /// Informational only; carries no score penalty.
    Info,
}
69
/// Quality dimension a [`DataQualityIssue`] falls under.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum IssueCategory {
    /// Missing or empty data.
    Completeness,
    /// Values outside their expected domain.
    Validity,
    /// Internally contradictory data.
    Consistency,
    /// Duplicate samples.
    Uniqueness,
    /// Deviation from expected/ground-truth values.
    Accuracy,
    /// Stale or out-of-date data.
    Timeliness,
    /// Outliers or other statistical anomalies.
    StatisticalAnomaly,
}
88
/// Configuration for distribution-drift detection.
#[derive(Debug, Clone)]
pub struct DriftDetectionConfig {
    /// Number of samples kept in the reference window.
    pub reference_window_size: usize,
    /// Threshold at which drift is reported; interpretation depends on the test.
    pub detection_threshold: f64,
    /// Statistical test used to compare distributions.
    pub statistical_test: StatisticalTest,
    /// Minimum number of samples required before detection runs.
    pub min_samples: usize,
    /// Whether to produce visualization output alongside the result.
    pub enable_visualization: bool,
}
103
104impl Default for DriftDetectionConfig {
105 fn default() -> Self {
106 Self {
107 reference_window_size: 1000,
108 detection_threshold: 0.05, statistical_test: StatisticalTest::KolmogorovSmirnov,
110 min_samples: 100,
111 enable_visualization: false,
112 }
113 }
114}
115
/// Statistical tests and distance measures available for drift detection.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum StatisticalTest {
    /// Two-sample Kolmogorov–Smirnov test.
    KolmogorovSmirnov,
    /// Chi-squared test on binned counts.
    ChiSquared,
    /// Population Stability Index (PSI).
    PopulationStabilityIndex,
    /// Kullback–Leibler divergence.
    KullbackLeibler,
    /// Jensen–Shannon divergence.
    JensenShannon,
}
130
/// Result of a drift-detection run.
#[derive(Debug, Clone)]
pub struct DriftDetectionResult {
    /// Whether drift was detected.
    pub drift_detected: bool,
    /// Test-specific drift score.
    pub drift_score: f64,
    /// p-value of the statistical test, when the chosen test provides one.
    pub p_value: Option<f64>,
    /// Distance between the reference and current distributions.
    pub distance_metric: f64,
    /// Kind of drift that was identified.
    pub drift_type: DriftType,
    /// Human-readable analysis summary.
    pub analysis: String,
    /// Names of the features exhibiting drift.
    pub affected_features: Vec<String>,
}
149
/// Classification of the kind of distribution drift observed.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum DriftType {
    /// No significant drift.
    NoDrift,
    /// Input feature distribution changed (P(X)).
    CovariateShift,
    /// Input-to-label relationship changed (P(Y|X)).
    ConceptDrift,
    /// Label distribution changed (P(Y)).
    LabelDrift,
    /// Multiple drift types observed simultaneously.
    CombinedDrift,
}
164
/// Analyzer that runs configurable quality checks over a [`Dataset`].
pub struct DataQualityAnalyzer {
    // Which checks to run and their tuning parameters.
    config: QualityAnalysisConfig,
}
170
/// Selects which checks [`DataQualityAnalyzer`] runs and their parameters.
#[derive(Debug, Clone)]
pub struct QualityAnalysisConfig {
    /// Check for missing/empty feature data.
    pub check_completeness: bool,
    /// Check that values fall within their expected domain.
    // NOTE(review): no validity check is implemented in this file yet.
    pub check_validity: bool,
    /// Check for duplicate samples.
    pub check_duplicates: bool,
    /// Check for statistical outliers in feature values.
    pub check_outliers: bool,
    /// Algorithm used for outlier detection.
    pub outlier_method: OutlierDetectionMethod,
    /// Method-specific cutoff: IQR fence multiplier or z-score threshold.
    pub outlier_threshold: f64,
    /// Cap on distinct values tracked per field.
    // NOTE(review): unused by the checks visible in this file — confirm consumer.
    pub max_unique_values: usize,
}
189
190impl Default for QualityAnalysisConfig {
191 fn default() -> Self {
192 Self {
193 check_completeness: true,
194 check_validity: true,
195 check_duplicates: true,
196 check_outliers: true,
197 outlier_method: OutlierDetectionMethod::IQR,
198 outlier_threshold: 1.5,
199 max_unique_values: 10000,
200 }
201 }
202}
203
/// Available outlier-detection algorithms.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum OutlierDetectionMethod {
    /// Interquartile-range (Tukey) fences.
    IQR,
    /// Standard z-score cutoff.
    ZScore,
    /// Median-based modified z-score.
    // NOTE(review): selecting this is a no-op in the analyzer in this file.
    ModifiedZScore,
    /// Isolation-forest model.
    // NOTE(review): selecting this is a no-op in the analyzer in this file.
    IsolationForest,
}
216
217impl DataQualityAnalyzer {
218 pub fn new(config: QualityAnalysisConfig) -> Self {
220 Self { config }
221 }
222
223 pub fn default() -> Self {
225 Self::new(QualityAnalysisConfig::default())
226 }
227
228 pub fn analyze<T>(
230 &self,
231 dataset: &dyn Dataset<T>,
232 dataset_name: impl Into<String>,
233 ) -> Result<DataQualityMetrics>
234 where
235 T: Clone
236 + Default
237 + scirs2_core::numeric::Zero
238 + scirs2_core::numeric::Float
239 + Send
240 + Sync
241 + 'static,
242 {
243 let dataset_name = dataset_name.into();
244 let total_samples = dataset.len();
245
246 if total_samples == 0 {
247 return Ok(DataQualityMetrics {
248 dataset_name,
249 total_samples: 0,
250 completeness: HashMap::new(),
251 validity: HashMap::new(),
252 consistency_score: 0.0,
253 timeliness_score: None,
254 accuracy_estimates: HashMap::new(),
255 uniqueness_score: 0.0,
256 overall_quality_score: 0.0,
257 issues: vec![DataQualityIssue {
258 severity: IssueSeverity::Critical,
259 category: IssueCategory::Completeness,
260 description: "Dataset is empty".to_string(),
261 affected_fields: vec![],
262 affected_count: 0,
263 remediation: Some("Add data to the dataset".to_string()),
264 }],
265 });
266 }
267
268 let mut metrics = DataQualityMetrics {
269 dataset_name,
270 total_samples,
271 completeness: HashMap::new(),
272 validity: HashMap::new(),
273 consistency_score: 1.0,
274 timeliness_score: None,
275 accuracy_estimates: HashMap::new(),
276 uniqueness_score: 1.0,
277 overall_quality_score: 0.0,
278 issues: Vec::new(),
279 };
280
281 if self.config.check_completeness {
283 self.check_completeness(dataset, &mut metrics)?;
284 }
285
286 if self.config.check_duplicates {
288 self.check_duplicates(dataset, &mut metrics)?;
289 }
290
291 if self.config.check_outliers {
293 self.check_outliers(dataset, &mut metrics)?;
294 }
295
296 metrics.overall_quality_score = self.calculate_overall_score(&metrics);
298
299 Ok(metrics)
300 }
301
302 fn check_completeness<T>(
303 &self,
304 dataset: &dyn Dataset<T>,
305 metrics: &mut DataQualityMetrics,
306 ) -> Result<()>
307 where
308 T: Clone + Default + scirs2_core::numeric::Zero + PartialEq + Send + Sync + 'static,
309 {
310 let mut non_zero_count = 0;
311
312 for i in 0..dataset.len().min(1000) {
313 if let Ok((features, _)) = dataset.get(i) {
315 if let Some(data) = features.as_slice() {
316 if data.iter().any(|x| *x != T::zero()) {
317 non_zero_count += 1;
318 }
319 }
320 }
321 }
322
323 let samples_checked = dataset.len().min(1000);
324 let completeness_score = if samples_checked > 0 {
325 non_zero_count as f64 / samples_checked as f64
326 } else {
327 0.0
328 };
329
330 metrics
331 .completeness
332 .insert("features".to_string(), completeness_score);
333
334 if completeness_score < 0.9 {
335 metrics.issues.push(DataQualityIssue {
336 severity: if completeness_score < 0.5 {
337 IssueSeverity::High
338 } else {
339 IssueSeverity::Medium
340 },
341 category: IssueCategory::Completeness,
342 description: format!("Low completeness score: {:.2}%", completeness_score * 100.0),
343 affected_fields: vec!["features".to_string()],
344 affected_count: ((1.0 - completeness_score) * samples_checked as f64) as usize,
345 remediation: Some(
346 "Investigate missing data and apply imputation if appropriate".to_string(),
347 ),
348 });
349 }
350
351 Ok(())
352 }
353
354 fn check_duplicates<T>(
355 &self,
356 dataset: &dyn Dataset<T>,
357 metrics: &mut DataQualityMetrics,
358 ) -> Result<()>
359 where
360 T: Clone
361 + Default
362 + scirs2_core::numeric::Zero
363 + scirs2_core::numeric::Float
364 + Send
365 + Sync
366 + 'static,
367 {
368 use std::collections::HashSet;
369
370 let samples_to_check = dataset.len().min(1000);
372 let mut seen_samples: HashSet<String> = HashSet::new();
373 let mut duplicate_count = 0usize;
374
375 for i in 0..samples_to_check {
376 if let Ok((features, _labels)) = dataset.get(i) {
377 if let Some(data_slice) = features.as_slice() {
379 let fingerprint: String = data_slice
382 .iter()
383 .map(|v| {
384 let f = v.to_f64().unwrap_or(0.0);
385 format!("{:.6}", f) })
387 .collect::<Vec<_>>()
388 .join(",");
389
390 if !seen_samples.insert(fingerprint) {
391 duplicate_count += 1;
392 }
393 }
394 }
395 }
396
397 let unique_count = samples_to_check - duplicate_count;
399 metrics.uniqueness_score = if samples_to_check > 0 {
400 unique_count as f64 / samples_to_check as f64
401 } else {
402 1.0
403 };
404
405 Ok(())
406 }
407
408 fn check_outliers<T>(
409 &self,
410 dataset: &dyn Dataset<T>,
411 metrics: &mut DataQualityMetrics,
412 ) -> Result<()>
413 where
414 T: Clone
415 + Default
416 + scirs2_core::numeric::Zero
417 + scirs2_core::numeric::Float
418 + Send
419 + Sync
420 + 'static,
421 {
422 let mut values: Vec<f64> = Vec::new();
423
424 for i in 0..dataset.len().min(1000) {
426 if let Ok((features, _)) = dataset.get(i) {
427 if let Some(data) = features.as_slice() {
428 for &val in data {
429 values.push(val.to_f64().unwrap_or(0.0));
430 }
431 }
432 }
433 }
434
435 if values.is_empty() {
436 return Ok(());
437 }
438
439 let outlier_count = match self.config.outlier_method {
441 OutlierDetectionMethod::IQR => self.detect_outliers_iqr(&values),
442 OutlierDetectionMethod::ZScore => self.detect_outliers_zscore(&values),
443 _ => 0,
444 };
445
446 if outlier_count > 0 {
447 let outlier_percentage = outlier_count as f64 / values.len() as f64;
448 if outlier_percentage > 0.05 {
449 metrics.issues.push(DataQualityIssue {
451 severity: IssueSeverity::Medium,
452 category: IssueCategory::StatisticalAnomaly,
453 description: format!(
454 "High percentage of outliers detected: {:.2}%",
455 outlier_percentage * 100.0
456 ),
457 affected_fields: vec!["features".to_string()],
458 affected_count: outlier_count,
459 remediation: Some(
460 "Review outlier values and consider outlier removal or transformation"
461 .to_string(),
462 ),
463 });
464 }
465 }
466
467 Ok(())
468 }
469
470 fn detect_outliers_iqr(&self, values: &[f64]) -> usize {
471 if values.len() < 4 {
472 return 0;
473 }
474
475 let mut sorted = values.to_vec();
476 sorted.sort_by(|a, b| a.partial_cmp(b).unwrap_or(std::cmp::Ordering::Equal));
477
478 let q1_idx = sorted.len() / 4;
479 let q3_idx = (3 * sorted.len()) / 4;
480 let q1 = sorted[q1_idx];
481 let q3 = sorted[q3_idx];
482 let iqr = q3 - q1;
483
484 let lower_bound = q1 - self.config.outlier_threshold * iqr;
485 let upper_bound = q3 + self.config.outlier_threshold * iqr;
486
487 values
488 .iter()
489 .filter(|&&v| v < lower_bound || v > upper_bound)
490 .count()
491 }
492
493 fn detect_outliers_zscore(&self, values: &[f64]) -> usize {
494 if values.is_empty() {
495 return 0;
496 }
497
498 let mean = values.iter().sum::<f64>() / values.len() as f64;
499 let variance =
500 values.iter().map(|&v| (v - mean).powi(2)).sum::<f64>() / values.len() as f64;
501 let std_dev = variance.sqrt();
502
503 if std_dev == 0.0 {
504 return 0;
505 }
506
507 values
508 .iter()
509 .filter(|&&v| ((v - mean) / std_dev).abs() > self.config.outlier_threshold)
510 .count()
511 }
512
513 fn calculate_overall_score(&self, metrics: &DataQualityMetrics) -> f64 {
514 let mut scores = Vec::new();
515
516 if let Some(&completeness) = metrics.completeness.get("features") {
518 scores.push(completeness);
519 }
520
521 scores.push(metrics.uniqueness_score);
523
524 scores.push(metrics.consistency_score);
526
527 let issue_penalty = metrics.issues.iter().fold(0.0, |acc, issue| {
529 acc + match issue.severity {
530 IssueSeverity::Critical => 0.3,
531 IssueSeverity::High => 0.2,
532 IssueSeverity::Medium => 0.1,
533 IssueSeverity::Low => 0.05,
534 IssueSeverity::Info => 0.0,
535 }
536 });
537
538 let base_score = if scores.is_empty() {
539 0.0
540 } else {
541 scores.iter().sum::<f64>() / scores.len() as f64
542 };
543
544 (base_score - issue_penalty).clamp(0.0, 1.0)
545 }
546
547 pub fn generate_report(&self, metrics: &DataQualityMetrics) -> String {
549 let mut report = format!(
550 "Data Quality Report: {}\n\
551 ================================\n\
552 Total Samples: {}\n\
553 Overall Quality Score: {:.2}%\n\n",
554 metrics.dataset_name,
555 metrics.total_samples,
556 metrics.overall_quality_score * 100.0
557 );
558
559 if !metrics.completeness.is_empty() {
561 report.push_str("Completeness:\n");
562 for (field, score) in &metrics.completeness {
563 report.push_str(&format!(" {}: {:.2}%\n", field, score * 100.0));
564 }
565 report.push('\n');
566 }
567
568 report.push_str(&format!(
570 "Uniqueness Score: {:.2}%\n\n",
571 metrics.uniqueness_score * 100.0
572 ));
573
574 if !metrics.issues.is_empty() {
576 report.push_str(&format!("Detected Issues ({}):\n", metrics.issues.len()));
577 for (i, issue) in metrics.issues.iter().enumerate() {
578 report.push_str(&format!(
579 " {}. [{:?}] [{:?}] {}\n",
580 i + 1,
581 issue.severity,
582 issue.category,
583 issue.description
584 ));
585 if !issue.affected_fields.is_empty() {
586 report.push_str(&format!(
587 " Affected fields: {}\n",
588 issue.affected_fields.join(", ")
589 ));
590 }
591 if let Some(remediation) = &issue.remediation {
592 report.push_str(&format!(" Remediation: {}\n", remediation));
593 }
594 }
595 } else {
596 report.push_str("No issues detected.\n");
597 }
598
599 report
600 }
601}
602
/// Extension trait adding quality-analysis convenience methods to any
/// [`Dataset`].
pub trait DataQualityExt<T>: Dataset<T> + Sized {
    /// Analyzes this dataset with a default-configured
    /// [`DataQualityAnalyzer`] and returns the computed metrics.
    ///
    /// # Errors
    /// Propagates any error from the underlying analysis.
    fn analyze_quality(&self, name: impl Into<String>) -> Result<DataQualityMetrics>
    where
        T: Clone
            + Default
            + scirs2_core::numeric::Zero
            + scirs2_core::numeric::Float
            + Send
            + Sync
            + 'static,
    {
        let analyzer = DataQualityAnalyzer::default();
        analyzer.analyze(self, name)
    }

    /// Analyzes this dataset and renders the metrics as a human-readable
    /// text report.
    ///
    /// # Errors
    /// Propagates any error from the underlying analysis.
    fn quality_report(&self, name: impl Into<String>) -> Result<String>
    where
        T: Clone
            + Default
            + scirs2_core::numeric::Zero
            + scirs2_core::numeric::Float
            + Send
            + Sync
            + 'static,
    {
        let metrics = self.analyze_quality(name)?;
        let analyzer = DataQualityAnalyzer::default();
        Ok(analyzer.generate_report(&metrics))
    }
}
636
// Blanket impl: every `Dataset<T>` gets the extension methods for free.
impl<T, D: Dataset<T>> DataQualityExt<T> for D {}
639
/// Summary of drift statistics computed by [`compute_drift`].
#[derive(Debug, Clone)]
pub struct DriftReport {
    /// Population Stability Index; > 0.2 is treated as significant drift.
    pub psi: f64,
    /// Two-sample Kolmogorov–Smirnov statistic (sup distance between ECDFs).
    pub ks_statistic: f64,
    /// Jensen–Shannon divergence over binned distributions (base-2, in `[0, 1]`).
    pub jsd: f64,
    /// True when `psi > 0.2` or `ks_statistic > 0.1`.
    pub is_significant_drift: bool,
}
657
658pub fn population_stability_index(
666 reference: &[f64],
667 current: &[f64],
668 n_bins: usize,
669) -> Result<f64> {
670 if reference.is_empty() {
671 return Err(TensorError::invalid_argument(
672 "reference slice is empty".to_string(),
673 ));
674 }
675 if current.is_empty() {
676 return Err(TensorError::invalid_argument(
677 "current slice is empty".to_string(),
678 ));
679 }
680 if n_bins == 0 {
681 return Err(TensorError::invalid_argument(
682 "n_bins must be > 0".to_string(),
683 ));
684 }
685
686 let min_val = reference
687 .iter()
688 .chain(current.iter())
689 .cloned()
690 .fold(f64::INFINITY, f64::min);
691 let max_val = reference
692 .iter()
693 .chain(current.iter())
694 .cloned()
695 .fold(f64::NEG_INFINITY, f64::max);
696
697 if (max_val - min_val).abs() < f64::EPSILON {
698 return Ok(0.0);
699 }
700
701 let bin_width = (max_val - min_val) / n_bins as f64;
702 let epsilon = 1e-9_f64;
703
704 let count_bins = |samples: &[f64]| -> Vec<f64> {
705 let n = samples.len() as f64;
706 let mut counts = vec![0_usize; n_bins];
707 for &v in samples {
708 let idx = ((v - min_val) / bin_width).floor() as usize;
709 counts[idx.min(n_bins - 1)] += 1;
710 }
711 counts
712 .into_iter()
713 .map(|c| (c as f64 + epsilon) / (n + n_bins as f64 * epsilon))
714 .collect()
715 };
716
717 let ref_pct = count_bins(reference);
718 let cur_pct = count_bins(current);
719
720 let psi = ref_pct
721 .iter()
722 .zip(cur_pct.iter())
723 .map(|(&r, &c)| (c - r) * (c / r).ln())
724 .sum::<f64>();
725
726 Ok(psi)
727}
728
/// Computes the two-sample Kolmogorov–Smirnov statistic: the supremum of
/// the absolute difference between the empirical CDFs of `sample_a` and
/// `sample_b`.
///
/// Both samples are sorted, then a merge-style sweep evaluates the ECDF
/// difference at every distinct value. NaNs compare as Equal during the
/// sort, so their placement is unspecified.
///
/// # Errors
/// Returns an error if either sample is empty.
pub fn ks_two_sample(sample_a: &[f64], sample_b: &[f64]) -> Result<f64> {
    if sample_a.is_empty() {
        return Err(TensorError::invalid_argument(
            "sample_a is empty".to_string(),
        ));
    }
    if sample_b.is_empty() {
        return Err(TensorError::invalid_argument(
            "sample_b is empty".to_string(),
        ));
    }

    let na = sample_a.len() as f64;
    let nb = sample_b.len() as f64;

    let mut sorted_a = sample_a.to_vec();
    let mut sorted_b = sample_b.to_vec();
    sorted_a.sort_by(|x, y| x.partial_cmp(y).unwrap_or(std::cmp::Ordering::Equal));
    sorted_b.sort_by(|x, y| x.partial_cmp(y).unwrap_or(std::cmp::Ordering::Equal));

    let mut ia = 0_usize;
    let mut ib = 0_usize;
    let mut max_diff = 0.0_f64;

    // Sweep: pick the smallest unprocessed value x across both samples,
    // advance each index past every element <= x; afterwards ia and ib are
    // the counts of values <= x, i.e. the ECDF numerators at x.
    while ia < sorted_a.len() || ib < sorted_b.len() {
        let x = match (sorted_a.get(ia), sorted_b.get(ib)) {
            (Some(&a), Some(&b)) => a.min(b),
            (Some(&a), None) => a,
            (None, Some(&b)) => b,
            (None, None) => break,
        };

        while ia < sorted_a.len() && sorted_a[ia] <= x {
            ia += 1;
        }
        while ib < sorted_b.len() && sorted_b[ib] <= x {
            ib += 1;
        }

        let ecdf_a = ia as f64 / na;
        let ecdf_b = ib as f64 / nb;
        let diff = (ecdf_a - ecdf_b).abs();
        if diff > max_diff {
            max_diff = diff;
        }
    }

    Ok(max_diff)
}
784
785pub fn jensen_shannon_divergence(p: &[f64], q: &[f64]) -> Result<f64> {
793 if p.is_empty() || q.is_empty() {
794 return Err(TensorError::invalid_argument(
795 "p and q must be non-empty".to_string(),
796 ));
797 }
798 if p.len() != q.len() {
799 return Err(TensorError::invalid_argument(
800 "p and q must have the same length".to_string(),
801 ));
802 }
803
804 let sum_p: f64 = p.iter().sum();
805 let sum_q: f64 = q.iter().sum();
806
807 if sum_p <= 0.0 {
808 return Err(TensorError::invalid_argument("p sums to zero".to_string()));
809 }
810 if sum_q <= 0.0 {
811 return Err(TensorError::invalid_argument("q sums to zero".to_string()));
812 }
813
814 let norm_p: Vec<f64> = p.iter().map(|&v| v / sum_p).collect();
815 let norm_q: Vec<f64> = q.iter().map(|&v| v / sum_q).collect();
816
817 let m: Vec<f64> = norm_p
818 .iter()
819 .zip(norm_q.iter())
820 .map(|(&pi, &qi)| (pi + qi) * 0.5)
821 .collect();
822
823 let kl_div = |dist: &[f64], mix: &[f64]| -> f64 {
824 dist.iter()
825 .zip(mix.iter())
826 .filter(|(&pi, &mi)| pi > 0.0 && mi > 0.0)
827 .map(|(&pi, &mi)| pi * (pi / mi).log2())
828 .sum::<f64>()
829 };
830
831 let jsd = 0.5 * kl_div(&norm_p, &m) + 0.5 * kl_div(&norm_q, &m);
832 Ok(jsd.clamp(0.0, 1.0))
833}
834
/// Computes a combined drift report between `reference` and `current`:
/// PSI (20 bins), the two-sample KS statistic, and the Jensen–Shannon
/// divergence over a shared 20-bin histogram.
///
/// Drift is flagged as significant when PSI > 0.2 or KS > 0.1, both
/// common rule-of-thumb thresholds.
///
/// # Errors
/// Returns an error if either slice is empty (propagated from the
/// underlying statistics).
pub fn compute_drift(reference: &[f64], current: &[f64]) -> Result<DriftReport> {
    let psi = population_stability_index(reference, current, 20)?;
    let ks_statistic = ks_two_sample(reference, current)?;

    let n_bins = 20_usize;
    // Shared range over both samples so the histograms are comparable.
    let min_val = reference
        .iter()
        .chain(current.iter())
        .cloned()
        .fold(f64::INFINITY, f64::min);
    let max_val = reference
        .iter()
        .chain(current.iter())
        .cloned()
        .fold(f64::NEG_INFINITY, f64::max);

    // Degenerate range (all values identical) => zero divergence.
    let jsd = if (max_val - min_val).abs() < f64::EPSILON {
        0.0
    } else {
        let bin_width = (max_val - min_val) / n_bins as f64;
        let mut hist_ref = vec![0_f64; n_bins];
        let mut hist_cur = vec![0_f64; n_bins];

        for &v in reference {
            let idx = ((v - min_val) / bin_width).floor() as usize;
            // Clamp so the maximum value lands in the last bin.
            hist_ref[idx.min(n_bins - 1)] += 1.0;
        }
        for &v in current {
            let idx = ((v - min_val) / bin_width).floor() as usize;
            hist_cur[idx.min(n_bins - 1)] += 1.0;
        }

        // jensen_shannon_divergence normalizes the raw counts internally.
        jensen_shannon_divergence(&hist_ref, &hist_cur)?
    };

    let is_significant_drift = psi > 0.2 || ks_statistic > 0.1;

    Ok(DriftReport {
        psi,
        ks_statistic,
        jsd,
        is_significant_drift,
    })
}
885
#[cfg(test)]
mod tests {
    use super::*;
    use crate::TensorDataset;
    use tenflowers_core::Tensor;

    #[test]
    fn test_quality_analyzer_creation() {
        let analyzer = DataQualityAnalyzer::default();
        assert!(analyzer.config.check_completeness);
        assert!(analyzer.config.check_validity);
    }

    #[test]
    fn test_empty_dataset_quality() {
        let features =
            Tensor::<f32>::from_vec(vec![], &[0, 1]).expect("test: tensor creation should succeed");
        let labels =
            Tensor::<f32>::from_vec(vec![], &[0]).expect("test: tensor creation should succeed");
        let dataset = TensorDataset::new(features, labels);

        let analyzer = DataQualityAnalyzer::default();
        let metrics = analyzer
            .analyze(&dataset, "test_dataset")
            .expect("test: operation should succeed");

        assert_eq!(metrics.total_samples, 0);
        assert_eq!(metrics.overall_quality_score, 0.0);
        assert!(!metrics.issues.is_empty());
    }

    #[test]
    fn test_quality_extension_trait() {
        let features = Tensor::<f32>::from_vec(vec![1.0, 2.0, 3.0, 4.0], &[2, 2])
            .expect("test: tensor creation should succeed");
        let labels = Tensor::<f32>::from_vec(vec![0.0, 1.0], &[2])
            .expect("test: tensor creation should succeed");
        let dataset = TensorDataset::new(features, labels);

        let metrics = dataset
            .analyze_quality("test_dataset")
            .expect("test: operation should succeed");
        assert_eq!(metrics.total_samples, 2);
        assert!(metrics.overall_quality_score > 0.0);
    }

    #[test]
    fn test_outlier_detection_iqr() {
        let config = QualityAnalysisConfig::default();
        let analyzer = DataQualityAnalyzer::new(config);

        // 100.0 lies far outside the IQR fences of the remaining values.
        let values = vec![1.0, 2.0, 3.0, 4.0, 5.0, 100.0];
        let outlier_count = analyzer.detect_outliers_iqr(&values);

        assert!(outlier_count > 0);
    }

    #[test]
    fn test_drift_detection_config() {
        let config = DriftDetectionConfig::default();
        assert_eq!(config.reference_window_size, 1000);
        assert_eq!(config.detection_threshold, 0.05);
        assert_eq!(config.statistical_test, StatisticalTest::KolmogorovSmirnov);
    }

    #[test]
    fn test_psi_identical_distributions_is_zero() {
        let data: Vec<f64> = (0..100).map(|i| i as f64).collect();
        let psi =
            population_stability_index(&data, &data, 10).expect("PSI should compute without error");
        assert!(
            psi < 1e-6,
            "PSI of identical distributions should be < 1e-6, got {}",
            psi
        );
    }

    #[test]
    fn test_ks_identical_sorted_is_zero() {
        let data: Vec<f64> = (0..50).map(|i| i as f64).collect();
        let ks = ks_two_sample(&data, &data).expect("KS statistic should compute without error");
        assert!(
            ks < 1e-10,
            "KS of identical distributions should be 0, got {}",
            ks
        );
    }

    #[test]
    fn test_jsd_identical_is_zero() {
        let data: Vec<f64> = vec![0.1, 0.2, 0.3, 0.2, 0.1, 0.05, 0.05];
        let jsd =
            jensen_shannon_divergence(&data, &data).expect("JSD should compute without error");
        assert!(
            jsd < 1e-10,
            "JSD of identical distributions should be 0, got {}",
            jsd
        );
    }

    #[test]
    fn test_psi_shifted_distribution_positive() {
        let reference: Vec<f64> = (0..100).map(|i| i as f64 * 0.1).collect();
        let current: Vec<f64> = (0..100).map(|i| 50.0 + i as f64 * 0.1).collect();
        // Fixed: `&current` was mangled to the mojibake `¤t` (an HTML
        // `&curren;` entity artifact), which does not compile.
        let psi = population_stability_index(&reference, &current, 10)
            .expect("PSI should compute without error");
        assert!(
            psi > 0.1,
            "PSI of shifted distributions should be > 0.1, got {}",
            psi
        );
    }
}