1use scirs2_core::ndarray::{Array1, Array2, Axis};
31
32use crate::error::{Result, TransformError};
33
/// Statistical test used by [`DriftDetector`] to compare each reference
/// feature column with the corresponding column of a new batch.
#[derive(Clone, Debug, PartialEq)]
#[non_exhaustive]
pub enum DriftMethod {
    /// Two-sample Kolmogorov–Smirnov test. The score is the KS statistic `D`;
    /// a feature is flagged when its p-value is below `significance_level`.
    KolmogorovSmirnov,
    /// Population Stability Index over reference-quantile bins. A feature is
    /// flagged when its PSI exceeds `psi_threshold`.
    PopulationStabilityIndex,
    /// 1-D Wasserstein (earth mover's) distance, the area between the two
    /// empirical CDFs. A feature is flagged when it exceeds
    /// `wasserstein_threshold`.
    Wasserstein,
    /// U-statistic estimate of the squared Maximum Mean Discrepancy under a
    /// Gaussian (RBF) kernel.
    MaximumMeanDiscrepancy,
}
62
63#[derive(Debug, Clone)]
69pub struct DriftDetectorConfig {
70 pub method: DriftMethod,
72 pub significance_level: f64,
74 pub n_bins: usize,
76 pub mmd_bandwidth: Option<f64>,
78 pub wasserstein_threshold: f64,
80 pub psi_threshold: f64,
82}
83
84impl Default for DriftDetectorConfig {
85 fn default() -> Self {
86 Self {
87 method: DriftMethod::KolmogorovSmirnov,
88 significance_level: 0.05,
89 n_bins: 10,
90 mmd_bandwidth: None,
91 wasserstein_threshold: 0.1,
92 psi_threshold: 0.2,
93 }
94 }
95}
96
97#[derive(Debug, Clone)]
103pub struct DriftReport {
104 pub feature_scores: Vec<f64>,
106 pub feature_drifted: Vec<bool>,
108 pub overall_score: f64,
110 pub drifted: bool,
112 pub method: DriftMethod,
114}
115
116pub struct DriftDetector {
123 config: DriftDetectorConfig,
124 reference: Array2<f64>,
126}
127
128impl DriftDetector {
129 pub fn fit(reference: &Array2<f64>, config: DriftDetectorConfig) -> Self {
135 Self {
136 config,
137 reference: reference.to_owned(),
138 }
139 }
140
141 pub fn detect(&self, current: &Array2<f64>) -> Result<DriftReport> {
150 let n_ref = self.reference.nrows();
151 let n_cur = current.nrows();
152 let n_features = self.reference.ncols();
153
154 if n_ref == 0 {
155 return Err(TransformError::InvalidInput(
156 "Reference dataset is empty".to_string(),
157 ));
158 }
159 if n_cur == 0 {
160 return Err(TransformError::InvalidInput(
161 "Current dataset is empty".to_string(),
162 ));
163 }
164 if current.ncols() != n_features {
165 return Err(TransformError::InvalidInput(format!(
166 "Feature dimension mismatch: reference has {n_features} features, current has {}",
167 current.ncols()
168 )));
169 }
170
171 match &self.config.method {
172 DriftMethod::KolmogorovSmirnov => self.detect_ks(current),
173 DriftMethod::PopulationStabilityIndex => self.detect_psi(current),
174 DriftMethod::Wasserstein => self.detect_wasserstein(current),
175 DriftMethod::MaximumMeanDiscrepancy => self.detect_mmd(current),
176 }
177 }
178
179 pub fn update_reference(&mut self, new_reference: &Array2<f64>) {
181 self.reference = new_reference.to_owned();
182 }
183
184 fn detect_ks(&self, current: &Array2<f64>) -> Result<DriftReport> {
189 let n_features = self.reference.ncols();
190 let mut scores = Vec::with_capacity(n_features);
191 let mut drifted_flags = Vec::with_capacity(n_features);
192
193 for f in 0..n_features {
194 let ref_col: Vec<f64> = self.reference.column(f).iter().copied().collect();
195 let cur_col: Vec<f64> = current.column(f).iter().copied().collect();
196
197 let (ks_stat, p_value) = ks_2samp(&ref_col, &cur_col)?;
198 scores.push(ks_stat);
199 drifted_flags.push(p_value < self.config.significance_level);
200 }
201
202 let overall_score = scores.iter().copied().sum::<f64>() / scores.len() as f64;
203 let drifted = drifted_flags.iter().any(|&d| d);
204
205 Ok(DriftReport {
206 feature_scores: scores,
207 feature_drifted: drifted_flags,
208 overall_score,
209 drifted,
210 method: DriftMethod::KolmogorovSmirnov,
211 })
212 }
213
214 fn detect_psi(&self, current: &Array2<f64>) -> Result<DriftReport> {
219 let n_features = self.reference.ncols();
220 let mut scores = Vec::with_capacity(n_features);
221 let mut drifted_flags = Vec::with_capacity(n_features);
222
223 for f in 0..n_features {
224 let ref_col: Vec<f64> = self.reference.column(f).iter().copied().collect();
225 let cur_col: Vec<f64> = current.column(f).iter().copied().collect();
226
227 let psi = compute_psi(&ref_col, &cur_col, self.config.n_bins)?;
228 scores.push(psi);
229 drifted_flags.push(psi > self.config.psi_threshold);
230 }
231
232 let overall_score = scores.iter().copied().sum::<f64>() / scores.len() as f64;
233 let drifted = drifted_flags.iter().any(|&d| d);
234
235 Ok(DriftReport {
236 feature_scores: scores,
237 feature_drifted: drifted_flags,
238 overall_score,
239 drifted,
240 method: DriftMethod::PopulationStabilityIndex,
241 })
242 }
243
244 fn detect_wasserstein(&self, current: &Array2<f64>) -> Result<DriftReport> {
249 let n_features = self.reference.ncols();
250 let mut scores = Vec::with_capacity(n_features);
251 let mut drifted_flags = Vec::with_capacity(n_features);
252
253 for f in 0..n_features {
254 let ref_col: Vec<f64> = self.reference.column(f).iter().copied().collect();
255 let cur_col: Vec<f64> = current.column(f).iter().copied().collect();
256
257 let w1 = wasserstein_1d_distance(&ref_col, &cur_col)?;
258 scores.push(w1);
259 drifted_flags.push(w1 > self.config.wasserstein_threshold);
260 }
261
262 let overall_score = scores.iter().copied().sum::<f64>() / scores.len() as f64;
263 let drifted = drifted_flags.iter().any(|&d| d);
264
265 Ok(DriftReport {
266 feature_scores: scores,
267 feature_drifted: drifted_flags,
268 overall_score,
269 drifted,
270 method: DriftMethod::Wasserstein,
271 })
272 }
273
274 fn detect_mmd(&self, current: &Array2<f64>) -> Result<DriftReport> {
279 let n_features = self.reference.ncols();
280
281 let bandwidth = match self.config.mmd_bandwidth {
283 Some(bw) => {
284 if bw <= 0.0 {
285 return Err(TransformError::InvalidInput(
286 "mmd_bandwidth must be positive".to_string(),
287 ));
288 }
289 bw
290 }
291 None => median_heuristic_bandwidth(&self.reference)?,
292 };
293
294 let mut scores = Vec::with_capacity(n_features);
296 let mut drifted_flags = Vec::with_capacity(n_features);
297
298 for f in 0..n_features {
299 let ref_col: Vec<f64> = self.reference.column(f).iter().copied().collect();
300 let cur_col: Vec<f64> = current.column(f).iter().copied().collect();
301
302 let mmd2 = mmd_u_statistic_1d(&ref_col, &cur_col, bandwidth)?;
303 scores.push(mmd2);
304 let n_eff = (ref_col.len().min(cur_col.len()) as f64).max(1.0);
307 let threshold = 4.0 / n_eff.sqrt();
308 drifted_flags.push(mmd2 > threshold);
309 }
310
311 let overall_score = scores.iter().copied().sum::<f64>() / scores.len() as f64;
312 let drifted = drifted_flags.iter().any(|&d| d);
313
314 Ok(DriftReport {
315 feature_scores: scores,
316 feature_drifted: drifted_flags,
317 overall_score,
318 drifted,
319 method: DriftMethod::MaximumMeanDiscrepancy,
320 })
321 }
322}
323
324fn ks_2samp(x: &[f64], y: &[f64]) -> Result<(f64, f64)> {
333 if x.is_empty() || y.is_empty() {
334 return Err(TransformError::InvalidInput(
335 "KS samples must be non-empty".to_string(),
336 ));
337 }
338
339 let mut xs = x.to_vec();
340 let mut ys = y.to_vec();
341 xs.sort_by(|a, b| a.partial_cmp(b).unwrap_or(std::cmp::Ordering::Equal));
342 ys.sort_by(|a, b| a.partial_cmp(b).unwrap_or(std::cmp::Ordering::Equal));
343
344 let n1 = xs.len();
345 let n2 = ys.len();
346 let n1f = n1 as f64;
347 let n2f = n2 as f64;
348
349 let mut i = 0usize;
351 let mut j = 0usize;
352 let mut d_max: f64 = 0.0;
353
354 while i < n1 || j < n2 {
355 let xv = if i < n1 { xs[i] } else { f64::INFINITY };
356 let yv = if j < n2 { ys[j] } else { f64::INFINITY };
357
358 let cur = xv.min(yv);
360 while i < n1 && xs[i] <= cur {
361 i += 1;
362 }
363 while j < n2 && ys[j] <= cur {
364 j += 1;
365 }
366
367 let cdf1 = i as f64 / n1f;
368 let cdf2 = j as f64 / n2f;
369 d_max = d_max.max((cdf1 - cdf2).abs());
370 }
371
372 let lambda = d_max * ((n1f * n2f / (n1f + n2f)).sqrt());
374 let p_value = kolmogorov_p_value(lambda);
375
376 Ok((d_max, p_value))
377}
378
/// Asymptotic Kolmogorov survival function `Q(λ) = P(K > λ)`.
///
/// Each of the two equivalent series converges quickly on only one side, so
/// the function switches at `λ = 1.18`:
/// - `λ < 1.18`: `Q = 1 - (√(2π)/λ) Σ_{k≥1} exp(-(2k-1)² π² / (8λ²))`
/// - `λ ≥ 1.18`: `Q = 2 Σ_{j≥1} (-1)^{j-1} exp(-2 j² λ²)`
///
/// The alternating series alone, truncated at 20 terms, is badly wrong for
/// small `λ`. For example, `λ = 0.001` gives ≈ 8e-4 instead of ≈ 1, which
/// made large, nearly identical samples look drifted.
///
/// Returns 1 for `λ <= 0` and 0 for `λ > 4` (where `Q < 3e-14`).
fn kolmogorov_p_value(lambda: f64) -> f64 {
    if lambda <= 0.0 {
        return 1.0;
    }
    if lambda > 4.0 {
        return 0.0;
    }

    if lambda < 1.18 {
        // CDF form of the series; every term is positive and decays fast here.
        let a = std::f64::consts::PI * std::f64::consts::PI / (8.0 * lambda * lambda);
        let mut series = 0.0_f64;
        for k in 1u32..=20 {
            let odd = f64::from(2 * k - 1);
            let term = (-odd * odd * a).exp();
            series += term;
            if term < 1e-16 {
                break;
            }
        }
        if series == 0.0 {
            // Every term underflowed, so the CDF is 0 to machine precision.
            // Returning early also avoids `0 * ∞` when λ is subnormal.
            return 1.0;
        }
        let cdf = (2.0 * std::f64::consts::PI).sqrt() / lambda * series;
        return (1.0 - cdf).clamp(0.0, 1.0);
    }

    let mut sum = 0.0_f64;
    for j in 1u32..=20 {
        let jf = f64::from(j);
        let term = (-2.0 * jf * jf * lambda * lambda).exp();
        sum += if j % 2 == 1 { term } else { -term };
        if term < 1e-15 {
            break;
        }
    }
    (2.0 * sum).clamp(0.0, 1.0)
}
404
405fn compute_psi(reference: &[f64], current: &[f64], n_bins: usize) -> Result<f64> {
415 if reference.is_empty() || current.is_empty() {
416 return Err(TransformError::InvalidInput(
417 "PSI samples must be non-empty".to_string(),
418 ));
419 }
420 if n_bins == 0 {
421 return Err(TransformError::InvalidInput(
422 "n_bins must be at least 1".to_string(),
423 ));
424 }
425
426 let mut ref_sorted: Vec<f64> = reference.to_vec();
427 ref_sorted.sort_by(|a, b| a.partial_cmp(b).unwrap_or(std::cmp::Ordering::Equal));
428
429 let mut edges: Vec<f64> = Vec::with_capacity(n_bins + 1);
431 edges.push(f64::NEG_INFINITY);
432 for i in 1..n_bins {
433 let q = i as f64 / n_bins as f64;
434 let idx = ((q * reference.len() as f64) as usize).min(reference.len() - 1);
435 edges.push(ref_sorted[idx]);
436 }
437 edges.push(f64::INFINITY);
438
439 edges.dedup_by(|a, b| (*a - *b).abs() < f64::EPSILON);
441
442 let actual_bins = edges.len() - 1;
443 if actual_bins == 0 {
444 return Ok(0.0);
446 }
447
448 let ref_n = reference.len() as f64;
449 let cur_n = current.len() as f64;
450 let epsilon = 1e-8; let mut ref_counts = vec![0u64; actual_bins];
453 let mut cur_counts = vec![0u64; actual_bins];
454
455 for &v in reference {
456 let bin = find_bin(v, &edges);
457 ref_counts[bin] += 1;
458 }
459 for &v in current {
460 let bin = find_bin(v, &edges);
461 cur_counts[bin] += 1;
462 }
463
464 let mut psi = 0.0_f64;
465 for b in 0..actual_bins {
466 let ref_pct = (ref_counts[b] as f64 / ref_n + epsilon).min(1.0);
467 let cur_pct = (cur_counts[b] as f64 / cur_n + epsilon).min(1.0);
468 psi += (ref_pct - cur_pct) * (ref_pct / cur_pct).ln();
469 }
470
471 Ok(psi.max(0.0))
472}
473
/// Index `b` of the bin `[edges[b], edges[b + 1])` that holds `value`.
///
/// `edges` is expected to be sorted ascending with at least two entries. The
/// first and last entries are never compared:
/// - values below `edges[1]` fall in bin 0;
/// - values at or above the last interior edge fall in the final bin;
/// - a value equal to an interior edge belongs to the bin starting there;
/// - NaN compares false against every edge and lands in bin 0.
fn find_bin(value: f64, edges: &[f64]) -> usize {
    let bin_count = edges.len() - 1;
    let interior = &edges[1..bin_count];
    // Binary search: the number of interior edges <= value is the bin index.
    let below_or_equal = interior.partition_point(|&edge| edge <= value);
    below_or_equal.min(bin_count - 1)
}
490
491fn wasserstein_1d_distance(x: &[f64], y: &[f64]) -> Result<f64> {
500 if x.is_empty() || y.is_empty() {
501 return Err(TransformError::InvalidInput(
502 "Wasserstein samples must be non-empty".to_string(),
503 ));
504 }
505
506 let mut xs: Vec<f64> = x.to_vec();
507 let mut ys: Vec<f64> = y.to_vec();
508 xs.sort_by(|a, b| a.partial_cmp(b).unwrap_or(std::cmp::Ordering::Equal));
509 ys.sort_by(|a, b| a.partial_cmp(b).unwrap_or(std::cmp::Ordering::Equal));
510
511 let n = xs.len();
512 let m = ys.len();
513 let nf = n as f64;
514 let mf = m as f64;
515
516 let mut events: Vec<f64> = xs.iter().chain(ys.iter()).copied().collect();
518 events.sort_by(|a, b| a.partial_cmp(b).unwrap_or(std::cmp::Ordering::Equal));
519 events.dedup_by(|a, b| (*a - *b).abs() < f64::EPSILON * a.abs().max(1.0));
520
521 let mut dist = 0.0_f64;
522 let mut ix = 0usize;
523 let mut iy = 0usize;
524
525 for w in events.windows(2) {
526 let lo = w[0];
527 let hi = w[1];
528 let dx = hi - lo;
529
530 while ix < n && xs[ix] <= lo {
532 ix += 1;
533 }
534 while iy < m && ys[iy] <= lo {
535 iy += 1;
536 }
537
538 let cdf_x = ix as f64 / nf;
539 let cdf_y = iy as f64 / mf;
540 dist += (cdf_x - cdf_y).abs() * dx;
541 }
542
543 Ok(dist)
544}
545
546fn mmd_u_statistic_1d(x: &[f64], y: &[f64], bandwidth: f64) -> Result<f64> {
558 if x.is_empty() || y.is_empty() {
559 return Err(TransformError::InvalidInput(
560 "MMD samples must be non-empty".to_string(),
561 ));
562 }
563 if bandwidth <= 0.0 {
564 return Err(TransformError::InvalidInput(
565 "MMD bandwidth must be positive".to_string(),
566 ));
567 }
568
569 let n = x.len();
570 let m = y.len();
571 let gamma = 1.0 / (2.0 * bandwidth * bandwidth);
572
573 let kxx = if n > 1 {
575 let mut sum = 0.0_f64;
576 for i in 0..n {
577 for j in (i + 1)..n {
578 let d = x[i] - x[j];
579 sum += (-gamma * d * d).exp();
580 }
581 }
582 2.0 * sum / (n * (n - 1)) as f64
583 } else {
584 0.0
585 };
586
587 let kyy = if m > 1 {
589 let mut sum = 0.0_f64;
590 for i in 0..m {
591 for j in (i + 1)..m {
592 let d = y[i] - y[j];
593 sum += (-gamma * d * d).exp();
594 }
595 }
596 2.0 * sum / (m * (m - 1)) as f64
597 } else {
598 0.0
599 };
600
601 let mut kxy_sum = 0.0_f64;
603 for &xi in x {
604 for &yi in y {
605 let d = xi - yi;
606 kxy_sum += (-gamma * d * d).exp();
607 }
608 }
609 let kxy = 2.0 * kxy_sum / (n * m) as f64;
610
611 Ok((kxx + kyy - kxy).max(0.0))
612}
613
614fn median_heuristic_bandwidth(data: &Array2<f64>) -> Result<f64> {
621 let n = data.nrows();
622 if n == 0 {
623 return Err(TransformError::InvalidInput(
624 "Cannot compute bandwidth on empty data".to_string(),
625 ));
626 }
627
628 let max_samples = 500usize;
630 let step = if n > max_samples { n / max_samples } else { 1 };
631 let indices: Vec<usize> = (0..n).step_by(step).collect();
632 let k = indices.len();
633
634 let mut dists: Vec<f64> = Vec::with_capacity(k * (k - 1) / 2);
635 for i in 0..k {
636 for j in (i + 1)..k {
637 let row_i = data.row(indices[i]);
638 let row_j = data.row(indices[j]);
639 let sq_dist: f64 = row_i
640 .iter()
641 .zip(row_j.iter())
642 .map(|(a, b)| (a - b) * (a - b))
643 .sum();
644 dists.push(sq_dist.sqrt());
645 }
646 }
647
648 if dists.is_empty() {
649 return Ok(1.0); }
651
652 dists.sort_by(|a, b| a.partial_cmp(b).unwrap_or(std::cmp::Ordering::Equal));
653 let median = dists[dists.len() / 2];
654 Ok(if median > 0.0 { median } else { 1.0 })
655}
656
657pub fn ks_test(x: &[f64], y: &[f64]) -> Result<(f64, f64)> {
665 ks_2samp(x, y)
666}
667
668pub fn psi(reference: &[f64], current: &[f64], n_bins: usize) -> Result<f64> {
670 compute_psi(reference, current, n_bins)
671}
672
673pub fn wasserstein_distance_1d(x: &[f64], y: &[f64]) -> Result<f64> {
675 wasserstein_1d_distance(x, y)
676}
677
678pub fn mmd_rbf(x: &[f64], y: &[f64], bandwidth: f64) -> Result<f64> {
680 mmd_u_statistic_1d(x, y, bandwidth)
681}
682
#[cfg(test)]
mod tests {
    use super::*;
    use scirs2_core::ndarray::Array2;

