//! Adversarial validation for robustness testing and data leakage detection
//!
//! This module provides methods to detect potential issues in training/test splits
//! by training discriminators to distinguish between training and test data.

use scirs2_core::ndarray::Array2;
use scirs2_core::RngExt;
use scirs2_core::SliceRandomExt;
use sklears_core::error::{Result, SklearsError};
10
/// Configuration for adversarial validation
#[derive(Debug, Clone)]
pub struct AdversarialValidationConfig {
    /// Number of cross-validation folds for discriminator training
    pub cv_folds: usize,
    /// Fraction of samples held out when evaluating the discriminator (expected in (0.0, 1.0))
    pub test_size: f64,
    /// AUC threshold above which train/test distributions are flagged as significantly different
    pub significance_threshold: f64,
    /// Number of bootstrap resamples used for the AUC confidence interval
    pub n_bootstrap: usize,
    /// Random state for reproducible results (`None` = nondeterministic seeding)
    pub random_state: Option<u64>,
    /// Whether to perform permutation-based feature importance analysis
    pub analyze_features: bool,
    /// Maximum number of gradient-descent iterations when fitting the discriminator
    pub max_iterations: usize,
}
29
30impl Default for AdversarialValidationConfig {
31    fn default() -> Self {
32        Self {
33            cv_folds: 5,
34            test_size: 0.2,
35            significance_threshold: 0.6, // AUC threshold for detecting issues
36            n_bootstrap: 1000,
37            random_state: None,
38            analyze_features: true,
39            max_iterations: 100,
40        }
41    }
42}
43
/// Results from adversarial validation
#[derive(Debug, Clone)]
pub struct AdversarialValidationResult {
    /// AUC score of the discriminator (0.5 = no difference, 1.0 = perfect discrimination)
    pub discriminator_auc: f64,
    /// 95% bootstrap confidence interval for the AUC score (lower, upper)
    pub auc_confidence_interval: (f64, f64),
    /// P-value for the test of the mean CV AUC against the null value 0.5
    pub p_value: f64,
    /// Whether the discriminator AUC exceeded the configured significance threshold
    pub is_significantly_different: bool,
    /// Permutation feature importance scores, one per feature (if enabled in config)
    pub feature_importance: Option<Vec<usize>>,
    /// Suspicious feature indices (features that help discriminate)
    pub suspicious_features: Vec<usize>,
    /// Per-fold AUC scores of the discriminator from cross-validation
    pub cv_scores: Vec<f64>,
    /// Detailed statistics about the validation
    pub statistics: AdversarialStatistics,
}
64
/// Detailed statistics from adversarial validation
#[derive(Debug, Clone)]
pub struct AdversarialStatistics {
    /// Number of training samples used
    pub n_train_samples: usize,
    /// Number of test samples used
    pub n_test_samples: usize,
    /// Number of features analyzed
    pub n_features: usize,
    /// Mean AUC across cross-validation folds
    pub mean_cv_auc: f64,
    /// Sample standard deviation of AUC across folds (0.0 for fewer than two folds)
    pub std_cv_auc: f64,
    /// Best single fold AUC
    pub best_cv_auc: f64,
    /// Worst single fold AUC
    pub worst_cv_auc: f64,
}
83
/// Adversarial validator for detecting data leakage and distribution shifts.
///
/// Works by training a simple logistic-regression discriminator to tell
/// training rows apart from test rows; high discriminator AUC indicates
/// the two sets are drawn from different distributions.
#[derive(Debug, Clone)]
pub struct AdversarialValidator {
    // Immutable configuration; all randomness is derived from `config.random_state`.
    config: AdversarialValidationConfig,
}
89
90impl AdversarialValidator {
    /// Create a new validator with the given configuration.
    pub fn new(config: AdversarialValidationConfig) -> Self {
        Self { config }
    }
94
95    /// Perform adversarial validation on training and test sets
96    pub fn validate(
97        &self,
98        train_data: &Array2<f64>,
99        test_data: &Array2<f64>,
100    ) -> Result<AdversarialValidationResult> {
101        if train_data.ncols() != test_data.ncols() {
102            return Err(SklearsError::InvalidInput(
103                "Training and test data must have the same number of features".to_string(),
104            ));
105        }
106
107        // Create combined dataset with labels (0 = train, 1 = test)
108        let (combined_data, labels) = self.prepare_adversarial_data(train_data, test_data)?;
109
110        // Train discriminator using cross-validation
111        let cv_scores = self.cross_validate_discriminator(&combined_data, &labels)?;
112
113        // Calculate main discriminator performance
114        let discriminator_auc = self.train_discriminator(&combined_data, &labels)?;
115
116        // Bootstrap confidence intervals
117        let auc_ci = self.bootstrap_confidence_interval(&combined_data, &labels)?;
118
119        // Feature importance analysis
120        let (feature_importance, suspicious_features) = if self.config.analyze_features {
121            self.analyze_feature_importance(&combined_data, &labels)?
122        } else {
123            (None, Vec::new())
124        };
125
126        // Statistical significance test
127        let p_value = self.calculate_p_value(&cv_scores);
128        let is_significantly_different = discriminator_auc > self.config.significance_threshold;
129
130        // Calculate statistics
131        let statistics = AdversarialStatistics {
132            n_train_samples: train_data.nrows(),
133            n_test_samples: test_data.nrows(),
134            n_features: train_data.ncols(),
135            mean_cv_auc: cv_scores.iter().sum::<f64>() / cv_scores.len() as f64,
136            std_cv_auc: self.calculate_std(&cv_scores),
137            best_cv_auc: cv_scores.iter().fold(f64::NEG_INFINITY, |a, &b| a.max(b)),
138            worst_cv_auc: cv_scores.iter().fold(f64::INFINITY, |a, &b| a.min(b)),
139        };
140
141        Ok(AdversarialValidationResult {
142            discriminator_auc,
143            auc_confidence_interval: auc_ci,
144            p_value,
145            is_significantly_different,
146            feature_importance,
147            suspicious_features,
148            cv_scores,
149            statistics,
150        })
151    }
152
153    /// Prepare adversarial dataset by combining train and test data
154    fn prepare_adversarial_data(
155        &self,
156        train_data: &Array2<f64>,
157        test_data: &Array2<f64>,
158    ) -> Result<(Array2<f64>, Vec<usize>)> {
159        let n_train = train_data.nrows();
160        let n_test = test_data.nrows();
161        let n_features = train_data.ncols();
162
163        // Combine data
164        let mut combined_data = Array2::zeros((n_train + n_test, n_features));
165
166        // Copy training data
167        for i in 0..n_train {
168            for j in 0..n_features {
169                combined_data[[i, j]] = train_data[[i, j]];
170            }
171        }
172
173        // Copy test data
174        for i in 0..n_test {
175            for j in 0..n_features {
176                combined_data[[n_train + i, j]] = test_data[[i, j]];
177            }
178        }
179
180        // Create labels (0 = train, 1 = test)
181        let mut labels = Vec::with_capacity(n_train + n_test);
182        labels.extend(vec![0; n_train]);
183        labels.extend(vec![1; n_test]);
184
185        Ok((combined_data, labels))
186    }
187
188    /// Cross-validate discriminator performance
189    fn cross_validate_discriminator(
190        &self,
191        data: &Array2<f64>,
192        labels: &[usize],
193    ) -> Result<Vec<f64>> {
194        let n_samples = data.nrows();
195        let fold_size = n_samples / self.config.cv_folds;
196        let mut cv_scores = Vec::new();
197
198        for fold in 0..self.config.cv_folds {
199            let test_start = fold * fold_size;
200            let test_end = if fold == self.config.cv_folds - 1 {
201                n_samples
202            } else {
203                (fold + 1) * fold_size
204            };
205
206            // Create train/test splits for this fold
207            let mut train_indices = Vec::new();
208            let mut test_indices = Vec::new();
209
210            for i in 0..n_samples {
211                if i >= test_start && i < test_end {
212                    test_indices.push(i);
213                } else {
214                    train_indices.push(i);
215                }
216            }
217
218            // Extract fold data
219            let train_fold_data = self.extract_rows(data, &train_indices);
220            let test_fold_data = self.extract_rows(data, &test_indices);
221            let train_fold_labels: Vec<usize> = train_indices.iter().map(|&i| labels[i]).collect();
222            let test_fold_labels: Vec<usize> = test_indices.iter().map(|&i| labels[i]).collect();
223
224            // Train discriminator on this fold
225            let fold_auc = self.train_simple_discriminator(
226                &train_fold_data,
227                &train_fold_labels,
228                &test_fold_data,
229                &test_fold_labels,
230            )?;
231            cv_scores.push(fold_auc);
232        }
233
234        Ok(cv_scores)
235    }
236
237    /// Train a discriminator and return AUC score
238    fn train_discriminator(&self, data: &Array2<f64>, labels: &[usize]) -> Result<f64> {
239        // Split data for training discriminator
240        let n_samples = data.nrows();
241        let test_size = (n_samples as f64 * self.config.test_size) as usize;
242
243        let mut indices: Vec<usize> = (0..n_samples).collect();
244        self.shuffle_indices(&mut indices);
245
246        let train_indices = &indices[test_size..];
247        let test_indices = &indices[..test_size];
248
249        let train_data = self.extract_rows(data, train_indices);
250        let test_data = self.extract_rows(data, test_indices);
251        let train_labels: Vec<usize> = train_indices.iter().map(|&i| labels[i]).collect();
252        let test_labels: Vec<usize> = test_indices.iter().map(|&i| labels[i]).collect();
253
254        self.train_simple_discriminator(&train_data, &train_labels, &test_data, &test_labels)
255    }
256
257    /// Train a simple logistic regression discriminator
258    fn train_simple_discriminator(
259        &self,
260        train_data: &Array2<f64>,
261        train_labels: &[usize],
262        test_data: &Array2<f64>,
263        test_labels: &[usize],
264    ) -> Result<f64> {
265        let n_features = train_data.ncols();
266        let mut weights = vec![0.0; n_features + 1]; // +1 for bias
267        let learning_rate = 0.01;
268
269        // Convert labels to -1/1 for logistic regression
270        let train_y: Vec<f64> = train_labels
271            .iter()
272            .map(|&label| if label == 1 { 1.0 } else { -1.0 })
273            .collect();
274
275        // Gradient descent training
276        for _iteration in 0..self.config.max_iterations {
277            let mut gradients = vec![0.0; n_features + 1];
278
279            for (i, &y) in train_y.iter().enumerate() {
280                // Compute prediction
281                let mut prediction = weights[0]; // bias
282                for j in 0..n_features {
283                    prediction += weights[j + 1] * train_data[[i, j]];
284                }
285
286                // Sigmoid activation
287                let prob = 1.0 / (1.0 + (-prediction).exp());
288                let error = prob - (y + 1.0) / 2.0; // Convert back to 0/1
289
290                // Update gradients
291                gradients[0] += error; // bias gradient
292                for j in 0..n_features {
293                    gradients[j + 1] += error * train_data[[i, j]];
294                }
295            }
296
297            // Update weights
298            for j in 0..weights.len() {
299                weights[j] -= learning_rate * gradients[j] / train_y.len() as f64;
300            }
301        }
302
303        // Evaluate on test set and calculate AUC
304        self.calculate_auc(&weights, test_data, test_labels)
305    }
306
307    /// Calculate AUC score
308    fn calculate_auc(
309        &self,
310        weights: &[f64],
311        test_data: &Array2<f64>,
312        test_labels: &[usize],
313    ) -> Result<f64> {
314        let n_features = test_data.ncols();
315        let mut predictions = Vec::new();
316
317        for i in 0..test_data.nrows() {
318            let mut prediction = weights[0]; // bias
319            for j in 0..n_features {
320                prediction += weights[j + 1] * test_data[[i, j]];
321            }
322            let prob = 1.0 / (1.0 + (-prediction).exp());
323            predictions.push(prob);
324        }
325
326        // Calculate AUC using trapezoidal rule
327        let mut positive_scores = Vec::new();
328        let mut negative_scores = Vec::new();
329
330        for (i, &score) in predictions.iter().enumerate() {
331            if test_labels[i] == 1 {
332                positive_scores.push(score);
333            } else {
334                negative_scores.push(score);
335            }
336        }
337
338        if positive_scores.is_empty() || negative_scores.is_empty() {
339            return Ok(0.5); // No discrimination possible
340        }
341
342        // Count concordant pairs
343        let mut concordant = 0;
344        let mut total = 0;
345
346        for &pos_score in &positive_scores {
347            for &neg_score in &negative_scores {
348                total += 1;
349                if pos_score > neg_score {
350                    concordant += 1;
351                }
352            }
353        }
354
355        Ok(concordant as f64 / total as f64)
356    }
357
358    /// Bootstrap confidence intervals for AUC
359    fn bootstrap_confidence_interval(
360        &self,
361        data: &Array2<f64>,
362        labels: &[usize],
363    ) -> Result<(f64, f64)> {
364        let mut bootstrap_aucs = Vec::new();
365        let n_samples = data.nrows();
366
367        for _ in 0..self.config.n_bootstrap {
368            // Bootstrap sample
369            let mut boot_indices = Vec::new();
370            for _ in 0..n_samples {
371                boot_indices.push(self.random_index(n_samples));
372            }
373
374            let boot_data = self.extract_rows(data, &boot_indices);
375            let boot_labels: Vec<usize> = boot_indices.iter().map(|&i| labels[i]).collect();
376
377            // Train discriminator on bootstrap sample
378            if let Ok(auc) = self.train_discriminator(&boot_data, &boot_labels) {
379                bootstrap_aucs.push(auc);
380            }
381        }
382
383        if bootstrap_aucs.is_empty() {
384            return Ok((0.5, 0.5));
385        }
386
387        bootstrap_aucs.sort_by(|a, b| a.partial_cmp(b).expect("operation should succeed"));
388
389        let lower_idx = (self.config.n_bootstrap as f64 * 0.025) as usize;
390        let upper_idx = (self.config.n_bootstrap as f64 * 0.975) as usize;
391
392        let lower_bound = bootstrap_aucs[lower_idx.min(bootstrap_aucs.len() - 1)];
393        let upper_bound = bootstrap_aucs[upper_idx.min(bootstrap_aucs.len() - 1)];
394
395        Ok((lower_bound, upper_bound))
396    }
397
398    /// Analyze feature importance for discrimination
399    fn analyze_feature_importance(
400        &self,
401        data: &Array2<f64>,
402        labels: &[usize],
403    ) -> Result<(Option<Vec<f64>>, Vec<usize>)> {
404        let n_features = data.ncols();
405        let mut feature_importance = vec![0.0; n_features];
406
407        // Calculate baseline AUC
408        let baseline_auc = self.train_discriminator(data, labels)?;
409
410        // Permute each feature and measure AUC drop
411        for feature_idx in 0..n_features {
412            let mut permuted_data = data.clone();
413
414            // Permute this feature
415            let mut feature_values: Vec<f64> =
416                (0..data.nrows()).map(|i| data[[i, feature_idx]]).collect();
417            self.shuffle_f64(&mut feature_values);
418
419            for (i, &value) in feature_values.iter().enumerate() {
420                permuted_data[[i, feature_idx]] = value;
421            }
422
423            // Calculate AUC with permuted feature
424            let permuted_auc = self.train_discriminator(&permuted_data, labels)?;
425
426            // Feature importance is the drop in AUC
427            feature_importance[feature_idx] = baseline_auc - permuted_auc;
428        }
429
430        // Identify suspicious features (those that help discrimination)
431        let mut suspicious_features = Vec::new();
432        let importance_threshold = 0.01; // 1% AUC drop threshold
433
434        for (i, &importance) in feature_importance.iter().enumerate() {
435            if importance > importance_threshold {
436                suspicious_features.push(i);
437            }
438        }
439
440        Ok((Some(feature_importance), suspicious_features))
441    }
442
443    /// Calculate p-value for statistical significance
444    fn calculate_p_value(&self, cv_scores: &[f64]) -> f64 {
445        // One-sample t-test against null hypothesis (AUC = 0.5)
446        let mean_auc = cv_scores.iter().sum::<f64>() / cv_scores.len() as f64;
447        let std_auc = self.calculate_std(cv_scores);
448        let n = cv_scores.len() as f64;
449
450        if std_auc == 0.0 {
451            return if mean_auc > 0.5 { 0.0 } else { 1.0 };
452        }
453
454        let t_stat = (mean_auc - 0.5) * n.sqrt() / std_auc;
455
456        // Approximate p-value using normal distribution (for large samples)
457        let p_value = 2.0 * (1.0 - self.normal_cdf(t_stat.abs()));
458        p_value.clamp(0.0, 1.0)
459    }
460
461    /// Calculate standard deviation
462    fn calculate_std(&self, values: &[f64]) -> f64 {
463        if values.len() < 2 {
464            return 0.0;
465        }
466
467        let mean = values.iter().sum::<f64>() / values.len() as f64;
468        let variance =
469            values.iter().map(|&x| (x - mean).powi(2)).sum::<f64>() / (values.len() - 1) as f64;
470
471        variance.sqrt()
472    }
473
474    /// Approximate normal CDF
475    fn normal_cdf(&self, x: f64) -> f64 {
476        0.5 * (1.0 + self.erf(x / 2.0_f64.sqrt()))
477    }
478
    /// Approximate error function.
    ///
    /// Abramowitz & Stegun formula 7.1.26, accurate to roughly 1.5e-7 absolute
    /// error, extended to negative inputs via the odd symmetry erf(-x) = -erf(x).
    fn erf(&self, x: f64) -> f64 {
        // Polynomial coefficients from A&S 7.1.26.
        let a1 = 0.254829592;
        let a2 = -0.284496736;
        let a3 = 1.421413741;
        let a4 = -1.453152027;
        let a5 = 1.061405429;
        let p = 0.3275911;

        // Work on |x| and restore the sign at the end (erf is odd).
        let sign = if x < 0.0 { -1.0 } else { 1.0 };
        let x = x.abs();

        // Horner evaluation of the rational approximation.
        let t = 1.0 / (1.0 + p * x);
        let y = 1.0 - (((((a5 * t + a4) * t) + a3) * t + a2) * t + a1) * t * (-x * x).exp();

        sign * y
    }
497
498    /// Extract rows from array by indices
499    fn extract_rows(&self, data: &Array2<f64>, indices: &[usize]) -> Array2<f64> {
500        let n_rows = indices.len();
501        let n_cols = data.ncols();
502        let mut result = Array2::zeros((n_rows, n_cols));
503
504        for (i, &idx) in indices.iter().enumerate() {
505            for j in 0..n_cols {
506                result[[i, j]] = data[[idx, j]];
507            }
508        }
509
510        result
511    }
512
    /// Shuffle indices in place.
    ///
    /// NOTE(review): a fresh RNG is constructed on every call. With a fixed
    /// `random_state` this produces the *same* permutation on each invocation,
    /// which keeps individual shuffles reproducible but makes repeated shuffles
    /// within one validation run fully correlated — confirm this is intended.
    fn shuffle_indices(&self, indices: &mut [usize]) {
        use scirs2_core::random::rngs::StdRng;
        use scirs2_core::random::SeedableRng;
        let mut rng = match self.config.random_state {
            Some(seed) => StdRng::seed_from_u64(seed),
            None => {
                use scirs2_core::random::thread_rng;
                StdRng::from_rng(&mut thread_rng())
            }
        };
        indices.shuffle(&mut rng);
    }
526
    /// Shuffle f64 values in place (used for permutation feature importance).
    ///
    /// NOTE(review): like `shuffle_indices`, this re-seeds a fresh RNG per call,
    /// so with a fixed `random_state` every column receives the *same*
    /// permutation — acceptable for per-feature importance, but confirm.
    fn shuffle_f64(&self, values: &mut [f64]) {
        use scirs2_core::random::rngs::StdRng;
        use scirs2_core::random::SeedableRng;
        let mut rng = match self.config.random_state {
            Some(seed) => StdRng::seed_from_u64(seed),
            None => {
                use scirs2_core::random::thread_rng;
                StdRng::from_rng(&mut thread_rng())
            }
        };
        values.shuffle(&mut rng);
    }
540
    /// Generate a random index in `0..max`.
    ///
    /// WARNING(review): a fresh RNG is seeded on every call, so with a fixed
    /// `random_state` this returns the *same* index each time it is invoked.
    /// Callers that need a sequence of independent draws must construct one
    /// RNG themselves and sample from it repeatedly rather than calling this
    /// method in a loop.
    fn random_index(&self, max: usize) -> usize {
        use scirs2_core::random::rngs::StdRng;
        use scirs2_core::random::SeedableRng;
        let mut rng = match self.config.random_state {
            Some(seed) => StdRng::seed_from_u64(seed),
            None => {
                use scirs2_core::random::thread_rng;
                StdRng::from_rng(&mut thread_rng())
            }
        };
        rng.random_range(0..max)
    }
554}
555
#[cfg(test)]
mod tests {
    use super::*;

