use scirs2_core::ndarray::Array2;
use scirs2_core::random::rngs::StdRng;
use scirs2_core::random::{thread_rng, SeedableRng};
use scirs2_core::RngExt;
use scirs2_core::SliceRandomExt;
use sklears_core::error::{Result, SklearsError};
10
/// Configuration for adversarial validation between a training and a test set.
#[derive(Debug, Clone)]
pub struct AdversarialValidationConfig {
    /// Number of cross-validation folds used to score the discriminator.
    pub cv_folds: usize,
    /// Fraction of samples held out when training the final discriminator.
    pub test_size: f64,
    /// AUC above which the two datasets are declared significantly different.
    pub significance_threshold: f64,
    /// Number of bootstrap resamples used for the AUC confidence interval.
    pub n_bootstrap: usize,
    /// Optional seed for reproducible shuffling and resampling.
    pub random_state: Option<u64>,
    /// Whether to run permutation-based feature-importance analysis.
    pub analyze_features: bool,
    /// Gradient-descent iterations for the logistic-regression discriminator.
    pub max_iterations: usize,
}
29
30impl Default for AdversarialValidationConfig {
31 fn default() -> Self {
32 Self {
33 cv_folds: 5,
34 test_size: 0.2,
35 significance_threshold: 0.6, n_bootstrap: 1000,
37 random_state: None,
38 analyze_features: true,
39 max_iterations: 100,
40 }
41 }
42}
43
/// Outcome of an adversarial-validation run.
#[derive(Debug, Clone)]
pub struct AdversarialValidationResult {
    /// AUC of the final discriminator on a held-out split (0.5 ≈ indistinguishable).
    pub discriminator_auc: f64,
    /// Bootstrap 95% confidence interval (lower, upper) for the AUC.
    pub auc_confidence_interval: (f64, f64),
    /// Two-sided p-value testing the mean CV AUC against the null value 0.5.
    pub p_value: f64,
    /// True when `discriminator_auc` exceeds the configured significance threshold.
    pub is_significantly_different: bool,
    /// Per-feature permutation importances, present when feature analysis ran.
    pub feature_importance: Option<Vec<f64>>,
    /// Indices of features whose importance exceeds the internal threshold.
    pub suspicious_features: Vec<usize>,
    /// Discriminator AUC for each cross-validation fold.
    pub cv_scores: Vec<f64>,
    /// Summary statistics of the run.
    pub statistics: AdversarialStatistics,
}
64
/// Summary statistics collected during adversarial validation.
#[derive(Debug, Clone)]
pub struct AdversarialStatistics {
    /// Number of rows in the training dataset.
    pub n_train_samples: usize,
    /// Number of rows in the test dataset.
    pub n_test_samples: usize,
    /// Number of feature columns (identical for both datasets).
    pub n_features: usize,
    /// Mean of the cross-validation AUC scores.
    pub mean_cv_auc: f64,
    /// Sample standard deviation of the cross-validation AUC scores.
    pub std_cv_auc: f64,
    /// Highest fold AUC observed.
    pub best_cv_auc: f64,
    /// Lowest fold AUC observed.
    pub worst_cv_auc: f64,
}
83
/// Runs adversarial validation: trains a discriminator to tell two datasets
/// apart and reports how distinguishable they are.
#[derive(Debug, Clone)]
pub struct AdversarialValidator {
    // Behavioural knobs for the whole pipeline (folds, bootstraps, seed, ...).
    config: AdversarialValidationConfig,
}
89
90impl AdversarialValidator {
91 pub fn new(config: AdversarialValidationConfig) -> Self {
92 Self { config }
93 }
94
95 pub fn validate(
97 &self,
98 train_data: &Array2<f64>,
99 test_data: &Array2<f64>,
100 ) -> Result<AdversarialValidationResult> {
101 if train_data.ncols() != test_data.ncols() {
102 return Err(SklearsError::InvalidInput(
103 "Training and test data must have the same number of features".to_string(),
104 ));
105 }
106
107 let (combined_data, labels) = self.prepare_adversarial_data(train_data, test_data)?;
109
110 let cv_scores = self.cross_validate_discriminator(&combined_data, &labels)?;
112
113 let discriminator_auc = self.train_discriminator(&combined_data, &labels)?;
115
116 let auc_ci = self.bootstrap_confidence_interval(&combined_data, &labels)?;
118
119 let (feature_importance, suspicious_features) = if self.config.analyze_features {
121 self.analyze_feature_importance(&combined_data, &labels)?
122 } else {
123 (None, Vec::new())
124 };
125
126 let p_value = self.calculate_p_value(&cv_scores);
128 let is_significantly_different = discriminator_auc > self.config.significance_threshold;
129
130 let statistics = AdversarialStatistics {
132 n_train_samples: train_data.nrows(),
133 n_test_samples: test_data.nrows(),
134 n_features: train_data.ncols(),
135 mean_cv_auc: cv_scores.iter().sum::<f64>() / cv_scores.len() as f64,
136 std_cv_auc: self.calculate_std(&cv_scores),
137 best_cv_auc: cv_scores.iter().fold(f64::NEG_INFINITY, |a, &b| a.max(b)),
138 worst_cv_auc: cv_scores.iter().fold(f64::INFINITY, |a, &b| a.min(b)),
139 };
140
141 Ok(AdversarialValidationResult {
142 discriminator_auc,
143 auc_confidence_interval: auc_ci,
144 p_value,
145 is_significantly_different,
146 feature_importance,
147 suspicious_features,
148 cv_scores,
149 statistics,
150 })
151 }
152
153 fn prepare_adversarial_data(
155 &self,
156 train_data: &Array2<f64>,
157 test_data: &Array2<f64>,
158 ) -> Result<(Array2<f64>, Vec<usize>)> {
159 let n_train = train_data.nrows();
160 let n_test = test_data.nrows();
161 let n_features = train_data.ncols();
162
163 let mut combined_data = Array2::zeros((n_train + n_test, n_features));
165
166 for i in 0..n_train {
168 for j in 0..n_features {
169 combined_data[[i, j]] = train_data[[i, j]];
170 }
171 }
172
173 for i in 0..n_test {
175 for j in 0..n_features {
176 combined_data[[n_train + i, j]] = test_data[[i, j]];
177 }
178 }
179
180 let mut labels = Vec::with_capacity(n_train + n_test);
182 labels.extend(vec![0; n_train]);
183 labels.extend(vec![1; n_test]);
184
185 Ok((combined_data, labels))
186 }
187
188 fn cross_validate_discriminator(
190 &self,
191 data: &Array2<f64>,
192 labels: &[usize],
193 ) -> Result<Vec<f64>> {
194 let n_samples = data.nrows();
195 let fold_size = n_samples / self.config.cv_folds;
196 let mut cv_scores = Vec::new();
197
198 for fold in 0..self.config.cv_folds {
199 let test_start = fold * fold_size;
200 let test_end = if fold == self.config.cv_folds - 1 {
201 n_samples
202 } else {
203 (fold + 1) * fold_size
204 };
205
206 let mut train_indices = Vec::new();
208 let mut test_indices = Vec::new();
209
210 for i in 0..n_samples {
211 if i >= test_start && i < test_end {
212 test_indices.push(i);
213 } else {
214 train_indices.push(i);
215 }
216 }
217
218 let train_fold_data = self.extract_rows(data, &train_indices);
220 let test_fold_data = self.extract_rows(data, &test_indices);
221 let train_fold_labels: Vec<usize> = train_indices.iter().map(|&i| labels[i]).collect();
222 let test_fold_labels: Vec<usize> = test_indices.iter().map(|&i| labels[i]).collect();
223
224 let fold_auc = self.train_simple_discriminator(
226 &train_fold_data,
227 &train_fold_labels,
228 &test_fold_data,
229 &test_fold_labels,
230 )?;
231 cv_scores.push(fold_auc);
232 }
233
234 Ok(cv_scores)
235 }
236
237 fn train_discriminator(&self, data: &Array2<f64>, labels: &[usize]) -> Result<f64> {
239 let n_samples = data.nrows();
241 let test_size = (n_samples as f64 * self.config.test_size) as usize;
242
243 let mut indices: Vec<usize> = (0..n_samples).collect();
244 self.shuffle_indices(&mut indices);
245
246 let train_indices = &indices[test_size..];
247 let test_indices = &indices[..test_size];
248
249 let train_data = self.extract_rows(data, train_indices);
250 let test_data = self.extract_rows(data, test_indices);
251 let train_labels: Vec<usize> = train_indices.iter().map(|&i| labels[i]).collect();
252 let test_labels: Vec<usize> = test_indices.iter().map(|&i| labels[i]).collect();
253
254 self.train_simple_discriminator(&train_data, &train_labels, &test_data, &test_labels)
255 }
256
257 fn train_simple_discriminator(
259 &self,
260 train_data: &Array2<f64>,
261 train_labels: &[usize],
262 test_data: &Array2<f64>,
263 test_labels: &[usize],
264 ) -> Result<f64> {
265 let n_features = train_data.ncols();
266 let mut weights = vec![0.0; n_features + 1]; let learning_rate = 0.01;
268
269 let train_y: Vec<f64> = train_labels
271 .iter()
272 .map(|&label| if label == 1 { 1.0 } else { -1.0 })
273 .collect();
274
275 for _iteration in 0..self.config.max_iterations {
277 let mut gradients = vec![0.0; n_features + 1];
278
279 for (i, &y) in train_y.iter().enumerate() {
280 let mut prediction = weights[0]; for j in 0..n_features {
283 prediction += weights[j + 1] * train_data[[i, j]];
284 }
285
286 let prob = 1.0 / (1.0 + (-prediction).exp());
288 let error = prob - (y + 1.0) / 2.0; gradients[0] += error; for j in 0..n_features {
293 gradients[j + 1] += error * train_data[[i, j]];
294 }
295 }
296
297 for j in 0..weights.len() {
299 weights[j] -= learning_rate * gradients[j] / train_y.len() as f64;
300 }
301 }
302
303 self.calculate_auc(&weights, test_data, test_labels)
305 }
306
307 fn calculate_auc(
309 &self,
310 weights: &[f64],
311 test_data: &Array2<f64>,
312 test_labels: &[usize],
313 ) -> Result<f64> {
314 let n_features = test_data.ncols();
315 let mut predictions = Vec::new();
316
317 for i in 0..test_data.nrows() {
318 let mut prediction = weights[0]; for j in 0..n_features {
320 prediction += weights[j + 1] * test_data[[i, j]];
321 }
322 let prob = 1.0 / (1.0 + (-prediction).exp());
323 predictions.push(prob);
324 }
325
326 let mut positive_scores = Vec::new();
328 let mut negative_scores = Vec::new();
329
330 for (i, &score) in predictions.iter().enumerate() {
331 if test_labels[i] == 1 {
332 positive_scores.push(score);
333 } else {
334 negative_scores.push(score);
335 }
336 }
337
338 if positive_scores.is_empty() || negative_scores.is_empty() {
339 return Ok(0.5); }
341
342 let mut concordant = 0;
344 let mut total = 0;
345
346 for &pos_score in &positive_scores {
347 for &neg_score in &negative_scores {
348 total += 1;
349 if pos_score > neg_score {
350 concordant += 1;
351 }
352 }
353 }
354
355 Ok(concordant as f64 / total as f64)
356 }
357
358 fn bootstrap_confidence_interval(
360 &self,
361 data: &Array2<f64>,
362 labels: &[usize],
363 ) -> Result<(f64, f64)> {
364 let mut bootstrap_aucs = Vec::new();
365 let n_samples = data.nrows();
366
367 for _ in 0..self.config.n_bootstrap {
368 let mut boot_indices = Vec::new();
370 for _ in 0..n_samples {
371 boot_indices.push(self.random_index(n_samples));
372 }
373
374 let boot_data = self.extract_rows(data, &boot_indices);
375 let boot_labels: Vec<usize> = boot_indices.iter().map(|&i| labels[i]).collect();
376
377 if let Ok(auc) = self.train_discriminator(&boot_data, &boot_labels) {
379 bootstrap_aucs.push(auc);
380 }
381 }
382
383 if bootstrap_aucs.is_empty() {
384 return Ok((0.5, 0.5));
385 }
386
387 bootstrap_aucs.sort_by(|a, b| a.partial_cmp(b).expect("operation should succeed"));
388
389 let lower_idx = (self.config.n_bootstrap as f64 * 0.025) as usize;
390 let upper_idx = (self.config.n_bootstrap as f64 * 0.975) as usize;
391
392 let lower_bound = bootstrap_aucs[lower_idx.min(bootstrap_aucs.len() - 1)];
393 let upper_bound = bootstrap_aucs[upper_idx.min(bootstrap_aucs.len() - 1)];
394
395 Ok((lower_bound, upper_bound))
396 }
397
398 fn analyze_feature_importance(
400 &self,
401 data: &Array2<f64>,
402 labels: &[usize],
403 ) -> Result<(Option<Vec<f64>>, Vec<usize>)> {
404 let n_features = data.ncols();
405 let mut feature_importance = vec![0.0; n_features];
406
407 let baseline_auc = self.train_discriminator(data, labels)?;
409
410 for feature_idx in 0..n_features {
412 let mut permuted_data = data.clone();
413
414 let mut feature_values: Vec<f64> =
416 (0..data.nrows()).map(|i| data[[i, feature_idx]]).collect();
417 self.shuffle_f64(&mut feature_values);
418
419 for (i, &value) in feature_values.iter().enumerate() {
420 permuted_data[[i, feature_idx]] = value;
421 }
422
423 let permuted_auc = self.train_discriminator(&permuted_data, labels)?;
425
426 feature_importance[feature_idx] = baseline_auc - permuted_auc;
428 }
429
430 let mut suspicious_features = Vec::new();
432 let importance_threshold = 0.01; for (i, &importance) in feature_importance.iter().enumerate() {
435 if importance > importance_threshold {
436 suspicious_features.push(i);
437 }
438 }
439
440 Ok((Some(feature_importance), suspicious_features))
441 }
442
443 fn calculate_p_value(&self, cv_scores: &[f64]) -> f64 {
445 let mean_auc = cv_scores.iter().sum::<f64>() / cv_scores.len() as f64;
447 let std_auc = self.calculate_std(cv_scores);
448 let n = cv_scores.len() as f64;
449
450 if std_auc == 0.0 {
451 return if mean_auc > 0.5 { 0.0 } else { 1.0 };
452 }
453
454 let t_stat = (mean_auc - 0.5) * n.sqrt() / std_auc;
455
456 let p_value = 2.0 * (1.0 - self.normal_cdf(t_stat.abs()));
458 p_value.clamp(0.0, 1.0)
459 }
460
461 fn calculate_std(&self, values: &[f64]) -> f64 {
463 if values.len() < 2 {
464 return 0.0;
465 }
466
467 let mean = values.iter().sum::<f64>() / values.len() as f64;
468 let variance =
469 values.iter().map(|&x| (x - mean).powi(2)).sum::<f64>() / (values.len() - 1) as f64;
470
471 variance.sqrt()
472 }
473
474 fn normal_cdf(&self, x: f64) -> f64 {
476 0.5 * (1.0 + self.erf(x / 2.0_f64.sqrt()))
477 }
478
479 fn erf(&self, x: f64) -> f64 {
481 let a1 = 0.254829592;
483 let a2 = -0.284496736;
484 let a3 = 1.421413741;
485 let a4 = -1.453152027;
486 let a5 = 1.061405429;
487 let p = 0.3275911;
488
489 let sign = if x < 0.0 { -1.0 } else { 1.0 };
490 let x = x.abs();
491
492 let t = 1.0 / (1.0 + p * x);
493 let y = 1.0 - (((((a5 * t + a4) * t) + a3) * t + a2) * t + a1) * t * (-x * x).exp();
494
495 sign * y
496 }
497
498 fn extract_rows(&self, data: &Array2<f64>, indices: &[usize]) -> Array2<f64> {
500 let n_rows = indices.len();
501 let n_cols = data.ncols();
502 let mut result = Array2::zeros((n_rows, n_cols));
503
504 for (i, &idx) in indices.iter().enumerate() {
505 for j in 0..n_cols {
506 result[[i, j]] = data[[idx, j]];
507 }
508 }
509
510 result
511 }
512
513 fn shuffle_indices(&self, indices: &mut [usize]) {
515 use scirs2_core::random::rngs::StdRng;
516 use scirs2_core::random::SeedableRng;
517 let mut rng = match self.config.random_state {
518 Some(seed) => StdRng::seed_from_u64(seed),
519 None => {
520 use scirs2_core::random::thread_rng;
521 StdRng::from_rng(&mut thread_rng())
522 }
523 };
524 indices.shuffle(&mut rng);
525 }
526
527 fn shuffle_f64(&self, values: &mut [f64]) {
529 use scirs2_core::random::rngs::StdRng;
530 use scirs2_core::random::SeedableRng;
531 let mut rng = match self.config.random_state {
532 Some(seed) => StdRng::seed_from_u64(seed),
533 None => {
534 use scirs2_core::random::thread_rng;
535 StdRng::from_rng(&mut thread_rng())
536 }
537 };
538 values.shuffle(&mut rng);
539 }
540
541 fn random_index(&self, max: usize) -> usize {
543 use scirs2_core::random::rngs::StdRng;
544 use scirs2_core::random::SeedableRng;
545 let mut rng = match self.config.random_state {
546 Some(seed) => StdRng::seed_from_u64(seed),
547 None => {
548 use scirs2_core::random::thread_rng;
549 StdRng::from_rng(&mut thread_rng())
550 }
551 };
552 rng.random_range(0..max)
553 }
554}
555
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_adversarial_validation_same_distribution() {
        // Fix the seed and shrink the bootstrap so the test is deterministic
        // and fast; the assertions themselves are unchanged.
        let config = AdversarialValidationConfig {
            random_state: Some(42),
            n_bootstrap: 50,
            ..Default::default()
        };
        let validator = AdversarialValidator::new(config);

        // Identical all-zero matrices: the discriminator has nothing to learn.
        let train_data = Array2::zeros((100, 5));
        let test_data = Array2::zeros((50, 5));

        let result = validator
            .validate(&train_data, &test_data)
            .expect("validation should succeed");

        assert!(
            result.discriminator_auc < 0.6,
            "AUC should be close to 0.5 for identical distributions"
        );
        assert!(
            !result.is_significantly_different,
            "Identical distributions should not be significantly different"
        );
    }

    #[test]
    fn test_adversarial_validation_different_distributions() {
        let config = AdversarialValidationConfig {
            significance_threshold: 0.6,
            random_state: Some(42),
            n_bootstrap: 50,
            ..Default::default()
        };
        let validator = AdversarialValidator::new(config);

        // Perfectly separable inputs: `zeros`/`ones` already produce the
        // desired values, so no element-by-element initialisation is needed.
        let train_data = Array2::zeros((100, 5));
        let test_data = Array2::ones((50, 5));

        let result = validator
            .validate(&train_data, &test_data)
            .expect("validation should succeed");

        assert!(
            result.discriminator_auc > 0.7,
            "AUC should be high for different distributions"
        );
    }
}