    /// `rows x cols` matrix of zeros.
    fn zeros_matrix(rows: usize, cols: usize) -> Array2<f64> {
        Array2::<f64>::zeros((rows, cols))
    }

    /// `n` evenly spaced values from `start` to `end` inclusive.
    ///
    /// `n == 1` yields `[start]` instead of dividing by `n - 1 = 0` (NaN).
    fn linspace_col(start: f64, end: f64, n: usize) -> Vec<f64> {
        if n == 1 {
            return vec![start];
        }
        (0..n)
            .map(|i| start + (end - start) * (i as f64) / ((n - 1) as f64))
            .collect()
    }

    #[test]
    fn test_ks_no_drift() {
        // Identical samples: KS statistic is exactly 0.
        let data: Vec<f64> = linspace_col(0.0, 1.0, 100);
        let config = DriftDetectorConfig {
            method: DriftMethod::KolmogorovSmirnov,
            ..Default::default()
        };
        let ref_mat = Array2::from_shape_vec((100, 1), data.clone()).expect("shape ok");
        let cur_mat = Array2::from_shape_vec((100, 1), data).expect("shape ok");

        let detector = DriftDetector::fit(&ref_mat, config);
        let report = detector.detect(&cur_mat).expect("detect ok");
        assert!(!report.drifted, "identical distributions should not drift");
        assert!(
            report.overall_score < 0.01,
            "KS statistic should be near 0, got {}",
            report.overall_score
        );
    }

