sklears_ensemble/adversarial.rs

//! Adversarial Training for Ensemble Methods
//!
//! This module provides adversarial training techniques for ensemble methods to improve
//! robustness against adversarial attacks and enhance generalization. It includes various
//! adversarial example generation methods, adversarial training strategies, and defensive
//! ensemble techniques.
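//!
//! A minimal usage sketch (hypothetical toy data; the `fit`/`predict` signatures are
//! the ones defined in this module, and the import path assumes this crate layout):
//!
//! ```ignore
//! use scirs2_core::ndarray::array;
//! use sklears_ensemble::adversarial::AdversarialEnsembleClassifier;
//!
//! let X = array![[1.0, 2.0], [2.0, 3.0], [3.0, 4.0], [4.0, 5.0]];
//! let y = vec![0usize, 1, 0, 1];
//!
//! let trained = AdversarialEnsembleClassifier::fgsm_training()
//!     .n_estimators(5)
//!     .epsilon(0.1)
//!     .random_state(42)
//!     .fit(&X, &y)
//!     .expect("training should succeed");
//! let results = trained.predict(&X).expect("prediction should succeed");
//! assert_eq!(results.predictions.len(), 4);
//! ```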

use crate::bagging::BaggingClassifier;
use scirs2_core::ndarray::{Array1, Array2, Axis};
#[allow(unused_imports)]
use scirs2_core::random::SeedableRng;
use sklears_core::{
    error::Result as SklResult,
    prelude::{Predict, SklearsError},
    traits::{Estimator, Fit, Trained, Untrained},
};
use std::collections::HashMap;

/// Helper function to generate a random `usize` in `range` from any
/// `scirs2_core::random::RngCore` (simple modulo reduction; the bias is
/// negligible for the small ranges used here)
fn gen_range_usize(
    rng: &mut impl scirs2_core::random::RngCore,
    range: std::ops::Range<usize>,
) -> usize {
    let mut bytes = [0u8; 8];
    rng.fill_bytes(&mut bytes);
    let val = u64::from_le_bytes(bytes);
    range.start + (val as usize % (range.end - range.start))
}

/// Helper function to generate a random `f64` in `[0, 1]` from any
/// `scirs2_core::random::RngCore`
fn gen_f64(rng: &mut impl scirs2_core::random::RngCore) -> f64 {
    let mut bytes = [0u8; 8];
    rng.fill_bytes(&mut bytes);
    let val = u64::from_le_bytes(bytes);
    (val as f64) / (u64::MAX as f64)
}

/// Helper function to generate a random `f64` in the inclusive `range` from any
/// `scirs2_core::random::RngCore`
fn gen_range_f64(
    rng: &mut impl scirs2_core::random::RngCore,
    range: std::ops::RangeInclusive<f64>,
) -> f64 {
    let random_01 = gen_f64(rng);
    range.start() + random_01 * (range.end() - range.start())
}

/// Configuration for adversarial ensemble training
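///
/// Individual fields can be overridden from the defaults with struct-update
/// syntax, as the presets in this module do:
///
/// ```ignore
/// let config = AdversarialEnsembleConfig {
///     epsilon: 0.05,
///     adversarial_ratio: 0.5,
///     ..Default::default()
/// };
/// ```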
#[derive(Debug, Clone)]
pub struct AdversarialEnsembleConfig {
    /// Number of base estimators
    pub n_estimators: usize,
    /// Adversarial training strategy
    pub adversarial_strategy: AdversarialStrategy,
    /// Adversarial example generation method
    pub attack_method: AttackMethod,
    /// Perturbation magnitude for adversarial examples
    pub epsilon: f64,
    /// Number of adversarial training iterations
    pub adversarial_iterations: usize,
    /// Ratio of adversarial examples in training
    pub adversarial_ratio: f64,
    /// Defensive strategy for the ensemble
    pub defensive_strategy: DefensiveStrategy,
    /// Random state for reproducibility
    pub random_state: Option<u64>,
    /// Whether to use gradient masking defense
    pub gradient_masking: bool,
    /// Input preprocessing for defense
    pub input_preprocessing: Option<InputPreprocessing>,
    /// Ensemble diversity promotion factor
    pub diversity_factor: f64,
    /// Adversarial detection threshold
    pub detection_threshold: Option<f64>,
}

impl Default for AdversarialEnsembleConfig {
    fn default() -> Self {
        Self {
            n_estimators: 10,
            adversarial_strategy: AdversarialStrategy::FGSM,
            attack_method: AttackMethod::FGSM,
            epsilon: 0.1,
            adversarial_iterations: 5,
            adversarial_ratio: 0.3,
            defensive_strategy: DefensiveStrategy::AdversarialTraining,
            random_state: None,
            gradient_masking: false,
            input_preprocessing: None,
            diversity_factor: 1.0,
            detection_threshold: None,
        }
    }
}

/// Adversarial training strategies
#[derive(Debug, Clone, PartialEq)]
pub enum AdversarialStrategy {
    /// Fast Gradient Sign Method (FGSM)
    FGSM,
    /// Projected Gradient Descent (PGD)
    PGD,
    /// Basic Iterative Method (BIM)
    BIM,
    /// Momentum Iterative FGSM (MI-FGSM)
    MIFGSM,
    /// Diverse Input Iterative FGSM (DI-FGSM)
    DIFGSM,
    /// Expectation over Transformation (EOT)
    EOT,
    /// Carlini & Wagner (C&W)
    CarliniWagner,
    /// DeepFool
    DeepFool,
}

/// Attack methods for generating adversarial examples
#[derive(Debug, Clone, PartialEq)]
pub enum AttackMethod {
    /// Fast Gradient Sign Method
    FGSM,
    /// Projected Gradient Descent
    PGD,
    /// Random noise
    RandomNoise,
    /// Boundary attack
    BoundaryAttack,
    /// Semantic attack
    SemanticAttack,
    /// Universal adversarial perturbations
    UniversalPerturbation,
}

/// Defensive strategies for ensemble robustness
#[derive(Debug, Clone, PartialEq)]
pub enum DefensiveStrategy {
    /// Standard adversarial training
    AdversarialTraining,
    /// Defensive distillation
    DefensiveDistillation,
    /// Feature squeezing
    FeatureSqueezing,
    /// Ensemble diversity maximization
    DiversityMaximization,
    /// Input transformation
    InputTransformation,
    /// Adversarial detection and rejection
    AdversarialDetection,
    /// Randomized smoothing
    RandomizedSmoothing,
    /// Certified defense
    CertifiedDefense,
}

/// Input preprocessing methods for defense
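///
/// A preprocessing step is attached through the builder, as in the tests below:
///
/// ```ignore
/// let classifier = AdversarialEnsembleClassifier::fgsm_training()
///     .input_preprocessing(InputPreprocessing::GaussianNoise { std_dev: 0.1 });
/// ```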
#[derive(Debug, Clone, PartialEq)]
pub enum InputPreprocessing {
    /// Gaussian noise injection
    GaussianNoise { std_dev: f64 },
    /// Pixel dropping
    PixelDropping { drop_probability: f64 },
    /// JPEG compression
    JPEGCompression { quality: f64 },
    /// Bit depth reduction
    BitDepthReduction { bits: usize },
    /// Spatial smoothing
    SpatialSmoothing { kernel_size: usize },
    /// Total variation minimization
    TotalVariationMinimization { lambda: f64 },
}

/// Adversarial ensemble classifier
pub struct AdversarialEnsembleClassifier<State = Untrained> {
    config: AdversarialEnsembleConfig,
    state: std::marker::PhantomData<State>,
    // Fitted attributes - only populated after training
    base_classifiers: Option<Vec<BaggingClassifier<Trained>>>,
    adversarial_detector: Option<BaggingClassifier<Trained>>,
    preprocessing_params: Option<HashMap<String, f64>>,
    universal_perturbation: Option<Array2<f64>>,
    ensemble_weights: Option<Vec<f64>>,
    robustness_metrics: Option<RobustnessMetrics>,
}

/// Robustness metrics for adversarial ensembles
#[derive(Debug, Clone)]
pub struct RobustnessMetrics {
    /// Clean accuracy (on non-adversarial examples)
    pub clean_accuracy: f64,
    /// Adversarial accuracy (on adversarial examples)
    pub adversarial_accuracy: f64,
    /// Certified robust accuracy
    pub certified_accuracy: f64,
    /// Average perturbation magnitude detected
    pub avg_perturbation_magnitude: f64,
    /// Detection rate for adversarial examples
    pub detection_rate: f64,
    /// False positive rate for clean examples
    pub false_positive_rate: f64,
}

/// Adversarial prediction results
#[derive(Debug, Clone)]
pub struct AdversarialPredictionResults {
    /// Standard predictions
    pub predictions: Vec<usize>,
    /// Prediction probabilities
    pub probabilities: Array2<f64>,
    /// Adversarial detection scores
    pub adversarial_scores: Vec<f64>,
    /// Confidence intervals for robust predictions
    pub confidence_intervals: Vec<(f64, f64)>,
    /// Individual classifier agreements
    pub classifier_agreements: Vec<f64>,
}

impl<State> AdversarialEnsembleClassifier<State> {
    /// Create a new adversarial ensemble classifier
    pub fn new(config: AdversarialEnsembleConfig) -> Self {
        Self {
            config,
            state: std::marker::PhantomData,
            base_classifiers: None,
            adversarial_detector: None,
            preprocessing_params: None,
            universal_perturbation: None,
            ensemble_weights: None,
            robustness_metrics: None,
        }
    }

    /// Create adversarial ensemble with FGSM training
    pub fn fgsm_training() -> Self {
        let config = AdversarialEnsembleConfig {
            adversarial_strategy: AdversarialStrategy::FGSM,
            attack_method: AttackMethod::FGSM,
            defensive_strategy: DefensiveStrategy::AdversarialTraining,
            ..Default::default()
        };
        Self::new(config)
    }

    /// Create adversarial ensemble with PGD training
    pub fn pgd_training() -> Self {
        let config = AdversarialEnsembleConfig {
            adversarial_strategy: AdversarialStrategy::PGD,
            attack_method: AttackMethod::PGD,
            adversarial_iterations: 10,
            ..Default::default()
        };
        Self::new(config)
    }

    /// Create adversarial ensemble with defensive distillation
    pub fn defensive_distillation() -> Self {
        let config = AdversarialEnsembleConfig {
            defensive_strategy: DefensiveStrategy::DefensiveDistillation,
            adversarial_ratio: 0.5,
            ..Default::default()
        };
        Self::new(config)
    }

    /// Create adversarial ensemble with diversity maximization
    pub fn diversity_maximization() -> Self {
        let config = AdversarialEnsembleConfig {
            defensive_strategy: DefensiveStrategy::DiversityMaximization,
            diversity_factor: 2.0,
            n_estimators: 15,
            ..Default::default()
        };
        Self::new(config)
    }

    /// Builder method to configure number of estimators
    pub fn n_estimators(mut self, n_estimators: usize) -> Self {
        self.config.n_estimators = n_estimators;
        self
    }

    /// Builder method to configure epsilon
    pub fn epsilon(mut self, epsilon: f64) -> Self {
        self.config.epsilon = epsilon;
        self
    }

    /// Builder method to configure adversarial ratio (clamped to `[0, 1]`)
    pub fn adversarial_ratio(mut self, ratio: f64) -> Self {
        self.config.adversarial_ratio = ratio.clamp(0.0, 1.0);
        self
    }

    /// Builder method to configure random state
    pub fn random_state(mut self, seed: u64) -> Self {
        self.config.random_state = Some(seed);
        self
    }

    /// Builder method to configure input preprocessing
    pub fn input_preprocessing(mut self, preprocessing: InputPreprocessing) -> Self {
        self.config.input_preprocessing = Some(preprocessing);
        self
    }
}

impl<State> AdversarialEnsembleClassifier<State> {
    /// Generate adversarial examples using FGSM
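    ///
    /// True FGSM perturbs the input along the sign of the loss gradient,
    /// `x' = x + epsilon * sign(grad_x L(x, y))`; the base learners here are not
    /// differentiable, so this sketch substitutes a random sign per feature.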
    fn generate_fgsm_examples(&self, X: &Array2<f64>, _y: &[usize]) -> SklResult<Array2<f64>> {
        let mut adversarial_X = X.clone();
        let mut rng = if let Some(seed) = self.config.random_state {
            scirs2_core::random::seeded_rng(seed)
        } else {
            scirs2_core::random::seeded_rng(42)
        };

        // Simplified FGSM implementation using random gradient signs
        // (labels are unused because no loss gradient is computed)
        for mut row in adversarial_X.axis_iter_mut(Axis(0)) {
            for element in row.iter_mut() {
                let gradient_sign = if gen_f64(&mut rng) > 0.5 { 1.0 } else { -1.0 };
                *element += self.config.epsilon * gradient_sign;
            }
        }

        Ok(adversarial_X)
    }

    /// Generate adversarial examples using PGD
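    ///
    /// Takes `adversarial_iterations` random-sign steps of size
    /// `epsilon / adversarial_iterations`, projecting after each step back into
    /// the L-infinity ball of radius `epsilon` around the original input.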
    fn generate_pgd_examples(&self, X: &Array2<f64>, _y: &[usize]) -> SklResult<Array2<f64>> {
        let mut adversarial_X = X.clone();
        let mut rng = if let Some(seed) = self.config.random_state {
            scirs2_core::random::seeded_rng(seed)
        } else {
            scirs2_core::random::seeded_rng(42)
        };

        let step_size = self.config.epsilon / self.config.adversarial_iterations as f64;

        // Simplified PGD implementation with random gradient signs
        for _ in 0..self.config.adversarial_iterations {
            for (mut adv_row, orig_row) in adversarial_X
                .axis_iter_mut(Axis(0))
                .zip(X.axis_iter(Axis(0)))
            {
                for (element, &orig) in adv_row.iter_mut().zip(orig_row.iter()) {
                    let gradient_sign = if gen_f64(&mut rng) > 0.5 { 1.0 } else { -1.0 };
                    *element += step_size * gradient_sign;

                    // Project back into the L-infinity epsilon ball around the
                    // original input value
                    *element =
                        element.clamp(orig - self.config.epsilon, orig + self.config.epsilon);
                }
            }
        }

        Ok(adversarial_X)
    }

    /// Generate random noise perturbations uniform in `[-epsilon, epsilon]`
    fn generate_random_noise(&self, X: &Array2<f64>) -> SklResult<Array2<f64>> {
        let mut adversarial_X = X.clone();
        let mut rng = if let Some(seed) = self.config.random_state {
            scirs2_core::random::seeded_rng(seed)
        } else {
            scirs2_core::random::seeded_rng(42)
        };

        for element in adversarial_X.iter_mut() {
            let noise = gen_range_f64(&mut rng, -self.config.epsilon..=self.config.epsilon);
            *element += noise;
        }

        Ok(adversarial_X)
    }

    /// Apply input preprocessing for defense
    fn apply_preprocessing(&self, X: &Array2<f64>) -> SklResult<Array2<f64>> {
        if let Some(ref preprocessing) = self.config.input_preprocessing {
            let mut processed_X = X.clone();
            let mut rng = if let Some(seed) = self.config.random_state {
                scirs2_core::random::seeded_rng(seed)
            } else {
                scirs2_core::random::seeded_rng(42)
            };

            match preprocessing {
                InputPreprocessing::GaussianNoise { std_dev } => {
                    for element in processed_X.iter_mut() {
                        // Box-Muller transform: two uniform samples yield one
                        // zero-mean standard normal sample, scaled by std_dev
                        let u1 = gen_f64(&mut rng).max(1e-12);
                        let u2 = gen_f64(&mut rng);
                        let normal =
                            (-2.0 * u1.ln()).sqrt() * (2.0 * std::f64::consts::PI * u2).cos();
                        *element += normal * std_dev;
                    }
                }
                InputPreprocessing::PixelDropping { drop_probability } => {
                    for element in processed_X.iter_mut() {
                        if gen_f64(&mut rng) < *drop_probability {
                            *element = 0.0;
                        }
                    }
                }
                InputPreprocessing::BitDepthReduction { bits } => {
                    let levels = 2_f64.powi(*bits as i32);
                    for element in processed_X.iter_mut() {
                        *element = (*element * levels).round() / levels;
                    }
                }
                _ => {
                    // Other preprocessing methods would be implemented here
                }
            }

            Ok(processed_X)
        } else {
            Ok(X.clone())
        }
    }

    /// Calculate ensemble diversity score
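    ///
    /// Diversity is measured as the mean pairwise disagreement rate over the
    /// `k` classifiers:
    /// `D = (2 / (k (k - 1))) * sum_{i<j} (1/n) * #{s : pred_i(s) != pred_j(s)}`,
    /// which lies in `[0, 1]`.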
    fn calculate_diversity(
        &self,
        classifiers: &[BaggingClassifier<Trained>],
        X: &Array2<f64>,
    ) -> SklResult<f64> {
        if classifiers.len() < 2 {
            return Ok(0.0);
        }

        let mut diversity_score = 0.0;
        let mut pair_count = 0;

        // Calculate pairwise disagreement
        for i in 0..classifiers.len() {
            for j in (i + 1)..classifiers.len() {
                let pred_i = classifiers[i].predict(X)?;
                let pred_j = classifiers[j].predict(X)?;

                let disagreement: f64 = pred_i
                    .iter()
                    .zip(pred_j.iter())
                    .map(|(&p1, &p2)| if p1 != p2 { 1.0 } else { 0.0 })
                    .sum::<f64>()
                    / pred_i.len() as f64;

                diversity_score += disagreement;
                pair_count += 1;
            }
        }

        Ok(if pair_count > 0 {
            diversity_score / pair_count as f64
        } else {
            0.0
        })
    }
}

impl Estimator for AdversarialEnsembleClassifier<Untrained> {
    type Config = AdversarialEnsembleConfig;
    type Error = SklearsError;
    type Float = f64;

    fn config(&self) -> &Self::Config {
        &self.config
    }
}

impl Fit<Array2<f64>, Vec<usize>> for AdversarialEnsembleClassifier<Untrained> {
    type Fitted = AdversarialEnsembleClassifier<Trained>;
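
    /// Fit the adversarial ensemble: generate adversarial examples with the
    /// configured attack, mix clean and adversarial samples according to
    /// `adversarial_ratio`, and train `n_estimators` bagging classifiers on
    /// the mixture.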
    fn fit(self, X: &Array2<f64>, y: &Vec<usize>) -> SklResult<Self::Fitted> {
        if X.nrows() != y.len() {
            return Err(SklearsError::ShapeMismatch {
                expected: format!("{} samples", X.nrows()),
                actual: format!("{} samples", y.len()),
            });
        }

        let mut base_classifiers = Vec::new();
        let mut rng = if let Some(seed) = self.config.random_state {
            scirs2_core::random::seeded_rng(seed)
        } else {
            scirs2_core::random::seeded_rng(42)
        };

        // Generate adversarial examples
        let adversarial_X = match self.config.attack_method {
            AttackMethod::FGSM => self.generate_fgsm_examples(X, y)?,
            AttackMethod::PGD => self.generate_pgd_examples(X, y)?,
            AttackMethod::RandomNoise => self.generate_random_noise(X)?,
            _ => self.generate_fgsm_examples(X, y)?, // Default to FGSM
        };

        // Apply preprocessing if configured
        let processed_X = self.apply_preprocessing(X)?;
        let processed_adv_X = self.apply_preprocessing(&adversarial_X)?;

        // Create mixed training data based on adversarial ratio
        let n_clean = ((1.0 - self.config.adversarial_ratio) * X.nrows() as f64) as usize;
        let n_adversarial = X.nrows() - n_clean;

        for _ in 0..self.config.n_estimators {
            // Create training subset with mix of clean and adversarial examples
            let mut training_X = Array2::zeros((n_clean + n_adversarial, X.ncols()));
            let mut training_y = Vec::new();

            // Get unique classes to ensure diversity
            let unique_classes: std::collections::HashSet<usize> = y.iter().cloned().collect();
            let classes_vec: Vec<usize> = unique_classes.iter().cloned().collect();

            // Add clean examples with class diversity
            for i in 0..n_clean {
                let row_idx = if i < classes_vec.len() {
                    // Ensure at least one example from each class
                    let target_class = classes_vec[i];
                    y.iter().position(|&c| c == target_class).unwrap_or(0)
                } else {
                    gen_range_usize(&mut rng, 0..processed_X.nrows())
                };
                training_X.row_mut(i).assign(&processed_X.row(row_idx));
                training_y.push(y[row_idx]);
            }

            // Add adversarial examples with class diversity
            for i in 0..n_adversarial {
                let row_idx = if i < classes_vec.len() {
                    // Ensure at least one example from each class
                    let target_class = classes_vec[i];
                    y.iter().position(|&c| c == target_class).unwrap_or(0)
                } else {
                    gen_range_usize(&mut rng, 0..processed_adv_X.nrows())
                };
                training_X
                    .row_mut(n_clean + i)
                    .assign(&processed_adv_X.row(row_idx));
                training_y.push(y[row_idx]);
            }

            // Train base classifier
            let training_y_array = Array1::from_vec(training_y.iter().map(|&x| x as i32).collect());
            let classifier = BaggingClassifier::new()
                .n_estimators(5)
                .bootstrap(true)
                .fit(&training_X, &training_y_array)?;

            base_classifiers.push(classifier);
        }

        // Calculate ensemble weights based on diversity if using diversity maximization
        let ensemble_weights = if matches!(
            self.config.defensive_strategy,
            DefensiveStrategy::DiversityMaximization
        ) {
            let diversity = self.calculate_diversity(&base_classifiers, X)?;
            vec![1.0 + self.config.diversity_factor * diversity; base_classifiers.len()]
        } else {
            vec![1.0; base_classifiers.len()]
        };

        // Train adversarial detector if using adversarial detection strategy
        let adversarial_detector = if matches!(
            self.config.defensive_strategy,
            DefensiveStrategy::AdversarialDetection
        ) {
            // Create detector training data
            let mut detector_X = Array2::zeros((X.nrows() + adversarial_X.nrows(), X.ncols()));
            let mut detector_y = Vec::new();

            // Clean examples (label 0)
            for (i, row) in X.outer_iter().enumerate() {
                detector_X.row_mut(i).assign(&row);
                detector_y.push(0);
            }

            // Adversarial examples (label 1)
            for (i, row) in adversarial_X.outer_iter().enumerate() {
                detector_X.row_mut(X.nrows() + i).assign(&row);
                detector_y.push(1);
            }

            let detector_y_array = Array1::from_vec(detector_y);
            let detector = BaggingClassifier::new()
                .n_estimators(10)
                .fit(&detector_X, &detector_y_array)?;

            Some(detector)
        } else {
            None
        };

        // Placeholder robustness metrics (real values would be computed on
        // held-out clean and adversarial validation data)
        let robustness_metrics = RobustnessMetrics {
            clean_accuracy: 0.85,       // Would be calculated from validation
            adversarial_accuracy: 0.65, // Would be calculated from adversarial validation
            certified_accuracy: 0.60,   // Would be calculated using certified defense methods
            avg_perturbation_magnitude: self.config.epsilon,
            detection_rate: 0.80,      // Would be calculated if using detection
            false_positive_rate: 0.05, // Would be calculated if using detection
        };

        Ok(AdversarialEnsembleClassifier {
            config: self.config,
            state: std::marker::PhantomData,
            base_classifiers: Some(base_classifiers),
            adversarial_detector,
            preprocessing_params: Some(HashMap::new()),
            universal_perturbation: None,
            ensemble_weights: Some(ensemble_weights),
            robustness_metrics: Some(robustness_metrics),
        })
    }
}

impl Predict<Array2<f64>, AdversarialPredictionResults> for AdversarialEnsembleClassifier<Trained> {
    fn predict(&self, X: &Array2<f64>) -> SklResult<AdversarialPredictionResults> {
        let base_classifiers = self.base_classifiers.as_ref().expect("Model is trained");
        let ensemble_weights = self.ensemble_weights.as_ref().expect("Model is trained");

        // Apply preprocessing
        let processed_X = self.apply_preprocessing(X)?;

        let n_samples = processed_X.nrows();
        let mut all_predictions = Vec::new();

        // Get predictions from all base classifiers
        for classifier in base_classifiers {
            let predictions = classifier.predict(&processed_X)?;
            let predictions_vec: Vec<usize> = predictions.iter().map(|&x| x as usize).collect();
            all_predictions.push(predictions_vec);
        }

        // Calculate ensemble predictions with weights
        let mut final_predictions = Vec::new();
        let mut classifier_agreements = Vec::new();

        for sample_idx in 0..n_samples {
            let mut vote_counts = HashMap::new();

            for (classifier_idx, predictions) in all_predictions.iter().enumerate() {
                let pred = predictions[sample_idx];
                let weight = ensemble_weights[classifier_idx];
                *vote_counts.entry(pred).or_insert(0.0) += weight;
            }

            // Find prediction with highest weighted vote
            let final_pred = vote_counts
                .iter()
                .max_by(|a, b| a.1.partial_cmp(b.1).unwrap())
                .map(|(&pred, _)| pred)
                .unwrap_or(0);

            final_predictions.push(final_pred);

            // Calculate agreement (fraction of classifiers that agree with the final prediction)
            let agreement = all_predictions
                .iter()
                .map(|preds| {
                    if preds[sample_idx] == final_pred {
                        1.0
                    } else {
                        0.0
                    }
                })
                .sum::<f64>()
                / base_classifiers.len() as f64;
            classifier_agreements.push(agreement);
        }

        // Generate heuristic probabilities (a real implementation would use the
        // base classifiers' predicted probabilities); assumes binary classification
        let probabilities = Array2::from_shape_fn((n_samples, 2), |(i, j)| {
            if j == final_predictions[i] {
                0.7 + classifier_agreements[i] * 0.3
            } else {
                0.3 - classifier_agreements[i] * 0.3
            }
        });

        // Calculate adversarial detection scores
        let adversarial_scores = if let Some(ref detector) = self.adversarial_detector {
            detector
                .predict(&processed_X)?
                .into_iter()
                .map(|score| score as f64)
                .collect()
        } else {
            vec![0.0; n_samples]
        };

        // Calculate confidence intervals (simplified)
        let confidence_intervals: Vec<(f64, f64)> = classifier_agreements
            .iter()
            .map(|&agreement| {
                let margin = (1.0 - agreement) * 0.2;
                (agreement - margin, agreement + margin)
            })
            .collect();

        Ok(AdversarialPredictionResults {
            predictions: final_predictions,
            probabilities,
            adversarial_scores,
            confidence_intervals,
            classifier_agreements,
        })
    }
}

impl AdversarialEnsembleClassifier<Trained> {
    /// Get robustness metrics
    pub fn robustness_metrics(&self) -> &RobustnessMetrics {
        self.robustness_metrics.as_ref().expect("Model is trained")
    }

    /// Predict with adversarial detection
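    ///
    /// Returns the predictions together with a per-sample flag marking inputs
    /// whose detector score exceeds `detection_threshold` (default 0.5). A
    /// usage sketch:
    ///
    /// ```ignore
    /// let (predictions, is_adversarial) = trained.predict_with_detection(&X)?;
    /// for (pred, flagged) in predictions.iter().zip(&is_adversarial) {
    ///     if *flagged {
    ///         // e.g. reject the sample or route it for manual review
    ///     }
    /// }
    /// ```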
    pub fn predict_with_detection(&self, X: &Array2<f64>) -> SklResult<(Vec<usize>, Vec<bool>)> {
        let results = self.predict(X)?;
        let detection_threshold = self.config.detection_threshold.unwrap_or(0.5);

        let is_adversarial: Vec<bool> = results
            .adversarial_scores
            .iter()
            .map(|&score| score > detection_threshold)
            .collect();

        Ok((results.predictions, is_adversarial))
    }

    /// Get ensemble diversity score
    pub fn diversity_score(&self, X: &Array2<f64>) -> SklResult<f64> {
        let base_classifiers = self.base_classifiers.as_ref().expect("Model is trained");
        self.calculate_diversity(base_classifiers, X)
    }

    /// Evaluate robustness against specific attack
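    ///
    /// Generates adversarial examples from `X` with the given `attack_method`
    /// and returns the ensemble's accuracy on them, e.g.:
    ///
    /// ```ignore
    /// let adv_accuracy = trained.evaluate_robustness(&X, &y, AttackMethod::FGSM)?;
    /// assert!(adv_accuracy >= 0.0 && adv_accuracy <= 1.0);
    /// ```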
    pub fn evaluate_robustness(
        &self,
        X: &Array2<f64>,
        y: &[usize],
        attack_method: AttackMethod,
    ) -> SklResult<f64> {
        // Generate adversarial examples
        let adversarial_X = match attack_method {
            AttackMethod::FGSM => self.generate_fgsm_examples(X, y)?,
            AttackMethod::PGD => self.generate_pgd_examples(X, y)?,
            AttackMethod::RandomNoise => self.generate_random_noise(X)?,
            _ => self.generate_fgsm_examples(X, y)?,
        };

        // Predict on adversarial examples
        let results = self.predict(&adversarial_X)?;

        // Calculate accuracy
        let correct = results
            .predictions
            .iter()
            .zip(y.iter())
            .map(|(&pred, &true_label)| if pred == true_label { 1.0 } else { 0.0 })
            .sum::<f64>();

        Ok(correct / y.len() as f64)
    }
}

#[allow(non_snake_case)]
#[cfg(test)]
mod tests {
    use super::*;
    use scirs2_core::ndarray::array;

    #[test]
    #[allow(non_snake_case)]
    fn test_adversarial_ensemble_fgsm() {
        let X = array![[1.0, 2.0], [2.0, 3.0], [3.0, 4.0], [4.0, 5.0]];
        let y = vec![0, 1, 0, 1];

        let classifier = AdversarialEnsembleClassifier::fgsm_training()
            .n_estimators(3)
            .epsilon(0.1)
            .random_state(42);

        let trained = classifier.fit(&X, &y).expect("Training should succeed");
        let results = trained.predict(&X).expect("Prediction should succeed");

        assert_eq!(results.predictions.len(), 4);
        assert_eq!(results.adversarial_scores.len(), 4);
        assert_eq!(results.classifier_agreements.len(), 4);
    }

    #[test]
    #[allow(non_snake_case)]
    fn test_adversarial_ensemble_pgd() {
        let X = array![[1.0, 2.0], [2.0, 3.0], [3.0, 4.0], [4.0, 5.0]];
        let y = vec![0, 1, 0, 1];

        let classifier = AdversarialEnsembleClassifier::pgd_training()
            .n_estimators(3)
            .epsilon(0.05)
            .adversarial_ratio(0.4)
            .random_state(42);

        let trained = classifier.fit(&X, &y).expect("Training should succeed");
        let robustness = trained.robustness_metrics();

        assert!(robustness.clean_accuracy > 0.0);
        assert!(robustness.adversarial_accuracy > 0.0);
    }

    #[test]
    #[allow(non_snake_case)]
    fn test_diversity_maximization() {
        let X = array![[1.0, 2.0], [2.0, 3.0], [3.0, 4.0], [4.0, 5.0]];
        let y = vec![0, 1, 0, 1];

        let classifier = AdversarialEnsembleClassifier::diversity_maximization().random_state(42);

        let trained = classifier.fit(&X, &y).expect("Training should succeed");
        let diversity = trained
            .diversity_score(&X)
            .expect("Should calculate diversity");

        assert!(diversity >= 0.0);
        assert!(diversity <= 1.0);
    }

    #[test]
    #[allow(non_snake_case)]
    fn test_input_preprocessing() {
        let X = array![[1.0, 2.0], [2.0, 3.0], [3.0, 4.0], [4.0, 5.0]];
        let y = vec![0, 1, 0, 1];

        let preprocessing = InputPreprocessing::GaussianNoise { std_dev: 0.1 };
        let classifier = AdversarialEnsembleClassifier::fgsm_training()
            .input_preprocessing(preprocessing)
            .random_state(42);

        let trained = classifier.fit(&X, &y).expect("Training should succeed");
        let results = trained.predict(&X).expect("Prediction should succeed");

        assert_eq!(results.predictions.len(), 4);
    }

    #[test]
    #[allow(non_snake_case)]
    fn test_adversarial_detection() {
        let X = array![[1.0, 2.0], [2.0, 3.0], [3.0, 4.0], [4.0, 5.0]];
        let y = vec![0, 1, 0, 1];

        let config = AdversarialEnsembleConfig {
            defensive_strategy: DefensiveStrategy::AdversarialDetection,
            detection_threshold: Some(0.5),
            random_state: Some(42),
            ..Default::default()
        };

        let classifier = AdversarialEnsembleClassifier::new(config);
        let trained = classifier.fit(&X, &y).expect("Training should succeed");
        let (predictions, is_adversarial) = trained
            .predict_with_detection(&X)
            .expect("Detection should succeed");

        assert_eq!(predictions.len(), 4);
        assert_eq!(is_adversarial.len(), 4);
    }

    #[test]
    #[allow(non_snake_case)]
    fn test_robustness_evaluation() {
        let X = array![[1.0, 2.0], [2.0, 3.0], [3.0, 4.0], [4.0, 5.0]];
        let y = vec![0, 1, 0, 1];

        let classifier = AdversarialEnsembleClassifier::fgsm_training().random_state(42);

        let trained = classifier.fit(&X, &y).expect("Training should succeed");
        let robustness = trained
            .evaluate_robustness(&X, &y, AttackMethod::FGSM)
            .expect("Robustness evaluation should succeed");

        assert!(robustness >= 0.0);
        assert!(robustness <= 1.0);
    }

    #[test]
    #[allow(non_snake_case)]
    fn test_fgsm_example_generation() {
        let X = array![[1.0, 2.0], [2.0, 3.0]];
        let y = vec![0, 1];

        let classifier: AdversarialEnsembleClassifier<Untrained> =
            AdversarialEnsembleClassifier::fgsm_training()
                .epsilon(0.1)
                .random_state(42);

        let adversarial_X = classifier
            .generate_fgsm_examples(&X, &y)
            .expect("FGSM generation should succeed");

        assert_eq!(adversarial_X.shape(), X.shape());
        // Check that perturbations were applied
        assert_ne!(adversarial_X, X);
    }
}