1use scirs2_core::ndarray::Array2;
7use scirs2_core::SliceRandomExt;
8use sklears_core::error::{Result, SklearsError};
9
/// Configuration for adversarial validation.
#[derive(Debug, Clone)]
pub struct AdversarialValidationConfig {
    /// Number of cross-validation folds used when scoring the discriminator.
    pub cv_folds: usize,
    /// Fraction of samples held out to evaluate the discriminator (0..1).
    pub test_size: f64,
    /// Hold-out AUC above which train/test are flagged as significantly different.
    pub significance_threshold: f64,
    /// Number of bootstrap resamples for the AUC confidence interval.
    pub n_bootstrap: usize,
    /// Optional seed for reproducible shuffling/sampling.
    pub random_state: Option<u64>,
    /// Whether to run permutation-based feature-importance analysis.
    pub analyze_features: bool,
    /// Gradient-descent iterations for the logistic discriminator.
    pub max_iterations: usize,
}
28
29impl Default for AdversarialValidationConfig {
30 fn default() -> Self {
31 Self {
32 cv_folds: 5,
33 test_size: 0.2,
34 significance_threshold: 0.6, n_bootstrap: 1000,
36 random_state: None,
37 analyze_features: true,
38 max_iterations: 100,
39 }
40 }
41}
42
/// Outcome of an adversarial validation run.
#[derive(Debug, Clone)]
pub struct AdversarialValidationResult {
    /// Hold-out AUC of the train-vs-test discriminator (0.5 ≈ indistinguishable).
    pub discriminator_auc: f64,
    /// Bootstrap 95% confidence interval (lower, upper) for the AUC.
    pub auc_confidence_interval: (f64, f64),
    /// Two-sided p-value that the mean CV AUC differs from chance (0.5).
    pub p_value: f64,
    /// True when `discriminator_auc` exceeds the configured significance threshold.
    pub is_significantly_different: bool,
    /// Per-feature permutation importance (Some only when feature analysis is enabled).
    pub feature_importance: Option<Vec<f64>>,
    /// Indices of features whose importance exceeds the internal threshold.
    pub suspicious_features: Vec<usize>,
    /// Discriminator AUC for each cross-validation fold.
    pub cv_scores: Vec<f64>,
    /// Summary statistics of the run.
    pub statistics: AdversarialStatistics,
}
63
/// Summary statistics collected during adversarial validation.
#[derive(Debug, Clone)]
pub struct AdversarialStatistics {
    /// Number of rows in the training sample.
    pub n_train_samples: usize,
    /// Number of rows in the test sample.
    pub n_test_samples: usize,
    /// Number of feature columns (shared by both samples).
    pub n_features: usize,
    /// Mean of the cross-validation AUC scores.
    pub mean_cv_auc: f64,
    /// Sample standard deviation of the cross-validation AUC scores.
    pub std_cv_auc: f64,
    /// Highest fold AUC.
    pub best_cv_auc: f64,
    /// Lowest fold AUC.
    pub worst_cv_auc: f64,
}
82
/// Adversarial validator: trains a simple logistic discriminator to tell the
/// training sample apart from the test sample. A hold-out AUC well above 0.5
/// indicates the two samples are separable, i.e. distribution shift.
#[derive(Debug, Clone)]
pub struct AdversarialValidator {
    /// Behavior knobs; see [`AdversarialValidationConfig`].
    config: AdversarialValidationConfig,
}
88
89impl AdversarialValidator {
90 pub fn new(config: AdversarialValidationConfig) -> Self {
91 Self { config }
92 }
93
    /// Run adversarial validation: train a discriminator to distinguish
    /// `train_data` (label 0) from `test_data` (label 1) and summarize how
    /// separable the two samples are. An AUC near 0.5 suggests both sets come
    /// from the same distribution.
    ///
    /// # Errors
    /// Returns `SklearsError::InvalidInput` when the two matrices disagree on
    /// the number of feature columns.
    pub fn validate(
        &self,
        train_data: &Array2<f64>,
        test_data: &Array2<f64>,
    ) -> Result<AdversarialValidationResult> {
        if train_data.ncols() != test_data.ncols() {
            return Err(SklearsError::InvalidInput(
                "Training and test data must have the same number of features".to_string(),
            ));
        }

        // Stack both samples into one matrix and label rows by origin.
        let (combined_data, labels) = self.prepare_adversarial_data(train_data, test_data)?;

        // Cross-validated AUCs of the discriminator.
        let cv_scores = self.cross_validate_discriminator(&combined_data, &labels)?;

        // Single hold-out AUC on a fresh shuffle.
        let discriminator_auc = self.train_discriminator(&combined_data, &labels)?;

        // Bootstrap confidence interval for the AUC.
        let auc_ci = self.bootstrap_confidence_interval(&combined_data, &labels)?;

        // Optional permutation feature importance.
        let (feature_importance, suspicious_features) = if self.config.analyze_features {
            self.analyze_feature_importance(&combined_data, &labels)?
        } else {
            (None, Vec::new())
        };

        let p_value = self.calculate_p_value(&cv_scores);
        // NOTE(review): significance is decided by comparing the hold-out AUC
        // against the configured threshold, not by `p_value` — confirm intended.
        let is_significantly_different = discriminator_auc > self.config.significance_threshold;

        let statistics = AdversarialStatistics {
            n_train_samples: train_data.nrows(),
            n_test_samples: test_data.nrows(),
            n_features: train_data.ncols(),
            mean_cv_auc: cv_scores.iter().sum::<f64>() / cv_scores.len() as f64,
            std_cv_auc: self.calculate_std(&cv_scores),
            best_cv_auc: cv_scores.iter().fold(f64::NEG_INFINITY, |a, &b| a.max(b)),
            worst_cv_auc: cv_scores.iter().fold(f64::INFINITY, |a, &b| a.min(b)),
        };

        Ok(AdversarialValidationResult {
            discriminator_auc,
            auc_confidence_interval: auc_ci,
            p_value,
            is_significantly_different,
            feature_importance,
            suspicious_features,
            cv_scores,
            statistics,
        })
    }
151
152 fn prepare_adversarial_data(
154 &self,
155 train_data: &Array2<f64>,
156 test_data: &Array2<f64>,
157 ) -> Result<(Array2<f64>, Vec<usize>)> {
158 let n_train = train_data.nrows();
159 let n_test = test_data.nrows();
160 let n_features = train_data.ncols();
161
162 let mut combined_data = Array2::zeros((n_train + n_test, n_features));
164
165 for i in 0..n_train {
167 for j in 0..n_features {
168 combined_data[[i, j]] = train_data[[i, j]];
169 }
170 }
171
172 for i in 0..n_test {
174 for j in 0..n_features {
175 combined_data[[n_train + i, j]] = test_data[[i, j]];
176 }
177 }
178
179 let mut labels = Vec::with_capacity(n_train + n_test);
181 labels.extend(vec![0; n_train]);
182 labels.extend(vec![1; n_test]);
183
184 Ok((combined_data, labels))
185 }
186
187 fn cross_validate_discriminator(
189 &self,
190 data: &Array2<f64>,
191 labels: &[usize],
192 ) -> Result<Vec<f64>> {
193 let n_samples = data.nrows();
194 let fold_size = n_samples / self.config.cv_folds;
195 let mut cv_scores = Vec::new();
196
197 for fold in 0..self.config.cv_folds {
198 let test_start = fold * fold_size;
199 let test_end = if fold == self.config.cv_folds - 1 {
200 n_samples
201 } else {
202 (fold + 1) * fold_size
203 };
204
205 let mut train_indices = Vec::new();
207 let mut test_indices = Vec::new();
208
209 for i in 0..n_samples {
210 if i >= test_start && i < test_end {
211 test_indices.push(i);
212 } else {
213 train_indices.push(i);
214 }
215 }
216
217 let train_fold_data = self.extract_rows(data, &train_indices);
219 let test_fold_data = self.extract_rows(data, &test_indices);
220 let train_fold_labels: Vec<usize> = train_indices.iter().map(|&i| labels[i]).collect();
221 let test_fold_labels: Vec<usize> = test_indices.iter().map(|&i| labels[i]).collect();
222
223 let fold_auc = self.train_simple_discriminator(
225 &train_fold_data,
226 &train_fold_labels,
227 &test_fold_data,
228 &test_fold_labels,
229 )?;
230 cv_scores.push(fold_auc);
231 }
232
233 Ok(cv_scores)
234 }
235
236 fn train_discriminator(&self, data: &Array2<f64>, labels: &[usize]) -> Result<f64> {
238 let n_samples = data.nrows();
240 let test_size = (n_samples as f64 * self.config.test_size) as usize;
241
242 let mut indices: Vec<usize> = (0..n_samples).collect();
243 self.shuffle_indices(&mut indices);
244
245 let train_indices = &indices[test_size..];
246 let test_indices = &indices[..test_size];
247
248 let train_data = self.extract_rows(data, train_indices);
249 let test_data = self.extract_rows(data, test_indices);
250 let train_labels: Vec<usize> = train_indices.iter().map(|&i| labels[i]).collect();
251 let test_labels: Vec<usize> = test_indices.iter().map(|&i| labels[i]).collect();
252
253 self.train_simple_discriminator(&train_data, &train_labels, &test_data, &test_labels)
254 }
255
256 fn train_simple_discriminator(
258 &self,
259 train_data: &Array2<f64>,
260 train_labels: &[usize],
261 test_data: &Array2<f64>,
262 test_labels: &[usize],
263 ) -> Result<f64> {
264 let n_features = train_data.ncols();
265 let mut weights = vec![0.0; n_features + 1]; let learning_rate = 0.01;
267
268 let train_y: Vec<f64> = train_labels
270 .iter()
271 .map(|&label| if label == 1 { 1.0 } else { -1.0 })
272 .collect();
273
274 for _iteration in 0..self.config.max_iterations {
276 let mut gradients = vec![0.0; n_features + 1];
277
278 for (i, &y) in train_y.iter().enumerate() {
279 let mut prediction = weights[0]; for j in 0..n_features {
282 prediction += weights[j + 1] * train_data[[i, j]];
283 }
284
285 let prob = 1.0 / (1.0 + (-prediction).exp());
287 let error = prob - (y + 1.0) / 2.0; gradients[0] += error; for j in 0..n_features {
292 gradients[j + 1] += error * train_data[[i, j]];
293 }
294 }
295
296 for j in 0..weights.len() {
298 weights[j] -= learning_rate * gradients[j] / train_y.len() as f64;
299 }
300 }
301
302 self.calculate_auc(&weights, test_data, test_labels)
304 }
305
306 fn calculate_auc(
308 &self,
309 weights: &[f64],
310 test_data: &Array2<f64>,
311 test_labels: &[usize],
312 ) -> Result<f64> {
313 let n_features = test_data.ncols();
314 let mut predictions = Vec::new();
315
316 for i in 0..test_data.nrows() {
317 let mut prediction = weights[0]; for j in 0..n_features {
319 prediction += weights[j + 1] * test_data[[i, j]];
320 }
321 let prob = 1.0 / (1.0 + (-prediction).exp());
322 predictions.push(prob);
323 }
324
325 let mut positive_scores = Vec::new();
327 let mut negative_scores = Vec::new();
328
329 for (i, &score) in predictions.iter().enumerate() {
330 if test_labels[i] == 1 {
331 positive_scores.push(score);
332 } else {
333 negative_scores.push(score);
334 }
335 }
336
337 if positive_scores.is_empty() || negative_scores.is_empty() {
338 return Ok(0.5); }
340
341 let mut concordant = 0;
343 let mut total = 0;
344
345 for &pos_score in &positive_scores {
346 for &neg_score in &negative_scores {
347 total += 1;
348 if pos_score > neg_score {
349 concordant += 1;
350 }
351 }
352 }
353
354 Ok(concordant as f64 / total as f64)
355 }
356
357 fn bootstrap_confidence_interval(
359 &self,
360 data: &Array2<f64>,
361 labels: &[usize],
362 ) -> Result<(f64, f64)> {
363 let mut bootstrap_aucs = Vec::new();
364 let n_samples = data.nrows();
365
366 for _ in 0..self.config.n_bootstrap {
367 let mut boot_indices = Vec::new();
369 for _ in 0..n_samples {
370 boot_indices.push(self.random_index(n_samples));
371 }
372
373 let boot_data = self.extract_rows(data, &boot_indices);
374 let boot_labels: Vec<usize> = boot_indices.iter().map(|&i| labels[i]).collect();
375
376 if let Ok(auc) = self.train_discriminator(&boot_data, &boot_labels) {
378 bootstrap_aucs.push(auc);
379 }
380 }
381
382 if bootstrap_aucs.is_empty() {
383 return Ok((0.5, 0.5));
384 }
385
386 bootstrap_aucs.sort_by(|a, b| a.partial_cmp(b).unwrap());
387
388 let lower_idx = (self.config.n_bootstrap as f64 * 0.025) as usize;
389 let upper_idx = (self.config.n_bootstrap as f64 * 0.975) as usize;
390
391 let lower_bound = bootstrap_aucs[lower_idx.min(bootstrap_aucs.len() - 1)];
392 let upper_bound = bootstrap_aucs[upper_idx.min(bootstrap_aucs.len() - 1)];
393
394 Ok((lower_bound, upper_bound))
395 }
396
397 fn analyze_feature_importance(
399 &self,
400 data: &Array2<f64>,
401 labels: &[usize],
402 ) -> Result<(Option<Vec<f64>>, Vec<usize>)> {
403 let n_features = data.ncols();
404 let mut feature_importance = vec![0.0; n_features];
405
406 let baseline_auc = self.train_discriminator(data, labels)?;
408
409 for feature_idx in 0..n_features {
411 let mut permuted_data = data.clone();
412
413 let mut feature_values: Vec<f64> =
415 (0..data.nrows()).map(|i| data[[i, feature_idx]]).collect();
416 self.shuffle_f64(&mut feature_values);
417
418 for (i, &value) in feature_values.iter().enumerate() {
419 permuted_data[[i, feature_idx]] = value;
420 }
421
422 let permuted_auc = self.train_discriminator(&permuted_data, labels)?;
424
425 feature_importance[feature_idx] = baseline_auc - permuted_auc;
427 }
428
429 let mut suspicious_features = Vec::new();
431 let importance_threshold = 0.01; for (i, &importance) in feature_importance.iter().enumerate() {
434 if importance > importance_threshold {
435 suspicious_features.push(i);
436 }
437 }
438
439 Ok((Some(feature_importance), suspicious_features))
440 }
441
442 fn calculate_p_value(&self, cv_scores: &[f64]) -> f64 {
444 let mean_auc = cv_scores.iter().sum::<f64>() / cv_scores.len() as f64;
446 let std_auc = self.calculate_std(cv_scores);
447 let n = cv_scores.len() as f64;
448
449 if std_auc == 0.0 {
450 return if mean_auc > 0.5 { 0.0 } else { 1.0 };
451 }
452
453 let t_stat = (mean_auc - 0.5) * n.sqrt() / std_auc;
454
455 let p_value = 2.0 * (1.0 - self.normal_cdf(t_stat.abs()));
457 p_value.clamp(0.0, 1.0)
458 }
459
460 fn calculate_std(&self, values: &[f64]) -> f64 {
462 if values.len() < 2 {
463 return 0.0;
464 }
465
466 let mean = values.iter().sum::<f64>() / values.len() as f64;
467 let variance =
468 values.iter().map(|&x| (x - mean).powi(2)).sum::<f64>() / (values.len() - 1) as f64;
469
470 variance.sqrt()
471 }
472
473 fn normal_cdf(&self, x: f64) -> f64 {
475 0.5 * (1.0 + self.erf(x / 2.0_f64.sqrt()))
476 }
477
    /// Error-function approximation (Abramowitz & Stegun formula 7.1.26),
    /// accurate to roughly 1.5e-7 in absolute error.
    fn erf(&self, x: f64) -> f64 {
        // Fitted polynomial coefficients from A&S 7.1.26 — do not alter.
        let a1 = 0.254829592;
        let a2 = -0.284496736;
        let a3 = 1.421413741;
        let a4 = -1.453152027;
        let a5 = 1.061405429;
        let p = 0.3275911;

        // erf is odd: evaluate on |x| and restore the sign at the end.
        let sign = if x < 0.0 { -1.0 } else { 1.0 };
        let x = x.abs();

        // Horner-evaluated rational approximation in t = 1 / (1 + p|x|).
        let t = 1.0 / (1.0 + p * x);
        let y = 1.0 - (((((a5 * t + a4) * t) + a3) * t + a2) * t + a1) * t * (-x * x).exp();

        sign * y
    }
496
497 fn extract_rows(&self, data: &Array2<f64>, indices: &[usize]) -> Array2<f64> {
499 let n_rows = indices.len();
500 let n_cols = data.ncols();
501 let mut result = Array2::zeros((n_rows, n_cols));
502
503 for (i, &idx) in indices.iter().enumerate() {
504 for j in 0..n_cols {
505 result[[i, j]] = data[[idx, j]];
506 }
507 }
508
509 result
510 }
511
    /// Shuffle `indices` in place, seeded from `config.random_state` when set.
    ///
    /// NOTE(review): a fresh RNG is rebuilt from the same seed on every call,
    /// so with `random_state = Some(seed)` each invocation produces the
    /// identical permutation — confirm whether one long-lived RNG was intended.
    fn shuffle_indices(&self, indices: &mut [usize]) {
        use scirs2_core::random::rngs::StdRng;
        use scirs2_core::random::SeedableRng;
        let mut rng = match self.config.random_state {
            Some(seed) => StdRng::seed_from_u64(seed),
            None => {
                use scirs2_core::random::thread_rng;
                StdRng::from_rng(&mut thread_rng())
            }
        };
        indices.shuffle(&mut rng);
    }
525
    /// Shuffle `values` in place, seeded from `config.random_state` when set.
    ///
    /// NOTE(review): like `shuffle_indices`, a fixed seed yields the identical
    /// permutation on every call — confirm whether that determinism is intended
    /// for repeated permutation-importance shuffles.
    fn shuffle_f64(&self, values: &mut [f64]) {
        use scirs2_core::random::rngs::StdRng;
        use scirs2_core::random::SeedableRng;
        let mut rng = match self.config.random_state {
            Some(seed) => StdRng::seed_from_u64(seed),
            None => {
                use scirs2_core::random::thread_rng;
                StdRng::from_rng(&mut thread_rng())
            }
        };
        values.shuffle(&mut rng);
    }
539
    /// Draw a uniformly random index in `0..max`.
    ///
    /// NOTE(review): with `random_state = Some(seed)` a fresh RNG is seeded on
    /// EVERY call, so this returns the same index each time; the bootstrap then
    /// resamples a single repeated row. Confirm whether the RNG should be
    /// created once and threaded through the call sites instead.
    fn random_index(&self, max: usize) -> usize {
        use scirs2_core::random::rngs::StdRng;
        use scirs2_core::random::Rng;
        use scirs2_core::random::SeedableRng;
        let mut rng = match self.config.random_state {
            Some(seed) => StdRng::seed_from_u64(seed),
            None => {
                use scirs2_core::random::thread_rng;
                StdRng::from_rng(&mut thread_rng())
            }
        };
        rng.gen_range(0..max)
    }
554}
555
#[allow(non_snake_case)]
#[cfg(test)]
mod tests {
    use super::*;