    #[test]
    fn test_ks_drift_detected() {
        // Disjoint supports: every reference value lies below every current one.
        let ref_data: Vec<f64> = linspace_col(0.0, 1.0, 100);
        let cur_data: Vec<f64> = linspace_col(10.0, 11.0, 100);
        let config = DriftDetectorConfig {
            method: DriftMethod::KolmogorovSmirnov,
            ..Default::default()
        };
        let ref_mat = Array2::from_shape_vec((100, 1), ref_data).expect("shape ok");
        let cur_mat = Array2::from_shape_vec((100, 1), cur_data).expect("shape ok");

        let detector = DriftDetector::fit(&ref_mat, config);
        let report = detector.detect(&cur_mat).expect("detect ok");
        assert!(report.drifted, "clearly shifted distributions should drift");
        assert!(
            report.overall_score > 0.9,
            "KS stat should be close to 1.0, got {}",
            report.overall_score
        );
    }

    #[test]
    fn test_ks_drift_multifeature() {
        let n = 100usize;
        let mut ref_data = vec![0.0f64; n * 3];
        let mut cur_data = vec![0.0f64; n * 3];

        for i in 0..n {
            let t = i as f64 / n as f64;
            // Feature 0 is shifted by 5; features 1 and 2 are unchanged.
            ref_data[i * 3] = t;
            cur_data[i * 3] = 5.0 + t;
            ref_data[i * 3 + 1] = t;
            cur_data[i * 3 + 1] = t;
            ref_data[i * 3 + 2] = t;
            cur_data[i * 3 + 2] = t;
        }

        let ref_mat = Array2::from_shape_vec((n, 3), ref_data).expect("shape ok");
        let cur_mat = Array2::from_shape_vec((n, 3), cur_data).expect("shape ok");
        let config = DriftDetectorConfig {
            method: DriftMethod::KolmogorovSmirnov,
            ..Default::default()
        };

        let detector = DriftDetector::fit(&ref_mat, config);
        let report = detector.detect(&cur_mat).expect("detect ok");
        assert!(report.drifted, "overall should be drifted");
        assert!(report.feature_drifted[0], "feature 0 should drift");
        assert!(!report.feature_drifted[1], "feature 1 should not drift");
        assert!(!report.feature_drifted[2], "feature 2 should not drift");
    }

