scirs2_transform/monitoring/
drift_detection.rs1use crate::error::{Result, TransformError};
16
/// Outcome of a single drift test between a reference sample and a test sample.
#[derive(Clone, Debug)]
pub struct DriftResult {
    /// `true` when the detector decided that the two samples differ.
    pub detected: bool,
    /// Value of the detector's test statistic.
    pub statistic: f64,
    /// p-value of the test, for detectors that compute one; `None` otherwise.
    pub p_value: Option<f64>,
    /// Value the statistic was compared against to reach the decision.
    pub threshold: f64,
}
29
30pub trait DriftDetector: Send + Sync {
32 fn detect(&self, reference: &[f64], test: &[f64]) -> Result<DriftResult>;
34}
35
/// Drift detector based on the two-sample Kolmogorov-Smirnov test.
#[derive(Clone, Debug)]
pub struct KolmogorovSmirnovDetector {
    /// Significance level (alpha) used to derive the critical value.
    significance_level: f64,
}
50
51impl KolmogorovSmirnovDetector {
52 pub fn new(significance_level: f64) -> Result<Self> {
54 if significance_level <= 0.0 || significance_level >= 1.0 {
55 return Err(TransformError::InvalidInput(
56 "significance_level must be in (0, 1)".to_string(),
57 ));
58 }
59 Ok(Self { significance_level })
60 }
61
62 pub fn default_config() -> Self {
64 Self {
65 significance_level: 0.05,
66 }
67 }
68}
69
70impl DriftDetector for KolmogorovSmirnovDetector {
71 fn detect(&self, reference: &[f64], test: &[f64]) -> Result<DriftResult> {
72 if reference.is_empty() || test.is_empty() {
73 return Err(TransformError::InvalidInput(
74 "Reference and test samples must be non-empty".to_string(),
75 ));
76 }
77
78 let mut ref_sorted: Vec<f64> = reference.to_vec();
79 ref_sorted.sort_by(|a, b| a.partial_cmp(b).unwrap_or(std::cmp::Ordering::Equal));
80 let mut test_sorted: Vec<f64> = test.to_vec();
81 test_sorted.sort_by(|a, b| a.partial_cmp(b).unwrap_or(std::cmp::Ordering::Equal));
82
83 let n1 = reference.len() as f64;
84 let n2 = test.len() as f64;
85
86 let mut i = 0usize;
88 let mut j = 0usize;
89 let mut d_max: f64 = 0.0;
90
91 while i < ref_sorted.len() || j < test_sorted.len() {
92 let ref_val = if i < ref_sorted.len() {
93 ref_sorted[i]
94 } else {
95 f64::INFINITY
96 };
97 let test_val = if j < test_sorted.len() {
98 test_sorted[j]
99 } else {
100 f64::INFINITY
101 };
102
103 if ref_val <= test_val {
104 i += 1;
105 }
106 if test_val <= ref_val {
107 j += 1;
108 }
109
110 let cdf_ref = (i as f64) / n1;
111 let cdf_test = (j as f64) / n2;
112 let diff = (cdf_ref - cdf_test).abs();
113 if diff > d_max {
114 d_max = diff;
115 }
116 }
117
118 let en = (n1 * n2 / (n1 + n2)).sqrt();
120 let lambda = (en + 0.12 + 0.11 / en) * d_max;
121 let p_value = ks_p_value(lambda);
122
123 let threshold = ks_critical_value(n1 as usize, n2 as usize, self.significance_level);
124 let detected = d_max > threshold;
125
126 Ok(DriftResult {
127 detected,
128 statistic: d_max,
129 p_value: Some(p_value),
130 threshold,
131 })
132 }
133}
134
/// Critical value of the two-sample KS statistic at significance level `alpha`.
///
/// Computes `c(alpha) / sqrt(n1 * n2 / (n1 + n2))` with
/// `c(alpha) = sqrt(-0.5 * ln(alpha / 2))`.
///
/// The sample sizes are converted to `f64` before they are combined, so very
/// large sizes cannot overflow integer arithmetic.
fn ks_critical_value(n1: usize, n2: usize, alpha: f64) -> f64 {
    let (n1, n2) = (n1 as f64, n2 as f64);
    let effective_n = (n1 * n2 / (n1 + n2)).sqrt();
    let c_alpha = (-0.5 * (alpha / 2.0).ln()).sqrt();
    c_alpha / effective_n
}
142
/// Asymptotic Kolmogorov-Smirnov survival function
/// `Q(lambda) = 2 * sum_{k>=1} (-1)^(k-1) * exp(-2 * k^2 * lambda^2)`.
///
/// Returns a value in `[0, 1]`:
/// * `1.0` for `lambda <= 0`;
/// * `0.0` for `lambda > 4`, where the series is below double precision;
/// * `1.0` when the series has not converged within the term budget. This only
///   happens for very small `lambda`, where the true value is
///   indistinguishable from 1 and the truncated alternating sum would be wrong;
/// * NaN when `lambda` is NaN.
fn ks_p_value(lambda: f64) -> f64 {
    const MAX_TERMS: i32 = 100;
    const TERM_TOLERANCE: f64 = 1e-15;

    if lambda.is_nan() {
        return f64::NAN;
    }
    if lambda <= 0.0 {
        return 1.0;
    }
    if lambda > 4.0 {
        return 0.0;
    }

    let exponent_scale = -2.0 * lambda * lambda;
    let mut sum = 0.0;
    let mut sign = 1.0;
    for k in 1..=MAX_TERMS {
        let term = sign * (exponent_scale * f64::from(k * k)).exp();
        sum += term;
        if term.abs() < TERM_TOLERANCE {
            // Converged: the clamp only absorbs rounding noise.
            return (2.0 * sum).clamp(0.0, 1.0);
        }
        sign = -sign;
    }

    // Not converged: lambda is so small that the distributions cannot be
    // told apart, so report no evidence against the null hypothesis.
    1.0
}
165
/// Drift detector based on the Population Stability Index (PSI).
#[derive(Clone, Debug)]
pub struct PopulationStabilityIndexDetector {
    /// Number of equal-width bins used to histogram both samples.
    n_bins: usize,
    /// PSI value above which drift is reported.
    threshold: f64,
}
187
188impl PopulationStabilityIndexDetector {
189 pub fn new(n_bins: usize, threshold: f64) -> Result<Self> {
194 if n_bins < 2 {
195 return Err(TransformError::InvalidInput(
196 "n_bins must be >= 2".to_string(),
197 ));
198 }
199 if threshold <= 0.0 {
200 return Err(TransformError::InvalidInput(
201 "threshold must be positive".to_string(),
202 ));
203 }
204 Ok(Self { n_bins, threshold })
205 }
206
207 pub fn default_config() -> Self {
209 Self {
210 n_bins: 10,
211 threshold: 0.25,
212 }
213 }
214}
215
216impl DriftDetector for PopulationStabilityIndexDetector {
217 fn detect(&self, reference: &[f64], test: &[f64]) -> Result<DriftResult> {
218 if reference.is_empty() || test.is_empty() {
219 return Err(TransformError::InvalidInput(
220 "Reference and test samples must be non-empty".to_string(),
221 ));
222 }
223
224 let mut global_min = f64::INFINITY;
226 let mut global_max = f64::NEG_INFINITY;
227 for &v in reference.iter().chain(test.iter()) {
228 if v < global_min {
229 global_min = v;
230 }
231 if v > global_max {
232 global_max = v;
233 }
234 }
235
236 if (global_max - global_min).abs() < 1e-15 {
237 return Ok(DriftResult {
239 detected: false,
240 statistic: 0.0,
241 p_value: None,
242 threshold: self.threshold,
243 });
244 }
245
246 let bin_width = (global_max - global_min) / self.n_bins as f64;
247 let eps = 1e-10; let ref_counts = bin_counts(reference, global_min, bin_width, self.n_bins);
251 let test_counts = bin_counts(test, global_min, bin_width, self.n_bins);
252
253 let n_ref = reference.len() as f64;
254 let n_test = test.len() as f64;
255
256 let mut psi = 0.0;
257 for i in 0..self.n_bins {
258 let p = (ref_counts[i] as f64 / n_ref) + eps;
259 let q = (test_counts[i] as f64 / n_test) + eps;
260 psi += (p - q) * (p / q).ln();
261 }
262
263 Ok(DriftResult {
264 detected: psi > self.threshold,
265 statistic: psi,
266 p_value: None,
267 threshold: self.threshold,
268 })
269 }
270}
271
/// Histograms `data` into `n_bins` equal-width bins starting at `min_val`.
///
/// A value's bin is `floor((value - min_val) / bin_width)`, capped at the last
/// bin so that the maximum value lands inside the histogram.
fn bin_counts(data: &[f64], min_val: f64, bin_width: f64, n_bins: usize) -> Vec<usize> {
    data.iter()
        .fold(vec![0usize; n_bins], |mut histogram, &value| {
            let raw_bin = ((value - min_val) / bin_width).floor() as usize;
            histogram[raw_bin.min(n_bins - 1)] += 1;
            histogram
        })
}
282
/// Drift detector based on the distance between two empirical CDFs.
#[derive(Clone, Debug)]
pub struct WassersteinDetector {
    /// Distance above which drift is reported.
    threshold: f64,
}
296
297impl WassersteinDetector {
298 pub fn new(threshold: f64) -> Result<Self> {
300 if threshold <= 0.0 {
301 return Err(TransformError::InvalidInput(
302 "threshold must be positive".to_string(),
303 ));
304 }
305 Ok(Self { threshold })
306 }
307}
308
309impl DriftDetector for WassersteinDetector {
310 fn detect(&self, reference: &[f64], test: &[f64]) -> Result<DriftResult> {
311 if reference.is_empty() || test.is_empty() {
312 return Err(TransformError::InvalidInput(
313 "Reference and test samples must be non-empty".to_string(),
314 ));
315 }
316
317 let mut ref_sorted: Vec<f64> = reference.to_vec();
318 ref_sorted.sort_by(|a, b| a.partial_cmp(b).unwrap_or(std::cmp::Ordering::Equal));
319 let mut test_sorted: Vec<f64> = test.to_vec();
320 test_sorted.sort_by(|a, b| a.partial_cmp(b).unwrap_or(std::cmp::Ordering::Equal));
321
322 let n1 = reference.len() as f64;
323 let n2 = test.len() as f64;
324
325 let mut all_vals: Vec<f64> = Vec::with_capacity(reference.len() + test.len());
327 all_vals.extend_from_slice(&ref_sorted);
328 all_vals.extend_from_slice(&test_sorted);
329 all_vals.sort_by(|a, b| a.partial_cmp(b).unwrap_or(std::cmp::Ordering::Equal));
330 all_vals.dedup();
331
332 let mut distance = 0.0;
333 let mut prev_val = all_vals[0];
334
335 for &val in all_vals.iter().skip(1) {
336 let cdf_ref = count_le(&ref_sorted, prev_val) as f64 / n1;
338 let cdf_test = count_le(&test_sorted, prev_val) as f64 / n2;
339 distance += (cdf_ref - cdf_test).abs() * (val - prev_val);
340 prev_val = val;
341 }
342
343 Ok(DriftResult {
344 detected: distance > self.threshold,
345 statistic: distance,
346 p_value: None,
347 threshold: self.threshold,
348 })
349 }
350}
351
/// Number of elements in `sorted` that are less than or equal to `val`.
///
/// `sorted` must be in ascending order. The lookup is a single binary search,
/// so long runs of duplicates cost nothing extra. A NaN `val` matches no
/// element and yields 0.
fn count_le(sorted: &[f64], val: f64) -> usize {
    sorted.partition_point(|&x| x <= val)
}
371
/// Drift detector based on the squared Maximum Mean Discrepancy with an RBF kernel.
#[derive(Clone, Debug)]
pub struct MaximumMeanDiscrepancyDetector {
    /// Bandwidth of the RBF kernel.
    sigma: f64,
    /// Number of features per sample in the flattened input slices.
    dim: usize,
    /// Squared-MMD value above which drift is reported.
    threshold: f64,
}
397
398impl MaximumMeanDiscrepancyDetector {
399 pub fn new(dim: usize, sigma: f64, threshold: f64) -> Result<Self> {
405 if dim == 0 {
406 return Err(TransformError::InvalidInput(
407 "dim must be positive".to_string(),
408 ));
409 }
410 if sigma <= 0.0 {
411 return Err(TransformError::InvalidInput(
412 "sigma must be positive".to_string(),
413 ));
414 }
415 if threshold <= 0.0 {
416 return Err(TransformError::InvalidInput(
417 "threshold must be positive".to_string(),
418 ));
419 }
420 Ok(Self {
421 sigma,
422 dim,
423 threshold,
424 })
425 }
426
427 pub fn detect_multivariate(&self, reference: &[f64], test: &[f64]) -> Result<DriftResult> {
432 if reference.len() % self.dim != 0 || test.len() % self.dim != 0 {
433 return Err(TransformError::InvalidInput(format!(
434 "Data length must be a multiple of dim ({})",
435 self.dim
436 )));
437 }
438
439 let n_ref = reference.len() / self.dim;
440 let n_test = test.len() / self.dim;
441
442 if n_ref < 2 || n_test < 2 {
443 return Err(TransformError::InvalidInput(
444 "Need at least 2 samples in each set".to_string(),
445 ));
446 }
447
448 let gamma = 1.0 / (2.0 * self.sigma * self.sigma);
449
450 let mut kxx = 0.0;
452 for i in 0..n_ref {
453 for j in (i + 1)..n_ref {
454 kxx += rbf_kernel(
455 &reference[i * self.dim..(i + 1) * self.dim],
456 &reference[j * self.dim..(j + 1) * self.dim],
457 gamma,
458 );
459 }
460 }
461 kxx *= 2.0 / (n_ref * (n_ref - 1)) as f64;
462
463 let mut kyy = 0.0;
465 for i in 0..n_test {
466 for j in (i + 1)..n_test {
467 kyy += rbf_kernel(
468 &test[i * self.dim..(i + 1) * self.dim],
469 &test[j * self.dim..(j + 1) * self.dim],
470 gamma,
471 );
472 }
473 }
474 kyy *= 2.0 / (n_test * (n_test - 1)) as f64;
475
476 let mut kxy = 0.0;
478 for i in 0..n_ref {
479 for j in 0..n_test {
480 kxy += rbf_kernel(
481 &reference[i * self.dim..(i + 1) * self.dim],
482 &test[j * self.dim..(j + 1) * self.dim],
483 gamma,
484 );
485 }
486 }
487 kxy /= (n_ref * n_test) as f64;
488
489 let mmd2 = kxx - 2.0 * kxy + kyy;
490 let mmd2 = mmd2.max(0.0); Ok(DriftResult {
493 detected: mmd2 > self.threshold,
494 statistic: mmd2,
495 p_value: None,
496 threshold: self.threshold,
497 })
498 }
499}
500
501impl DriftDetector for MaximumMeanDiscrepancyDetector {
502 fn detect(&self, reference: &[f64], test: &[f64]) -> Result<DriftResult> {
504 if self.dim != 1 {
505 return Err(TransformError::InvalidInput(
506 "Use detect_multivariate() for dim > 1".to_string(),
507 ));
508 }
509 self.detect_multivariate(reference, test)
510 }
511}
512
/// RBF kernel `exp(-gamma * ||x - y||^2)` between two points.
///
/// Coordinates are paired positionally; if the slices differ in length the
/// extra coordinates of the longer one are ignored.
fn rbf_kernel(x: &[f64], y: &[f64], gamma: f64) -> f64 {
    let mut squared_distance = 0.0;
    for (a, b) in x.iter().zip(y) {
        let delta = a - b;
        squared_distance += delta.powi(2);
    }
    (-gamma * squared_distance).exp()
}
518
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds `n` evenly spaced values: `i * step + offset` for `i` in `0..n`.
    fn ramp(n: usize, step: f64, offset: f64) -> Vec<f64> {
        (0..n).map(|i| (i as f64) * step + offset).collect()
    }

    // ---- Kolmogorov-Smirnov ------------------------------------------------

    #[test]
    fn test_ks_no_drift_same_distribution() {
        let baseline = ramp(200, 0.01, 0.0);
        let nudged = ramp(200, 0.01, 0.001);

        let outcome = KolmogorovSmirnovDetector::default_config()
            .detect(&baseline, &nudged)
            .expect("detect");

        assert!(
            !outcome.detected,
            "Should NOT detect drift on nearly identical distributions: stat={}",
            outcome.statistic
        );
        assert!(outcome.p_value.is_some());
    }

    #[test]
    fn test_ks_detect_mean_shift() {
        let baseline = ramp(300, 0.01, 0.0);
        let shifted = ramp(300, 0.01, 5.0);

        let outcome = KolmogorovSmirnovDetector::default_config()
            .detect(&baseline, &shifted)
            .expect("detect");

        assert!(
            outcome.detected,
            "Should detect drift after mean shift of 5.0: stat={}",
            outcome.statistic
        );
    }

    #[test]
    fn test_ks_empty_input() {
        let detector = KolmogorovSmirnovDetector::default_config();
        assert!(detector.detect(&[], &[1.0]).is_err());
        assert!(detector.detect(&[1.0], &[]).is_err());
    }

    #[test]
    fn test_ks_invalid_significance() {
        for level in [0.0, 1.0, -0.1] {
            assert!(KolmogorovSmirnovDetector::new(level).is_err());
        }
    }

    // ---- Population Stability Index ----------------------------------------

    #[test]
    fn test_psi_identical_distributions() {
        let samples = ramp(500, 0.01, 0.0);

        let outcome = PopulationStabilityIndexDetector::default_config()
            .detect(&samples, &samples)
            .expect("detect");

        assert!(
            outcome.statistic < 0.01,
            "PSI for identical distributions should be ~0, got {}",
            outcome.statistic
        );
        assert!(!outcome.detected);
    }

    #[test]
    fn test_psi_detect_shift() {
        let baseline = ramp(500, 0.01, 0.0);
        let shifted = ramp(500, 0.01, 10.0);

        let outcome = PopulationStabilityIndexDetector::default_config()
            .detect(&baseline, &shifted)
            .expect("detect");

        assert!(
            outcome.detected,
            "PSI should detect large distribution shift: psi={}",
            outcome.statistic
        );
    }

    #[test]
    fn test_psi_constant_values() {
        let samples = vec![1.0; 100];

        let outcome = PopulationStabilityIndexDetector::default_config()
            .detect(&samples, &samples)
            .expect("detect");

        assert!(!outcome.detected);
        assert!(outcome.statistic.abs() < 1e-10);
    }

    // ---- Wasserstein -------------------------------------------------------

    #[test]
    fn test_wasserstein_no_drift() {
        let baseline = ramp(200, 0.01, 0.0);
        let nudged = ramp(200, 0.01, 0.001);

        let detector = WassersteinDetector::new(1.0).expect("create");
        let outcome = detector.detect(&baseline, &nudged).expect("detect");

        assert!(
            !outcome.detected,
            "Should not detect drift: distance={}",
            outcome.statistic
        );
    }

    #[test]
    fn test_wasserstein_detect_shift() {
        let baseline = ramp(200, 0.01, 0.0);
        let shifted = ramp(200, 0.01, 10.0);

        let detector = WassersteinDetector::new(1.0).expect("create");
        let outcome = detector.detect(&baseline, &shifted).expect("detect");

        assert!(
            outcome.detected,
            "Should detect shift of 10.0: distance={}",
            outcome.statistic
        );
    }

    // ---- Maximum Mean Discrepancy ------------------------------------------

    #[test]
    fn test_mmd_no_drift() {
        let baseline = ramp(100, 0.1, 0.0);
        let nudged = ramp(100, 0.1, 0.01);

        let detector = MaximumMeanDiscrepancyDetector::new(1, 1.0, 0.1).expect("create");
        let outcome = detector.detect(&baseline, &nudged).expect("detect");

        assert!(
            !outcome.detected,
            "Should not detect drift on similar distributions: mmd2={}",
            outcome.statistic
        );
    }

    #[test]
    fn test_mmd_detect_shift() {
        let baseline = ramp(50, 0.1, 0.0);
        let shifted = ramp(50, 0.1, 100.0);

        let detector = MaximumMeanDiscrepancyDetector::new(1, 1.0, 0.01).expect("create");
        let outcome = detector.detect(&baseline, &shifted).expect("detect");

        assert!(
            outcome.detected,
            "Should detect large shift: mmd2={}",
            outcome.statistic
        );
    }

    #[test]
    fn test_mmd_multivariate() {
        // 60 values read as 20 samples of 3 features each.
        let dim = 3;
        let baseline = ramp(60, 0.01, 0.0);
        let shifted = ramp(60, 0.01, 50.0);

        let detector = MaximumMeanDiscrepancyDetector::new(dim, 1.0, 0.01).expect("create");
        let outcome = detector
            .detect_multivariate(&baseline, &shifted)
            .expect("detect");

        assert!(
            outcome.detected,
            "Should detect multivariate drift: mmd2={}",
            outcome.statistic
        );
    }

    #[test]
    fn test_mmd_error_wrong_dim() {
        let detector = MaximumMeanDiscrepancyDetector::new(3, 1.0, 0.1).expect("create");
        // Five values cannot be split into samples of three features.
        let ragged = vec![1.0, 2.0, 3.0, 4.0, 5.0];
        assert!(detector.detect_multivariate(&ragged, &ragged).is_err());
    }
}