    /// Identical train/test distributions should look inseparable: AUC near
    /// chance and no significance flag.
    #[test]
    fn test_adversarial_validation_same_distribution() {
        let config = AdversarialValidationConfig::default();
        let validator = AdversarialValidator::new(config);

        // Both samples all-zero: drawn from the same (degenerate) distribution.
        let train_data = Array2::zeros((100, 5));
        let test_data = Array2::zeros((50, 5));

        let result = validator.validate(&train_data, &test_data).unwrap();

        assert!(
            result.discriminator_auc < 0.6,
            "AUC should be close to 0.5 for identical distributions"
        );
        assert!(
            !result.is_significantly_different,
            "Identical distributions should not be significantly different"
        );
    }

    /// Disjoint constant distributions (all zeros vs all ones) are trivially
    /// separable, so the discriminator AUC should be high.
    #[test]
    fn test_adversarial_validation_different_distributions() {
        let config = AdversarialValidationConfig {
            significance_threshold: 0.6,
            ..Default::default()
        };
        let validator = AdversarialValidator::new(config);

        // Dead code removed: the constructors already produce the desired
        // constants, so the original loops re-writing 0.0/1.0 into the
        // freshly built `zeros`/`ones` arrays (and the `mut` bindings they
        // required) did nothing.
        let train_data = Array2::zeros((100, 5));
        let test_data = Array2::ones((50, 5));

        let result = validator.validate(&train_data, &test_data).unwrap();

        assert!(
            result.discriminator_auc > 0.7,
            "AUC should be high for different distributions"
        );
    }
}