    #[test]
    fn test_psi_no_drift() {
        let data: Vec<f64> = linspace_col(0.0, 1.0, 100);
        let config = DriftDetectorConfig {
            method: DriftMethod::PopulationStabilityIndex,
            ..Default::default()
        };
        let ref_mat = Array2::from_shape_vec((100, 1), data.clone()).expect("shape ok");
        let cur_mat = Array2::from_shape_vec((100, 1), data).expect("shape ok");

        let detector = DriftDetector::fit(&ref_mat, config);
        let report = detector.detect(&cur_mat).expect("detect ok");
        assert!(
            !report.drifted,
            "same distribution should not trigger PSI drift"
        );
    }

    #[test]
    fn test_psi_severe_drift() {
        let ref_data: Vec<f64> = linspace_col(0.0, 1.0, 200);
        let cur_data: Vec<f64> = linspace_col(5.0, 6.0, 200);
        let config = DriftDetectorConfig {
            method: DriftMethod::PopulationStabilityIndex,
            psi_threshold: 0.2,
            ..Default::default()
        };
        let ref_mat = Array2::from_shape_vec((200, 1), ref_data).expect("shape ok");
        let cur_mat = Array2::from_shape_vec((200, 1), cur_data).expect("shape ok");

        let detector = DriftDetector::fit(&ref_mat, config);
        let report = detector.detect(&cur_mat).expect("detect ok");
        assert!(report.drifted, "severe shift should trigger PSI drift");
        assert!(
            report.overall_score > 0.2,
            "PSI should exceed threshold, got {}",
            report.overall_score
        );
    }

