sklears_model_selection/
worst_case_validation.rs

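//! Worst-Case Validation Scenarios
//!
//! This module provides worst-case scenario validation for robust model evaluation.
//! It generates challenging validation scenarios to test model robustness and reliability
//! under adverse conditions, including adversarial examples, distribution shifts, and
//! extreme data conditions.
//!
//! The adversarial scenarios are model-agnostic: because no model gradients are available
//! here, they are simulated with random (gradient-free) perturbations of bounded size rather
//! than true gradient-based attacks.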
7
8use scirs2_core::ndarray::{Array1, Array2, Axis};
9use scirs2_core::random::rngs::StdRng;
10use scirs2_core::random::Rng;
11use scirs2_core::random::SeedableRng;
12use scirs2_core::SliceRandomExt;
13use sklears_core::types::Float;
14use std::collections::HashMap;
15
/// Worst-case scenario types.
///
/// Each variant describes one family of stress test applied to a dataset `(x, y)`. During
/// generation, one parameter of each variant (noted per field below) is multiplied by every
/// entry of `WorstCaseValidationConfig::severity_levels`, yielding one dataset per level.
#[derive(Debug, Clone)]
pub enum WorstCaseScenario {
    /// Adversarial examples with maximum perturbation
    AdversarialExamples {
        /// Per-value perturbation budget; scaled by the severity level.
        epsilon: Float,
        /// Strategy used to perturb the feature values.
        attack_method: AdversarialAttackMethod,
        /// Whether the attack aims at a specific class.
        /// NOTE(review): not read by `WorstCaseScenarioGenerator` in this file.
        targeted: bool,
    },
    /// Distribution shift scenarios
    DistributionShift {
        /// Kind of shift to simulate.
        shift_type: DistributionShiftType,
        /// Base strength of the shift; scaled by the severity level.
        severity: Float,
    },
    /// Extreme outliers and anomalies
    ExtremeOutliers {
        /// Fraction of rows (samples) turned into outliers; not scaled by severity.
        outlier_fraction: Float,
        /// Size of the random offset added to each value of an outlier row; scaled by the
        /// severity level.
        outlier_magnitude: Float,
    },
    /// Class imbalance scenarios
    ClassImbalance {
        /// Target size of the minority class as a fraction of all samples; not scaled.
        minority_fraction: Float,
        /// Imbalance ratio; scaled by the severity level but only reflected in the scenario
        /// name. NOTE(review): not used when resampling the data in this file.
        imbalance_ratio: Float,
    },
    /// Feature corruption scenarios
    FeatureCorruption {
        /// Per-value probability that a feature value is corrupted; scaled by severity.
        corruption_rate: Float,
        /// How a selected value is corrupted.
        corruption_type: CorruptionType,
    },
    /// Temporal drift for time series
    TemporalDrift {
        /// Strength of the offset added to each row; scaled by the severity level.
        drift_rate: Float,
        /// How the offset evolves with the row index (treated as time).
        drift_pattern: DriftPattern,
    },
    /// Label noise scenarios
    LabelNoise {
        /// Per-label probability of corruption; scaled by the severity level.
        noise_rate: Float,
        /// How corrupted labels are chosen.
        noise_pattern: NoisePattern,
    },
    /// Missing data scenarios
    MissingData {
        /// Base rate at which values are replaced by `NaN`; scaled by the severity level.
        missing_rate: Float,
        /// Mechanism deciding which values go missing.
        missing_pattern: MissingPattern,
    },
}
64
/// Adversarial attack methods.
///
/// The generator in this module has no access to model gradients, so every method is
/// approximated with random perturbation directions of bounded size.
#[derive(Debug, Clone)]
pub enum AdversarialAttackMethod {
    /// Fast Gradient Sign Method (simulated: one random-sign step of size epsilon per value)
    FGSM,
    /// Projected Gradient Descent
    PGD {
        /// Number of perturbation steps; each step has size `epsilon / iterations`.
        iterations: usize,
    },
    /// Basic Iterative Method
    BIM {
        /// Number of perturbation steps; each step has size `epsilon / iterations`.
        iterations: usize,
    },
    /// Carlini & Wagner attack
    CW {
        /// Attack confidence. NOTE(review): ignored by the generator in this file.
        confidence: Float,
    },
    /// Boundary Attack
    BoundaryAttack {
        /// Iteration budget. NOTE(review): ignored by the generator in this file.
        iterations: usize,
    },
    /// Random noise attack (uniform additive noise)
    RandomNoise,
}
81
/// Distribution shift types that the scenario generator can simulate.
#[derive(Debug, Clone)]
pub enum DistributionShiftType {
    /// Covariate shift (input distribution changes); simulated by adding a
    /// feature-index-dependent offset to every value.
    CovariateShift,
    /// Prior probability shift (class distribution changes); simulated by randomly dropping
    /// samples of the smallest label value.
    PriorShift,
    /// Concept drift (relationship between input and output changes); simulated by
    /// reassigning a random fraction of the labels.
    ConceptDrift,
    /// Domain shift (complete domain change); simulated by a non-linear (`tanh`) transform of
    /// every value, scaled by the severity.
    DomainShift,
}
94
/// Corruption types for features.
///
/// A corruption is applied independently to each feature value selected for corruption.
#[derive(Debug, Clone)]
pub enum CorruptionType {
    /// Gaussian noise
    GaussianNoise {
        /// Scale of the additive noise.
        std: Float,
    },
    /// Salt and pepper noise
    SaltPepperNoise {
        /// Probability that a corrupted value becomes `1.0` rather than `0.0`.
        ratio: Float,
    },
    /// Multiplicative noise
    MultiplicativeNoise {
        /// Width of the random multiplier range, which is centred on `1.0`.
        factor: Float,
    },
    /// Feature masking (corrupted values are set to `0.0`)
    FeatureMasking,
    /// Value quantization
    Quantization {
        /// Number of quantization steps across `[-1, 1]`; quantized values are also clamped
        /// to that range (presumably features are expected to be scaled into it).
        levels: usize,
    },
}
109
/// Drift patterns for temporal data.
///
/// The row index of the feature matrix is treated as the time axis.
#[derive(Debug, Clone)]
pub enum DriftPattern {
    /// Gradual linear drift
    Linear,
    /// Sudden step change (applied to the rows past the midpoint)
    Sudden,
    /// Exponential drift
    Exponential,
    /// Seasonal drift
    Seasonal {
        /// Length of one cycle, measured in rows.
        period: usize,
    },
    /// Random walk drift
    RandomWalk,
}
124
/// Label noise patterns.
///
/// Class indices below refer to positions in the sorted list of distinct labels found in `y`.
#[derive(Debug, Clone)]
pub enum NoisePattern {
    /// Uniform random label flipping
    Uniform,
    /// Class-conditional noise (some classes more affected)
    ClassConditional {
        /// Multiplier on the noise rate for each class index; classes without an entry use
        /// the unscaled noise rate.
        class_weights: Vec<Float>,
    },
    /// Systematic bias towards specific classes
    SystematicBias {
        /// Class index that corrupted labels are set to; an out-of-range index falls back to
        /// the first (smallest) class.
        target_class: usize,
    },
}
135
/// Missing data patterns.
///
/// Missing values are represented as `NaN` in the generated feature matrix.
#[derive(Debug, Clone)]
pub enum MissingPattern {
    /// Missing completely at random
    MCAR,
    /// Missing at random (depends on observed variables)
    MAR,
    /// Missing not at random (depends on unobserved variables)
    MNAR,
    /// Block missing (consecutive features)
    BlockMissing {
        /// Number of consecutive feature columns blanked per block.
        block_size: usize,
    },
}
148
/// Worst-case validation configuration.
///
/// See the `Default` implementation below for the default values.
#[derive(Debug, Clone)]
pub struct WorstCaseValidationConfig {
    /// Scenario families to generate.
    pub scenarios: Vec<WorstCaseScenario>,
    /// NOTE(review): not read by `WorstCaseScenarioGenerator`; any use happens outside the
    /// code visible here.
    pub n_worst_case_samples: usize,
    /// Name of the evaluation metric (the default is `"accuracy"`).
    pub evaluation_metric: String,
    /// Confidence level, presumably for the reported confidence intervals — TODO confirm.
    pub confidence_level: Float,
    /// Seed for the scenario RNG; `None` seeds it from the thread-local RNG.
    pub random_state: Option<u64>,
    /// Multipliers applied to each scenario's scale parameter; every scenario is generated
    /// once per level.
    pub severity_levels: Vec<Float>,
}
159
/// Worst-case validation result.
///
/// NOTE(review): these fields are filled in by `WorstCaseValidator::validate`; their exact
/// definitions should be confirmed against that method.
#[derive(Debug, Clone)]
pub struct WorstCaseValidationResult {
    /// Per-scenario results, presumably keyed by the generated scenario name.
    pub scenario_results: HashMap<String, ScenarioResult>,
    /// Presumably the lowest score observed across all scenarios.
    pub overall_worst_case_score: Float,
    /// Aggregate robustness score (definition not visible here).
    pub robustness_score: Float,
    /// Presumably the fraction of scenarios counted as failures.
    pub failure_rate: Float,
    /// Presumably the drop in performance relative to the unperturbed baseline.
    pub performance_degradation: Float,
    /// Presumably `(lower, upper)` confidence bounds keyed by scenario or metric name.
    pub confidence_intervals: HashMap<String, (Float, Float)>,
}
170
/// Individual scenario result.
///
/// NOTE(review): populated outside the visible code; field meanings below are hedged.
#[derive(Debug, Clone)]
pub struct ScenarioResult {
    /// Generated scenario name (for example `"Outliers_frac_0.10_mag_3.00"`).
    pub scenario_name: String,
    /// Presumably the model score on the perturbed dataset.
    pub worst_case_score: Float,
    /// Presumably the model score on the unperturbed dataset.
    pub baseline_score: Float,
    /// Presumably the difference between baseline and worst-case scores.
    pub performance_drop: Float,
    /// Indices of examples judged as failures (selection rule not visible here).
    pub failure_examples: Vec<usize>,
    /// Robustness summary for this scenario.
    pub robustness_metrics: RobustnessMetrics,
}
181
/// Robustness metrics for validation.
///
/// NOTE(review): the formulas behind these scores are not visible in this chunk — confirm
/// against the code that constructs this struct before relying on their ranges.
#[derive(Debug, Clone)]
pub struct RobustnessMetrics {
    /// Stability of performance under perturbation.
    pub stability_score: Float,
    /// Consistency of predictions across perturbations.
    pub consistency_score: Float,
    /// Ability to retain performance under stress.
    pub resilience_score: Float,
    /// Ability to recover performance after stress.
    pub recovery_score: Float,
    /// Perturbation level at which the model breaks down.
    pub breakdown_point: Float,
}
191
/// Worst-case scenario generator.
///
/// Produces perturbed copies of a dataset for every configured scenario and severity level.
#[derive(Debug, Clone)]
pub struct WorstCaseScenarioGenerator {
    /// Scenario configuration (scenarios, severity levels, seed).
    config: WorstCaseValidationConfig,
    /// RNG used for every random perturbation; seeded from `config.random_state` when set.
    rng: StdRng,
}
198
/// Worst-case validator.
///
/// Scores a model closure on the original data and on every generated worst-case scenario.
#[derive(Debug)]
pub struct WorstCaseValidator {
    /// Source of the perturbed datasets.
    generator: WorstCaseScenarioGenerator,
}
204
205impl Default for WorstCaseValidationConfig {
206    fn default() -> Self {
207        Self {
208            scenarios: vec![
209                WorstCaseScenario::AdversarialExamples {
210                    epsilon: 0.1,
211                    attack_method: AdversarialAttackMethod::FGSM,
212                    targeted: false,
213                },
214                WorstCaseScenario::DistributionShift {
215                    shift_type: DistributionShiftType::CovariateShift,
216                    severity: 1.0,
217                },
218                WorstCaseScenario::ExtremeOutliers {
219                    outlier_fraction: 0.1,
220                    outlier_magnitude: 3.0,
221                },
222            ],
223            n_worst_case_samples: 1000,
224            evaluation_metric: "accuracy".to_string(),
225            confidence_level: 0.95,
226            random_state: None,
227            severity_levels: vec![0.5, 1.0, 1.5, 2.0],
228        }
229    }
230}
231
232impl WorstCaseScenarioGenerator {
233    /// Create a new worst-case scenario generator
234    pub fn new(config: WorstCaseValidationConfig) -> Self {
235        let rng = match config.random_state {
236            Some(seed) => StdRng::seed_from_u64(seed),
237            None => {
238                use scirs2_core::random::thread_rng;
239                StdRng::from_rng(&mut thread_rng())
240            }
241        };
242
243        Self { config, rng }
244    }
245
246    /// Generate worst-case scenarios for given data
247    pub fn generate_scenarios(
248        &mut self,
249        x: &Array2<Float>,
250        y: &Array1<Float>,
251    ) -> Result<Vec<(Array2<Float>, Array1<Float>, String)>, Box<dyn std::error::Error>> {
252        let mut scenarios = Vec::new();
253
254        let scenarios_clone = self.config.scenarios.clone();
255        let severity_levels_clone = self.config.severity_levels.clone();
256
257        for scenario in &scenarios_clone {
258            for &severity in &severity_levels_clone {
259                let (worst_x, worst_y, name) =
260                    self.generate_single_scenario(x, y, scenario, severity)?;
261                scenarios.push((worst_x, worst_y, name));
262            }
263        }
264
265        Ok(scenarios)
266    }
267
268    /// Generate a single worst-case scenario
269    fn generate_single_scenario(
270        &mut self,
271        x: &Array2<Float>,
272        y: &Array1<Float>,
273        scenario: &WorstCaseScenario,
274        severity: Float,
275    ) -> Result<(Array2<Float>, Array1<Float>, String), Box<dyn std::error::Error>> {
276        match scenario {
277            WorstCaseScenario::AdversarialExamples {
278                epsilon,
279                attack_method,
280                targeted,
281            } => {
282                let (adv_x, adv_y) = self.generate_adversarial_examples(
283                    x,
284                    y,
285                    *epsilon * severity,
286                    attack_method,
287                    *targeted,
288                )?;
289                let name = format!(
290                    "Adversarial_{:?}_eps_{:.3}",
291                    attack_method,
292                    epsilon * severity
293                );
294                Ok((adv_x, adv_y, name))
295            }
296            WorstCaseScenario::DistributionShift {
297                shift_type,
298                severity: base_severity,
299            } => {
300                let (shift_x, shift_y) =
301                    self.generate_distribution_shift(x, y, shift_type, base_severity * severity)?;
302                let name = format!(
303                    "DistShift_{:?}_sev_{:.2}",
304                    shift_type,
305                    base_severity * severity
306                );
307                Ok((shift_x, shift_y, name))
308            }
309            WorstCaseScenario::ExtremeOutliers {
310                outlier_fraction,
311                outlier_magnitude,
312            } => {
313                let (outlier_x, outlier_y) = self.generate_extreme_outliers(
314                    x,
315                    y,
316                    *outlier_fraction,
317                    outlier_magnitude * severity,
318                )?;
319                let name = format!(
320                    "Outliers_frac_{:.2}_mag_{:.2}",
321                    outlier_fraction,
322                    outlier_magnitude * severity
323                );
324                Ok((outlier_x, outlier_y, name))
325            }
326            WorstCaseScenario::ClassImbalance {
327                minority_fraction,
328                imbalance_ratio,
329            } => {
330                let (imbal_x, imbal_y) = self.generate_class_imbalance(
331                    x,
332                    y,
333                    *minority_fraction,
334                    imbalance_ratio * severity,
335                )?;
336                let name = format!(
337                    "ClassImbalance_frac_{:.2}_ratio_{:.2}",
338                    minority_fraction,
339                    imbalance_ratio * severity
340                );
341                Ok((imbal_x, imbal_y, name))
342            }
343            WorstCaseScenario::FeatureCorruption {
344                corruption_rate,
345                corruption_type,
346            } => {
347                let (corr_x, corr_y) = self.generate_feature_corruption(
348                    x,
349                    y,
350                    corruption_rate * severity,
351                    corruption_type,
352                )?;
353                let name = format!(
354                    "Corruption_{:?}_rate_{:.2}",
355                    corruption_type,
356                    corruption_rate * severity
357                );
358                Ok((corr_x, corr_y, name))
359            }
360            WorstCaseScenario::TemporalDrift {
361                drift_rate,
362                drift_pattern,
363            } => {
364                let (drift_x, drift_y) =
365                    self.generate_temporal_drift(x, y, drift_rate * severity, drift_pattern)?;
366                let name = format!(
367                    "TemporalDrift_{:?}_rate_{:.2}",
368                    drift_pattern,
369                    drift_rate * severity
370                );
371                Ok((drift_x, drift_y, name))
372            }
373            WorstCaseScenario::LabelNoise {
374                noise_rate,
375                noise_pattern,
376            } => {
377                let (noise_x, noise_y) =
378                    self.generate_label_noise(x, y, noise_rate * severity, noise_pattern)?;
379                let name = format!(
380                    "LabelNoise_{:?}_rate_{:.2}",
381                    noise_pattern,
382                    noise_rate * severity
383                );
384                Ok((noise_x, noise_y, name))
385            }
386            WorstCaseScenario::MissingData {
387                missing_rate,
388                missing_pattern,
389            } => {
390                let (missing_x, missing_y) =
391                    self.generate_missing_data(x, y, missing_rate * severity, missing_pattern)?;
392                let name = format!(
393                    "MissingData_{:?}_rate_{:.2}",
394                    missing_pattern,
395                    missing_rate * severity
396                );
397                Ok((missing_x, missing_y, name))
398            }
399        }
400    }
401
402    /// Generate adversarial examples
403    fn generate_adversarial_examples(
404        &mut self,
405        x: &Array2<Float>,
406        y: &Array1<Float>,
407        epsilon: Float,
408        attack_method: &AdversarialAttackMethod,
409        _targeted: bool,
410    ) -> Result<(Array2<Float>, Array1<Float>), Box<dyn std::error::Error>> {
411        let mut adv_x = x.clone();
412
413        match attack_method {
414            AdversarialAttackMethod::FGSM => {
415                // Fast Gradient Sign Method
416                for mut row in adv_x.axis_iter_mut(Axis(0)) {
417                    for val in row.iter_mut() {
418                        let perturbation = if self.rng.gen_bool(0.5) {
419                            epsilon
420                        } else {
421                            -epsilon
422                        };
423                        *val += perturbation;
424                    }
425                }
426            }
427            AdversarialAttackMethod::PGD { iterations } => {
428                // Projected Gradient Descent
429                for _ in 0..*iterations {
430                    for mut row in adv_x.axis_iter_mut(Axis(0)) {
431                        for val in row.iter_mut() {
432                            let step_size = epsilon / (*iterations as Float);
433                            let perturbation = if self.rng.gen_bool(0.5) {
434                                step_size
435                            } else {
436                                -step_size
437                            };
438                            *val += perturbation;
439                            // Project back to epsilon ball
440                            *val = val.max(-epsilon).min(epsilon);
441                        }
442                    }
443                }
444            }
445            AdversarialAttackMethod::BIM { iterations } => {
446                // Basic Iterative Method
447                let alpha = epsilon / (*iterations as Float);
448                for _ in 0..*iterations {
449                    for mut row in adv_x.axis_iter_mut(Axis(0)) {
450                        for val in row.iter_mut() {
451                            let perturbation = if self.rng.gen_bool(0.5) {
452                                alpha
453                            } else {
454                                -alpha
455                            };
456                            *val += perturbation;
457                        }
458                    }
459                }
460            }
461            AdversarialAttackMethod::RandomNoise => {
462                // Random noise attack
463                for mut row in adv_x.axis_iter_mut(Axis(0)) {
464                    for val in row.iter_mut() {
465                        let noise = self.rng.gen_range(-epsilon..epsilon + 1.0);
466                        *val += noise;
467                    }
468                }
469            }
470            AdversarialAttackMethod::CW { .. } | AdversarialAttackMethod::BoundaryAttack { .. } => {
471                // Simplified implementation for C&W and Boundary Attack
472                for mut row in adv_x.axis_iter_mut(Axis(0)) {
473                    for val in row.iter_mut() {
474                        let perturbation = self.rng.gen_range(-epsilon..epsilon + 1.0);
475                        *val += perturbation;
476                    }
477                }
478            }
479        }
480
481        Ok((adv_x, y.clone()))
482    }
483
484    /// Generate distribution shift scenarios
485    fn generate_distribution_shift(
486        &mut self,
487        x: &Array2<Float>,
488        y: &Array1<Float>,
489        shift_type: &DistributionShiftType,
490        severity: Float,
491    ) -> Result<(Array2<Float>, Array1<Float>), Box<dyn std::error::Error>> {
492        let mut shift_x = x.clone();
493        let mut shift_y = y.clone();
494
495        match shift_type {
496            DistributionShiftType::CovariateShift => {
497                // Add systematic bias to features
498                for mut row in shift_x.axis_iter_mut(Axis(0)) {
499                    for (i, val) in row.iter_mut().enumerate() {
500                        let shift = severity * (i as Float * 0.1).sin();
501                        *val += shift;
502                    }
503                }
504            }
505            DistributionShiftType::PriorShift => {
506                // Change class distribution by removing samples from certain classes
507                let mut unique_classes: Vec<Float> = y.iter().cloned().collect();
508                unique_classes.sort_by(|a, b| a.partial_cmp(b).unwrap());
509                unique_classes.dedup();
510                if unique_classes.len() > 1 {
511                    let target_class = unique_classes[0];
512                    let removal_prob = severity * 0.5;
513
514                    let mut keep_indices = Vec::new();
515                    for (i, &class) in y.iter().enumerate() {
516                        if class != target_class || self.rng.random::<Float>() > removal_prob {
517                            keep_indices.push(i);
518                        }
519                    }
520
521                    // Create new arrays with selected indices
522                    let mut new_x_data = Vec::new();
523                    for &i in keep_indices.iter() {
524                        new_x_data.extend(x.row(i).iter().cloned());
525                    }
526                    let new_x =
527                        Array2::from_shape_vec((keep_indices.len(), x.ncols()), new_x_data)?;
528                    let new_y = Array1::from_vec(keep_indices.iter().map(|&i| y[i]).collect());
529
530                    return Ok((new_x, new_y));
531                }
532            }
533            DistributionShiftType::ConceptDrift => {
534                // Change the relationship between features and labels
535                for label in shift_y.iter_mut() {
536                    if self.rng.random::<Float>() < severity * 0.2 {
537                        // Flip some labels to simulate concept drift
538                        *label = 1.0 - *label;
539                    }
540                }
541            }
542            DistributionShiftType::DomainShift => {
543                // Apply domain transformation
544                for mut row in shift_x.axis_iter_mut(Axis(0)) {
545                    for val in row.iter_mut() {
546                        // Apply non-linear transformation
547                        *val = val.tanh() * severity;
548                    }
549                }
550            }
551        }
552
553        Ok((shift_x, shift_y))
554    }
555
556    /// Generate extreme outliers
557    fn generate_extreme_outliers(
558        &mut self,
559        x: &Array2<Float>,
560        y: &Array1<Float>,
561        outlier_fraction: Float,
562        outlier_magnitude: Float,
563    ) -> Result<(Array2<Float>, Array1<Float>), Box<dyn std::error::Error>> {
564        let mut outlier_x = x.clone();
565        let n_outliers = (x.nrows() as Float * outlier_fraction) as usize;
566
567        let mut outlier_indices: Vec<usize> = (0..x.nrows()).collect();
568        outlier_indices.shuffle(&mut self.rng);
569        outlier_indices.truncate(n_outliers);
570
571        for &idx in &outlier_indices {
572            for val in outlier_x.row_mut(idx) {
573                let outlier_value = self
574                    .rng
575                    .gen_range(-outlier_magnitude..outlier_magnitude + 1.0);
576                *val += outlier_value;
577            }
578        }
579
580        Ok((outlier_x, y.clone()))
581    }
582
583    /// Generate class imbalance scenarios
584    fn generate_class_imbalance(
585        &mut self,
586        x: &Array2<Float>,
587        y: &Array1<Float>,
588        minority_fraction: Float,
589        _imbalance_ratio: Float,
590    ) -> Result<(Array2<Float>, Array1<Float>), Box<dyn std::error::Error>> {
591        let mut unique_classes: Vec<Float> = y.iter().cloned().collect();
592        unique_classes.sort_by(|a, b| a.partial_cmp(b).unwrap());
593        unique_classes.dedup();
594        if unique_classes.len() < 2 {
595            return Ok((x.clone(), y.clone()));
596        }
597
598        let minority_class = unique_classes[0];
599        let target_minority_count = (x.nrows() as Float * minority_fraction) as usize;
600
601        let mut keep_indices = Vec::new();
602        let mut minority_count = 0;
603
604        for (i, &class) in y.iter().enumerate() {
605            if class == minority_class {
606                if minority_count < target_minority_count {
607                    keep_indices.push(i);
608                    minority_count += 1;
609                }
610            } else {
611                keep_indices.push(i);
612            }
613        }
614
615        let mut new_x_data = Vec::new();
616        for &i in keep_indices.iter() {
617            new_x_data.extend(x.row(i).iter().cloned());
618        }
619        let new_x = Array2::from_shape_vec((keep_indices.len(), x.ncols()), new_x_data)?;
620        let new_y = Array1::from_vec(keep_indices.iter().map(|&i| y[i]).collect());
621
622        Ok((new_x, new_y))
623    }
624
625    /// Generate feature corruption
626    fn generate_feature_corruption(
627        &mut self,
628        x: &Array2<Float>,
629        y: &Array1<Float>,
630        corruption_rate: Float,
631        corruption_type: &CorruptionType,
632    ) -> Result<(Array2<Float>, Array1<Float>), Box<dyn std::error::Error>> {
633        let mut corrupted_x = x.clone();
634
635        match corruption_type {
636            CorruptionType::GaussianNoise { std } => {
637                for val in corrupted_x.iter_mut() {
638                    if self.rng.random::<Float>() < corruption_rate {
639                        let noise = self.rng.random::<Float>() * std;
640                        *val += noise;
641                    }
642                }
643            }
644            CorruptionType::SaltPepperNoise { ratio } => {
645                for val in corrupted_x.iter_mut() {
646                    if self.rng.random::<Float>() < corruption_rate {
647                        *val = if self.rng.random::<Float>() < *ratio {
648                            1.0
649                        } else {
650                            0.0
651                        };
652                    }
653                }
654            }
655            CorruptionType::MultiplicativeNoise { factor } => {
656                for val in corrupted_x.iter_mut() {
657                    if self.rng.random::<Float>() < corruption_rate {
658                        let noise = 1.0 + (self.rng.random::<Float>() - 0.5) * factor;
659                        *val *= noise;
660                    }
661                }
662            }
663            CorruptionType::FeatureMasking => {
664                for val in corrupted_x.iter_mut() {
665                    if self.rng.random::<Float>() < corruption_rate {
666                        *val = 0.0;
667                    }
668                }
669            }
670            CorruptionType::Quantization { levels } => {
671                let step_size = 2.0 / (*levels as Float);
672                for val in corrupted_x.iter_mut() {
673                    if self.rng.random::<Float>() < corruption_rate {
674                        *val = ((*val / step_size).round() * step_size).clamp(-1.0, 1.0);
675                    }
676                }
677            }
678        }
679
680        Ok((corrupted_x, y.clone()))
681    }
682
683    /// Generate temporal drift
684    fn generate_temporal_drift(
685        &mut self,
686        x: &Array2<Float>,
687        y: &Array1<Float>,
688        drift_rate: Float,
689        drift_pattern: &DriftPattern,
690    ) -> Result<(Array2<Float>, Array1<Float>), Box<dyn std::error::Error>> {
691        let mut drift_x = x.clone();
692        let n_samples = x.nrows();
693
694        for (t, row) in drift_x.axis_iter_mut(Axis(0)).enumerate() {
695            let time_factor = t as Float / n_samples as Float;
696
697            let drift_magnitude = match drift_pattern {
698                DriftPattern::Linear => drift_rate * time_factor,
699                DriftPattern::Sudden => {
700                    if time_factor > 0.5 {
701                        drift_rate
702                    } else {
703                        0.0
704                    }
705                }
706                DriftPattern::Exponential => drift_rate * time_factor.exp(),
707                DriftPattern::Seasonal { period } => {
708                    drift_rate
709                        * (2.0 * std::f64::consts::PI * t as Float / *period as Float).sin()
710                            as Float
711                }
712                DriftPattern::RandomWalk => drift_rate * self.rng.random::<Float>(),
713            };
714
715            for val in row {
716                *val += drift_magnitude;
717            }
718        }
719
720        Ok((drift_x, y.clone()))
721    }
722
723    /// Generate label noise
724    fn generate_label_noise(
725        &mut self,
726        x: &Array2<Float>,
727        y: &Array1<Float>,
728        noise_rate: Float,
729        noise_pattern: &NoisePattern,
730    ) -> Result<(Array2<Float>, Array1<Float>), Box<dyn std::error::Error>> {
731        let mut noisy_y = y.clone();
732        let mut unique_classes: Vec<Float> = y.iter().cloned().collect();
733        unique_classes.sort_by(|a, b| a.partial_cmp(b).unwrap());
734        unique_classes.dedup();
735
736        if unique_classes.len() < 2 {
737            return Ok((x.clone(), noisy_y));
738        }
739
740        match noise_pattern {
741            NoisePattern::Uniform => {
742                for label in noisy_y.iter_mut() {
743                    if self.rng.random::<Float>() < noise_rate {
744                        // Flip to random other class
745                        let other_classes: Vec<Float> = unique_classes
746                            .iter()
747                            .filter(|&&c| c != *label)
748                            .cloned()
749                            .collect();
750                        if !other_classes.is_empty() {
751                            *label = other_classes[self.rng.gen_range(0..other_classes.len())];
752                        }
753                    }
754                }
755            }
756            NoisePattern::ClassConditional { class_weights } => {
757                for label in noisy_y.iter_mut() {
758                    let class_idx = unique_classes
759                        .iter()
760                        .position(|&c| c == *label)
761                        .unwrap_or(0);
762                    let class_noise_rate = if class_idx < class_weights.len() {
763                        noise_rate * class_weights[class_idx]
764                    } else {
765                        noise_rate
766                    };
767
768                    if self.rng.random::<Float>() < class_noise_rate {
769                        let other_classes: Vec<Float> = unique_classes
770                            .iter()
771                            .filter(|&&c| c != *label)
772                            .cloned()
773                            .collect();
774                        if !other_classes.is_empty() {
775                            *label = other_classes[self.rng.gen_range(0..other_classes.len())];
776                        }
777                    }
778                }
779            }
780            NoisePattern::SystematicBias { target_class } => {
781                let target_class_value = if *target_class < unique_classes.len() {
782                    unique_classes[*target_class]
783                } else {
784                    unique_classes[0]
785                };
786
787                for label in noisy_y.iter_mut() {
788                    if self.rng.random::<Float>() < noise_rate {
789                        *label = target_class_value;
790                    }
791                }
792            }
793        }
794
795        Ok((x.clone(), noisy_y))
796    }
797
798    /// Generate missing data scenarios
799    fn generate_missing_data(
800        &mut self,
801        x: &Array2<Float>,
802        y: &Array1<Float>,
803        missing_rate: Float,
804        missing_pattern: &MissingPattern,
805    ) -> Result<(Array2<Float>, Array1<Float>), Box<dyn std::error::Error>> {
806        let mut missing_x = x.clone();
807
808        match missing_pattern {
809            MissingPattern::MCAR => {
810                // Missing completely at random
811                for val in missing_x.iter_mut() {
812                    if self.rng.random::<Float>() < missing_rate {
813                        *val = Float::NAN;
814                    }
815                }
816            }
817            MissingPattern::MAR => {
818                // Missing at random (depends on other features)
819                for row in missing_x.axis_iter_mut(Axis(0)) {
820                    let row_mean =
821                        row.iter().filter(|v| v.is_finite()).sum::<Float>() / row.len() as Float;
822                    let missing_prob = if row_mean > 0.0 {
823                        missing_rate * 1.5
824                    } else {
825                        missing_rate * 0.5
826                    };
827
828                    for val in row {
829                        if self.rng.random::<Float>() < missing_prob {
830                            *val = Float::NAN;
831                        }
832                    }
833                }
834            }
835            MissingPattern::MNAR => {
836                // Missing not at random (depends on the value itself)
837                for val in missing_x.iter_mut() {
838                    let missing_prob = if *val > 0.5 {
839                        missing_rate * 2.0
840                    } else {
841                        missing_rate * 0.5
842                    };
843                    if self.rng.random::<Float>() < missing_prob {
844                        *val = Float::NAN;
845                    }
846                }
847            }
848            MissingPattern::BlockMissing { block_size } => {
849                // Block missing (consecutive features)
850                let n_cols = missing_x.ncols();
851                let n_blocks = (missing_rate * n_cols as Float) as usize / block_size;
852
853                for _ in 0..n_blocks {
854                    let start_col = self.rng.gen_range(0..n_cols.saturating_sub(*block_size));
855                    let end_col = (start_col + block_size).min(n_cols);
856
857                    for mut row in missing_x.axis_iter_mut(Axis(0)) {
858                        for j in start_col..end_col {
859                            row[j] = Float::NAN;
860                        }
861                    }
862                }
863            }
864        }
865
866        Ok((missing_x, y.clone()))
867    }
868}
869
870impl WorstCaseValidator {
871    /// Create a new worst-case validator
872    pub fn new(config: WorstCaseValidationConfig) -> Self {
873        let generator = WorstCaseScenarioGenerator::new(config);
874        Self { generator }
875    }
876
877    /// Validate model robustness under worst-case scenarios
878    pub fn validate<F>(
879        &mut self,
880        x: &Array2<Float>,
881        y: &Array1<Float>,
882        model_fn: F,
883    ) -> Result<WorstCaseValidationResult, Box<dyn std::error::Error>>
884    where
885        F: Fn(&Array2<Float>, &Array1<Float>) -> Result<Float, Box<dyn std::error::Error>>,
886    {
887        // Baseline performance
888        let baseline_score = model_fn(x, y)?;
889
890        // Generate worst-case scenarios
891        let scenarios = self.generator.generate_scenarios(x, y)?;
892
893        let mut scenario_results = HashMap::new();
894        let mut all_scores = Vec::new();
895        let mut failure_count = 0;
896
897        for (scenario_x, scenario_y, scenario_name) in scenarios {
898            let scenario_score = model_fn(&scenario_x, &scenario_y).unwrap_or(0.0);
899            all_scores.push(scenario_score);
900
901            let performance_drop = (baseline_score - scenario_score) / baseline_score;
902
903            // Check for failure (significant performance drop)
904            if performance_drop > 0.5 {
905                failure_count += 1;
906            }
907
908            let robustness_metrics =
909                self.calculate_robustness_metrics(baseline_score, scenario_score, &scenario_x, x);
910
911            let result = ScenarioResult {
912                scenario_name: scenario_name.clone(),
913                worst_case_score: scenario_score,
914                baseline_score,
915                performance_drop,
916                failure_examples: vec![], // Could be populated with specific failure indices
917                robustness_metrics,
918            };
919
920            scenario_results.insert(scenario_name, result);
921        }
922
923        let overall_worst_case_score = all_scores.iter().fold(Float::INFINITY, |a, &b| a.min(b));
924
925        let performance_degradation = (baseline_score - overall_worst_case_score) / baseline_score;
926        let failure_rate = failure_count as Float / all_scores.len() as Float;
927        let robustness_score = 1.0 - performance_degradation;
928
929        // Calculate confidence intervals (simplified)
930        let mut confidence_intervals = HashMap::new();
931        for (scenario_name, result) in &scenario_results {
932            let ci_lower = result.worst_case_score * 0.9;
933            let ci_upper = result.worst_case_score * 1.1;
934            confidence_intervals.insert(scenario_name.clone(), (ci_lower, ci_upper));
935        }
936
937        Ok(WorstCaseValidationResult {
938            scenario_results,
939            overall_worst_case_score,
940            robustness_score,
941            failure_rate,
942            performance_degradation,
943            confidence_intervals,
944        })
945    }
946
947    /// Calculate robustness metrics
948    fn calculate_robustness_metrics(
949        &self,
950        baseline_score: Float,
951        scenario_score: Float,
952        scenario_x: &Array2<Float>,
953        original_x: &Array2<Float>,
954    ) -> RobustnessMetrics {
955        let stability_score = (scenario_score / baseline_score).min(1.0);
956
957        // Calculate data similarity for consistency score
958        let data_similarity = self.calculate_data_similarity(scenario_x, original_x);
959        let consistency_score = stability_score * data_similarity;
960
961        let resilience_score = if scenario_score > baseline_score * 0.7 {
962            1.0
963        } else {
964            0.0
965        };
966        let recovery_score = stability_score; // Simplified
967        let breakdown_point = 1.0 - stability_score;
968
969        RobustnessMetrics {
970            stability_score,
971            consistency_score,
972            resilience_score,
973            recovery_score,
974            breakdown_point,
975        }
976    }
977
978    /// Calculate similarity between datasets
979    fn calculate_data_similarity(&self, x1: &Array2<Float>, x2: &Array2<Float>) -> Float {
980        if x1.dim() != x2.dim() {
981            return 0.0;
982        }
983
984        let mut similarity_sum = 0.0;
985        let mut count = 0;
986
987        for (row1, row2) in x1.axis_iter(Axis(0)).zip(x2.axis_iter(Axis(0))) {
988            let mut row_similarity = 0.0;
989            let mut valid_features = 0;
990
991            for (&val1, &val2) in row1.iter().zip(row2.iter()) {
992                if val1.is_finite() && val2.is_finite() {
993                    row_similarity += 1.0 - (val1 - val2).abs();
994                    valid_features += 1;
995                }
996            }
997
998            if valid_features > 0 {
999                similarity_sum += row_similarity / valid_features as Float;
1000                count += 1;
1001            }
1002        }
1003
1004        if count > 0 {
1005            similarity_sum / count as Float
1006        } else {
1007            0.0
1008        }
1009    }
1010}
1011
1012/// Convenience function for worst-case validation
1013pub fn worst_case_validate<F>(
1014    x: &Array2<Float>,
1015    y: &Array1<Float>,
1016    model_fn: F,
1017    config: Option<WorstCaseValidationConfig>,
1018) -> Result<WorstCaseValidationResult, Box<dyn std::error::Error>>
1019where
1020    F: Fn(&Array2<Float>, &Array1<Float>) -> Result<Float, Box<dyn std::error::Error>>,
1021{
1022    let config = config.unwrap_or_default();
1023    let mut validator = WorstCaseValidator::new(config);
1024    validator.validate(x, y, model_fn)
1025}
1026
#[allow(non_snake_case)]
#[cfg(test)]
mod tests {
    use super::*;

    /// Row-major `rows x cols` matrix holding `0, 1, 2, ...`, each multiplied by `scale`.
    fn sequential_matrix(rows: usize, cols: usize, scale: Float) -> Array2<Float> {
        Array2::from_shape_fn((rows, cols), |(r, c)| (r * cols + c) as Float * scale)
    }

    /// `n` labels alternating `0.0, 1.0, 0.0, ...`.
    fn alternating_labels(n: usize) -> Array1<Float> {
        (0..n).map(|i| (i % 2) as Float).collect()
    }

    #[test]
    fn test_worst_case_scenario_generator() {
        let mut generator = WorstCaseScenarioGenerator::new(WorstCaseValidationConfig::default());
        let features = sequential_matrix(10, 3, 1.0);
        let labels = alternating_labels(10);

        let generated = generator
            .generate_scenarios(&features, &labels)
            .expect("scenario generation should succeed");
        assert!(!generated.is_empty());
    }

    #[test]
    fn test_adversarial_example_generation() {
        let mut generator = WorstCaseScenarioGenerator::new(WorstCaseValidationConfig::default());
        let features = sequential_matrix(5, 3, 1.0);
        let labels = alternating_labels(5);

        let (perturbed_x, perturbed_y) = generator
            .generate_adversarial_examples(
                &features,
                &labels,
                0.1,
                &AdversarialAttackMethod::FGSM,
                false,
            )
            .expect("adversarial generation should succeed");

        assert_eq!(perturbed_x.dim(), features.dim());
        assert_eq!(perturbed_y.len(), labels.len());
    }

    #[test]
    fn test_worst_case_validation() {
        let config = WorstCaseValidationConfig {
            scenarios: vec![WorstCaseScenario::ExtremeOutliers {
                outlier_fraction: 0.1,
                outlier_magnitude: 2.0,
            }],
            n_worst_case_samples: 100,
            severity_levels: vec![1.0],
            ..Default::default()
        };

        let features = sequential_matrix(10, 3, 0.1);
        let labels = alternating_labels(10);

        // Constant "accuracy" stand-in: the model is insensitive to its inputs.
        let constant_model =
            |_x: &Array2<Float>, _y: &Array1<Float>| -> Result<Float, Box<dyn std::error::Error>> {
                Ok(0.8)
            };

        let outcome = worst_case_validate(&features, &labels, constant_model, Some(config))
            .expect("validation should succeed");

        assert!((0.0..=1.0).contains(&outcome.robustness_score));
        assert!(!outcome.scenario_results.is_empty());
    }

    #[test]
    fn test_label_noise_generation() {
        let mut generator = WorstCaseScenarioGenerator::new(WorstCaseValidationConfig::default());
        let features = sequential_matrix(10, 3, 1.0);
        let labels = alternating_labels(10);

        let (noisy_features, noisy_labels) = generator
            .generate_label_noise(&features, &labels, 0.2, &NoisePattern::Uniform)
            .expect("label noise generation should succeed");

        assert_eq!(noisy_features.dim(), features.dim());
        assert_eq!(noisy_labels.len(), labels.len());
    }
}