sklears_model_selection/
drift_detection.rs

1//! Data drift detection for validation and monitoring
2//!
3//! This module provides methods to detect distribution changes in data
4//! that can affect model performance over time.
5
6use scirs2_core::ndarray::{Array1, Array2};
7use scirs2_core::SliceRandomExt;
8use sklears_core::error::{Result, SklearsError};
9use std::cmp::Ordering;
10
11/// Configuration for drift detection
12#[derive(Debug, Clone)]
13pub struct DriftDetectionConfig {
14    /// Detection method to use
15    pub detection_method: DriftDetectionMethod,
16    /// Significance level for statistical tests
17    pub alpha: f64,
18    /// Window size for windowed detection methods
19    pub window_size: usize,
20    /// Warning threshold (fraction of alpha)
21    pub warning_threshold: f64,
22    /// Minimum samples required for detection
23    pub min_samples: usize,
24    /// Whether to detect multivariate drift
25    pub multivariate: bool,
26    /// Random state for reproducible results
27    pub random_state: Option<u64>,
28}
29
30impl Default for DriftDetectionConfig {
31    fn default() -> Self {
32        Self {
33            detection_method: DriftDetectionMethod::KolmogorovSmirnov,
34            alpha: 0.05,
35            window_size: 100,
36            warning_threshold: 0.5,
37            min_samples: 30,
38            multivariate: false,
39            random_state: None,
40        }
41    }
42}
43
/// Algorithms available to [`DriftDetector`].
///
/// The first six variants compare the current batch with the reference sample. The last
/// four (`ADWIN`, `PageHinkley`, `DDM`, `EDDM`) only inspect the incoming batch, although
/// `detect_drift` still requires a reference to have been set.
///
/// The enum is a plain fieldless tag, so it is `Copy` and comparable/hashable.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum DriftDetectionMethod {
    /// Two-sample Kolmogorov-Smirnov test, applied per feature.
    KolmogorovSmirnov,
    /// Simplified two-sample Anderson-Darling statistic, applied per feature.
    AndersonDarling,
    /// Mann-Whitney U test, applied per feature.
    MannWhitney,
    /// Permutation test on the difference of pooled means.
    Permutation,
    /// Population Stability Index (PSI), averaged over features.
    PopulationStabilityIndex,
    /// Maximum Mean Discrepancy (MMD) with a permutation p-value.
    MaximumMeanDiscrepancy,
    /// ADWIN (Adaptive Windowing), simplified.
    ADWIN,
    /// Page-Hinkley test, simplified.
    PageHinkley,
    /// Drift Detection Method (DDM).
    DDM,
    /// Early Drift Detection Method (EDDM), simplified.
    EDDM,
}
68
69/// Results from drift detection
70#[derive(Debug, Clone)]
71pub struct DriftDetectionResult {
72    /// Whether drift was detected
73    pub drift_detected: bool,
74    /// Whether warning threshold was exceeded
75    pub warning_detected: bool,
76    /// Test statistic value
77    pub test_statistic: f64,
78    /// P-value for statistical tests
79    pub p_value: Option<f64>,
80    /// Threshold used for detection
81    pub threshold: f64,
82    /// Drift score (higher = more drift)
83    pub drift_score: f64,
84    /// Per-feature drift scores
85    pub feature_drift_scores: Option<Vec<f64>>,
86    /// Detailed statistics
87    pub statistics: DriftStatistics,
88}
89
/// Bookkeeping numbers attached to every [`DriftDetectionResult`].
#[derive(Clone, Debug)]
pub struct DriftStatistics {
    /// Rows in the reference sample (reported as 0 by the batch-only methods).
    pub n_reference: usize,
    /// Rows in the batch passed to `detect_drift`.
    pub n_current: usize,
    /// Number of feature columns analysed.
    pub n_features: usize,
    /// Magnitude of the detected change; the current detectors report their drift score here.
    pub drift_magnitude: f64,
    /// Heuristic confidence: `1 - p` for p-value methods, otherwise a fixed value per outcome.
    pub confidence: f64,
    /// Results recorded since the most recent drifting result; `None` if drift never fired.
    pub time_since_drift: Option<usize>,
}
106
107/// Drift detector for monitoring data distribution changes
108#[derive(Debug, Clone)]
109pub struct DriftDetector {
110    config: DriftDetectionConfig,
111    reference_data: Option<Array2<f64>>,
112    current_window: Vec<Array1<f64>>,
113    drift_history: Vec<DriftDetectionResult>,
114    last_drift_time: Option<usize>,
115}
116
117impl DriftDetector {
118    pub fn new(config: DriftDetectionConfig) -> Self {
119        Self {
120            config,
121            reference_data: None,
122            current_window: Vec::new(),
123            drift_history: Vec::new(),
124            last_drift_time: None,
125        }
126    }
127
128    /// Set reference data for drift detection
129    pub fn set_reference(&mut self, reference_data: Array2<f64>) {
130        self.reference_data = Some(reference_data);
131    }
132
133    /// Detect drift in new data
134    pub fn detect_drift(&mut self, current_data: &Array2<f64>) -> Result<DriftDetectionResult> {
135        if self.reference_data.is_none() {
136            return Err(SklearsError::NotFitted {
137                operation: "drift detection".to_string(),
138            });
139        }
140
141        let reference = self.reference_data.as_ref().unwrap();
142
143        if reference.ncols() != current_data.ncols() {
144            return Err(SklearsError::InvalidInput(
145                "Reference and current data must have same number of features".to_string(),
146            ));
147        }
148
149        let result = match self.config.detection_method {
150            DriftDetectionMethod::KolmogorovSmirnov => {
151                self.kolmogorov_smirnov_test(reference, current_data)?
152            }
153            DriftDetectionMethod::AndersonDarling => {
154                self.anderson_darling_test(reference, current_data)?
155            }
156            DriftDetectionMethod::MannWhitney => self.mann_whitney_test(reference, current_data)?,
157            DriftDetectionMethod::Permutation => self.permutation_test(reference, current_data)?,
158            DriftDetectionMethod::PopulationStabilityIndex => {
159                self.population_stability_index(reference, current_data)?
160            }
161            DriftDetectionMethod::MaximumMeanDiscrepancy => {
162                self.maximum_mean_discrepancy(reference, current_data)?
163            }
164            DriftDetectionMethod::ADWIN => self.adwin_test(current_data)?,
165            DriftDetectionMethod::PageHinkley => self.page_hinkley_test(current_data)?,
166            DriftDetectionMethod::DDM => self.ddm_test(current_data)?,
167            DriftDetectionMethod::EDDM => self.eddm_test(current_data)?,
168        };
169
170        self.drift_history.push(result.clone());
171
172        if result.drift_detected {
173            self.last_drift_time = Some(self.drift_history.len());
174        }
175
176        Ok(result)
177    }
178
179    /// Kolmogorov-Smirnov test for drift detection
180    fn kolmogorov_smirnov_test(
181        &self,
182        reference: &Array2<f64>,
183        current: &Array2<f64>,
184    ) -> Result<DriftDetectionResult> {
185        let n_features = reference.ncols();
186        let mut feature_scores = Vec::new();
187        let mut max_statistic: f64 = 0.0;
188        let mut min_p_value: f64 = 1.0;
189
190        for feature_idx in 0..n_features {
191            let ref_feature: Vec<f64> = (0..reference.nrows())
192                .map(|i| reference[[i, feature_idx]])
193                .collect();
194            let cur_feature: Vec<f64> = (0..current.nrows())
195                .map(|i| current[[i, feature_idx]])
196                .collect();
197
198            let (statistic, p_value) = self.ks_test(&ref_feature, &cur_feature);
199            feature_scores.push(statistic);
200            max_statistic = max_statistic.max(statistic);
201            min_p_value = min_p_value.min(p_value);
202        }
203
204        let drift_detected = min_p_value < self.config.alpha;
205        let warning_detected =
206            min_p_value < self.config.alpha * (1.0 + self.config.warning_threshold);
207
208        let statistics = DriftStatistics {
209            n_reference: reference.nrows(),
210            n_current: current.nrows(),
211            n_features,
212            drift_magnitude: max_statistic,
213            confidence: 1.0 - min_p_value,
214            time_since_drift: self.time_since_drift(),
215        };
216
217        Ok(DriftDetectionResult {
218            drift_detected,
219            warning_detected,
220            test_statistic: max_statistic,
221            p_value: Some(min_p_value),
222            threshold: self.config.alpha,
223            drift_score: max_statistic,
224            feature_drift_scores: Some(feature_scores),
225            statistics,
226        })
227    }
228
229    /// Anderson-Darling test for drift detection
230    fn anderson_darling_test(
231        &self,
232        reference: &Array2<f64>,
233        current: &Array2<f64>,
234    ) -> Result<DriftDetectionResult> {
235        // Simplified Anderson-Darling test implementation
236        let n_features = reference.ncols();
237        let mut feature_scores = Vec::new();
238        let mut max_statistic: f64 = 0.0;
239
240        for feature_idx in 0..n_features {
241            let ref_feature: Vec<f64> = (0..reference.nrows())
242                .map(|i| reference[[i, feature_idx]])
243                .collect();
244            let cur_feature: Vec<f64> = (0..current.nrows())
245                .map(|i| current[[i, feature_idx]])
246                .collect();
247
248            let statistic = self.anderson_darling_statistic(&ref_feature, &cur_feature);
249            feature_scores.push(statistic);
250            max_statistic = max_statistic.max(statistic);
251        }
252
253        // Approximate threshold for Anderson-Darling
254        let threshold = 2.492; // Critical value for alpha = 0.05
255        let drift_detected = max_statistic > threshold;
256        let warning_detected = max_statistic > threshold * self.config.warning_threshold;
257
258        let statistics = DriftStatistics {
259            n_reference: reference.nrows(),
260            n_current: current.nrows(),
261            n_features,
262            drift_magnitude: max_statistic,
263            confidence: if drift_detected { 0.95 } else { 0.5 },
264            time_since_drift: self.time_since_drift(),
265        };
266
267        Ok(DriftDetectionResult {
268            drift_detected,
269            warning_detected,
270            test_statistic: max_statistic,
271            p_value: None,
272            threshold,
273            drift_score: max_statistic,
274            feature_drift_scores: Some(feature_scores),
275            statistics,
276        })
277    }
278
279    /// Mann-Whitney U test for drift detection
280    fn mann_whitney_test(
281        &self,
282        reference: &Array2<f64>,
283        current: &Array2<f64>,
284    ) -> Result<DriftDetectionResult> {
285        let n_features = reference.ncols();
286        let mut feature_scores = Vec::new();
287        let mut max_statistic: f64 = 0.0;
288        let mut min_p_value: f64 = 1.0;
289
290        for feature_idx in 0..n_features {
291            let ref_feature: Vec<f64> = (0..reference.nrows())
292                .map(|i| reference[[i, feature_idx]])
293                .collect();
294            let cur_feature: Vec<f64> = (0..current.nrows())
295                .map(|i| current[[i, feature_idx]])
296                .collect();
297
298            let (u_statistic, p_value) = self.mann_whitney_u_test(&ref_feature, &cur_feature);
299            let normalized_statistic = u_statistic / (ref_feature.len() * cur_feature.len()) as f64;
300
301            feature_scores.push(normalized_statistic);
302            max_statistic = max_statistic.max(normalized_statistic);
303            min_p_value = min_p_value.min(p_value);
304        }
305
306        let drift_detected = min_p_value < self.config.alpha;
307        let warning_detected =
308            min_p_value < self.config.alpha * (1.0 + self.config.warning_threshold);
309
310        let statistics = DriftStatistics {
311            n_reference: reference.nrows(),
312            n_current: current.nrows(),
313            n_features,
314            drift_magnitude: max_statistic,
315            confidence: 1.0 - min_p_value,
316            time_since_drift: self.time_since_drift(),
317        };
318
319        Ok(DriftDetectionResult {
320            drift_detected,
321            warning_detected,
322            test_statistic: max_statistic,
323            p_value: Some(min_p_value),
324            threshold: self.config.alpha,
325            drift_score: max_statistic,
326            feature_drift_scores: Some(feature_scores),
327            statistics,
328        })
329    }
330
331    /// Permutation test for drift detection
332    fn permutation_test(
333        &self,
334        reference: &Array2<f64>,
335        current: &Array2<f64>,
336    ) -> Result<DriftDetectionResult> {
337        let n_permutations = 1000;
338        let observed_statistic = self.calculate_permutation_statistic(reference, current);
339
340        let mut permutation_statistics = Vec::new();
341        let combined_data = self.combine_data(reference, current);
342        let n_ref = reference.nrows();
343
344        for _ in 0..n_permutations {
345            let (perm_ref, perm_cur) = self.random_permutation_split(&combined_data, n_ref);
346            let perm_statistic = self.calculate_permutation_statistic(&perm_ref, &perm_cur);
347            permutation_statistics.push(perm_statistic);
348        }
349
350        // Calculate p-value
351        let extreme_count = permutation_statistics
352            .iter()
353            .filter(|&&stat| stat >= observed_statistic)
354            .count();
355        let p_value = extreme_count as f64 / n_permutations as f64;
356
357        let drift_detected = p_value < self.config.alpha;
358        let warning_detected = p_value < self.config.alpha * (1.0 + self.config.warning_threshold);
359
360        let statistics = DriftStatistics {
361            n_reference: reference.nrows(),
362            n_current: current.nrows(),
363            n_features: reference.ncols(),
364            drift_magnitude: observed_statistic,
365            confidence: 1.0 - p_value,
366            time_since_drift: self.time_since_drift(),
367        };
368
369        Ok(DriftDetectionResult {
370            drift_detected,
371            warning_detected,
372            test_statistic: observed_statistic,
373            p_value: Some(p_value),
374            threshold: self.config.alpha,
375            drift_score: observed_statistic,
376            feature_drift_scores: None,
377            statistics,
378        })
379    }
380
381    /// Population Stability Index (PSI) for drift detection
382    fn population_stability_index(
383        &self,
384        reference: &Array2<f64>,
385        current: &Array2<f64>,
386    ) -> Result<DriftDetectionResult> {
387        let n_features = reference.ncols();
388        let n_bins = 10;
389        let mut feature_scores = Vec::new();
390        let mut total_psi = 0.0;
391
392        for feature_idx in 0..n_features {
393            let ref_feature: Vec<f64> = (0..reference.nrows())
394                .map(|i| reference[[i, feature_idx]])
395                .collect();
396            let cur_feature: Vec<f64> = (0..current.nrows())
397                .map(|i| current[[i, feature_idx]])
398                .collect();
399
400            let psi = self.calculate_psi(&ref_feature, &cur_feature, n_bins);
401            feature_scores.push(psi);
402            total_psi += psi;
403        }
404
405        let avg_psi = total_psi / n_features as f64;
406
407        // PSI thresholds: <0.1 (no drift), 0.1-0.2 (minor), >0.2 (major)
408        let drift_detected = avg_psi > 0.2;
409        let warning_detected = avg_psi > 0.1;
410
411        let statistics = DriftStatistics {
412            n_reference: reference.nrows(),
413            n_current: current.nrows(),
414            n_features,
415            drift_magnitude: avg_psi,
416            confidence: if drift_detected { 0.8 } else { 0.5 },
417            time_since_drift: self.time_since_drift(),
418        };
419
420        Ok(DriftDetectionResult {
421            drift_detected,
422            warning_detected,
423            test_statistic: avg_psi,
424            p_value: None,
425            threshold: 0.2,
426            drift_score: avg_psi,
427            feature_drift_scores: Some(feature_scores),
428            statistics,
429        })
430    }
431
432    /// Maximum Mean Discrepancy (MMD) test
433    fn maximum_mean_discrepancy(
434        &self,
435        reference: &Array2<f64>,
436        current: &Array2<f64>,
437    ) -> Result<DriftDetectionResult> {
438        let mmd_statistic = self.calculate_mmd(reference, current);
439
440        // Use permutation test to get p-value
441        let n_permutations = 1000;
442        let mut permutation_mmds = Vec::new();
443        let combined_data = self.combine_data(reference, current);
444        let n_ref = reference.nrows();
445
446        for _ in 0..n_permutations {
447            let (perm_ref, perm_cur) = self.random_permutation_split(&combined_data, n_ref);
448            let perm_mmd = self.calculate_mmd(&perm_ref, &perm_cur);
449            permutation_mmds.push(perm_mmd);
450        }
451
452        let extreme_count = permutation_mmds
453            .iter()
454            .filter(|&&mmd| mmd >= mmd_statistic)
455            .count();
456        let p_value = extreme_count as f64 / n_permutations as f64;
457
458        let drift_detected = p_value < self.config.alpha;
459        let warning_detected = p_value < self.config.alpha * (1.0 + self.config.warning_threshold);
460
461        let statistics = DriftStatistics {
462            n_reference: reference.nrows(),
463            n_current: current.nrows(),
464            n_features: reference.ncols(),
465            drift_magnitude: mmd_statistic,
466            confidence: 1.0 - p_value,
467            time_since_drift: self.time_since_drift(),
468        };
469
470        Ok(DriftDetectionResult {
471            drift_detected,
472            warning_detected,
473            test_statistic: mmd_statistic,
474            p_value: Some(p_value),
475            threshold: self.config.alpha,
476            drift_score: mmd_statistic,
477            feature_drift_scores: None,
478            statistics,
479        })
480    }
481
482    /// ADWIN (Adaptive Windowing) drift detection
483    fn adwin_test(&mut self, current: &Array2<f64>) -> Result<DriftDetectionResult> {
484        // Simplified ADWIN implementation
485        // In practice, this would maintain adaptive windows
486
487        let n_samples = current.nrows();
488        let avg_performance = self.calculate_average_performance(current);
489
490        // Add to current window
491        for i in 0..n_samples {
492            let sample = current.row(i).to_owned();
493            self.current_window.push(sample);
494        }
495
496        // Keep window size manageable
497        if self.current_window.len() > self.config.window_size * 2 {
498            let excess = self.current_window.len() - self.config.window_size;
499            self.current_window.drain(0..excess);
500        }
501
502        let drift_detected =
503            self.current_window.len() >= self.config.min_samples && avg_performance < 0.5; // Simplified threshold
504
505        let statistics = DriftStatistics {
506            n_reference: 0,
507            n_current: current.nrows(),
508            n_features: current.ncols(),
509            drift_magnitude: 1.0 - avg_performance,
510            confidence: if drift_detected { 0.8 } else { 0.5 },
511            time_since_drift: self.time_since_drift(),
512        };
513
514        Ok(DriftDetectionResult {
515            drift_detected,
516            warning_detected: avg_performance < 0.7,
517            test_statistic: 1.0 - avg_performance,
518            p_value: None,
519            threshold: 0.5,
520            drift_score: 1.0 - avg_performance,
521            feature_drift_scores: None,
522            statistics,
523        })
524    }
525
526    /// Page-Hinkley test for drift detection
527    fn page_hinkley_test(&self, current: &Array2<f64>) -> Result<DriftDetectionResult> {
528        // Simplified Page-Hinkley test
529        let avg_performance = self.calculate_average_performance(current);
530        let threshold = 3.0; // Typical threshold
531
532        let cumulative_sum = (0.5 - avg_performance) * current.nrows() as f64;
533        let drift_detected = cumulative_sum.abs() > threshold;
534
535        let statistics = DriftStatistics {
536            n_reference: 0,
537            n_current: current.nrows(),
538            n_features: current.ncols(),
539            drift_magnitude: cumulative_sum.abs(),
540            confidence: if drift_detected { 0.8 } else { 0.5 },
541            time_since_drift: self.time_since_drift(),
542        };
543
544        Ok(DriftDetectionResult {
545            drift_detected,
546            warning_detected: cumulative_sum.abs() > threshold * 0.7,
547            test_statistic: cumulative_sum.abs(),
548            p_value: None,
549            threshold,
550            drift_score: cumulative_sum.abs(),
551            feature_drift_scores: None,
552            statistics,
553        })
554    }
555
556    /// DDM (Drift Detection Method)
557    fn ddm_test(&self, current: &Array2<f64>) -> Result<DriftDetectionResult> {
558        // Simplified DDM implementation
559        let error_rate = 1.0 - self.calculate_average_performance(current);
560        let std_error = (error_rate * (1.0 - error_rate) / current.nrows() as f64).sqrt();
561
562        let warning_threshold = error_rate + 2.0 * std_error;
563        let drift_threshold = error_rate + 3.0 * std_error;
564
565        let drift_detected = error_rate > drift_threshold;
566        let warning_detected = error_rate > warning_threshold;
567
568        let statistics = DriftStatistics {
569            n_reference: 0,
570            n_current: current.nrows(),
571            n_features: current.ncols(),
572            drift_magnitude: error_rate,
573            confidence: if drift_detected {
574                0.99
575            } else if warning_detected {
576                0.95
577            } else {
578                0.5
579            },
580            time_since_drift: self.time_since_drift(),
581        };
582
583        Ok(DriftDetectionResult {
584            drift_detected,
585            warning_detected,
586            test_statistic: error_rate,
587            p_value: None,
588            threshold: drift_threshold,
589            drift_score: error_rate,
590            feature_drift_scores: None,
591            statistics,
592        })
593    }
594
595    /// EDDM (Early Drift Detection Method)
596    fn eddm_test(&self, current: &Array2<f64>) -> Result<DriftDetectionResult> {
597        // Simplified EDDM implementation
598        let avg_performance = self.calculate_average_performance(current);
599        let _distance_between_errors = 1.0 / (1.0 - avg_performance + 1e-8);
600
601        let threshold = 0.95;
602        let drift_detected = avg_performance < threshold;
603
604        let statistics = DriftStatistics {
605            n_reference: 0,
606            n_current: current.nrows(),
607            n_features: current.ncols(),
608            drift_magnitude: 1.0 - avg_performance,
609            confidence: if drift_detected { 0.8 } else { 0.5 },
610            time_since_drift: self.time_since_drift(),
611        };
612
613        Ok(DriftDetectionResult {
614            drift_detected,
615            warning_detected: avg_performance < 0.98,
616            test_statistic: 1.0 - avg_performance,
617            p_value: None,
618            threshold: 1.0 - threshold,
619            drift_score: 1.0 - avg_performance,
620            feature_drift_scores: None,
621            statistics,
622        })
623    }
624
625    // Helper methods
626
627    /// Kolmogorov-Smirnov test implementation
628    fn ks_test(&self, sample1: &[f64], sample2: &[f64]) -> (f64, f64) {
629        let mut combined: Vec<(f64, usize)> = sample1.iter().map(|&x| (x, 0)).collect();
630        combined.extend(sample2.iter().map(|&x| (x, 1)));
631        combined.sort_by(|a, b| a.0.partial_cmp(&b.0).unwrap_or(Ordering::Equal));
632
633        let n1 = sample1.len() as f64;
634        let n2 = sample2.len() as f64;
635        let mut cdf1 = 0.0;
636        let mut cdf2 = 0.0;
637        let mut max_diff: f64 = 0.0;
638
639        for (_, group) in combined {
640            if group == 0 {
641                cdf1 += 1.0 / n1;
642            } else {
643                cdf2 += 1.0 / n2;
644            }
645            max_diff = max_diff.max((cdf1 - cdf2).abs());
646        }
647
648        // Approximate p-value calculation
649        let ks_statistic = max_diff;
650        let en = (n1 * n2 / (n1 + n2)).sqrt();
651        let lambda = en * ks_statistic;
652        let p_value = 2.0 * (-2.0 * lambda * lambda).exp();
653
654        (ks_statistic, p_value.clamp(0.0, 1.0))
655    }
656
657    /// Anderson-Darling statistic calculation
658    fn anderson_darling_statistic(&self, sample1: &[f64], sample2: &[f64]) -> f64 {
659        // Simplified implementation
660        let mut combined: Vec<f64> = sample1.iter().chain(sample2.iter()).cloned().collect();
661        combined.sort_by(|a, b| a.partial_cmp(b).unwrap_or(Ordering::Equal));
662
663        let n1 = sample1.len() as f64;
664        let n2 = sample2.len() as f64;
665        let n = n1 + n2;
666
667        let mut h = 0.0;
668        let mut prev_val = f64::NEG_INFINITY;
669        let _i = 0.0;
670
671        for &val in &combined {
672            if val != prev_val {
673                let count1 = sample1.iter().filter(|&&x| x <= val).count() as f64;
674                let count2 = sample2.iter().filter(|&&x| x <= val).count() as f64;
675
676                let l = count1 + count2;
677                if l > 0.0 && l < n {
678                    h += (count1 / n1 - count2 / n2).powi(2) / (l * (n - l));
679                }
680                prev_val = val;
681            }
682        }
683
684        n1 * n2 * h / n
685    }
686
687    /// Mann-Whitney U test implementation
688    fn mann_whitney_u_test(&self, sample1: &[f64], sample2: &[f64]) -> (f64, f64) {
689        let mut combined: Vec<(f64, usize)> = sample1.iter().map(|&x| (x, 0)).collect();
690        combined.extend(sample2.iter().map(|&x| (x, 1)));
691        combined.sort_by(|a, b| a.0.partial_cmp(&b.0).unwrap_or(Ordering::Equal));
692
693        let mut rank_sum1 = 0.0;
694        for (rank, (_, group)) in combined.iter().enumerate() {
695            if *group == 0 {
696                rank_sum1 += (rank + 1) as f64;
697            }
698        }
699
700        let n1 = sample1.len() as f64;
701        let n2 = sample2.len() as f64;
702        let u1 = rank_sum1 - n1 * (n1 + 1.0) / 2.0;
703        let u2 = n1 * n2 - u1;
704        let u_statistic = u1.min(u2);
705
706        // Approximate p-value using normal approximation
707        let mu = n1 * n2 / 2.0;
708        let sigma = (n1 * n2 * (n1 + n2 + 1.0) / 12.0).sqrt();
709        let z = (u_statistic - mu).abs() / sigma;
710        let p_value = 2.0 * (1.0 - self.normal_cdf(z));
711
712        (u_statistic, p_value.clamp(0.0, 1.0))
713    }
714
715    /// Calculate permutation test statistic
716    fn calculate_permutation_statistic(
717        &self,
718        ref_data: &Array2<f64>,
719        cur_data: &Array2<f64>,
720    ) -> f64 {
721        // Use mean difference as statistic
722        let ref_mean = self.calculate_mean(ref_data);
723        let cur_mean = self.calculate_mean(cur_data);
724        (ref_mean - cur_mean).abs()
725    }
726
727    /// Combine two datasets
728    fn combine_data(&self, data1: &Array2<f64>, data2: &Array2<f64>) -> Array2<f64> {
729        let n_rows = data1.nrows() + data2.nrows();
730        let n_cols = data1.ncols();
731        let mut combined = Array2::zeros((n_rows, n_cols));
732
733        // Copy data1
734        for i in 0..data1.nrows() {
735            for j in 0..n_cols {
736                combined[[i, j]] = data1[[i, j]];
737            }
738        }
739
740        // Copy data2
741        for i in 0..data2.nrows() {
742            for j in 0..n_cols {
743                combined[[data1.nrows() + i, j]] = data2[[i, j]];
744            }
745        }
746
747        combined
748    }
749
750    /// Random permutation split
751    fn random_permutation_split(
752        &self,
753        data: &Array2<f64>,
754        n_first: usize,
755    ) -> (Array2<f64>, Array2<f64>) {
756        use scirs2_core::random::rngs::StdRng;
757        use scirs2_core::random::SeedableRng;
758
759        let mut rng = match self.config.random_state {
760            Some(seed) => StdRng::seed_from_u64(seed),
761            None => {
762                use scirs2_core::random::thread_rng;
763                StdRng::from_rng(&mut thread_rng())
764            }
765        };
766
767        let mut indices: Vec<usize> = (0..data.nrows()).collect();
768        indices.shuffle(&mut rng);
769
770        let n_cols = data.ncols();
771        let mut first = Array2::zeros((n_first, n_cols));
772        let mut second = Array2::zeros((data.nrows() - n_first, n_cols));
773
774        for (i, &idx) in indices[..n_first].iter().enumerate() {
775            for j in 0..n_cols {
776                first[[i, j]] = data[[idx, j]];
777            }
778        }
779
780        for (i, &idx) in indices[n_first..].iter().enumerate() {
781            for j in 0..n_cols {
782                second[[i, j]] = data[[idx, j]];
783            }
784        }
785
786        (first, second)
787    }
788
789    /// Calculate Population Stability Index
790    fn calculate_psi(&self, reference: &[f64], current: &[f64], n_bins: usize) -> f64 {
791        // Calculate bin edges based on reference data
792        let mut ref_sorted = reference.to_vec();
793        ref_sorted.sort_by(|a, b| a.partial_cmp(b).unwrap_or(Ordering::Equal));
794
795        let mut bin_edges = Vec::new();
796        for i in 0..=n_bins {
797            let quantile = i as f64 / n_bins as f64;
798            let idx = ((ref_sorted.len() - 1) as f64 * quantile) as usize;
799            bin_edges.push(ref_sorted[idx.min(ref_sorted.len() - 1)]);
800        }
801
802        // Calculate bin counts
803        let ref_counts = self.calculate_bin_counts(reference, &bin_edges);
804        let cur_counts = self.calculate_bin_counts(current, &bin_edges);
805
806        // Calculate PSI
807        let mut psi = 0.0;
808        for i in 0..n_bins {
809            let ref_prop = ref_counts[i] / reference.len() as f64;
810            let cur_prop = cur_counts[i] / current.len() as f64;
811
812            if ref_prop > 0.0 && cur_prop > 0.0 {
813                psi += (cur_prop - ref_prop) * (cur_prop / ref_prop).ln();
814            }
815        }
816
817        psi
818    }
819
820    /// Calculate bin counts
821    fn calculate_bin_counts(&self, data: &[f64], bin_edges: &[f64]) -> Vec<f64> {
822        let n_bins = bin_edges.len() - 1;
823        let mut counts = vec![0.0; n_bins];
824
825        for &value in data {
826            for i in 0..n_bins {
827                if (i == n_bins - 1 || value < bin_edges[i + 1]) && value >= bin_edges[i] {
828                    counts[i] += 1.0;
829                    break;
830                }
831            }
832        }
833
834        counts
835    }
836
837    /// Calculate Maximum Mean Discrepancy
838    fn calculate_mmd(&self, data1: &Array2<f64>, data2: &Array2<f64>) -> f64 {
839        // Simplified MMD with linear kernel
840        let mean1 = self.calculate_mean(data1);
841        let mean2 = self.calculate_mean(data2);
842        (mean1 - mean2).abs()
843    }
844
845    /// Calculate mean of dataset
846    fn calculate_mean(&self, data: &Array2<f64>) -> f64 {
847        let mut sum = 0.0;
848        let mut count = 0;
849
850        for i in 0..data.nrows() {
851            for j in 0..data.ncols() {
852                sum += data[[i, j]];
853                count += 1;
854            }
855        }
856
857        if count > 0 {
858            sum / count as f64
859        } else {
860            0.0
861        }
862    }
863
864    /// Calculate average performance (simplified)
865    fn calculate_average_performance(&self, data: &Array2<f64>) -> f64 {
866        // Simplified performance calculation
867        let mean = self.calculate_mean(data);
868        // Normalize to [0, 1] range (simplified)
869        (mean + 1.0) / 2.0
870    }
871
872    /// Time since last drift
873    fn time_since_drift(&self) -> Option<usize> {
874        self.last_drift_time
875            .map(|last| self.drift_history.len() - last)
876    }
877
878    /// Normal CDF approximation
879    fn normal_cdf(&self, x: f64) -> f64 {
880        0.5 * (1.0 + self.erf(x / 2.0_f64.sqrt()))
881    }
882
883    /// Error function approximation
884    fn erf(&self, x: f64) -> f64 {
885        let a1 = 0.254829592;
886        let a2 = -0.284496736;
887        let a3 = 1.421413741;
888        let a4 = -1.453152027;
889        let a5 = 1.061405429;
890        let p = 0.3275911;
891
892        let sign = if x < 0.0 { -1.0 } else { 1.0 };
893        let x = x.abs();
894
895        let t = 1.0 / (1.0 + p * x);
896        let y = 1.0 - (((((a5 * t + a4) * t) + a3) * t + a2) * t + a1) * t * (-x * x).exp();
897
898        sign * y
899    }
900}
901
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_ks_drift_detection() {
        let mut detector = DriftDetector::new(DriftDetectionConfig::default());

        // Reference: x = i / 100 for i in 0..100, paired with sin(x).
        let reference = Array2::from_shape_fn((100, 2), |(row, col)| {
            let x = row as f64 / 100.0;
            if col == 0 {
                x
            } else {
                x.sin()
            }
        });
        detector.set_reference(reference);

        // Current: every other reference point, i.e. the same distribution with no drift.
        let current = Array2::from_shape_fn((50, 2), |(row, col)| {
            let x = (row * 2) as f64 / 100.0;
            if col == 0 {
                x
            } else {
                x.sin()
            }
        });

        let result = detector.detect_drift(&current).unwrap();
        assert!(
            !result.drift_detected,
            "Should not detect drift in similar data"
        );
    }

    #[test]
    fn test_psi_drift_detection() {
        let mut detector = DriftDetector::new(DriftDetectionConfig {
            detection_method: DriftDetectionMethod::PopulationStabilityIndex,
            ..Default::default()
        });

        detector.set_reference(Array2::from_shape_fn((100, 1), |(row, _)| {
            row as f64 / 100.0
        }));

        // Current sample is shifted by +0.5 and spread over a wider range.
        let current = Array2::from_shape_fn((50, 1), |(row, _)| row as f64 / 50.0 + 0.5);

        let result = detector.detect_drift(&current).unwrap();
        assert!(result.drift_score > 0.1, "Should detect distribution shift");
    }
}