    #[test]
    fn test_wasserstein_drift() {
        // Shifting a sample by 0.5 moves every unit of mass by 0.5, so W1 = 0.5.
        let ref_data: Vec<f64> = linspace_col(0.0, 1.0, 100);
        let cur_data: Vec<f64> = linspace_col(0.5, 1.5, 100);
        let config = DriftDetectorConfig {
            method: DriftMethod::Wasserstein,
            wasserstein_threshold: 0.1,
            ..Default::default()
        };
        let ref_mat = Array2::from_shape_vec((100, 1), ref_data).expect("shape ok");
        let cur_mat = Array2::from_shape_vec((100, 1), cur_data).expect("shape ok");

        let detector = DriftDetector::fit(&ref_mat, config);
        let report = detector.detect(&cur_mat).expect("detect ok");
        assert!(
            report.drifted,
            "shifted distribution should trigger W1 drift"
        );
        assert!(
            (report.overall_score - 0.5).abs() < 0.05,
            "W1 distance should be ~0.5, got {}",
            report.overall_score
        );
    }

    #[test]
    fn test_wasserstein_no_drift() {
        let data: Vec<f64> = linspace_col(0.0, 1.0, 100);
        let config = DriftDetectorConfig {
            method: DriftMethod::Wasserstein,
            wasserstein_threshold: 0.1,
            ..Default::default()
        };
        let ref_mat = Array2::from_shape_vec((100, 1), data.clone()).expect("shape ok");
        let cur_mat = Array2::from_shape_vec((100, 1), data).expect("shape ok");

        let detector = DriftDetector::fit(&ref_mat, config);
        let report = detector.detect(&cur_mat).expect("detect ok");
        assert!(!report.drifted, "identical distributions should not drift");
        assert!(
            report.overall_score < 1e-10,
            "W1 should be 0, got {}",
            report.overall_score
        );
    }

