1use scirs2_core::ndarray::{Array1, Array2};
7use scirs2_core::SliceRandomExt;
8use sklears_core::error::{Result, SklearsError};
9use std::cmp::Ordering;
10
11#[derive(Debug, Clone)]
13pub struct DriftDetectionConfig {
14 pub detection_method: DriftDetectionMethod,
16 pub alpha: f64,
18 pub window_size: usize,
20 pub warning_threshold: f64,
22 pub min_samples: usize,
24 pub multivariate: bool,
26 pub random_state: Option<u64>,
28}
29
30impl Default for DriftDetectionConfig {
31 fn default() -> Self {
32 Self {
33 detection_method: DriftDetectionMethod::KolmogorovSmirnov,
34 alpha: 0.05,
35 window_size: 100,
36 warning_threshold: 0.5,
37 min_samples: 30,
38 multivariate: false,
39 random_state: None,
40 }
41 }
42}
43
/// Statistical test or heuristic run by [`DriftDetector`].
///
/// The first six variants compare the current batch with the stored reference
/// sample. `ADWIN`, `PageHinkley`, `DDM` and `EDDM` are simplified, batch-level
/// heuristics named after the corresponding streaming detectors; see the
/// detector implementation for exactly what each one computes.
///
/// The enum is fieldless, so it is `Copy` and can be compared or hashed
/// (e.g. `config.detection_method == DriftDetectionMethod::ADWIN`).
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum DriftDetectionMethod {
    /// Per-feature two-sample Kolmogorov–Smirnov test.
    KolmogorovSmirnov,
    /// Per-feature two-sample Anderson–Darling statistic.
    AndersonDarling,
    /// Per-feature Mann–Whitney U test.
    MannWhitney,
    /// Permutation test on a mean-difference statistic.
    Permutation,
    /// Per-feature population stability index over reference-quantile bins.
    PopulationStabilityIndex,
    /// Permutation test on a maximum-mean-discrepancy statistic.
    MaximumMeanDiscrepancy,
    /// Simplified heuristic named after ADWIN (adaptive windowing).
    ADWIN,
    /// Simplified heuristic named after the Page–Hinkley test.
    PageHinkley,
    /// Simplified heuristic named after DDM (Drift Detection Method).
    DDM,
    /// Simplified heuristic named after EDDM (Early Drift Detection Method).
    EDDM,
}
68
69#[derive(Debug, Clone)]
71pub struct DriftDetectionResult {
72 pub drift_detected: bool,
74 pub warning_detected: bool,
76 pub test_statistic: f64,
78 pub p_value: Option<f64>,
80 pub threshold: f64,
82 pub drift_score: f64,
84 pub feature_drift_scores: Option<Vec<f64>>,
86 pub statistics: DriftStatistics,
88}
89
/// Auxiliary metrics attached to every [`DriftDetectionResult`].
#[derive(Clone, Debug)]
pub struct DriftStatistics {
    /// Number of reference rows used (0 for methods that ignore the reference).
    pub n_reference: usize,
    /// Number of rows in the batch under test.
    pub n_current: usize,
    /// Number of feature columns.
    pub n_features: usize,
    /// Method-specific magnitude of the measured drift.
    pub drift_magnitude: f64,
    /// Heuristic confidence in the decision (`1 - p` for p-value methods).
    pub confidence: f64,
    /// Detections elapsed since drift was last flagged, if it ever was.
    pub time_since_drift: Option<usize>,
}
106
107#[derive(Debug, Clone)]
109pub struct DriftDetector {
110 config: DriftDetectionConfig,
111 reference_data: Option<Array2<f64>>,
112 current_window: Vec<Array1<f64>>,
113 drift_history: Vec<DriftDetectionResult>,
114 last_drift_time: Option<usize>,
115}
116
117impl DriftDetector {
118 pub fn new(config: DriftDetectionConfig) -> Self {
119 Self {
120 config,
121 reference_data: None,
122 current_window: Vec::new(),
123 drift_history: Vec::new(),
124 last_drift_time: None,
125 }
126 }
127
128 pub fn set_reference(&mut self, reference_data: Array2<f64>) {
130 self.reference_data = Some(reference_data);
131 }
132
133 pub fn detect_drift(&mut self, current_data: &Array2<f64>) -> Result<DriftDetectionResult> {
135 if self.reference_data.is_none() {
136 return Err(SklearsError::NotFitted {
137 operation: "drift detection".to_string(),
138 });
139 }
140
141 let reference = self
142 .reference_data
143 .as_ref()
144 .expect("operation should succeed");
145
146 if reference.ncols() != current_data.ncols() {
147 return Err(SklearsError::InvalidInput(
148 "Reference and current data must have same number of features".to_string(),
149 ));
150 }
151
152 let result = match self.config.detection_method {
153 DriftDetectionMethod::KolmogorovSmirnov => {
154 self.kolmogorov_smirnov_test(reference, current_data)?
155 }
156 DriftDetectionMethod::AndersonDarling => {
157 self.anderson_darling_test(reference, current_data)?
158 }
159 DriftDetectionMethod::MannWhitney => self.mann_whitney_test(reference, current_data)?,
160 DriftDetectionMethod::Permutation => self.permutation_test(reference, current_data)?,
161 DriftDetectionMethod::PopulationStabilityIndex => {
162 self.population_stability_index(reference, current_data)?
163 }
164 DriftDetectionMethod::MaximumMeanDiscrepancy => {
165 self.maximum_mean_discrepancy(reference, current_data)?
166 }
167 DriftDetectionMethod::ADWIN => self.adwin_test(current_data)?,
168 DriftDetectionMethod::PageHinkley => self.page_hinkley_test(current_data)?,
169 DriftDetectionMethod::DDM => self.ddm_test(current_data)?,
170 DriftDetectionMethod::EDDM => self.eddm_test(current_data)?,
171 };
172
173 self.drift_history.push(result.clone());
174
175 if result.drift_detected {
176 self.last_drift_time = Some(self.drift_history.len());
177 }
178
179 Ok(result)
180 }
181
182 fn kolmogorov_smirnov_test(
184 &self,
185 reference: &Array2<f64>,
186 current: &Array2<f64>,
187 ) -> Result<DriftDetectionResult> {
188 let n_features = reference.ncols();
189 let mut feature_scores = Vec::new();
190 let mut max_statistic: f64 = 0.0;
191 let mut min_p_value: f64 = 1.0;
192
193 for feature_idx in 0..n_features {
194 let ref_feature: Vec<f64> = (0..reference.nrows())
195 .map(|i| reference[[i, feature_idx]])
196 .collect();
197 let cur_feature: Vec<f64> = (0..current.nrows())
198 .map(|i| current[[i, feature_idx]])
199 .collect();
200
201 let (statistic, p_value) = self.ks_test(&ref_feature, &cur_feature);
202 feature_scores.push(statistic);
203 max_statistic = max_statistic.max(statistic);
204 min_p_value = min_p_value.min(p_value);
205 }
206
207 let drift_detected = min_p_value < self.config.alpha;
208 let warning_detected =
209 min_p_value < self.config.alpha * (1.0 + self.config.warning_threshold);
210
211 let statistics = DriftStatistics {
212 n_reference: reference.nrows(),
213 n_current: current.nrows(),
214 n_features,
215 drift_magnitude: max_statistic,
216 confidence: 1.0 - min_p_value,
217 time_since_drift: self.time_since_drift(),
218 };
219
220 Ok(DriftDetectionResult {
221 drift_detected,
222 warning_detected,
223 test_statistic: max_statistic,
224 p_value: Some(min_p_value),
225 threshold: self.config.alpha,
226 drift_score: max_statistic,
227 feature_drift_scores: Some(feature_scores),
228 statistics,
229 })
230 }
231
232 fn anderson_darling_test(
234 &self,
235 reference: &Array2<f64>,
236 current: &Array2<f64>,
237 ) -> Result<DriftDetectionResult> {
238 let n_features = reference.ncols();
240 let mut feature_scores = Vec::new();
241 let mut max_statistic: f64 = 0.0;
242
243 for feature_idx in 0..n_features {
244 let ref_feature: Vec<f64> = (0..reference.nrows())
245 .map(|i| reference[[i, feature_idx]])
246 .collect();
247 let cur_feature: Vec<f64> = (0..current.nrows())
248 .map(|i| current[[i, feature_idx]])
249 .collect();
250
251 let statistic = self.anderson_darling_statistic(&ref_feature, &cur_feature);
252 feature_scores.push(statistic);
253 max_statistic = max_statistic.max(statistic);
254 }
255
256 let threshold = 2.492; let drift_detected = max_statistic > threshold;
259 let warning_detected = max_statistic > threshold * self.config.warning_threshold;
260
261 let statistics = DriftStatistics {
262 n_reference: reference.nrows(),
263 n_current: current.nrows(),
264 n_features,
265 drift_magnitude: max_statistic,
266 confidence: if drift_detected { 0.95 } else { 0.5 },
267 time_since_drift: self.time_since_drift(),
268 };
269
270 Ok(DriftDetectionResult {
271 drift_detected,
272 warning_detected,
273 test_statistic: max_statistic,
274 p_value: None,
275 threshold,
276 drift_score: max_statistic,
277 feature_drift_scores: Some(feature_scores),
278 statistics,
279 })
280 }
281
282 fn mann_whitney_test(
284 &self,
285 reference: &Array2<f64>,
286 current: &Array2<f64>,
287 ) -> Result<DriftDetectionResult> {
288 let n_features = reference.ncols();
289 let mut feature_scores = Vec::new();
290 let mut max_statistic: f64 = 0.0;
291 let mut min_p_value: f64 = 1.0;
292
293 for feature_idx in 0..n_features {
294 let ref_feature: Vec<f64> = (0..reference.nrows())
295 .map(|i| reference[[i, feature_idx]])
296 .collect();
297 let cur_feature: Vec<f64> = (0..current.nrows())
298 .map(|i| current[[i, feature_idx]])
299 .collect();
300
301 let (u_statistic, p_value) = self.mann_whitney_u_test(&ref_feature, &cur_feature);
302 let normalized_statistic = u_statistic / (ref_feature.len() * cur_feature.len()) as f64;
303
304 feature_scores.push(normalized_statistic);
305 max_statistic = max_statistic.max(normalized_statistic);
306 min_p_value = min_p_value.min(p_value);
307 }
308
309 let drift_detected = min_p_value < self.config.alpha;
310 let warning_detected =
311 min_p_value < self.config.alpha * (1.0 + self.config.warning_threshold);
312
313 let statistics = DriftStatistics {
314 n_reference: reference.nrows(),
315 n_current: current.nrows(),
316 n_features,
317 drift_magnitude: max_statistic,
318 confidence: 1.0 - min_p_value,
319 time_since_drift: self.time_since_drift(),
320 };
321
322 Ok(DriftDetectionResult {
323 drift_detected,
324 warning_detected,
325 test_statistic: max_statistic,
326 p_value: Some(min_p_value),
327 threshold: self.config.alpha,
328 drift_score: max_statistic,
329 feature_drift_scores: Some(feature_scores),
330 statistics,
331 })
332 }
333
334 fn permutation_test(
336 &self,
337 reference: &Array2<f64>,
338 current: &Array2<f64>,
339 ) -> Result<DriftDetectionResult> {
340 let n_permutations = 1000;
341 let observed_statistic = self.calculate_permutation_statistic(reference, current);
342
343 let mut permutation_statistics = Vec::new();
344 let combined_data = self.combine_data(reference, current);
345 let n_ref = reference.nrows();
346
347 for _ in 0..n_permutations {
348 let (perm_ref, perm_cur) = self.random_permutation_split(&combined_data, n_ref);
349 let perm_statistic = self.calculate_permutation_statistic(&perm_ref, &perm_cur);
350 permutation_statistics.push(perm_statistic);
351 }
352
353 let extreme_count = permutation_statistics
355 .iter()
356 .filter(|&&stat| stat >= observed_statistic)
357 .count();
358 let p_value = extreme_count as f64 / n_permutations as f64;
359
360 let drift_detected = p_value < self.config.alpha;
361 let warning_detected = p_value < self.config.alpha * (1.0 + self.config.warning_threshold);
362
363 let statistics = DriftStatistics {
364 n_reference: reference.nrows(),
365 n_current: current.nrows(),
366 n_features: reference.ncols(),
367 drift_magnitude: observed_statistic,
368 confidence: 1.0 - p_value,
369 time_since_drift: self.time_since_drift(),
370 };
371
372 Ok(DriftDetectionResult {
373 drift_detected,
374 warning_detected,
375 test_statistic: observed_statistic,
376 p_value: Some(p_value),
377 threshold: self.config.alpha,
378 drift_score: observed_statistic,
379 feature_drift_scores: None,
380 statistics,
381 })
382 }
383
384 fn population_stability_index(
386 &self,
387 reference: &Array2<f64>,
388 current: &Array2<f64>,
389 ) -> Result<DriftDetectionResult> {
390 let n_features = reference.ncols();
391 let n_bins = 10;
392 let mut feature_scores = Vec::new();
393 let mut total_psi = 0.0;
394
395 for feature_idx in 0..n_features {
396 let ref_feature: Vec<f64> = (0..reference.nrows())
397 .map(|i| reference[[i, feature_idx]])
398 .collect();
399 let cur_feature: Vec<f64> = (0..current.nrows())
400 .map(|i| current[[i, feature_idx]])
401 .collect();
402
403 let psi = self.calculate_psi(&ref_feature, &cur_feature, n_bins);
404 feature_scores.push(psi);
405 total_psi += psi;
406 }
407
408 let avg_psi = total_psi / n_features as f64;
409
410 let drift_detected = avg_psi > 0.2;
412 let warning_detected = avg_psi > 0.1;
413
414 let statistics = DriftStatistics {
415 n_reference: reference.nrows(),
416 n_current: current.nrows(),
417 n_features,
418 drift_magnitude: avg_psi,
419 confidence: if drift_detected { 0.8 } else { 0.5 },
420 time_since_drift: self.time_since_drift(),
421 };
422
423 Ok(DriftDetectionResult {
424 drift_detected,
425 warning_detected,
426 test_statistic: avg_psi,
427 p_value: None,
428 threshold: 0.2,
429 drift_score: avg_psi,
430 feature_drift_scores: Some(feature_scores),
431 statistics,
432 })
433 }
434
435 fn maximum_mean_discrepancy(
437 &self,
438 reference: &Array2<f64>,
439 current: &Array2<f64>,
440 ) -> Result<DriftDetectionResult> {
441 let mmd_statistic = self.calculate_mmd(reference, current);
442
443 let n_permutations = 1000;
445 let mut permutation_mmds = Vec::new();
446 let combined_data = self.combine_data(reference, current);
447 let n_ref = reference.nrows();
448
449 for _ in 0..n_permutations {
450 let (perm_ref, perm_cur) = self.random_permutation_split(&combined_data, n_ref);
451 let perm_mmd = self.calculate_mmd(&perm_ref, &perm_cur);
452 permutation_mmds.push(perm_mmd);
453 }
454
455 let extreme_count = permutation_mmds
456 .iter()
457 .filter(|&&mmd| mmd >= mmd_statistic)
458 .count();
459 let p_value = extreme_count as f64 / n_permutations as f64;
460
461 let drift_detected = p_value < self.config.alpha;
462 let warning_detected = p_value < self.config.alpha * (1.0 + self.config.warning_threshold);
463
464 let statistics = DriftStatistics {
465 n_reference: reference.nrows(),
466 n_current: current.nrows(),
467 n_features: reference.ncols(),
468 drift_magnitude: mmd_statistic,
469 confidence: 1.0 - p_value,
470 time_since_drift: self.time_since_drift(),
471 };
472
473 Ok(DriftDetectionResult {
474 drift_detected,
475 warning_detected,
476 test_statistic: mmd_statistic,
477 p_value: Some(p_value),
478 threshold: self.config.alpha,
479 drift_score: mmd_statistic,
480 feature_drift_scores: None,
481 statistics,
482 })
483 }
484
485 fn adwin_test(&mut self, current: &Array2<f64>) -> Result<DriftDetectionResult> {
487 let n_samples = current.nrows();
491 let avg_performance = self.calculate_average_performance(current);
492
493 for i in 0..n_samples {
495 let sample = current.row(i).to_owned();
496 self.current_window.push(sample);
497 }
498
499 if self.current_window.len() > self.config.window_size * 2 {
501 let excess = self.current_window.len() - self.config.window_size;
502 self.current_window.drain(0..excess);
503 }
504
505 let drift_detected =
506 self.current_window.len() >= self.config.min_samples && avg_performance < 0.5; let statistics = DriftStatistics {
509 n_reference: 0,
510 n_current: current.nrows(),
511 n_features: current.ncols(),
512 drift_magnitude: 1.0 - avg_performance,
513 confidence: if drift_detected { 0.8 } else { 0.5 },
514 time_since_drift: self.time_since_drift(),
515 };
516
517 Ok(DriftDetectionResult {
518 drift_detected,
519 warning_detected: avg_performance < 0.7,
520 test_statistic: 1.0 - avg_performance,
521 p_value: None,
522 threshold: 0.5,
523 drift_score: 1.0 - avg_performance,
524 feature_drift_scores: None,
525 statistics,
526 })
527 }
528
529 fn page_hinkley_test(&self, current: &Array2<f64>) -> Result<DriftDetectionResult> {
531 let avg_performance = self.calculate_average_performance(current);
533 let threshold = 3.0; let cumulative_sum = (0.5 - avg_performance) * current.nrows() as f64;
536 let drift_detected = cumulative_sum.abs() > threshold;
537
538 let statistics = DriftStatistics {
539 n_reference: 0,
540 n_current: current.nrows(),
541 n_features: current.ncols(),
542 drift_magnitude: cumulative_sum.abs(),
543 confidence: if drift_detected { 0.8 } else { 0.5 },
544 time_since_drift: self.time_since_drift(),
545 };
546
547 Ok(DriftDetectionResult {
548 drift_detected,
549 warning_detected: cumulative_sum.abs() > threshold * 0.7,
550 test_statistic: cumulative_sum.abs(),
551 p_value: None,
552 threshold,
553 drift_score: cumulative_sum.abs(),
554 feature_drift_scores: None,
555 statistics,
556 })
557 }
558
559 fn ddm_test(&self, current: &Array2<f64>) -> Result<DriftDetectionResult> {
561 let error_rate = 1.0 - self.calculate_average_performance(current);
563 let std_error = (error_rate * (1.0 - error_rate) / current.nrows() as f64).sqrt();
564
565 let warning_threshold = error_rate + 2.0 * std_error;
566 let drift_threshold = error_rate + 3.0 * std_error;
567
568 let drift_detected = error_rate > drift_threshold;
569 let warning_detected = error_rate > warning_threshold;
570
571 let statistics = DriftStatistics {
572 n_reference: 0,
573 n_current: current.nrows(),
574 n_features: current.ncols(),
575 drift_magnitude: error_rate,
576 confidence: if drift_detected {
577 0.99
578 } else if warning_detected {
579 0.95
580 } else {
581 0.5
582 },
583 time_since_drift: self.time_since_drift(),
584 };
585
586 Ok(DriftDetectionResult {
587 drift_detected,
588 warning_detected,
589 test_statistic: error_rate,
590 p_value: None,
591 threshold: drift_threshold,
592 drift_score: error_rate,
593 feature_drift_scores: None,
594 statistics,
595 })
596 }
597
598 fn eddm_test(&self, current: &Array2<f64>) -> Result<DriftDetectionResult> {
600 let avg_performance = self.calculate_average_performance(current);
602 let _distance_between_errors = 1.0 / (1.0 - avg_performance + 1e-8);
603
604 let threshold = 0.95;
605 let drift_detected = avg_performance < threshold;
606
607 let statistics = DriftStatistics {
608 n_reference: 0,
609 n_current: current.nrows(),
610 n_features: current.ncols(),
611 drift_magnitude: 1.0 - avg_performance,
612 confidence: if drift_detected { 0.8 } else { 0.5 },
613 time_since_drift: self.time_since_drift(),
614 };
615
616 Ok(DriftDetectionResult {
617 drift_detected,
618 warning_detected: avg_performance < 0.98,
619 test_statistic: 1.0 - avg_performance,
620 p_value: None,
621 threshold: 1.0 - threshold,
622 drift_score: 1.0 - avg_performance,
623 feature_drift_scores: None,
624 statistics,
625 })
626 }
627
628 fn ks_test(&self, sample1: &[f64], sample2: &[f64]) -> (f64, f64) {
632 let mut combined: Vec<(f64, usize)> = sample1.iter().map(|&x| (x, 0)).collect();
633 combined.extend(sample2.iter().map(|&x| (x, 1)));
634 combined.sort_by(|a, b| a.0.partial_cmp(&b.0).unwrap_or(Ordering::Equal));
635
636 let n1 = sample1.len() as f64;
637 let n2 = sample2.len() as f64;
638 let mut cdf1 = 0.0;
639 let mut cdf2 = 0.0;
640 let mut max_diff: f64 = 0.0;
641
642 for (_, group) in combined {
643 if group == 0 {
644 cdf1 += 1.0 / n1;
645 } else {
646 cdf2 += 1.0 / n2;
647 }
648 max_diff = max_diff.max((cdf1 - cdf2).abs());
649 }
650
651 let ks_statistic = max_diff;
653 let en = (n1 * n2 / (n1 + n2)).sqrt();
654 let lambda = en * ks_statistic;
655 let p_value = 2.0 * (-2.0 * lambda * lambda).exp();
656
657 (ks_statistic, p_value.clamp(0.0, 1.0))
658 }
659
660 fn anderson_darling_statistic(&self, sample1: &[f64], sample2: &[f64]) -> f64 {
662 let mut combined: Vec<f64> = sample1.iter().chain(sample2.iter()).cloned().collect();
664 combined.sort_by(|a, b| a.partial_cmp(b).unwrap_or(Ordering::Equal));
665
666 let n1 = sample1.len() as f64;
667 let n2 = sample2.len() as f64;
668 let n = n1 + n2;
669
670 let mut h = 0.0;
671 let mut prev_val = f64::NEG_INFINITY;
672 let _i = 0.0;
673
674 for &val in &combined {
675 if val != prev_val {
676 let count1 = sample1.iter().filter(|&&x| x <= val).count() as f64;
677 let count2 = sample2.iter().filter(|&&x| x <= val).count() as f64;
678
679 let l = count1 + count2;
680 if l > 0.0 && l < n {
681 h += (count1 / n1 - count2 / n2).powi(2) / (l * (n - l));
682 }
683 prev_val = val;
684 }
685 }
686
687 n1 * n2 * h / n
688 }
689
690 fn mann_whitney_u_test(&self, sample1: &[f64], sample2: &[f64]) -> (f64, f64) {
692 let mut combined: Vec<(f64, usize)> = sample1.iter().map(|&x| (x, 0)).collect();
693 combined.extend(sample2.iter().map(|&x| (x, 1)));
694 combined.sort_by(|a, b| a.0.partial_cmp(&b.0).unwrap_or(Ordering::Equal));
695
696 let mut rank_sum1 = 0.0;
697 for (rank, (_, group)) in combined.iter().enumerate() {
698 if *group == 0 {
699 rank_sum1 += (rank + 1) as f64;
700 }
701 }
702
703 let n1 = sample1.len() as f64;
704 let n2 = sample2.len() as f64;
705 let u1 = rank_sum1 - n1 * (n1 + 1.0) / 2.0;
706 let u2 = n1 * n2 - u1;
707 let u_statistic = u1.min(u2);
708
709 let mu = n1 * n2 / 2.0;
711 let sigma = (n1 * n2 * (n1 + n2 + 1.0) / 12.0).sqrt();
712 let z = (u_statistic - mu).abs() / sigma;
713 let p_value = 2.0 * (1.0 - self.normal_cdf(z));
714
715 (u_statistic, p_value.clamp(0.0, 1.0))
716 }
717
718 fn calculate_permutation_statistic(
720 &self,
721 ref_data: &Array2<f64>,
722 cur_data: &Array2<f64>,
723 ) -> f64 {
724 let ref_mean = self.calculate_mean(ref_data);
726 let cur_mean = self.calculate_mean(cur_data);
727 (ref_mean - cur_mean).abs()
728 }
729
730 fn combine_data(&self, data1: &Array2<f64>, data2: &Array2<f64>) -> Array2<f64> {
732 let n_rows = data1.nrows() + data2.nrows();
733 let n_cols = data1.ncols();
734 let mut combined = Array2::zeros((n_rows, n_cols));
735
736 for i in 0..data1.nrows() {
738 for j in 0..n_cols {
739 combined[[i, j]] = data1[[i, j]];
740 }
741 }
742
743 for i in 0..data2.nrows() {
745 for j in 0..n_cols {
746 combined[[data1.nrows() + i, j]] = data2[[i, j]];
747 }
748 }
749
750 combined
751 }
752
753 fn random_permutation_split(
755 &self,
756 data: &Array2<f64>,
757 n_first: usize,
758 ) -> (Array2<f64>, Array2<f64>) {
759 use scirs2_core::random::rngs::StdRng;
760 use scirs2_core::random::SeedableRng;
761
762 let mut rng = match self.config.random_state {
763 Some(seed) => StdRng::seed_from_u64(seed),
764 None => {
765 use scirs2_core::random::thread_rng;
766 StdRng::from_rng(&mut thread_rng())
767 }
768 };
769
770 let mut indices: Vec<usize> = (0..data.nrows()).collect();
771 indices.shuffle(&mut rng);
772
773 let n_cols = data.ncols();
774 let mut first = Array2::zeros((n_first, n_cols));
775 let mut second = Array2::zeros((data.nrows() - n_first, n_cols));
776
777 for (i, &idx) in indices[..n_first].iter().enumerate() {
778 for j in 0..n_cols {
779 first[[i, j]] = data[[idx, j]];
780 }
781 }
782
783 for (i, &idx) in indices[n_first..].iter().enumerate() {
784 for j in 0..n_cols {
785 second[[i, j]] = data[[idx, j]];
786 }
787 }
788
789 (first, second)
790 }
791
792 fn calculate_psi(&self, reference: &[f64], current: &[f64], n_bins: usize) -> f64 {
794 let mut ref_sorted = reference.to_vec();
796 ref_sorted.sort_by(|a, b| a.partial_cmp(b).unwrap_or(Ordering::Equal));
797
798 let mut bin_edges = Vec::new();
799 for i in 0..=n_bins {
800 let quantile = i as f64 / n_bins as f64;
801 let idx = ((ref_sorted.len() - 1) as f64 * quantile) as usize;
802 bin_edges.push(ref_sorted[idx.min(ref_sorted.len() - 1)]);
803 }
804
805 let ref_counts = self.calculate_bin_counts(reference, &bin_edges);
807 let cur_counts = self.calculate_bin_counts(current, &bin_edges);
808
809 let mut psi = 0.0;
811 for i in 0..n_bins {
812 let ref_prop = ref_counts[i] / reference.len() as f64;
813 let cur_prop = cur_counts[i] / current.len() as f64;
814
815 if ref_prop > 0.0 && cur_prop > 0.0 {
816 psi += (cur_prop - ref_prop) * (cur_prop / ref_prop).ln();
817 }
818 }
819
820 psi
821 }
822
823 fn calculate_bin_counts(&self, data: &[f64], bin_edges: &[f64]) -> Vec<f64> {
825 let n_bins = bin_edges.len() - 1;
826 let mut counts = vec![0.0; n_bins];
827
828 for &value in data {
829 for i in 0..n_bins {
830 if (i == n_bins - 1 || value < bin_edges[i + 1]) && value >= bin_edges[i] {
831 counts[i] += 1.0;
832 break;
833 }
834 }
835 }
836
837 counts
838 }
839
840 fn calculate_mmd(&self, data1: &Array2<f64>, data2: &Array2<f64>) -> f64 {
842 let mean1 = self.calculate_mean(data1);
844 let mean2 = self.calculate_mean(data2);
845 (mean1 - mean2).abs()
846 }
847
848 fn calculate_mean(&self, data: &Array2<f64>) -> f64 {
850 let mut sum = 0.0;
851 let mut count = 0;
852
853 for i in 0..data.nrows() {
854 for j in 0..data.ncols() {
855 sum += data[[i, j]];
856 count += 1;
857 }
858 }
859
860 if count > 0 {
861 sum / count as f64
862 } else {
863 0.0
864 }
865 }
866
867 fn calculate_average_performance(&self, data: &Array2<f64>) -> f64 {
869 let mean = self.calculate_mean(data);
871 (mean + 1.0) / 2.0
873 }
874
875 fn time_since_drift(&self) -> Option<usize> {
877 self.last_drift_time
878 .map(|last| self.drift_history.len() - last)
879 }
880
881 fn normal_cdf(&self, x: f64) -> f64 {
883 0.5 * (1.0 + self.erf(x / 2.0_f64.sqrt()))
884 }
885
886 fn erf(&self, x: f64) -> f64 {
888 let a1 = 0.254829592;
889 let a2 = -0.284496736;
890 let a3 = 1.421413741;
891 let a4 = -1.453152027;
892 let a5 = 1.061405429;
893 let p = 0.3275911;
894
895 let sign = if x < 0.0 { -1.0 } else { 1.0 };
896 let x = x.abs();
897
898 let t = 1.0 / (1.0 + p * x);
899 let y = 1.0 - (((((a5 * t + a4) * t) + a3) * t + a2) * t + a1) * t * (-x * x).exp();
900
901 sign * y
902 }
903}
904
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_ks_drift_detection() {
        let config = DriftDetectionConfig::default();
        let mut detector = DriftDetector::new(config);

        let mut reference = Array2::zeros((100, 2));
        for i in 0..100 {
            reference[[i, 0]] = i as f64 / 100.0;
            reference[[i, 1]] = (i as f64 / 100.0).sin();
        }
        detector.set_reference(reference);

        // Every other reference point: same distribution, half as many samples.
        let mut current = Array2::zeros((50, 2));
        for i in 0..50 {
            let idx = i * 2;
            current[[i, 0]] = idx as f64 / 100.0;
            current[[i, 1]] = (idx as f64 / 100.0).sin();
        }

        let result = detector
            .detect_drift(&current)
            .expect("operation should succeed");
        assert!(
            !result.drift_detected,
            "Should not detect drift in similar data"
        );
    }

    #[test]
    fn test_psi_drift_detection() {
        let config = DriftDetectionConfig {
            detection_method: DriftDetectionMethod::PopulationStabilityIndex,
            ..Default::default()
        };
        let mut detector = DriftDetector::new(config);

        let mut reference = Array2::zeros((100, 1));
        for i in 0..100 {
            reference[[i, 0]] = i as f64 / 100.0;
        }
        detector.set_reference(reference);

        // Uniform on [0.5, 1.5): shifted half a unit to the right.
        let mut current = Array2::zeros((50, 1));
        for i in 0..50 {
            current[[i, 0]] = (i as f64 / 50.0) + 0.5;
        }

        let result = detector
            .detect_drift(&current)
            .expect("operation should succeed");
        assert!(result.drift_score > 0.1, "Should detect distribution shift");
    }

    #[test]
    fn test_detect_without_reference_is_not_fitted() {
        let mut detector = DriftDetector::new(DriftDetectionConfig::default());
        let current: Array2<f64> = Array2::zeros((5, 2));

        let err = detector
            .detect_drift(&current)
            .expect_err("detection without a reference must fail");
        assert!(matches!(err, SklearsError::NotFitted { .. }));
    }

    #[test]
    fn test_feature_count_mismatch_is_rejected() {
        let mut detector = DriftDetector::new(DriftDetectionConfig::default());
        detector.set_reference(Array2::zeros((5, 2)));
        let current: Array2<f64> = Array2::zeros((5, 3));

        let err = detector
            .detect_drift(&current)
            .expect_err("mismatched feature counts must fail");
        assert!(matches!(err, SklearsError::InvalidInput(_)));
    }
}