sklears_model_selection/adversarial_validation.rs

//! Adversarial validation for robustness testing and data leakage detection
//!
//! This module provides methods to detect potential issues in training/test splits
//! by training discriminators to distinguish between training and test data.
//!
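//! # Examples
//!
//! A minimal usage sketch (illustrative only; the crate path in the `use` line is
//! assumed from this file's location, and the all-zero arrays merely stand in for
//! real feature matrices):
//!
//! ```ignore
//! use scirs2_core::ndarray::Array2;
//! use sklears_model_selection::adversarial_validation::{
//!     AdversarialValidationConfig, AdversarialValidator,
//! };
//!
//! let validator = AdversarialValidator::new(AdversarialValidationConfig::default());
//! let train = Array2::<f64>::zeros((100, 5));
//! let test = Array2::<f64>::zeros((50, 5));
//! let result = validator.validate(&train, &test).unwrap();
//! // An AUC near 0.5 means the discriminator cannot tell train and test apart.
//! println!("discriminator AUC: {:.3}", result.discriminator_auc);
//! ```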
use scirs2_core::ndarray::Array2;
use scirs2_core::SliceRandomExt;
use sklears_core::error::{Result, SklearsError};

/// Configuration for adversarial validation
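///
/// # Examples
///
/// A hedged sketch of overriding the defaults (the crate path is assumed from this
/// file's location):
///
/// ```ignore
/// use sklears_model_selection::adversarial_validation::AdversarialValidationConfig;
///
/// let config = AdversarialValidationConfig {
///     significance_threshold: 0.55,
///     random_state: Some(42),
///     ..Default::default()
/// };
/// ```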
#[derive(Debug, Clone)]
pub struct AdversarialValidationConfig {
    /// Number of cross-validation folds for discriminator training
    pub cv_folds: usize,
    /// Test size for discriminator evaluation
    pub test_size: f64,
    /// Threshold for considering distributions significantly different
    pub significance_threshold: f64,
    /// Number of bootstrap samples for confidence intervals
    pub n_bootstrap: usize,
    /// Random state for reproducible results
    pub random_state: Option<u64>,
    /// Whether to perform feature importance analysis
    pub analyze_features: bool,
    /// Maximum number of discriminator iterations
    pub max_iterations: usize,
}

impl Default for AdversarialValidationConfig {
    fn default() -> Self {
        Self {
            cv_folds: 5,
            test_size: 0.2,
            significance_threshold: 0.6, // AUC threshold for detecting issues
            n_bootstrap: 1000,
            random_state: None,
            analyze_features: true,
            max_iterations: 100,
        }
    }
}

/// Results from adversarial validation
#[derive(Debug, Clone)]
pub struct AdversarialValidationResult {
    /// AUC score of the discriminator (0.5 = no difference, 1.0 = perfect discrimination)
    pub discriminator_auc: f64,
    /// Confidence interval for the AUC score
    pub auc_confidence_interval: (f64, f64),
    /// P-value for statistical significance test
    pub p_value: f64,
    /// Whether the distributions are significantly different
    pub is_significantly_different: bool,
    /// Feature importance scores (if enabled)
    pub feature_importance: Option<Vec<f64>>,
    /// Suspicious feature indices (features that help discriminate)
    pub suspicious_features: Vec<usize>,
    /// Cross-validation scores for the discriminator
    pub cv_scores: Vec<f64>,
    /// Detailed statistics about the validation
    pub statistics: AdversarialStatistics,
}

/// Detailed statistics from adversarial validation
#[derive(Debug, Clone)]
pub struct AdversarialStatistics {
    /// Number of training samples used
    pub n_train_samples: usize,
    /// Number of test samples used
    pub n_test_samples: usize,
    /// Number of features analyzed
    pub n_features: usize,
    /// Mean AUC across cross-validation folds
    pub mean_cv_auc: f64,
    /// Standard deviation of AUC across folds
    pub std_cv_auc: f64,
    /// Best single fold AUC
    pub best_cv_auc: f64,
    /// Worst single fold AUC
    pub worst_cv_auc: f64,
}

/// Adversarial validator for detecting data leakage and distribution shifts
#[derive(Debug, Clone)]
pub struct AdversarialValidator {
    config: AdversarialValidationConfig,
}

impl AdversarialValidator {
    pub fn new(config: AdversarialValidationConfig) -> Self {
        Self { config }
    }

    /// Perform adversarial validation on training and test sets
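    ///
    /// The combined data is labelled 0 for training rows and 1 for test rows, and a
    /// simple discriminator is trained to separate the two. An AUC near 0.5 means the
    /// split looks exchangeable; an AUC above `significance_threshold` flags a possible
    /// distribution shift or leakage.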
    pub fn validate(
        &self,
        train_data: &Array2<f64>,
        test_data: &Array2<f64>,
    ) -> Result<AdversarialValidationResult> {
        if train_data.ncols() != test_data.ncols() {
            return Err(SklearsError::InvalidInput(
                "Training and test data must have the same number of features".to_string(),
            ));
        }

        // Create combined dataset with labels (0 = train, 1 = test)
        let (combined_data, labels) = self.prepare_adversarial_data(train_data, test_data)?;

        // Train discriminator using cross-validation
        let cv_scores = self.cross_validate_discriminator(&combined_data, &labels)?;

        // Calculate main discriminator performance
        let discriminator_auc = self.train_discriminator(&combined_data, &labels)?;

        // Bootstrap confidence intervals
        let auc_ci = self.bootstrap_confidence_interval(&combined_data, &labels)?;

        // Feature importance analysis
        let (feature_importance, suspicious_features) = if self.config.analyze_features {
            self.analyze_feature_importance(&combined_data, &labels)?
        } else {
            (None, Vec::new())
        };

        // Statistical significance test
        let p_value = self.calculate_p_value(&cv_scores);
        let is_significantly_different = discriminator_auc > self.config.significance_threshold;

        // Calculate statistics
        let statistics = AdversarialStatistics {
            n_train_samples: train_data.nrows(),
            n_test_samples: test_data.nrows(),
            n_features: train_data.ncols(),
            mean_cv_auc: cv_scores.iter().sum::<f64>() / cv_scores.len() as f64,
            std_cv_auc: self.calculate_std(&cv_scores),
            best_cv_auc: cv_scores.iter().fold(f64::NEG_INFINITY, |a, &b| a.max(b)),
            worst_cv_auc: cv_scores.iter().fold(f64::INFINITY, |a, &b| a.min(b)),
        };

        Ok(AdversarialValidationResult {
            discriminator_auc,
            auc_confidence_interval: auc_ci,
            p_value,
            is_significantly_different,
            feature_importance,
            suspicious_features,
            cv_scores,
            statistics,
        })
    }

    /// Prepare adversarial dataset by combining train and test data
    fn prepare_adversarial_data(
        &self,
        train_data: &Array2<f64>,
        test_data: &Array2<f64>,
    ) -> Result<(Array2<f64>, Vec<usize>)> {
        let n_train = train_data.nrows();
        let n_test = test_data.nrows();
        let n_features = train_data.ncols();

        // Combine data
        let mut combined_data = Array2::zeros((n_train + n_test, n_features));

        // Copy training data
        for i in 0..n_train {
            for j in 0..n_features {
                combined_data[[i, j]] = train_data[[i, j]];
            }
        }

        // Copy test data
        for i in 0..n_test {
            for j in 0..n_features {
                combined_data[[n_train + i, j]] = test_data[[i, j]];
            }
        }

        // Create labels (0 = train, 1 = test)
        let mut labels = Vec::with_capacity(n_train + n_test);
        labels.extend(vec![0; n_train]);
        labels.extend(vec![1; n_test]);

        Ok((combined_data, labels))
    }

    /// Cross-validate discriminator performance
    fn cross_validate_discriminator(
        &self,
        data: &Array2<f64>,
        labels: &[usize],
    ) -> Result<Vec<f64>> {
        let n_samples = data.nrows();
        let fold_size = n_samples / self.config.cv_folds;
        let mut cv_scores = Vec::new();

        // Shuffle the sample order first: the combined data is laid out
        // train-then-test, so contiguous folds over the raw order would put
        // only one class in most fold test sets and yield degenerate AUCs.
        let mut permutation: Vec<usize> = (0..n_samples).collect();
        self.shuffle_indices(&mut permutation);

        for fold in 0..self.config.cv_folds {
            let test_start = fold * fold_size;
            let test_end = if fold == self.config.cv_folds - 1 {
                n_samples
            } else {
                (fold + 1) * fold_size
            };

            // Create train/test splits for this fold from the shuffled order
            let test_indices: Vec<usize> = permutation[test_start..test_end].to_vec();
            let train_indices: Vec<usize> = permutation[..test_start]
                .iter()
                .chain(permutation[test_end..].iter())
                .copied()
                .collect();

            // Extract fold data
            let train_fold_data = self.extract_rows(data, &train_indices);
            let test_fold_data = self.extract_rows(data, &test_indices);
            let train_fold_labels: Vec<usize> = train_indices.iter().map(|&i| labels[i]).collect();
            let test_fold_labels: Vec<usize> = test_indices.iter().map(|&i| labels[i]).collect();

            // Train discriminator on this fold
            let fold_auc = self.train_simple_discriminator(
                &train_fold_data,
                &train_fold_labels,
                &test_fold_data,
                &test_fold_labels,
            )?;
            cv_scores.push(fold_auc);
        }

        Ok(cv_scores)
    }

    /// Train a discriminator and return AUC score
    fn train_discriminator(&self, data: &Array2<f64>, labels: &[usize]) -> Result<f64> {
        // Split data for training discriminator
        let n_samples = data.nrows();
        let test_size = (n_samples as f64 * self.config.test_size) as usize;

        let mut indices: Vec<usize> = (0..n_samples).collect();
        self.shuffle_indices(&mut indices);

        let train_indices = &indices[test_size..];
        let test_indices = &indices[..test_size];

        let train_data = self.extract_rows(data, train_indices);
        let test_data = self.extract_rows(data, test_indices);
        let train_labels: Vec<usize> = train_indices.iter().map(|&i| labels[i]).collect();
        let test_labels: Vec<usize> = test_indices.iter().map(|&i| labels[i]).collect();

        self.train_simple_discriminator(&train_data, &train_labels, &test_data, &test_labels)
    }

    /// Train a simple logistic regression discriminator
    fn train_simple_discriminator(
        &self,
        train_data: &Array2<f64>,
        train_labels: &[usize],
        test_data: &Array2<f64>,
        test_labels: &[usize],
    ) -> Result<f64> {
        let n_features = train_data.ncols();
        let mut weights = vec![0.0; n_features + 1]; // +1 for bias
        let learning_rate = 0.01;

        // Convert labels to 0/1 targets for logistic regression
        let train_y: Vec<f64> = train_labels
            .iter()
            .map(|&label| if label == 1 { 1.0 } else { 0.0 })
            .collect();

        // Gradient descent training
        for _iteration in 0..self.config.max_iterations {
            let mut gradients = vec![0.0; n_features + 1];

            for (i, &y) in train_y.iter().enumerate() {
                // Compute prediction
                let mut prediction = weights[0]; // bias
                for j in 0..n_features {
                    prediction += weights[j + 1] * train_data[[i, j]];
                }

                // Sigmoid activation
                let prob = 1.0 / (1.0 + (-prediction).exp());
                let error = prob - y; // gradient of the logistic loss

                // Update gradients
                gradients[0] += error; // bias gradient
                for j in 0..n_features {
                    gradients[j + 1] += error * train_data[[i, j]];
                }
            }

            // Update weights
            for j in 0..weights.len() {
                weights[j] -= learning_rate * gradients[j] / train_y.len() as f64;
            }
        }

        // Evaluate on test set and calculate AUC
        self.calculate_auc(&weights, test_data, test_labels)
    }

    /// Calculate AUC score
    fn calculate_auc(
        &self,
        weights: &[f64],
        test_data: &Array2<f64>,
        test_labels: &[usize],
    ) -> Result<f64> {
        let n_features = test_data.ncols();
        let mut predictions = Vec::new();

        for i in 0..test_data.nrows() {
            let mut prediction = weights[0]; // bias
            for j in 0..n_features {
                prediction += weights[j + 1] * test_data[[i, j]];
            }
            let prob = 1.0 / (1.0 + (-prediction).exp());
            predictions.push(prob);
        }

        // Rank-based AUC: compare every test (positive) score against every train (negative) score
        let mut positive_scores = Vec::new();
        let mut negative_scores = Vec::new();

        for (i, &score) in predictions.iter().enumerate() {
            if test_labels[i] == 1 {
                positive_scores.push(score);
            } else {
                negative_scores.push(score);
            }
        }

        if positive_scores.is_empty() || negative_scores.is_empty() {
            return Ok(0.5); // No discrimination possible
        }

        // Count concordant pairs; ties count as half so that an uninformative
        // discriminator (all scores equal) still comes out at AUC = 0.5.
        let mut concordant = 0.0;
        let mut total = 0.0;

        for &pos_score in &positive_scores {
            for &neg_score in &negative_scores {
                total += 1.0;
                if pos_score > neg_score {
                    concordant += 1.0;
                } else if (pos_score - neg_score).abs() < f64::EPSILON {
                    concordant += 0.5;
                }
            }
        }

        Ok(concordant / total)
    }

    /// Bootstrap confidence intervals for AUC
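    ///
    /// Returns the 2.5th and 97.5th percentiles of the bootstrapped AUC values,
    /// i.e. an approximate 95% percentile interval.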
    fn bootstrap_confidence_interval(
        &self,
        data: &Array2<f64>,
        labels: &[usize],
    ) -> Result<(f64, f64)> {
        let mut bootstrap_aucs = Vec::new();
        let n_samples = data.nrows();

        for _ in 0..self.config.n_bootstrap {
            // Bootstrap sample
            let mut boot_indices = Vec::new();
            for _ in 0..n_samples {
                boot_indices.push(self.random_index(n_samples));
            }

            let boot_data = self.extract_rows(data, &boot_indices);
            let boot_labels: Vec<usize> = boot_indices.iter().map(|&i| labels[i]).collect();

            // Train discriminator on bootstrap sample
            if let Ok(auc) = self.train_discriminator(&boot_data, &boot_labels) {
                bootstrap_aucs.push(auc);
            }
        }

        if bootstrap_aucs.is_empty() {
            return Ok((0.5, 0.5));
        }

        bootstrap_aucs.sort_by(|a, b| a.partial_cmp(b).unwrap());

        let lower_idx = (self.config.n_bootstrap as f64 * 0.025) as usize;
        let upper_idx = (self.config.n_bootstrap as f64 * 0.975) as usize;

        let lower_bound = bootstrap_aucs[lower_idx.min(bootstrap_aucs.len() - 1)];
        let upper_bound = bootstrap_aucs[upper_idx.min(bootstrap_aucs.len() - 1)];

        Ok((lower_bound, upper_bound))
    }

    /// Analyze feature importance for discrimination
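    ///
    /// Uses permutation importance: each feature column is shuffled in turn and the
    /// importance is the drop in discriminator AUC relative to the unpermuted baseline.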
    fn analyze_feature_importance(
        &self,
        data: &Array2<f64>,
        labels: &[usize],
    ) -> Result<(Option<Vec<f64>>, Vec<usize>)> {
        let n_features = data.ncols();
        let mut feature_importance = vec![0.0; n_features];

        // Calculate baseline AUC
        let baseline_auc = self.train_discriminator(data, labels)?;

        // Permute each feature and measure AUC drop
        for feature_idx in 0..n_features {
            let mut permuted_data = data.clone();

            // Permute this feature
            let mut feature_values: Vec<f64> =
                (0..data.nrows()).map(|i| data[[i, feature_idx]]).collect();
            self.shuffle_f64(&mut feature_values);

            for (i, &value) in feature_values.iter().enumerate() {
                permuted_data[[i, feature_idx]] = value;
            }

            // Calculate AUC with permuted feature
            let permuted_auc = self.train_discriminator(&permuted_data, labels)?;

            // Feature importance is the drop in AUC
            feature_importance[feature_idx] = baseline_auc - permuted_auc;
        }

        // Identify suspicious features (those that help discrimination)
        let mut suspicious_features = Vec::new();
        let importance_threshold = 0.01; // 1% AUC drop threshold

        for (i, &importance) in feature_importance.iter().enumerate() {
            if importance > importance_threshold {
                suspicious_features.push(i);
            }
        }

        Ok((Some(feature_importance), suspicious_features))
    }

    /// Calculate p-value for statistical significance
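    ///
    /// Performs a one-sample t-test of the cross-validation AUCs against the null value
    /// of 0.5, using t = (mean_auc - 0.5) * sqrt(n) / std_auc and a normal approximation
    /// for the two-sided p-value.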
    fn calculate_p_value(&self, cv_scores: &[f64]) -> f64 {
        // One-sample t-test against null hypothesis (AUC = 0.5)
        let mean_auc = cv_scores.iter().sum::<f64>() / cv_scores.len() as f64;
        let std_auc = self.calculate_std(cv_scores);
        let n = cv_scores.len() as f64;

        if std_auc == 0.0 {
            return if mean_auc > 0.5 { 0.0 } else { 1.0 };
        }

        let t_stat = (mean_auc - 0.5) * n.sqrt() / std_auc;

        // Approximate p-value using normal distribution (for large samples)
        let p_value = 2.0 * (1.0 - self.normal_cdf(t_stat.abs()));
        p_value.clamp(0.0, 1.0)
    }

    /// Calculate standard deviation
    fn calculate_std(&self, values: &[f64]) -> f64 {
        if values.len() < 2 {
            return 0.0;
        }

        let mean = values.iter().sum::<f64>() / values.len() as f64;
        let variance =
            values.iter().map(|&x| (x - mean).powi(2)).sum::<f64>() / (values.len() - 1) as f64;

        variance.sqrt()
    }

    /// Approximate normal CDF
    fn normal_cdf(&self, x: f64) -> f64 {
        0.5 * (1.0 + self.erf(x / 2.0_f64.sqrt()))
    }

    /// Approximate error function
479    fn erf(&self, x: f64) -> f64 {
480        // Abramowitz and Stegun approximation
481        let a1 = 0.254829592;
482        let a2 = -0.284496736;
483        let a3 = 1.421413741;
484        let a4 = -1.453152027;
485        let a5 = 1.061405429;
486        let p = 0.3275911;
487
488        let sign = if x < 0.0 { -1.0 } else { 1.0 };
489        let x = x.abs();
490
491        let t = 1.0 / (1.0 + p * x);
492        let y = 1.0 - (((((a5 * t + a4) * t) + a3) * t + a2) * t + a1) * t * (-x * x).exp();
493
494        sign * y
495    }
496
497    /// Extract rows from array by indices
498    fn extract_rows(&self, data: &Array2<f64>, indices: &[usize]) -> Array2<f64> {
499        let n_rows = indices.len();
500        let n_cols = data.ncols();
501        let mut result = Array2::zeros((n_rows, n_cols));
502
503        for (i, &idx) in indices.iter().enumerate() {
504            for j in 0..n_cols {
505                result[[i, j]] = data[[idx, j]];
506            }
507        }
508
509        result
510    }
511
512    /// Shuffle indices randomly
513    fn shuffle_indices(&self, indices: &mut [usize]) {
514        use scirs2_core::random::rngs::StdRng;
515        use scirs2_core::random::SeedableRng;
516        let mut rng = match self.config.random_state {
517            Some(seed) => StdRng::seed_from_u64(seed),
518            None => {
519                use scirs2_core::random::thread_rng;
520                StdRng::from_rng(&mut thread_rng())
521            }
522        };
523        indices.shuffle(&mut rng);
524    }
525
526    /// Shuffle f64 values randomly
527    fn shuffle_f64(&self, values: &mut [f64]) {
528        use scirs2_core::random::rngs::StdRng;
529        use scirs2_core::random::SeedableRng;
530        let mut rng = match self.config.random_state {
531            Some(seed) => StdRng::seed_from_u64(seed),
532            None => {
533                use scirs2_core::random::thread_rng;
534                StdRng::from_rng(&mut thread_rng())
535            }
536        };
537        values.shuffle(&mut rng);
538    }
539
540    /// Generate random index
541    fn random_index(&self, max: usize) -> usize {
542        use scirs2_core::random::rngs::StdRng;
543        use scirs2_core::random::Rng;
544        use scirs2_core::random::SeedableRng;
545        let mut rng = match self.config.random_state {
546            Some(seed) => StdRng::seed_from_u64(seed),
547            None => {
548                use scirs2_core::random::thread_rng;
549                StdRng::from_rng(&mut thread_rng())
550            }
551        };
552        rng.gen_range(0..max)
553    }
554}
555
556#[allow(non_snake_case)]
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_adversarial_validation_same_distribution() {
        let config = AdversarialValidationConfig::default();
        let validator = AdversarialValidator::new(config);

        // Create identical distributions
        let train_data = Array2::zeros((100, 5));
        let test_data = Array2::zeros((50, 5));

        let result = validator.validate(&train_data, &test_data).unwrap();

        // AUC should be close to 0.5 for identical distributions
        assert!(
            result.discriminator_auc < 0.6,
            "AUC should be close to 0.5 for identical distributions"
        );
        assert!(
            !result.is_significantly_different,
            "Identical distributions should not be significantly different"
        );
    }

    #[test]
    fn test_adversarial_validation_different_distributions() {
        let config = AdversarialValidationConfig {
            significance_threshold: 0.6,
            ..Default::default()
        };
        let validator = AdversarialValidator::new(config);

        // Create clearly different distributions: all zeros vs. all ones
        let train_data = Array2::zeros((100, 5));
        let test_data = Array2::ones((50, 5));

        let result = validator.validate(&train_data, &test_data).unwrap();

        // AUC should be high for clearly different distributions
        assert!(
            result.discriminator_auc > 0.7,
            "AUC should be high for different distributions"
        );
    }
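
    // A hedged sketch of an additional check (not in the original suite): with a fixed
    // `random_state` the whole pipeline is deterministic, so repeated runs on the same
    // inputs should report the same discriminator AUC and CV scores.
    #[test]
    fn test_adversarial_validation_reproducible_with_seed() {
        let config = AdversarialValidationConfig {
            random_state: Some(42),
            n_bootstrap: 10,
            ..Default::default()
        };
        let validator = AdversarialValidator::new(config);

        let train_data = Array2::zeros((60, 3));
        let test_data = Array2::zeros((30, 3));

        let first = validator.validate(&train_data, &test_data).unwrap();
        let second = validator.validate(&train_data, &test_data).unwrap();

        assert_eq!(first.discriminator_auc, second.discriminator_auc);
        assert_eq!(first.cv_scores, second.cv_scores);
    }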
}