    #[test]
    fn test_mmd_identical() {
        let data: Vec<f64> = linspace_col(0.0, 1.0, 50);
        let ref_mat = Array2::from_shape_vec((50, 1), data.clone()).expect("shape ok");
        let cur_mat = Array2::from_shape_vec((50, 1), data).expect("shape ok");
        let config = DriftDetectorConfig {
            method: DriftMethod::MaximumMeanDiscrepancy,
            mmd_bandwidth: Some(0.5),
            ..Default::default()
        };

        let detector = DriftDetector::fit(&ref_mat, config);
        let report = detector.detect(&cur_mat).expect("detect ok");
        assert!(
            report.overall_score < 1e-6,
            "MMD of identical distributions should be near 0, got {}",
            report.overall_score
        );
    }

    #[test]
    fn test_mmd_different() {
        let ref_data: Vec<f64> = linspace_col(-3.0, 3.0, 60);
        let cur_data: Vec<f64> = linspace_col(2.0, 8.0, 60);
        let ref_mat = Array2::from_shape_vec((60, 1), ref_data).expect("shape ok");
        let cur_mat = Array2::from_shape_vec((60, 1), cur_data).expect("shape ok");
        let config = DriftDetectorConfig {
            method: DriftMethod::MaximumMeanDiscrepancy,
            mmd_bandwidth: Some(1.0),
            ..Default::default()
        };

        let detector = DriftDetector::fit(&ref_mat, config);
        let report = detector.detect(&cur_mat).expect("detect ok");
        assert!(
            report.overall_score > 0.01,
            "MMD between the [-3, 3] and [2, 8] grids should be positive, got {}",
            report.overall_score
        );
    }

