//! Statistical drift detection between a reference dataset and newly observed data.

use scirs2_core::ndarray::{Array1, Array2};
use scirs2_core::SliceRandomExt;
use sklears_core::error::{Result, SklearsError};
use std::cmp::Ordering;

/// Configuration for drift detection.
#[derive(Debug, Clone)]
pub struct DriftDetectionConfig {
    /// Statistical test or streaming detector used to flag drift.
    pub detection_method: DriftDetectionMethod,
    /// Significance level for the hypothesis-test based methods.
    pub alpha: f64,
    /// Window size used by the streaming detectors.
    pub window_size: usize,
    /// Margin that widens the drift threshold into a warning zone.
    pub warning_threshold: f64,
    /// Minimum number of buffered samples before drift can be reported.
    pub min_samples: usize,
    /// Whether drift should be assessed jointly across features.
    pub multivariate: bool,
    /// Seed for the permutation-based methods.
    pub random_state: Option<u64>,
}

impl Default for DriftDetectionConfig {
    fn default() -> Self {
        Self {
            detection_method: DriftDetectionMethod::KolmogorovSmirnov,
            alpha: 0.05,
            window_size: 100,
            warning_threshold: 0.5,
            min_samples: 30,
            multivariate: false,
            random_state: None,
        }
    }
}

/// Available drift detection methods.
#[derive(Debug, Clone)]
pub enum DriftDetectionMethod {
    /// Two-sample Kolmogorov-Smirnov test, applied per feature.
    KolmogorovSmirnov,
    /// Two-sample Anderson-Darling test, applied per feature.
    AndersonDarling,
    /// Mann-Whitney U test, applied per feature.
    MannWhitney,
    /// Permutation test on a mean-difference statistic.
    Permutation,
    /// Population Stability Index over binned feature values.
    PopulationStabilityIndex,
    /// Maximum Mean Discrepancy with a permutation-based p-value.
    MaximumMeanDiscrepancy,
    /// Adaptive windowing over recent samples (ADWIN-style).
    ADWIN,
    /// Page-Hinkley cumulative-sum test.
    PageHinkley,
    /// Error-rate based Drift Detection Method (DDM).
    DDM,
    /// Early Drift Detection Method (EDDM).
    EDDM,
}

/// Result of a single drift detection run.
#[derive(Debug, Clone)]
pub struct DriftDetectionResult {
    /// Whether drift was detected.
    pub drift_detected: bool,
    /// Whether the warning threshold was crossed.
    pub warning_detected: bool,
    /// Value of the underlying test statistic.
    pub test_statistic: f64,
    /// P-value, where the method provides one.
    pub p_value: Option<f64>,
    /// Threshold the statistic or p-value was compared against.
    pub threshold: f64,
    /// Overall drift score (higher means stronger drift).
    pub drift_score: f64,
    /// Per-feature drift scores, where available.
    pub feature_drift_scores: Option<Vec<f64>>,
    /// Summary statistics for this detection run.
    pub statistics: DriftStatistics,
}

/// Summary statistics accompanying a drift detection result.
#[derive(Debug, Clone)]
pub struct DriftStatistics {
    /// Number of reference samples.
    pub n_reference: usize,
    /// Number of current samples.
    pub n_current: usize,
    /// Number of features.
    pub n_features: usize,
    /// Magnitude of the detected drift.
    pub drift_magnitude: f64,
    /// Confidence associated with the decision.
    pub confidence: f64,
    /// Number of detection runs since the last detected drift.
    pub time_since_drift: Option<usize>,
}

/// Stateful drift detector that compares incoming data against a reference dataset.
#[derive(Debug, Clone)]
pub struct DriftDetector {
    config: DriftDetectionConfig,
    reference_data: Option<Array2<f64>>,
    current_window: Vec<Array1<f64>>,
    drift_history: Vec<DriftDetectionResult>,
    last_drift_time: Option<usize>,
}

impl DriftDetector {
    /// Create a new detector with the given configuration.
    pub fn new(config: DriftDetectionConfig) -> Self {
        Self {
            config,
            reference_data: None,
            current_window: Vec::new(),
            drift_history: Vec::new(),
            last_drift_time: None,
        }
    }

    /// Set the reference (baseline) dataset used for comparison.
    pub fn set_reference(&mut self, reference_data: Array2<f64>) {
        self.reference_data = Some(reference_data);
    }

    /// Run the configured drift test on `current_data` and record the result.
    pub fn detect_drift(&mut self, current_data: &Array2<f64>) -> Result<DriftDetectionResult> {
        if self.reference_data.is_none() {
            return Err(SklearsError::NotFitted {
                operation: "drift detection".to_string(),
            });
        }

        let reference = self.reference_data.as_ref().unwrap();

        if reference.ncols() != current_data.ncols() {
            return Err(SklearsError::InvalidInput(
                "Reference and current data must have same number of features".to_string(),
            ));
        }

        let result = match self.config.detection_method {
            DriftDetectionMethod::KolmogorovSmirnov => {
                self.kolmogorov_smirnov_test(reference, current_data)?
            }
            DriftDetectionMethod::AndersonDarling => {
                self.anderson_darling_test(reference, current_data)?
            }
            DriftDetectionMethod::MannWhitney => self.mann_whitney_test(reference, current_data)?,
            DriftDetectionMethod::Permutation => self.permutation_test(reference, current_data)?,
            DriftDetectionMethod::PopulationStabilityIndex => {
                self.population_stability_index(reference, current_data)?
            }
            DriftDetectionMethod::MaximumMeanDiscrepancy => {
                self.maximum_mean_discrepancy(reference, current_data)?
            }
            DriftDetectionMethod::ADWIN => self.adwin_test(current_data)?,
            DriftDetectionMethod::PageHinkley => self.page_hinkley_test(current_data)?,
            DriftDetectionMethod::DDM => self.ddm_test(current_data)?,
            DriftDetectionMethod::EDDM => self.eddm_test(current_data)?,
        };

        self.drift_history.push(result.clone());

        if result.drift_detected {
            self.last_drift_time = Some(self.drift_history.len());
        }

        Ok(result)
    }

    /// Per-feature two-sample Kolmogorov-Smirnov test, aggregated via the minimum p-value.
    fn kolmogorov_smirnov_test(
        &self,
        reference: &Array2<f64>,
        current: &Array2<f64>,
    ) -> Result<DriftDetectionResult> {
        let n_features = reference.ncols();
        let mut feature_scores = Vec::new();
        let mut max_statistic: f64 = 0.0;
        let mut min_p_value: f64 = 1.0;

        for feature_idx in 0..n_features {
            let ref_feature: Vec<f64> = (0..reference.nrows())
                .map(|i| reference[[i, feature_idx]])
                .collect();
            let cur_feature: Vec<f64> = (0..current.nrows())
                .map(|i| current[[i, feature_idx]])
                .collect();

            let (statistic, p_value) = self.ks_test(&ref_feature, &cur_feature);
            feature_scores.push(statistic);
            max_statistic = max_statistic.max(statistic);
            min_p_value = min_p_value.min(p_value);
        }

        let drift_detected = min_p_value < self.config.alpha;
        let warning_detected =
            min_p_value < self.config.alpha * (1.0 + self.config.warning_threshold);

        let statistics = DriftStatistics {
            n_reference: reference.nrows(),
            n_current: current.nrows(),
            n_features,
            drift_magnitude: max_statistic,
            confidence: 1.0 - min_p_value,
            time_since_drift: self.time_since_drift(),
        };

        Ok(DriftDetectionResult {
            drift_detected,
            warning_detected,
            test_statistic: max_statistic,
            p_value: Some(min_p_value),
            threshold: self.config.alpha,
            drift_score: max_statistic,
            feature_drift_scores: Some(feature_scores),
            statistics,
        })
    }

    /// Per-feature two-sample Anderson-Darling test, aggregated via the maximum statistic.
    fn anderson_darling_test(
        &self,
        reference: &Array2<f64>,
        current: &Array2<f64>,
    ) -> Result<DriftDetectionResult> {
        let n_features = reference.ncols();
        let mut feature_scores = Vec::new();
        let mut max_statistic: f64 = 0.0;

        for feature_idx in 0..n_features {
            let ref_feature: Vec<f64> = (0..reference.nrows())
                .map(|i| reference[[i, feature_idx]])
                .collect();
            let cur_feature: Vec<f64> = (0..current.nrows())
                .map(|i| current[[i, feature_idx]])
                .collect();

            let statistic = self.anderson_darling_statistic(&ref_feature, &cur_feature);
            feature_scores.push(statistic);
            max_statistic = max_statistic.max(statistic);
        }

        // Approximate critical value for the two-sample Anderson-Darling test at the 5% level.
        let threshold = 2.492;
        let drift_detected = max_statistic > threshold;
        let warning_detected = max_statistic > threshold * self.config.warning_threshold;

        let statistics = DriftStatistics {
            n_reference: reference.nrows(),
            n_current: current.nrows(),
            n_features,
            drift_magnitude: max_statistic,
            confidence: if drift_detected { 0.95 } else { 0.5 },
            time_since_drift: self.time_since_drift(),
        };

        Ok(DriftDetectionResult {
            drift_detected,
            warning_detected,
            test_statistic: max_statistic,
            p_value: None,
            threshold,
            drift_score: max_statistic,
            feature_drift_scores: Some(feature_scores),
            statistics,
        })
    }

    /// Per-feature Mann-Whitney U test, aggregated via the minimum p-value.
    fn mann_whitney_test(
        &self,
        reference: &Array2<f64>,
        current: &Array2<f64>,
    ) -> Result<DriftDetectionResult> {
        let n_features = reference.ncols();
        let mut feature_scores = Vec::new();
        let mut max_statistic: f64 = 0.0;
        let mut min_p_value: f64 = 1.0;

        for feature_idx in 0..n_features {
            let ref_feature: Vec<f64> = (0..reference.nrows())
                .map(|i| reference[[i, feature_idx]])
                .collect();
            let cur_feature: Vec<f64> = (0..current.nrows())
                .map(|i| current[[i, feature_idx]])
                .collect();

            let (u_statistic, p_value) = self.mann_whitney_u_test(&ref_feature, &cur_feature);
            let normalized_statistic = u_statistic / (ref_feature.len() * cur_feature.len()) as f64;

            feature_scores.push(normalized_statistic);
            max_statistic = max_statistic.max(normalized_statistic);
            min_p_value = min_p_value.min(p_value);
        }

        let drift_detected = min_p_value < self.config.alpha;
        let warning_detected =
            min_p_value < self.config.alpha * (1.0 + self.config.warning_threshold);

        let statistics = DriftStatistics {
            n_reference: reference.nrows(),
            n_current: current.nrows(),
            n_features,
            drift_magnitude: max_statistic,
            confidence: 1.0 - min_p_value,
            time_since_drift: self.time_since_drift(),
        };

        Ok(DriftDetectionResult {
            drift_detected,
            warning_detected,
            test_statistic: max_statistic,
            p_value: Some(min_p_value),
            threshold: self.config.alpha,
            drift_score: max_statistic,
            feature_drift_scores: Some(feature_scores),
            statistics,
        })
    }

    /// Permutation test on the absolute difference of overall means.
    fn permutation_test(
        &self,
        reference: &Array2<f64>,
        current: &Array2<f64>,
    ) -> Result<DriftDetectionResult> {
        let n_permutations = 1000;
        let observed_statistic = self.calculate_permutation_statistic(reference, current);

        let mut permutation_statistics = Vec::new();
        let combined_data = self.combine_data(reference, current);
        let n_ref = reference.nrows();

        for _ in 0..n_permutations {
            let (perm_ref, perm_cur) = self.random_permutation_split(&combined_data, n_ref);
            let perm_statistic = self.calculate_permutation_statistic(&perm_ref, &perm_cur);
            permutation_statistics.push(perm_statistic);
        }

        // P-value: fraction of permutations at least as extreme as the observed statistic.
        let extreme_count = permutation_statistics
            .iter()
            .filter(|&&stat| stat >= observed_statistic)
            .count();
        let p_value = extreme_count as f64 / n_permutations as f64;

        let drift_detected = p_value < self.config.alpha;
        let warning_detected = p_value < self.config.alpha * (1.0 + self.config.warning_threshold);

        let statistics = DriftStatistics {
            n_reference: reference.nrows(),
            n_current: current.nrows(),
            n_features: reference.ncols(),
            drift_magnitude: observed_statistic,
            confidence: 1.0 - p_value,
            time_since_drift: self.time_since_drift(),
        };

        Ok(DriftDetectionResult {
            drift_detected,
            warning_detected,
            test_statistic: observed_statistic,
            p_value: Some(p_value),
            threshold: self.config.alpha,
            drift_score: observed_statistic,
            feature_drift_scores: None,
            statistics,
        })
    }

    /// Population Stability Index (PSI) averaged across features.
    fn population_stability_index(
        &self,
        reference: &Array2<f64>,
        current: &Array2<f64>,
    ) -> Result<DriftDetectionResult> {
        let n_features = reference.ncols();
        let n_bins = 10;
        let mut feature_scores = Vec::new();
        let mut total_psi = 0.0;

        for feature_idx in 0..n_features {
            let ref_feature: Vec<f64> = (0..reference.nrows())
                .map(|i| reference[[i, feature_idx]])
                .collect();
            let cur_feature: Vec<f64> = (0..current.nrows())
                .map(|i| current[[i, feature_idx]])
                .collect();

            let psi = self.calculate_psi(&ref_feature, &cur_feature, n_bins);
            feature_scores.push(psi);
            total_psi += psi;
        }

        let avg_psi = total_psi / n_features as f64;

        // Common PSI rules of thumb: > 0.2 indicates a significant shift, > 0.1 a moderate one.
        let drift_detected = avg_psi > 0.2;
        let warning_detected = avg_psi > 0.1;

        let statistics = DriftStatistics {
            n_reference: reference.nrows(),
            n_current: current.nrows(),
            n_features,
            drift_magnitude: avg_psi,
            confidence: if drift_detected { 0.8 } else { 0.5 },
            time_since_drift: self.time_since_drift(),
        };

        Ok(DriftDetectionResult {
            drift_detected,
            warning_detected,
            test_statistic: avg_psi,
            p_value: None,
            threshold: 0.2,
            drift_score: avg_psi,
            feature_drift_scores: Some(feature_scores),
            statistics,
        })
    }

    /// Maximum Mean Discrepancy with a permutation-based p-value.
    fn maximum_mean_discrepancy(
        &self,
        reference: &Array2<f64>,
        current: &Array2<f64>,
    ) -> Result<DriftDetectionResult> {
        let mmd_statistic = self.calculate_mmd(reference, current);

        let n_permutations = 1000;
        let mut permutation_mmds = Vec::new();
        let combined_data = self.combine_data(reference, current);
        let n_ref = reference.nrows();

        for _ in 0..n_permutations {
            let (perm_ref, perm_cur) = self.random_permutation_split(&combined_data, n_ref);
            let perm_mmd = self.calculate_mmd(&perm_ref, &perm_cur);
            permutation_mmds.push(perm_mmd);
        }

        let extreme_count = permutation_mmds
            .iter()
            .filter(|&&mmd| mmd >= mmd_statistic)
            .count();
        let p_value = extreme_count as f64 / n_permutations as f64;

        let drift_detected = p_value < self.config.alpha;
        let warning_detected = p_value < self.config.alpha * (1.0 + self.config.warning_threshold);

        let statistics = DriftStatistics {
            n_reference: reference.nrows(),
            n_current: current.nrows(),
            n_features: reference.ncols(),
            drift_magnitude: mmd_statistic,
            confidence: 1.0 - p_value,
            time_since_drift: self.time_since_drift(),
        };

        Ok(DriftDetectionResult {
            drift_detected,
            warning_detected,
            test_statistic: mmd_statistic,
            p_value: Some(p_value),
            threshold: self.config.alpha,
            drift_score: mmd_statistic,
            feature_drift_scores: None,
            statistics,
        })
    }

    /// ADWIN-style detector over a sliding window of recent samples.
    fn adwin_test(&mut self, current: &Array2<f64>) -> Result<DriftDetectionResult> {
        let n_samples = current.nrows();
        let avg_performance = self.calculate_average_performance(current);

        for i in 0..n_samples {
            let sample = current.row(i).to_owned();
            self.current_window.push(sample);
        }

        // Trim the window once it grows past twice the configured size.
        if self.current_window.len() > self.config.window_size * 2 {
            let excess = self.current_window.len() - self.config.window_size;
            self.current_window.drain(0..excess);
        }

        let drift_detected =
            self.current_window.len() >= self.config.min_samples && avg_performance < 0.5;

        let statistics = DriftStatistics {
            n_reference: 0,
            n_current: current.nrows(),
            n_features: current.ncols(),
            drift_magnitude: 1.0 - avg_performance,
            confidence: if drift_detected { 0.8 } else { 0.5 },
            time_since_drift: self.time_since_drift(),
        };

        Ok(DriftDetectionResult {
            drift_detected,
            warning_detected: avg_performance < 0.7,
            test_statistic: 1.0 - avg_performance,
            p_value: None,
            threshold: 0.5,
            drift_score: 1.0 - avg_performance,
            feature_drift_scores: None,
            statistics,
        })
    }

    /// Page-Hinkley style cumulative-sum test on the average performance.
    fn page_hinkley_test(&self, current: &Array2<f64>) -> Result<DriftDetectionResult> {
        let avg_performance = self.calculate_average_performance(current);
        let threshold = 3.0;
        let cumulative_sum = (0.5 - avg_performance) * current.nrows() as f64;
        let drift_detected = cumulative_sum.abs() > threshold;

        let statistics = DriftStatistics {
            n_reference: 0,
            n_current: current.nrows(),
            n_features: current.ncols(),
            drift_magnitude: cumulative_sum.abs(),
            confidence: if drift_detected { 0.8 } else { 0.5 },
            time_since_drift: self.time_since_drift(),
        };

        Ok(DriftDetectionResult {
            drift_detected,
            warning_detected: cumulative_sum.abs() > threshold * 0.7,
            test_statistic: cumulative_sum.abs(),
            p_value: None,
            threshold,
            drift_score: cumulative_sum.abs(),
            feature_drift_scores: None,
            statistics,
        })
    }

    /// DDM-style test based on the error rate and its standard error.
    fn ddm_test(&self, current: &Array2<f64>) -> Result<DriftDetectionResult> {
        let error_rate = 1.0 - self.calculate_average_performance(current);
        let std_error = (error_rate * (1.0 - error_rate) / current.nrows() as f64).sqrt();

        let warning_threshold = error_rate + 2.0 * std_error;
        let drift_threshold = error_rate + 3.0 * std_error;

        let drift_detected = error_rate > drift_threshold;
        let warning_detected = error_rate > warning_threshold;

        let statistics = DriftStatistics {
            n_reference: 0,
            n_current: current.nrows(),
            n_features: current.ncols(),
            drift_magnitude: error_rate,
            confidence: if drift_detected {
                0.99
            } else if warning_detected {
                0.95
            } else {
                0.5
            },
            time_since_drift: self.time_since_drift(),
        };

        Ok(DriftDetectionResult {
            drift_detected,
            warning_detected,
            test_statistic: error_rate,
            p_value: None,
            threshold: drift_threshold,
            drift_score: error_rate,
            feature_drift_scores: None,
            statistics,
        })
    }

    /// EDDM-style test based on the average performance level.
    fn eddm_test(&self, current: &Array2<f64>) -> Result<DriftDetectionResult> {
        let avg_performance = self.calculate_average_performance(current);
        let _distance_between_errors = 1.0 / (1.0 - avg_performance + 1e-8);

        let threshold = 0.95;
        let drift_detected = avg_performance < threshold;

        let statistics = DriftStatistics {
            n_reference: 0,
            n_current: current.nrows(),
            n_features: current.ncols(),
            drift_magnitude: 1.0 - avg_performance,
            confidence: if drift_detected { 0.8 } else { 0.5 },
            time_since_drift: self.time_since_drift(),
        };

        Ok(DriftDetectionResult {
            drift_detected,
            warning_detected: avg_performance < 0.98,
            test_statistic: 1.0 - avg_performance,
            p_value: None,
            threshold: 1.0 - threshold,
            drift_score: 1.0 - avg_performance,
            feature_drift_scores: None,
            statistics,
        })
    }

    /// Two-sample Kolmogorov-Smirnov statistic and an approximate p-value.
    fn ks_test(&self, sample1: &[f64], sample2: &[f64]) -> (f64, f64) {
        let mut combined: Vec<(f64, usize)> = sample1.iter().map(|&x| (x, 0)).collect();
        combined.extend(sample2.iter().map(|&x| (x, 1)));
        combined.sort_by(|a, b| a.0.partial_cmp(&b.0).unwrap_or(Ordering::Equal));

        let n1 = sample1.len() as f64;
        let n2 = sample2.len() as f64;
        let mut cdf1 = 0.0;
        let mut cdf2 = 0.0;
        let mut max_diff: f64 = 0.0;

        for (_, group) in combined {
            if group == 0 {
                cdf1 += 1.0 / n1;
            } else {
                cdf2 += 1.0 / n2;
            }
            max_diff = max_diff.max((cdf1 - cdf2).abs());
        }

        let ks_statistic = max_diff;
        // One-term approximation of the asymptotic Kolmogorov distribution.
        let en = (n1 * n2 / (n1 + n2)).sqrt();
        let lambda = en * ks_statistic;
        let p_value = 2.0 * (-2.0 * lambda * lambda).exp();

        (ks_statistic, p_value.clamp(0.0, 1.0))
    }

    /// Two-sample Anderson-Darling statistic.
    fn anderson_darling_statistic(&self, sample1: &[f64], sample2: &[f64]) -> f64 {
        let mut combined: Vec<f64> = sample1.iter().chain(sample2.iter()).cloned().collect();
        combined.sort_by(|a, b| a.partial_cmp(b).unwrap_or(Ordering::Equal));

        let n1 = sample1.len() as f64;
        let n2 = sample2.len() as f64;
        let n = n1 + n2;

        let mut h = 0.0;
        let mut prev_val = f64::NEG_INFINITY;

        for &val in &combined {
            if val != prev_val {
                let count1 = sample1.iter().filter(|&&x| x <= val).count() as f64;
                let count2 = sample2.iter().filter(|&&x| x <= val).count() as f64;

                let l = count1 + count2;
                if l > 0.0 && l < n {
                    h += (count1 / n1 - count2 / n2).powi(2) / (l * (n - l));
                }
                prev_val = val;
            }
        }

        n1 * n2 * h / n
    }

    /// Mann-Whitney U statistic with a normal-approximation p-value.
    fn mann_whitney_u_test(&self, sample1: &[f64], sample2: &[f64]) -> (f64, f64) {
        let mut combined: Vec<(f64, usize)> = sample1.iter().map(|&x| (x, 0)).collect();
        combined.extend(sample2.iter().map(|&x| (x, 1)));
        combined.sort_by(|a, b| a.0.partial_cmp(&b.0).unwrap_or(Ordering::Equal));

        let mut rank_sum1 = 0.0;
        for (rank, (_, group)) in combined.iter().enumerate() {
            if *group == 0 {
                rank_sum1 += (rank + 1) as f64;
            }
        }

        let n1 = sample1.len() as f64;
        let n2 = sample2.len() as f64;
        let u1 = rank_sum1 - n1 * (n1 + 1.0) / 2.0;
        let u2 = n1 * n2 - u1;
        let u_statistic = u1.min(u2);

        // Normal approximation to the distribution of U.
        let mu = n1 * n2 / 2.0;
        let sigma = (n1 * n2 * (n1 + n2 + 1.0) / 12.0).sqrt();
        let z = (u_statistic - mu).abs() / sigma;
        let p_value = 2.0 * (1.0 - self.normal_cdf(z));

        (u_statistic, p_value.clamp(0.0, 1.0))
    }

    fn calculate_permutation_statistic(
        &self,
        ref_data: &Array2<f64>,
        cur_data: &Array2<f64>,
    ) -> f64 {
        let ref_mean = self.calculate_mean(ref_data);
        let cur_mean = self.calculate_mean(cur_data);
        (ref_mean - cur_mean).abs()
    }

    fn combine_data(&self, data1: &Array2<f64>, data2: &Array2<f64>) -> Array2<f64> {
        let n_rows = data1.nrows() + data2.nrows();
        let n_cols = data1.ncols();
        let mut combined = Array2::zeros((n_rows, n_cols));

        for i in 0..data1.nrows() {
            for j in 0..n_cols {
                combined[[i, j]] = data1[[i, j]];
            }
        }

        for i in 0..data2.nrows() {
            for j in 0..n_cols {
                combined[[data1.nrows() + i, j]] = data2[[i, j]];
            }
        }

        combined
    }

    fn random_permutation_split(
        &self,
        data: &Array2<f64>,
        n_first: usize,
    ) -> (Array2<f64>, Array2<f64>) {
        use scirs2_core::random::rngs::StdRng;
        use scirs2_core::random::SeedableRng;

        let mut rng = match self.config.random_state {
            Some(seed) => StdRng::seed_from_u64(seed),
            None => {
                use scirs2_core::random::thread_rng;
                StdRng::from_rng(&mut thread_rng())
            }
        };

        let mut indices: Vec<usize> = (0..data.nrows()).collect();
        indices.shuffle(&mut rng);

        let n_cols = data.ncols();
        let mut first = Array2::zeros((n_first, n_cols));
        let mut second = Array2::zeros((data.nrows() - n_first, n_cols));

        for (i, &idx) in indices[..n_first].iter().enumerate() {
            for j in 0..n_cols {
                first[[i, j]] = data[[idx, j]];
            }
        }

        for (i, &idx) in indices[n_first..].iter().enumerate() {
            for j in 0..n_cols {
                second[[i, j]] = data[[idx, j]];
            }
        }

        (first, second)
    }

    /// Population Stability Index between two samples, using reference quantile bins.
    fn calculate_psi(&self, reference: &[f64], current: &[f64], n_bins: usize) -> f64 {
        let mut ref_sorted = reference.to_vec();
        ref_sorted.sort_by(|a, b| a.partial_cmp(b).unwrap_or(Ordering::Equal));

        // Bin edges taken from the quantiles of the reference distribution.
        let mut bin_edges = Vec::new();
        for i in 0..=n_bins {
            let quantile = i as f64 / n_bins as f64;
            let idx = ((ref_sorted.len() - 1) as f64 * quantile) as usize;
            bin_edges.push(ref_sorted[idx.min(ref_sorted.len() - 1)]);
        }

        let ref_counts = self.calculate_bin_counts(reference, &bin_edges);
        let cur_counts = self.calculate_bin_counts(current, &bin_edges);

        let mut psi = 0.0;
        for i in 0..n_bins {
            let ref_prop = ref_counts[i] / reference.len() as f64;
            let cur_prop = cur_counts[i] / current.len() as f64;

            if ref_prop > 0.0 && cur_prop > 0.0 {
                psi += (cur_prop - ref_prop) * (cur_prop / ref_prop).ln();
            }
        }

        psi
    }

    fn calculate_bin_counts(&self, data: &[f64], bin_edges: &[f64]) -> Vec<f64> {
        let n_bins = bin_edges.len() - 1;
        let mut counts = vec![0.0; n_bins];

        for &value in data {
            for i in 0..n_bins {
                if (i == n_bins - 1 || value < bin_edges[i + 1]) && value >= bin_edges[i] {
                    counts[i] += 1.0;
                    break;
                }
            }
        }

        counts
    }

    /// Simplified MMD estimate: absolute difference of the overall means.
    fn calculate_mmd(&self, data1: &Array2<f64>, data2: &Array2<f64>) -> f64 {
        let mean1 = self.calculate_mean(data1);
        let mean2 = self.calculate_mean(data2);
        (mean1 - mean2).abs()
    }

    fn calculate_mean(&self, data: &Array2<f64>) -> f64 {
        let mut sum = 0.0;
        let mut count = 0;

        for i in 0..data.nrows() {
            for j in 0..data.ncols() {
                sum += data[[i, j]];
                count += 1;
            }
        }

        if count > 0 {
            sum / count as f64
        } else {
            0.0
        }
    }

    /// Proxy performance score: maps the overall mean into [0, 1] for values roughly in [-1, 1].
    fn calculate_average_performance(&self, data: &Array2<f64>) -> f64 {
        let mean = self.calculate_mean(data);
        (mean + 1.0) / 2.0
    }

    /// Number of detection runs since the last detected drift, if any.
    fn time_since_drift(&self) -> Option<usize> {
        self.last_drift_time
            .map(|last| self.drift_history.len() - last)
    }

    /// Standard normal CDF via the error function.
    fn normal_cdf(&self, x: f64) -> f64 {
        0.5 * (1.0 + self.erf(x / 2.0_f64.sqrt()))
    }

    /// Polynomial approximation of the error function (Abramowitz & Stegun 7.1.26).
    fn erf(&self, x: f64) -> f64 {
        let a1 = 0.254829592;
        let a2 = -0.284496736;
        let a3 = 1.421413741;
        let a4 = -1.453152027;
        let a5 = 1.061405429;
        let p = 0.3275911;

        let sign = if x < 0.0 { -1.0 } else { 1.0 };
        let x = x.abs();

        let t = 1.0 / (1.0 + p * x);
        let y = 1.0 - (((((a5 * t + a4) * t) + a3) * t + a2) * t + a1) * t * (-x * x).exp();

        sign * y
    }
}

#[allow(non_snake_case)]
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_ks_drift_detection() {
        let config = DriftDetectionConfig::default();
        let mut detector = DriftDetector::new(config);

        let mut reference = Array2::zeros((100, 2));
        for i in 0..100 {
            reference[[i, 0]] = i as f64 / 100.0;
            reference[[i, 1]] = (i as f64 / 100.0).sin();
        }
        detector.set_reference(reference);

        // Current data drawn from the same distribution (every other reference point).
        let mut current = Array2::zeros((50, 2));
        for i in 0..50 {
            let idx = i * 2;
            current[[i, 0]] = idx as f64 / 100.0;
            current[[i, 1]] = (idx as f64 / 100.0).sin();
        }

        let result = detector.detect_drift(&current).unwrap();
        assert!(
            !result.drift_detected,
            "Should not detect drift in similar data"
        );
    }
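
    // Sanity check (sketch): the default KS detector should flag current data that
    // lies far outside the reference range.
    #[test]
    fn test_ks_drift_detection_on_shifted_data() {
        let config = DriftDetectionConfig::default();
        let mut detector = DriftDetector::new(config);

        let mut reference = Array2::zeros((100, 2));
        for i in 0..100 {
            reference[[i, 0]] = i as f64 / 100.0;
            reference[[i, 1]] = (i as f64 / 100.0).sin();
        }
        detector.set_reference(reference);

        // Current data shifted well outside the reference range.
        let mut current = Array2::zeros((50, 2));
        for i in 0..50 {
            current[[i, 0]] = 5.0 + i as f64 / 50.0;
            current[[i, 1]] = 5.0 + (i as f64 / 50.0).sin();
        }

        let result = detector.detect_drift(&current).unwrap();
        assert!(
            result.drift_detected,
            "Should detect drift after a large shift"
        );
    }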

    #[test]
    fn test_psi_drift_detection() {
        let config = DriftDetectionConfig {
            detection_method: DriftDetectionMethod::PopulationStabilityIndex,
            ..Default::default()
        };
        let mut detector = DriftDetector::new(config);

        let mut reference = Array2::zeros((100, 1));
        for i in 0..100 {
            reference[[i, 0]] = i as f64 / 100.0;
        }
        detector.set_reference(reference);

        // Current data shifted by 0.5 relative to the reference range.
        let mut current = Array2::zeros((50, 1));
        for i in 0..50 {
            current[[i, 0]] = (i as f64 / 50.0) + 0.5;
        }

        let result = detector.detect_drift(&current).unwrap();
        assert!(result.drift_score > 0.1, "Should detect distribution shift");
    }
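
    // Sanity check (sketch): calling detect_drift before set_reference should
    // return an error rather than panic.
    #[test]
    fn test_detect_drift_without_reference_errors() {
        let mut detector = DriftDetector::new(DriftDetectionConfig::default());
        let current = Array2::zeros((10, 2));
        assert!(detector.detect_drift(&current).is_err());
    }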
}