Skip to main content

sklears_model_selection/
drift_detection.rs

1//! Data drift detection for validation and monitoring
2//!
3//! This module provides methods to detect distribution changes in data
4//! that can affect model performance over time.
5
6use scirs2_core::ndarray::{Array1, Array2};
7use scirs2_core::SliceRandomExt;
8use sklears_core::error::{Result, SklearsError};
9use std::cmp::Ordering;
10
11/// Configuration for drift detection
12#[derive(Debug, Clone)]
13pub struct DriftDetectionConfig {
14    /// Detection method to use
15    pub detection_method: DriftDetectionMethod,
16    /// Significance level for statistical tests
17    pub alpha: f64,
18    /// Window size for windowed detection methods
19    pub window_size: usize,
20    /// Warning threshold (fraction of alpha)
21    pub warning_threshold: f64,
22    /// Minimum samples required for detection
23    pub min_samples: usize,
24    /// Whether to detect multivariate drift
25    pub multivariate: bool,
26    /// Random state for reproducible results
27    pub random_state: Option<u64>,
28}
29
30impl Default for DriftDetectionConfig {
31    fn default() -> Self {
32        Self {
33            detection_method: DriftDetectionMethod::KolmogorovSmirnov,
34            alpha: 0.05,
35            window_size: 100,
36            warning_threshold: 0.5,
37            min_samples: 30,
38            multivariate: false,
39            random_state: None,
40        }
41    }
42}
43
/// Drift detection methods supported by [`DriftDetector`].
///
/// The first six variants are two-sample comparisons between the stored reference data and
/// the current batch. The remaining four are simplified stream-monitoring heuristics.
///
/// The enum is comparable and hashable, so configurations can be checked with `==` or used
/// as map keys.
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub enum DriftDetectionMethod {
    /// Kolmogorov-Smirnov test for univariate drift
    KolmogorovSmirnov,
    /// Anderson-Darling test
    AndersonDarling,
    /// Mann-Whitney U test
    MannWhitney,
    /// Permutation test
    Permutation,
    /// Population Stability Index (PSI)
    PopulationStabilityIndex,
    /// Maximum Mean Discrepancy (MMD)
    MaximumMeanDiscrepancy,
    /// ADWIN (Adaptive Windowing)
    ADWIN,
    /// Page-Hinkley test
    PageHinkley,
    /// Drift Detection Method (DDM)
    DDM,
    /// Early Drift Detection Method (EDDM)
    EDDM,
}
68
69/// Results from drift detection
70#[derive(Debug, Clone)]
71pub struct DriftDetectionResult {
72    /// Whether drift was detected
73    pub drift_detected: bool,
74    /// Whether warning threshold was exceeded
75    pub warning_detected: bool,
76    /// Test statistic value
77    pub test_statistic: f64,
78    /// P-value for statistical tests
79    pub p_value: Option<f64>,
80    /// Threshold used for detection
81    pub threshold: f64,
82    /// Drift score (higher = more drift)
83    pub drift_score: f64,
84    /// Per-feature drift scores
85    pub feature_drift_scores: Option<Vec<f64>>,
86    /// Detailed statistics
87    pub statistics: DriftStatistics,
88}
89
/// Supplementary statistics attached to every drift-detection result.
#[derive(Clone, Debug)]
pub struct DriftStatistics {
    /// Rows in the reference data (reported as 0 by methods that do not use it).
    pub n_reference: usize,
    /// Rows in the current batch.
    pub n_current: usize,
    /// Number of features (columns) analyzed.
    pub n_features: usize,
    /// Method-specific estimate of how large the drift is.
    pub drift_magnitude: f64,
    /// Heuristic confidence in the drift decision.
    pub confidence: f64,
    /// Number of results recorded after the most recent drift, if a drift has been reported.
    pub time_since_drift: Option<usize>,
}
106
107/// Drift detector for monitoring data distribution changes
108#[derive(Debug, Clone)]
109pub struct DriftDetector {
110    config: DriftDetectionConfig,
111    reference_data: Option<Array2<f64>>,
112    current_window: Vec<Array1<f64>>,
113    drift_history: Vec<DriftDetectionResult>,
114    last_drift_time: Option<usize>,
115}
116
117impl DriftDetector {
118    pub fn new(config: DriftDetectionConfig) -> Self {
119        Self {
120            config,
121            reference_data: None,
122            current_window: Vec::new(),
123            drift_history: Vec::new(),
124            last_drift_time: None,
125        }
126    }
127
128    /// Set reference data for drift detection
129    pub fn set_reference(&mut self, reference_data: Array2<f64>) {
130        self.reference_data = Some(reference_data);
131    }
132
133    /// Detect drift in new data
134    pub fn detect_drift(&mut self, current_data: &Array2<f64>) -> Result<DriftDetectionResult> {
135        if self.reference_data.is_none() {
136            return Err(SklearsError::NotFitted {
137                operation: "drift detection".to_string(),
138            });
139        }
140
141        let reference = self
142            .reference_data
143            .as_ref()
144            .expect("operation should succeed");
145
146        if reference.ncols() != current_data.ncols() {
147            return Err(SklearsError::InvalidInput(
148                "Reference and current data must have same number of features".to_string(),
149            ));
150        }
151
152        let result = match self.config.detection_method {
153            DriftDetectionMethod::KolmogorovSmirnov => {
154                self.kolmogorov_smirnov_test(reference, current_data)?
155            }
156            DriftDetectionMethod::AndersonDarling => {
157                self.anderson_darling_test(reference, current_data)?
158            }
159            DriftDetectionMethod::MannWhitney => self.mann_whitney_test(reference, current_data)?,
160            DriftDetectionMethod::Permutation => self.permutation_test(reference, current_data)?,
161            DriftDetectionMethod::PopulationStabilityIndex => {
162                self.population_stability_index(reference, current_data)?
163            }
164            DriftDetectionMethod::MaximumMeanDiscrepancy => {
165                self.maximum_mean_discrepancy(reference, current_data)?
166            }
167            DriftDetectionMethod::ADWIN => self.adwin_test(current_data)?,
168            DriftDetectionMethod::PageHinkley => self.page_hinkley_test(current_data)?,
169            DriftDetectionMethod::DDM => self.ddm_test(current_data)?,
170            DriftDetectionMethod::EDDM => self.eddm_test(current_data)?,
171        };
172
173        self.drift_history.push(result.clone());
174
175        if result.drift_detected {
176            self.last_drift_time = Some(self.drift_history.len());
177        }
178
179        Ok(result)
180    }
181
182    /// Kolmogorov-Smirnov test for drift detection
183    fn kolmogorov_smirnov_test(
184        &self,
185        reference: &Array2<f64>,
186        current: &Array2<f64>,
187    ) -> Result<DriftDetectionResult> {
188        let n_features = reference.ncols();
189        let mut feature_scores = Vec::new();
190        let mut max_statistic: f64 = 0.0;
191        let mut min_p_value: f64 = 1.0;
192
193        for feature_idx in 0..n_features {
194            let ref_feature: Vec<f64> = (0..reference.nrows())
195                .map(|i| reference[[i, feature_idx]])
196                .collect();
197            let cur_feature: Vec<f64> = (0..current.nrows())
198                .map(|i| current[[i, feature_idx]])
199                .collect();
200
201            let (statistic, p_value) = self.ks_test(&ref_feature, &cur_feature);
202            feature_scores.push(statistic);
203            max_statistic = max_statistic.max(statistic);
204            min_p_value = min_p_value.min(p_value);
205        }
206
207        let drift_detected = min_p_value < self.config.alpha;
208        let warning_detected =
209            min_p_value < self.config.alpha * (1.0 + self.config.warning_threshold);
210
211        let statistics = DriftStatistics {
212            n_reference: reference.nrows(),
213            n_current: current.nrows(),
214            n_features,
215            drift_magnitude: max_statistic,
216            confidence: 1.0 - min_p_value,
217            time_since_drift: self.time_since_drift(),
218        };
219
220        Ok(DriftDetectionResult {
221            drift_detected,
222            warning_detected,
223            test_statistic: max_statistic,
224            p_value: Some(min_p_value),
225            threshold: self.config.alpha,
226            drift_score: max_statistic,
227            feature_drift_scores: Some(feature_scores),
228            statistics,
229        })
230    }
231
232    /// Anderson-Darling test for drift detection
233    fn anderson_darling_test(
234        &self,
235        reference: &Array2<f64>,
236        current: &Array2<f64>,
237    ) -> Result<DriftDetectionResult> {
238        // Simplified Anderson-Darling test implementation
239        let n_features = reference.ncols();
240        let mut feature_scores = Vec::new();
241        let mut max_statistic: f64 = 0.0;
242
243        for feature_idx in 0..n_features {
244            let ref_feature: Vec<f64> = (0..reference.nrows())
245                .map(|i| reference[[i, feature_idx]])
246                .collect();
247            let cur_feature: Vec<f64> = (0..current.nrows())
248                .map(|i| current[[i, feature_idx]])
249                .collect();
250
251            let statistic = self.anderson_darling_statistic(&ref_feature, &cur_feature);
252            feature_scores.push(statistic);
253            max_statistic = max_statistic.max(statistic);
254        }
255
256        // Approximate threshold for Anderson-Darling
257        let threshold = 2.492; // Critical value for alpha = 0.05
258        let drift_detected = max_statistic > threshold;
259        let warning_detected = max_statistic > threshold * self.config.warning_threshold;
260
261        let statistics = DriftStatistics {
262            n_reference: reference.nrows(),
263            n_current: current.nrows(),
264            n_features,
265            drift_magnitude: max_statistic,
266            confidence: if drift_detected { 0.95 } else { 0.5 },
267            time_since_drift: self.time_since_drift(),
268        };
269
270        Ok(DriftDetectionResult {
271            drift_detected,
272            warning_detected,
273            test_statistic: max_statistic,
274            p_value: None,
275            threshold,
276            drift_score: max_statistic,
277            feature_drift_scores: Some(feature_scores),
278            statistics,
279        })
280    }
281
282    /// Mann-Whitney U test for drift detection
283    fn mann_whitney_test(
284        &self,
285        reference: &Array2<f64>,
286        current: &Array2<f64>,
287    ) -> Result<DriftDetectionResult> {
288        let n_features = reference.ncols();
289        let mut feature_scores = Vec::new();
290        let mut max_statistic: f64 = 0.0;
291        let mut min_p_value: f64 = 1.0;
292
293        for feature_idx in 0..n_features {
294            let ref_feature: Vec<f64> = (0..reference.nrows())
295                .map(|i| reference[[i, feature_idx]])
296                .collect();
297            let cur_feature: Vec<f64> = (0..current.nrows())
298                .map(|i| current[[i, feature_idx]])
299                .collect();
300
301            let (u_statistic, p_value) = self.mann_whitney_u_test(&ref_feature, &cur_feature);
302            let normalized_statistic = u_statistic / (ref_feature.len() * cur_feature.len()) as f64;
303
304            feature_scores.push(normalized_statistic);
305            max_statistic = max_statistic.max(normalized_statistic);
306            min_p_value = min_p_value.min(p_value);
307        }
308
309        let drift_detected = min_p_value < self.config.alpha;
310        let warning_detected =
311            min_p_value < self.config.alpha * (1.0 + self.config.warning_threshold);
312
313        let statistics = DriftStatistics {
314            n_reference: reference.nrows(),
315            n_current: current.nrows(),
316            n_features,
317            drift_magnitude: max_statistic,
318            confidence: 1.0 - min_p_value,
319            time_since_drift: self.time_since_drift(),
320        };
321
322        Ok(DriftDetectionResult {
323            drift_detected,
324            warning_detected,
325            test_statistic: max_statistic,
326            p_value: Some(min_p_value),
327            threshold: self.config.alpha,
328            drift_score: max_statistic,
329            feature_drift_scores: Some(feature_scores),
330            statistics,
331        })
332    }
333
334    /// Permutation test for drift detection
335    fn permutation_test(
336        &self,
337        reference: &Array2<f64>,
338        current: &Array2<f64>,
339    ) -> Result<DriftDetectionResult> {
340        let n_permutations = 1000;
341        let observed_statistic = self.calculate_permutation_statistic(reference, current);
342
343        let mut permutation_statistics = Vec::new();
344        let combined_data = self.combine_data(reference, current);
345        let n_ref = reference.nrows();
346
347        for _ in 0..n_permutations {
348            let (perm_ref, perm_cur) = self.random_permutation_split(&combined_data, n_ref);
349            let perm_statistic = self.calculate_permutation_statistic(&perm_ref, &perm_cur);
350            permutation_statistics.push(perm_statistic);
351        }
352
353        // Calculate p-value
354        let extreme_count = permutation_statistics
355            .iter()
356            .filter(|&&stat| stat >= observed_statistic)
357            .count();
358        let p_value = extreme_count as f64 / n_permutations as f64;
359
360        let drift_detected = p_value < self.config.alpha;
361        let warning_detected = p_value < self.config.alpha * (1.0 + self.config.warning_threshold);
362
363        let statistics = DriftStatistics {
364            n_reference: reference.nrows(),
365            n_current: current.nrows(),
366            n_features: reference.ncols(),
367            drift_magnitude: observed_statistic,
368            confidence: 1.0 - p_value,
369            time_since_drift: self.time_since_drift(),
370        };
371
372        Ok(DriftDetectionResult {
373            drift_detected,
374            warning_detected,
375            test_statistic: observed_statistic,
376            p_value: Some(p_value),
377            threshold: self.config.alpha,
378            drift_score: observed_statistic,
379            feature_drift_scores: None,
380            statistics,
381        })
382    }
383
384    /// Population Stability Index (PSI) for drift detection
385    fn population_stability_index(
386        &self,
387        reference: &Array2<f64>,
388        current: &Array2<f64>,
389    ) -> Result<DriftDetectionResult> {
390        let n_features = reference.ncols();
391        let n_bins = 10;
392        let mut feature_scores = Vec::new();
393        let mut total_psi = 0.0;
394
395        for feature_idx in 0..n_features {
396            let ref_feature: Vec<f64> = (0..reference.nrows())
397                .map(|i| reference[[i, feature_idx]])
398                .collect();
399            let cur_feature: Vec<f64> = (0..current.nrows())
400                .map(|i| current[[i, feature_idx]])
401                .collect();
402
403            let psi = self.calculate_psi(&ref_feature, &cur_feature, n_bins);
404            feature_scores.push(psi);
405            total_psi += psi;
406        }
407
408        let avg_psi = total_psi / n_features as f64;
409
410        // PSI thresholds: <0.1 (no drift), 0.1-0.2 (minor), >0.2 (major)
411        let drift_detected = avg_psi > 0.2;
412        let warning_detected = avg_psi > 0.1;
413
414        let statistics = DriftStatistics {
415            n_reference: reference.nrows(),
416            n_current: current.nrows(),
417            n_features,
418            drift_magnitude: avg_psi,
419            confidence: if drift_detected { 0.8 } else { 0.5 },
420            time_since_drift: self.time_since_drift(),
421        };
422
423        Ok(DriftDetectionResult {
424            drift_detected,
425            warning_detected,
426            test_statistic: avg_psi,
427            p_value: None,
428            threshold: 0.2,
429            drift_score: avg_psi,
430            feature_drift_scores: Some(feature_scores),
431            statistics,
432        })
433    }
434
435    /// Maximum Mean Discrepancy (MMD) test
436    fn maximum_mean_discrepancy(
437        &self,
438        reference: &Array2<f64>,
439        current: &Array2<f64>,
440    ) -> Result<DriftDetectionResult> {
441        let mmd_statistic = self.calculate_mmd(reference, current);
442
443        // Use permutation test to get p-value
444        let n_permutations = 1000;
445        let mut permutation_mmds = Vec::new();
446        let combined_data = self.combine_data(reference, current);
447        let n_ref = reference.nrows();
448
449        for _ in 0..n_permutations {
450            let (perm_ref, perm_cur) = self.random_permutation_split(&combined_data, n_ref);
451            let perm_mmd = self.calculate_mmd(&perm_ref, &perm_cur);
452            permutation_mmds.push(perm_mmd);
453        }
454
455        let extreme_count = permutation_mmds
456            .iter()
457            .filter(|&&mmd| mmd >= mmd_statistic)
458            .count();
459        let p_value = extreme_count as f64 / n_permutations as f64;
460
461        let drift_detected = p_value < self.config.alpha;
462        let warning_detected = p_value < self.config.alpha * (1.0 + self.config.warning_threshold);
463
464        let statistics = DriftStatistics {
465            n_reference: reference.nrows(),
466            n_current: current.nrows(),
467            n_features: reference.ncols(),
468            drift_magnitude: mmd_statistic,
469            confidence: 1.0 - p_value,
470            time_since_drift: self.time_since_drift(),
471        };
472
473        Ok(DriftDetectionResult {
474            drift_detected,
475            warning_detected,
476            test_statistic: mmd_statistic,
477            p_value: Some(p_value),
478            threshold: self.config.alpha,
479            drift_score: mmd_statistic,
480            feature_drift_scores: None,
481            statistics,
482        })
483    }
484
485    /// ADWIN (Adaptive Windowing) drift detection
486    fn adwin_test(&mut self, current: &Array2<f64>) -> Result<DriftDetectionResult> {
487        // Simplified ADWIN implementation
488        // In practice, this would maintain adaptive windows
489
490        let n_samples = current.nrows();
491        let avg_performance = self.calculate_average_performance(current);
492
493        // Add to current window
494        for i in 0..n_samples {
495            let sample = current.row(i).to_owned();
496            self.current_window.push(sample);
497        }
498
499        // Keep window size manageable
500        if self.current_window.len() > self.config.window_size * 2 {
501            let excess = self.current_window.len() - self.config.window_size;
502            self.current_window.drain(0..excess);
503        }
504
505        let drift_detected =
506            self.current_window.len() >= self.config.min_samples && avg_performance < 0.5; // Simplified threshold
507
508        let statistics = DriftStatistics {
509            n_reference: 0,
510            n_current: current.nrows(),
511            n_features: current.ncols(),
512            drift_magnitude: 1.0 - avg_performance,
513            confidence: if drift_detected { 0.8 } else { 0.5 },
514            time_since_drift: self.time_since_drift(),
515        };
516
517        Ok(DriftDetectionResult {
518            drift_detected,
519            warning_detected: avg_performance < 0.7,
520            test_statistic: 1.0 - avg_performance,
521            p_value: None,
522            threshold: 0.5,
523            drift_score: 1.0 - avg_performance,
524            feature_drift_scores: None,
525            statistics,
526        })
527    }
528
529    /// Page-Hinkley test for drift detection
530    fn page_hinkley_test(&self, current: &Array2<f64>) -> Result<DriftDetectionResult> {
531        // Simplified Page-Hinkley test
532        let avg_performance = self.calculate_average_performance(current);
533        let threshold = 3.0; // Typical threshold
534
535        let cumulative_sum = (0.5 - avg_performance) * current.nrows() as f64;
536        let drift_detected = cumulative_sum.abs() > threshold;
537
538        let statistics = DriftStatistics {
539            n_reference: 0,
540            n_current: current.nrows(),
541            n_features: current.ncols(),
542            drift_magnitude: cumulative_sum.abs(),
543            confidence: if drift_detected { 0.8 } else { 0.5 },
544            time_since_drift: self.time_since_drift(),
545        };
546
547        Ok(DriftDetectionResult {
548            drift_detected,
549            warning_detected: cumulative_sum.abs() > threshold * 0.7,
550            test_statistic: cumulative_sum.abs(),
551            p_value: None,
552            threshold,
553            drift_score: cumulative_sum.abs(),
554            feature_drift_scores: None,
555            statistics,
556        })
557    }
558
559    /// DDM (Drift Detection Method)
560    fn ddm_test(&self, current: &Array2<f64>) -> Result<DriftDetectionResult> {
561        // Simplified DDM implementation
562        let error_rate = 1.0 - self.calculate_average_performance(current);
563        let std_error = (error_rate * (1.0 - error_rate) / current.nrows() as f64).sqrt();
564
565        let warning_threshold = error_rate + 2.0 * std_error;
566        let drift_threshold = error_rate + 3.0 * std_error;
567
568        let drift_detected = error_rate > drift_threshold;
569        let warning_detected = error_rate > warning_threshold;
570
571        let statistics = DriftStatistics {
572            n_reference: 0,
573            n_current: current.nrows(),
574            n_features: current.ncols(),
575            drift_magnitude: error_rate,
576            confidence: if drift_detected {
577                0.99
578            } else if warning_detected {
579                0.95
580            } else {
581                0.5
582            },
583            time_since_drift: self.time_since_drift(),
584        };
585
586        Ok(DriftDetectionResult {
587            drift_detected,
588            warning_detected,
589            test_statistic: error_rate,
590            p_value: None,
591            threshold: drift_threshold,
592            drift_score: error_rate,
593            feature_drift_scores: None,
594            statistics,
595        })
596    }
597
598    /// EDDM (Early Drift Detection Method)
599    fn eddm_test(&self, current: &Array2<f64>) -> Result<DriftDetectionResult> {
600        // Simplified EDDM implementation
601        let avg_performance = self.calculate_average_performance(current);
602        let _distance_between_errors = 1.0 / (1.0 - avg_performance + 1e-8);
603
604        let threshold = 0.95;
605        let drift_detected = avg_performance < threshold;
606
607        let statistics = DriftStatistics {
608            n_reference: 0,
609            n_current: current.nrows(),
610            n_features: current.ncols(),
611            drift_magnitude: 1.0 - avg_performance,
612            confidence: if drift_detected { 0.8 } else { 0.5 },
613            time_since_drift: self.time_since_drift(),
614        };
615
616        Ok(DriftDetectionResult {
617            drift_detected,
618            warning_detected: avg_performance < 0.98,
619            test_statistic: 1.0 - avg_performance,
620            p_value: None,
621            threshold: 1.0 - threshold,
622            drift_score: 1.0 - avg_performance,
623            feature_drift_scores: None,
624            statistics,
625        })
626    }
627
628    // Helper methods
629
630    /// Kolmogorov-Smirnov test implementation
631    fn ks_test(&self, sample1: &[f64], sample2: &[f64]) -> (f64, f64) {
632        let mut combined: Vec<(f64, usize)> = sample1.iter().map(|&x| (x, 0)).collect();
633        combined.extend(sample2.iter().map(|&x| (x, 1)));
634        combined.sort_by(|a, b| a.0.partial_cmp(&b.0).unwrap_or(Ordering::Equal));
635
636        let n1 = sample1.len() as f64;
637        let n2 = sample2.len() as f64;
638        let mut cdf1 = 0.0;
639        let mut cdf2 = 0.0;
640        let mut max_diff: f64 = 0.0;
641
642        for (_, group) in combined {
643            if group == 0 {
644                cdf1 += 1.0 / n1;
645            } else {
646                cdf2 += 1.0 / n2;
647            }
648            max_diff = max_diff.max((cdf1 - cdf2).abs());
649        }
650
651        // Approximate p-value calculation
652        let ks_statistic = max_diff;
653        let en = (n1 * n2 / (n1 + n2)).sqrt();
654        let lambda = en * ks_statistic;
655        let p_value = 2.0 * (-2.0 * lambda * lambda).exp();
656
657        (ks_statistic, p_value.clamp(0.0, 1.0))
658    }
659
660    /// Anderson-Darling statistic calculation
661    fn anderson_darling_statistic(&self, sample1: &[f64], sample2: &[f64]) -> f64 {
662        // Simplified implementation
663        let mut combined: Vec<f64> = sample1.iter().chain(sample2.iter()).cloned().collect();
664        combined.sort_by(|a, b| a.partial_cmp(b).unwrap_or(Ordering::Equal));
665
666        let n1 = sample1.len() as f64;
667        let n2 = sample2.len() as f64;
668        let n = n1 + n2;
669
670        let mut h = 0.0;
671        let mut prev_val = f64::NEG_INFINITY;
672        let _i = 0.0;
673
674        for &val in &combined {
675            if val != prev_val {
676                let count1 = sample1.iter().filter(|&&x| x <= val).count() as f64;
677                let count2 = sample2.iter().filter(|&&x| x <= val).count() as f64;
678
679                let l = count1 + count2;
680                if l > 0.0 && l < n {
681                    h += (count1 / n1 - count2 / n2).powi(2) / (l * (n - l));
682                }
683                prev_val = val;
684            }
685        }
686
687        n1 * n2 * h / n
688    }
689
690    /// Mann-Whitney U test implementation
691    fn mann_whitney_u_test(&self, sample1: &[f64], sample2: &[f64]) -> (f64, f64) {
692        let mut combined: Vec<(f64, usize)> = sample1.iter().map(|&x| (x, 0)).collect();
693        combined.extend(sample2.iter().map(|&x| (x, 1)));
694        combined.sort_by(|a, b| a.0.partial_cmp(&b.0).unwrap_or(Ordering::Equal));
695
696        let mut rank_sum1 = 0.0;
697        for (rank, (_, group)) in combined.iter().enumerate() {
698            if *group == 0 {
699                rank_sum1 += (rank + 1) as f64;
700            }
701        }
702
703        let n1 = sample1.len() as f64;
704        let n2 = sample2.len() as f64;
705        let u1 = rank_sum1 - n1 * (n1 + 1.0) / 2.0;
706        let u2 = n1 * n2 - u1;
707        let u_statistic = u1.min(u2);
708
709        // Approximate p-value using normal approximation
710        let mu = n1 * n2 / 2.0;
711        let sigma = (n1 * n2 * (n1 + n2 + 1.0) / 12.0).sqrt();
712        let z = (u_statistic - mu).abs() / sigma;
713        let p_value = 2.0 * (1.0 - self.normal_cdf(z));
714
715        (u_statistic, p_value.clamp(0.0, 1.0))
716    }
717
718    /// Calculate permutation test statistic
719    fn calculate_permutation_statistic(
720        &self,
721        ref_data: &Array2<f64>,
722        cur_data: &Array2<f64>,
723    ) -> f64 {
724        // Use mean difference as statistic
725        let ref_mean = self.calculate_mean(ref_data);
726        let cur_mean = self.calculate_mean(cur_data);
727        (ref_mean - cur_mean).abs()
728    }
729
730    /// Combine two datasets
731    fn combine_data(&self, data1: &Array2<f64>, data2: &Array2<f64>) -> Array2<f64> {
732        let n_rows = data1.nrows() + data2.nrows();
733        let n_cols = data1.ncols();
734        let mut combined = Array2::zeros((n_rows, n_cols));
735
736        // Copy data1
737        for i in 0..data1.nrows() {
738            for j in 0..n_cols {
739                combined[[i, j]] = data1[[i, j]];
740            }
741        }
742
743        // Copy data2
744        for i in 0..data2.nrows() {
745            for j in 0..n_cols {
746                combined[[data1.nrows() + i, j]] = data2[[i, j]];
747            }
748        }
749
750        combined
751    }
752
753    /// Random permutation split
754    fn random_permutation_split(
755        &self,
756        data: &Array2<f64>,
757        n_first: usize,
758    ) -> (Array2<f64>, Array2<f64>) {
759        use scirs2_core::random::rngs::StdRng;
760        use scirs2_core::random::SeedableRng;
761
762        let mut rng = match self.config.random_state {
763            Some(seed) => StdRng::seed_from_u64(seed),
764            None => {
765                use scirs2_core::random::thread_rng;
766                StdRng::from_rng(&mut thread_rng())
767            }
768        };
769
770        let mut indices: Vec<usize> = (0..data.nrows()).collect();
771        indices.shuffle(&mut rng);
772
773        let n_cols = data.ncols();
774        let mut first = Array2::zeros((n_first, n_cols));
775        let mut second = Array2::zeros((data.nrows() - n_first, n_cols));
776
777        for (i, &idx) in indices[..n_first].iter().enumerate() {
778            for j in 0..n_cols {
779                first[[i, j]] = data[[idx, j]];
780            }
781        }
782
783        for (i, &idx) in indices[n_first..].iter().enumerate() {
784            for j in 0..n_cols {
785                second[[i, j]] = data[[idx, j]];
786            }
787        }
788
789        (first, second)
790    }
791
792    /// Calculate Population Stability Index
793    fn calculate_psi(&self, reference: &[f64], current: &[f64], n_bins: usize) -> f64 {
794        // Calculate bin edges based on reference data
795        let mut ref_sorted = reference.to_vec();
796        ref_sorted.sort_by(|a, b| a.partial_cmp(b).unwrap_or(Ordering::Equal));
797
798        let mut bin_edges = Vec::new();
799        for i in 0..=n_bins {
800            let quantile = i as f64 / n_bins as f64;
801            let idx = ((ref_sorted.len() - 1) as f64 * quantile) as usize;
802            bin_edges.push(ref_sorted[idx.min(ref_sorted.len() - 1)]);
803        }
804
805        // Calculate bin counts
806        let ref_counts = self.calculate_bin_counts(reference, &bin_edges);
807        let cur_counts = self.calculate_bin_counts(current, &bin_edges);
808
809        // Calculate PSI
810        let mut psi = 0.0;
811        for i in 0..n_bins {
812            let ref_prop = ref_counts[i] / reference.len() as f64;
813            let cur_prop = cur_counts[i] / current.len() as f64;
814
815            if ref_prop > 0.0 && cur_prop > 0.0 {
816                psi += (cur_prop - ref_prop) * (cur_prop / ref_prop).ln();
817            }
818        }
819
820        psi
821    }
822
823    /// Calculate bin counts
824    fn calculate_bin_counts(&self, data: &[f64], bin_edges: &[f64]) -> Vec<f64> {
825        let n_bins = bin_edges.len() - 1;
826        let mut counts = vec![0.0; n_bins];
827
828        for &value in data {
829            for i in 0..n_bins {
830                if (i == n_bins - 1 || value < bin_edges[i + 1]) && value >= bin_edges[i] {
831                    counts[i] += 1.0;
832                    break;
833                }
834            }
835        }
836
837        counts
838    }
839
840    /// Calculate Maximum Mean Discrepancy
841    fn calculate_mmd(&self, data1: &Array2<f64>, data2: &Array2<f64>) -> f64 {
842        // Simplified MMD with linear kernel
843        let mean1 = self.calculate_mean(data1);
844        let mean2 = self.calculate_mean(data2);
845        (mean1 - mean2).abs()
846    }
847
848    /// Calculate mean of dataset
849    fn calculate_mean(&self, data: &Array2<f64>) -> f64 {
850        let mut sum = 0.0;
851        let mut count = 0;
852
853        for i in 0..data.nrows() {
854            for j in 0..data.ncols() {
855                sum += data[[i, j]];
856                count += 1;
857            }
858        }
859
860        if count > 0 {
861            sum / count as f64
862        } else {
863            0.0
864        }
865    }
866
867    /// Calculate average performance (simplified)
868    fn calculate_average_performance(&self, data: &Array2<f64>) -> f64 {
869        // Simplified performance calculation
870        let mean = self.calculate_mean(data);
871        // Normalize to [0, 1] range (simplified)
872        (mean + 1.0) / 2.0
873    }
874
875    /// Time since last drift
876    fn time_since_drift(&self) -> Option<usize> {
877        self.last_drift_time
878            .map(|last| self.drift_history.len() - last)
879    }
880
881    /// Normal CDF approximation
882    fn normal_cdf(&self, x: f64) -> f64 {
883        0.5 * (1.0 + self.erf(x / 2.0_f64.sqrt()))
884    }
885
886    /// Error function approximation
887    fn erf(&self, x: f64) -> f64 {
888        let a1 = 0.254829592;
889        let a2 = -0.284496736;
890        let a3 = 1.421413741;
891        let a4 = -1.453152027;
892        let a5 = 1.061405429;
893        let p = 0.3275911;
894
895        let sign = if x < 0.0 { -1.0 } else { 1.0 };
896        let x = x.abs();
897
898        let t = 1.0 / (1.0 + p * x);
899        let y = 1.0 - (((((a5 * t + a4) * t) + a3) * t + a2) * t + a1) * t * (-x * x).exp();
900
901        sign * y
902    }
903}
904
#[allow(non_snake_case)]
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_ks_drift_detection() {
        let mut detector = DriftDetector::new(DriftDetectionConfig::default());

        // Reference: x = i / 100 for i in 0..100, with sin(x) as a second feature.
        let reference = Array2::from_shape_fn((100, 2), |(i, j)| {
            let x = i as f64 / 100.0;
            if j == 0 {
                x
            } else {
                x.sin()
            }
        });
        detector.set_reference(reference);

        // Current: every other reference point, i.e. a subsample of the same distribution.
        let current = Array2::from_shape_fn((50, 2), |(i, j)| {
            let x = (2 * i) as f64 / 100.0;
            if j == 0 {
                x
            } else {
                x.sin()
            }
        });

        let result = detector
            .detect_drift(&current)
            .expect("operation should succeed");
        assert!(
            !result.drift_detected,
            "Should not detect drift in similar data"
        );
    }

    #[test]
    fn test_psi_drift_detection() {
        let config = DriftDetectionConfig {
            detection_method: DriftDetectionMethod::PopulationStabilityIndex,
            ..Default::default()
        };
        let mut detector = DriftDetector::new(config);

        // Reference: uniform grid on [0, 0.99].
        detector.set_reference(Array2::from_shape_fn((100, 1), |(i, _)| i as f64 / 100.0));

        // Current: grid on [0.5, 1.48], i.e. the distribution shifted upwards.
        let current = Array2::from_shape_fn((50, 1), |(i, _)| (i as f64 / 50.0) + 0.5);

        let result = detector
            .detect_drift(&current)
            .expect("operation should succeed");
        // PSI should pick up the shift.
        assert!(result.drift_score > 0.1, "Should detect distribution shift");
    }
}