    #[test]
    fn test_mmd_median_heuristic() {
        let ref_data: Vec<f64> = linspace_col(0.0, 1.0, 40);
        let cur_data: Vec<f64> = linspace_col(5.0, 6.0, 40);
        let ref_mat = Array2::from_shape_vec((40, 1), ref_data).expect("shape ok");
        let cur_mat = Array2::from_shape_vec((40, 1), cur_data).expect("shape ok");
        let config = DriftDetectorConfig {
            method: DriftMethod::MaximumMeanDiscrepancy,
            mmd_bandwidth: None,
            ..Default::default()
        };

        let detector = DriftDetector::fit(&ref_mat, config);
        let report = detector.detect(&cur_mat).expect("detect ok");
        assert!(report.overall_score >= 0.0, "MMD² must be non-negative");
    }

    #[test]
    fn test_update_reference() {
        let initial_ref = zeros_matrix(50, 2);
        let config = DriftDetectorConfig::default();
        let mut detector = DriftDetector::fit(&initial_ref, config);

        // Move the baseline to where the new data lives.
        let shifted_ref = Array2::from_elem((50, 2), 5.0);
        detector.update_reference(&shifted_ref);

        let current = Array2::from_elem((50, 2), 5.0);
        let report = detector.detect(&current).expect("detect ok");
        assert!(!report.drifted, "after update, same data should not drift");
    }