    /// Identical (all-zero) train and test data must not look distinguishable:
    /// AUC should stay near 0.5 and the significance flag must stay off.
    #[test]
    fn test_adversarial_validation_same_distribution() {
        let validator = AdversarialValidator::new(AdversarialValidationConfig::default());

        let train_data = Array2::zeros((100, 5));
        let test_data = Array2::zeros((50, 5));

        let result = validator
            .validate(&train_data, &test_data)
            .expect("operation should succeed");

        assert!(
            result.discriminator_auc < 0.6,
            "AUC should be close to 0.5 for identical distributions"
        );
        assert!(
            !result.is_significantly_different,
            "Identical distributions should not be significantly different"
        );
    }

    /// Disjoint constant distributions (all zeros vs. all ones) should be easy
    /// for the discriminator to separate.
    #[test]
    fn test_adversarial_validation_different_distributions() {
        let config = AdversarialValidationConfig {
            significance_threshold: 0.6,
            ..Default::default()
        };
        let validator = AdversarialValidator::new(config);

        // Array2::zeros / Array2::ones already produce the desired constant
        // matrices; the previous element-wise re-assignment loops (and the
        // `mut` bindings they required) were redundant.
        let train_data = Array2::zeros((100, 5));
        let test_data = Array2::ones((50, 5));

        let result = validator
            .validate(&train_data, &test_data)
            .expect("operation should succeed");

        assert!(
            result.discriminator_auc > 0.7,
            "AUC should be high for different distributions"
        );
    }
}