    #[test]
    fn test_dimension_mismatch_error() {
        let ref_mat = zeros_matrix(50, 3);
        let cur_mat = zeros_matrix(50, 2);
        let config = DriftDetectorConfig::default();
        let detector = DriftDetector::fit(&ref_mat, config);
        let result = detector.detect(&cur_mat);
        assert!(result.is_err(), "should error on dimension mismatch");
    }

    #[test]
    fn test_ks_test_function() {
        let x: Vec<f64> = linspace_col(0.0, 1.0, 50);
        let y: Vec<f64> = linspace_col(0.0, 1.0, 50);
        let (d, p) = ks_test(&x, &y).expect("ks test ok");
        assert!(d < 0.05, "KS stat should be near 0");
        assert!(p > 0.5, "p-value should be high");
    }

    #[test]
    fn test_psi_function() {
        let ref_data: Vec<f64> = linspace_col(0.0, 1.0, 100);
        let cur_data: Vec<f64> = linspace_col(0.0, 1.0, 100);
        let score = psi(&ref_data, &cur_data, 10).expect("psi ok");
        assert!(score < 0.01, "PSI should be near 0 for same distribution");
    }

    #[test]
    fn test_wasserstein_distance_1d_function() {
        let x = vec![0.0, 1.0, 2.0, 3.0];
        let y = vec![1.0, 2.0, 3.0, 4.0];
        let d = wasserstein_distance_1d(&x, &y).expect("w1 ok");
        assert!((d - 1.0).abs() < 0.05, "W1 should be ~1.0, got {d}");
    }

    #[test]
    fn test_mmd_rbf_function() {
        let x = vec![0.0f64; 10];
        let y = vec![0.0f64; 10];
        let mmd2 = mmd_rbf(&x, &y, 1.0).expect("mmd ok");
        assert!(mmd2 < 1e-8, "MMD of identical samples should be 0");
    }
}