use scirs2_core::ndarray::{Array1, Array2, Axis};
use scirs2_core::random::rngs::StdRng;
use scirs2_core::random::Rng;
use scirs2_core::random::SeedableRng;
use scirs2_core::SliceRandomExt;
use sklears_core::types::Float;
use std::collections::HashMap;

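/// Worst-case conditions that can be synthesized from a dataset to stress-test
/// a model: adversarial perturbations, distribution shift, extreme outliers,
/// class imbalance, feature corruption, temporal drift, label noise, and
/// missing data.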
#[derive(Debug, Clone)]
pub enum WorstCaseScenario {
    AdversarialExamples {
        epsilon: Float,
        attack_method: AdversarialAttackMethod,
        targeted: bool,
    },
    DistributionShift {
        shift_type: DistributionShiftType,
        severity: Float,
    },
    ExtremeOutliers {
        outlier_fraction: Float,
        outlier_magnitude: Float,
    },
    ClassImbalance {
        minority_fraction: Float,
        imbalance_ratio: Float,
    },
    FeatureCorruption {
        corruption_rate: Float,
        corruption_type: CorruptionType,
    },
    TemporalDrift {
        drift_rate: Float,
        drift_pattern: DriftPattern,
    },
    LabelNoise {
        noise_rate: Float,
        noise_pattern: NoisePattern,
    },
    MissingData {
        missing_rate: Float,
        missing_pattern: MissingPattern,
    },
}

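/// Perturbation strategies used by [`WorstCaseScenario::AdversarialExamples`].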
#[derive(Debug, Clone)]
pub enum AdversarialAttackMethod {
    FGSM,
    PGD { iterations: usize },
    BIM { iterations: usize },
    CW { confidence: Float },
    BoundaryAttack { iterations: usize },
    RandomNoise,
}

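/// Kinds of train/test distribution mismatch simulated by
/// [`WorstCaseScenario::DistributionShift`].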
#[derive(Debug, Clone)]
pub enum DistributionShiftType {
    CovariateShift,
    PriorShift,
    ConceptDrift,
    DomainShift,
}

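/// Feature-level corruption models used by [`WorstCaseScenario::FeatureCorruption`].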
#[derive(Debug, Clone)]
pub enum CorruptionType {
    GaussianNoise { std: Float },
    SaltPepperNoise { ratio: Float },
    MultiplicativeNoise { factor: Float },
    FeatureMasking,
    Quantization { levels: usize },
}

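/// Temporal shapes of drift applied by [`WorstCaseScenario::TemporalDrift`].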
#[derive(Debug, Clone)]
pub enum DriftPattern {
    Linear,
    Sudden,
    Exponential,
    Seasonal { period: usize },
    RandomWalk,
}

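/// Label-noise structures applied by [`WorstCaseScenario::LabelNoise`].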
#[derive(Debug, Clone)]
pub enum NoisePattern {
    Uniform,
    ClassConditional { class_weights: Vec<Float> },
    SystematicBias { target_class: usize },
}

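/// Missing-data mechanisms applied by [`WorstCaseScenario::MissingData`]:
/// missing completely at random (MCAR), missing at random (MAR), missing not
/// at random (MNAR), or contiguous blocks of columns.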
#[derive(Debug, Clone)]
pub enum MissingPattern {
    MCAR,
    MAR,
    MNAR,
    BlockMissing { block_size: usize },
}

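/// Configuration for worst-case validation: which scenarios to generate, at
/// which severity levels, and how results are evaluated.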
#[derive(Debug, Clone)]
pub struct WorstCaseValidationConfig {
    pub scenarios: Vec<WorstCaseScenario>,
    pub n_worst_case_samples: usize,
    pub evaluation_metric: String,
    pub confidence_level: Float,
    pub random_state: Option<u64>,
    pub severity_levels: Vec<Float>,
}

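/// Aggregated outcome of a worst-case validation run across all generated
/// scenarios.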
#[derive(Debug, Clone)]
pub struct WorstCaseValidationResult {
    pub scenario_results: HashMap<String, ScenarioResult>,
    pub overall_worst_case_score: Float,
    pub robustness_score: Float,
    pub failure_rate: Float,
    pub performance_degradation: Float,
    pub confidence_intervals: HashMap<String, (Float, Float)>,
}

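/// Result for a single generated scenario, comparing its score against the
/// baseline score on the unmodified data.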
#[derive(Debug, Clone)]
pub struct ScenarioResult {
    pub scenario_name: String,
    pub worst_case_score: Float,
    pub baseline_score: Float,
    pub performance_drop: Float,
    pub failure_examples: Vec<usize>,
    pub robustness_metrics: RobustnessMetrics,
}

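/// Heuristic robustness indicators derived from the worst-case/baseline score
/// ratio and the similarity between the original and perturbed data.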
#[derive(Debug, Clone)]
pub struct RobustnessMetrics {
    pub stability_score: Float,
    pub consistency_score: Float,
    pub resilience_score: Float,
    pub recovery_score: Float,
    pub breakdown_point: Float,
}

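/// Generates perturbed copies of a dataset for every configured scenario and
/// severity level.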
#[derive(Debug, Clone)]
pub struct WorstCaseScenarioGenerator {
    config: WorstCaseValidationConfig,
    rng: StdRng,
}

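/// Runs a model evaluation function against every generated worst-case
/// scenario and summarizes robustness.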
#[derive(Debug)]
pub struct WorstCaseValidator {
    generator: WorstCaseScenarioGenerator,
}

impl Default for WorstCaseValidationConfig {
    fn default() -> Self {
        Self {
            scenarios: vec![
                WorstCaseScenario::AdversarialExamples {
                    epsilon: 0.1,
                    attack_method: AdversarialAttackMethod::FGSM,
                    targeted: false,
                },
                WorstCaseScenario::DistributionShift {
                    shift_type: DistributionShiftType::CovariateShift,
                    severity: 1.0,
                },
                WorstCaseScenario::ExtremeOutliers {
                    outlier_fraction: 0.1,
                    outlier_magnitude: 3.0,
                },
            ],
            n_worst_case_samples: 1000,
            evaluation_metric: "accuracy".to_string(),
            confidence_level: 0.95,
            random_state: None,
            severity_levels: vec![0.5, 1.0, 1.5, 2.0],
        }
    }
}

impl WorstCaseScenarioGenerator {
    pub fn new(config: WorstCaseValidationConfig) -> Self {
        let rng = match config.random_state {
            Some(seed) => StdRng::seed_from_u64(seed),
            None => {
                use scirs2_core::random::thread_rng;
                StdRng::from_rng(&mut thread_rng())
            }
        };

        Self { config, rng }
    }

    pub fn generate_scenarios(
        &mut self,
        x: &Array2<Float>,
        y: &Array1<Float>,
    ) -> Result<Vec<(Array2<Float>, Array1<Float>, String)>, Box<dyn std::error::Error>> {
        let mut scenarios = Vec::new();

        let scenarios_clone = self.config.scenarios.clone();
        let severity_levels_clone = self.config.severity_levels.clone();

        for scenario in &scenarios_clone {
            for &severity in &severity_levels_clone {
                let (worst_x, worst_y, name) =
                    self.generate_single_scenario(x, y, scenario, severity)?;
                scenarios.push((worst_x, worst_y, name));
            }
        }

        Ok(scenarios)
    }

    fn generate_single_scenario(
        &mut self,
        x: &Array2<Float>,
        y: &Array1<Float>,
        scenario: &WorstCaseScenario,
        severity: Float,
    ) -> Result<(Array2<Float>, Array1<Float>, String), Box<dyn std::error::Error>> {
        match scenario {
            WorstCaseScenario::AdversarialExamples {
                epsilon,
                attack_method,
                targeted,
            } => {
                let (adv_x, adv_y) = self.generate_adversarial_examples(
                    x,
                    y,
                    *epsilon * severity,
                    attack_method,
                    *targeted,
                )?;
                let name = format!(
                    "Adversarial_{:?}_eps_{:.3}",
                    attack_method,
                    epsilon * severity
                );
                Ok((adv_x, adv_y, name))
            }
            WorstCaseScenario::DistributionShift {
                shift_type,
                severity: base_severity,
            } => {
                let (shift_x, shift_y) =
                    self.generate_distribution_shift(x, y, shift_type, base_severity * severity)?;
                let name = format!(
                    "DistShift_{:?}_sev_{:.2}",
                    shift_type,
                    base_severity * severity
                );
                Ok((shift_x, shift_y, name))
            }
            WorstCaseScenario::ExtremeOutliers {
                outlier_fraction,
                outlier_magnitude,
            } => {
                let (outlier_x, outlier_y) = self.generate_extreme_outliers(
                    x,
                    y,
                    *outlier_fraction,
                    outlier_magnitude * severity,
                )?;
                let name = format!(
                    "Outliers_frac_{:.2}_mag_{:.2}",
                    outlier_fraction,
                    outlier_magnitude * severity
                );
                Ok((outlier_x, outlier_y, name))
            }
            WorstCaseScenario::ClassImbalance {
                minority_fraction,
                imbalance_ratio,
            } => {
                let (imbal_x, imbal_y) = self.generate_class_imbalance(
                    x,
                    y,
                    *minority_fraction,
                    imbalance_ratio * severity,
                )?;
                let name = format!(
                    "ClassImbalance_frac_{:.2}_ratio_{:.2}",
                    minority_fraction,
                    imbalance_ratio * severity
                );
                Ok((imbal_x, imbal_y, name))
            }
            WorstCaseScenario::FeatureCorruption {
                corruption_rate,
                corruption_type,
            } => {
                let (corr_x, corr_y) = self.generate_feature_corruption(
                    x,
                    y,
                    corruption_rate * severity,
                    corruption_type,
                )?;
                let name = format!(
                    "Corruption_{:?}_rate_{:.2}",
                    corruption_type,
                    corruption_rate * severity
                );
                Ok((corr_x, corr_y, name))
            }
            WorstCaseScenario::TemporalDrift {
                drift_rate,
                drift_pattern,
            } => {
                let (drift_x, drift_y) =
                    self.generate_temporal_drift(x, y, drift_rate * severity, drift_pattern)?;
                let name = format!(
                    "TemporalDrift_{:?}_rate_{:.2}",
                    drift_pattern,
                    drift_rate * severity
                );
                Ok((drift_x, drift_y, name))
            }
            WorstCaseScenario::LabelNoise {
                noise_rate,
                noise_pattern,
            } => {
                let (noise_x, noise_y) =
                    self.generate_label_noise(x, y, noise_rate * severity, noise_pattern)?;
                let name = format!(
                    "LabelNoise_{:?}_rate_{:.2}",
                    noise_pattern,
                    noise_rate * severity
                );
                Ok((noise_x, noise_y, name))
            }
            WorstCaseScenario::MissingData {
                missing_rate,
                missing_pattern,
            } => {
                let (missing_x, missing_y) =
                    self.generate_missing_data(x, y, missing_rate * severity, missing_pattern)?;
                let name = format!(
                    "MissingData_{:?}_rate_{:.2}",
                    missing_pattern,
                    missing_rate * severity
                );
                Ok((missing_x, missing_y, name))
            }
        }
    }

    fn generate_adversarial_examples(
        &mut self,
        x: &Array2<Float>,
        y: &Array1<Float>,
        epsilon: Float,
        attack_method: &AdversarialAttackMethod,
        _targeted: bool,
    ) -> Result<(Array2<Float>, Array1<Float>), Box<dyn std::error::Error>> {
        let mut adv_x = x.clone();

        match attack_method {
            AdversarialAttackMethod::FGSM => {
                for mut row in adv_x.axis_iter_mut(Axis(0)) {
                    for val in row.iter_mut() {
                        let perturbation = if self.rng.gen_bool(0.5) {
                            epsilon
                        } else {
                            -epsilon
                        };
                        *val += perturbation;
                    }
                }
            }
            AdversarialAttackMethod::PGD { iterations } => {
                let step_size = epsilon / (*iterations as Float);
                for _ in 0..*iterations {
                    for (mut adv_row, orig_row) in
                        adv_x.axis_iter_mut(Axis(0)).zip(x.axis_iter(Axis(0)))
                    {
                        for (val, &orig) in adv_row.iter_mut().zip(orig_row.iter()) {
                            let perturbation = if self.rng.gen_bool(0.5) {
                                step_size
                            } else {
                                -step_size
                            };
                            *val += perturbation;
                            // Project back onto the epsilon-ball around the
                            // original value.
                            *val = val.max(orig - epsilon).min(orig + epsilon);
                        }
                    }
                }
            }
            AdversarialAttackMethod::BIM { iterations } => {
                let alpha = epsilon / (*iterations as Float);
                for _ in 0..*iterations {
                    for mut row in adv_x.axis_iter_mut(Axis(0)) {
                        for val in row.iter_mut() {
                            let perturbation = if self.rng.gen_bool(0.5) {
                                alpha
                            } else {
                                -alpha
                            };
                            *val += perturbation;
                        }
                    }
                }
            }
            AdversarialAttackMethod::RandomNoise => {
                for mut row in adv_x.axis_iter_mut(Axis(0)) {
                    for val in row.iter_mut() {
                        // Uniform noise bounded by the epsilon budget.
                        let noise = self.rng.gen_range(-epsilon..=epsilon);
                        *val += noise;
                    }
                }
            }
            AdversarialAttackMethod::CW { .. } | AdversarialAttackMethod::BoundaryAttack { .. } => {
                // Simplified surrogate for CW / boundary attacks: random
                // perturbations bounded by the epsilon budget.
                for mut row in adv_x.axis_iter_mut(Axis(0)) {
                    for val in row.iter_mut() {
                        let perturbation = self.rng.gen_range(-epsilon..=epsilon);
                        *val += perturbation;
                    }
                }
            }
        }

        Ok((adv_x, y.clone()))
    }

    fn generate_distribution_shift(
        &mut self,
        x: &Array2<Float>,
        y: &Array1<Float>,
        shift_type: &DistributionShiftType,
        severity: Float,
    ) -> Result<(Array2<Float>, Array1<Float>), Box<dyn std::error::Error>> {
        let mut shift_x = x.clone();
        let mut shift_y = y.clone();

        match shift_type {
            DistributionShiftType::CovariateShift => {
                for mut row in shift_x.axis_iter_mut(Axis(0)) {
                    for (i, val) in row.iter_mut().enumerate() {
                        let shift = severity * (i as Float * 0.1).sin();
                        *val += shift;
                    }
                }
            }
            DistributionShiftType::PriorShift => {
                let mut unique_classes: Vec<Float> = y.iter().cloned().collect();
                unique_classes.sort_by(|a, b| a.partial_cmp(b).unwrap());
                unique_classes.dedup();
                if unique_classes.len() > 1 {
                    let target_class = unique_classes[0];
                    let removal_prob = severity * 0.5;

                    let mut keep_indices = Vec::new();
                    for (i, &class) in y.iter().enumerate() {
                        if class != target_class || self.rng.random::<Float>() > removal_prob {
                            keep_indices.push(i);
                        }
                    }

                    let mut new_x_data = Vec::new();
                    for &i in keep_indices.iter() {
                        new_x_data.extend(x.row(i).iter().cloned());
                    }
                    let new_x =
                        Array2::from_shape_vec((keep_indices.len(), x.ncols()), new_x_data)?;
                    let new_y = Array1::from_vec(keep_indices.iter().map(|&i| y[i]).collect());

                    return Ok((new_x, new_y));
                }
            }
            DistributionShiftType::ConceptDrift => {
                for label in shift_y.iter_mut() {
                    if self.rng.random::<Float>() < severity * 0.2 {
                        *label = 1.0 - *label;
                    }
                }
            }
            DistributionShiftType::DomainShift => {
                for mut row in shift_x.axis_iter_mut(Axis(0)) {
                    for val in row.iter_mut() {
                        *val = val.tanh() * severity;
                    }
                }
            }
        }

        Ok((shift_x, shift_y))
    }

    fn generate_extreme_outliers(
        &mut self,
        x: &Array2<Float>,
        y: &Array1<Float>,
        outlier_fraction: Float,
        outlier_magnitude: Float,
    ) -> Result<(Array2<Float>, Array1<Float>), Box<dyn std::error::Error>> {
        let mut outlier_x = x.clone();
        let n_outliers = (x.nrows() as Float * outlier_fraction) as usize;

        let mut outlier_indices: Vec<usize> = (0..x.nrows()).collect();
        outlier_indices.shuffle(&mut self.rng);
        outlier_indices.truncate(n_outliers);

        for &idx in &outlier_indices {
            for val in outlier_x.row_mut(idx) {
                // Uniform offset bounded by the outlier magnitude.
                let outlier_value = self.rng.gen_range(-outlier_magnitude..=outlier_magnitude);
                *val += outlier_value;
            }
        }

        Ok((outlier_x, y.clone()))
    }

    fn generate_class_imbalance(
        &mut self,
        x: &Array2<Float>,
        y: &Array1<Float>,
        minority_fraction: Float,
        _imbalance_ratio: Float,
    ) -> Result<(Array2<Float>, Array1<Float>), Box<dyn std::error::Error>> {
        let mut unique_classes: Vec<Float> = y.iter().cloned().collect();
        unique_classes.sort_by(|a, b| a.partial_cmp(b).unwrap());
        unique_classes.dedup();
        if unique_classes.len() < 2 {
            return Ok((x.clone(), y.clone()));
        }

        let minority_class = unique_classes[0];
        let target_minority_count = (x.nrows() as Float * minority_fraction) as usize;

        let mut keep_indices = Vec::new();
        let mut minority_count = 0;

        for (i, &class) in y.iter().enumerate() {
            if class == minority_class {
                if minority_count < target_minority_count {
                    keep_indices.push(i);
                    minority_count += 1;
                }
            } else {
                keep_indices.push(i);
            }
        }

        let mut new_x_data = Vec::new();
        for &i in keep_indices.iter() {
            new_x_data.extend(x.row(i).iter().cloned());
        }
        let new_x = Array2::from_shape_vec((keep_indices.len(), x.ncols()), new_x_data)?;
        let new_y = Array1::from_vec(keep_indices.iter().map(|&i| y[i]).collect());

        Ok((new_x, new_y))
    }

    fn generate_feature_corruption(
        &mut self,
        x: &Array2<Float>,
        y: &Array1<Float>,
        corruption_rate: Float,
        corruption_type: &CorruptionType,
    ) -> Result<(Array2<Float>, Array1<Float>), Box<dyn std::error::Error>> {
        let mut corrupted_x = x.clone();

        match corruption_type {
            CorruptionType::GaussianNoise { std } => {
                for val in corrupted_x.iter_mut() {
                    if self.rng.random::<Float>() < corruption_rate {
                        // Draw zero-mean Gaussian noise with the requested
                        // standard deviation via the Box-Muller transform.
                        let u1: Float = 1.0 - self.rng.random::<Float>();
                        let u2: Float = self.rng.random::<Float>();
                        let z = (-2.0 * u1.ln()).sqrt()
                            * (2.0 * std::f64::consts::PI * u2).cos();
                        *val += z * std;
                    }
                }
            }
            CorruptionType::SaltPepperNoise { ratio } => {
                for val in corrupted_x.iter_mut() {
                    if self.rng.random::<Float>() < corruption_rate {
                        *val = if self.rng.random::<Float>() < *ratio {
                            1.0
                        } else {
                            0.0
                        };
                    }
                }
            }
            CorruptionType::MultiplicativeNoise { factor } => {
                for val in corrupted_x.iter_mut() {
                    if self.rng.random::<Float>() < corruption_rate {
                        let noise = 1.0 + (self.rng.random::<Float>() - 0.5) * factor;
                        *val *= noise;
                    }
                }
            }
            CorruptionType::FeatureMasking => {
                for val in corrupted_x.iter_mut() {
                    if self.rng.random::<Float>() < corruption_rate {
                        *val = 0.0;
                    }
                }
            }
            CorruptionType::Quantization { levels } => {
                let step_size = 2.0 / (*levels as Float);
                for val in corrupted_x.iter_mut() {
                    if self.rng.random::<Float>() < corruption_rate {
                        *val = ((*val / step_size).round() * step_size).clamp(-1.0, 1.0);
                    }
                }
            }
        }

        Ok((corrupted_x, y.clone()))
    }

    fn generate_temporal_drift(
        &mut self,
        x: &Array2<Float>,
        y: &Array1<Float>,
        drift_rate: Float,
        drift_pattern: &DriftPattern,
    ) -> Result<(Array2<Float>, Array1<Float>), Box<dyn std::error::Error>> {
        let mut drift_x = x.clone();
        let n_samples = x.nrows();

        for (t, row) in drift_x.axis_iter_mut(Axis(0)).enumerate() {
            let time_factor = t as Float / n_samples as Float;

            let drift_magnitude = match drift_pattern {
                DriftPattern::Linear => drift_rate * time_factor,
                DriftPattern::Sudden => {
                    if time_factor > 0.5 {
                        drift_rate
                    } else {
                        0.0
                    }
                }
                DriftPattern::Exponential => drift_rate * time_factor.exp(),
                DriftPattern::Seasonal { period } => {
                    drift_rate
                        * (2.0 * std::f64::consts::PI * t as Float / *period as Float).sin()
                }
                DriftPattern::RandomWalk => drift_rate * self.rng.random::<Float>(),
            };

            for val in row {
                *val += drift_magnitude;
            }
        }

        Ok((drift_x, y.clone()))
    }

    fn generate_label_noise(
        &mut self,
        x: &Array2<Float>,
        y: &Array1<Float>,
        noise_rate: Float,
        noise_pattern: &NoisePattern,
    ) -> Result<(Array2<Float>, Array1<Float>), Box<dyn std::error::Error>> {
        let mut noisy_y = y.clone();
        let mut unique_classes: Vec<Float> = y.iter().cloned().collect();
        unique_classes.sort_by(|a, b| a.partial_cmp(b).unwrap());
        unique_classes.dedup();

        if unique_classes.len() < 2 {
            return Ok((x.clone(), noisy_y));
        }

        match noise_pattern {
            NoisePattern::Uniform => {
                for label in noisy_y.iter_mut() {
                    if self.rng.random::<Float>() < noise_rate {
                        let other_classes: Vec<Float> = unique_classes
                            .iter()
                            .filter(|&&c| c != *label)
                            .cloned()
                            .collect();
                        if !other_classes.is_empty() {
                            *label = other_classes[self.rng.gen_range(0..other_classes.len())];
                        }
                    }
                }
            }
            NoisePattern::ClassConditional { class_weights } => {
                for label in noisy_y.iter_mut() {
                    let class_idx = unique_classes
                        .iter()
                        .position(|&c| c == *label)
                        .unwrap_or(0);
                    let class_noise_rate = if class_idx < class_weights.len() {
                        noise_rate * class_weights[class_idx]
                    } else {
                        noise_rate
                    };

                    if self.rng.random::<Float>() < class_noise_rate {
                        let other_classes: Vec<Float> = unique_classes
                            .iter()
                            .filter(|&&c| c != *label)
                            .cloned()
                            .collect();
                        if !other_classes.is_empty() {
                            *label = other_classes[self.rng.gen_range(0..other_classes.len())];
                        }
                    }
                }
            }
            NoisePattern::SystematicBias { target_class } => {
                let target_class_value = if *target_class < unique_classes.len() {
                    unique_classes[*target_class]
                } else {
                    unique_classes[0]
                };

                for label in noisy_y.iter_mut() {
                    if self.rng.random::<Float>() < noise_rate {
                        *label = target_class_value;
                    }
                }
            }
        }

        Ok((x.clone(), noisy_y))
    }

    fn generate_missing_data(
        &mut self,
        x: &Array2<Float>,
        y: &Array1<Float>,
        missing_rate: Float,
        missing_pattern: &MissingPattern,
    ) -> Result<(Array2<Float>, Array1<Float>), Box<dyn std::error::Error>> {
        let mut missing_x = x.clone();

        match missing_pattern {
            MissingPattern::MCAR => {
                for val in missing_x.iter_mut() {
                    if self.rng.random::<Float>() < missing_rate {
                        *val = Float::NAN;
                    }
                }
            }
            MissingPattern::MAR => {
                for row in missing_x.axis_iter_mut(Axis(0)) {
                    // Missingness probability depends on the row mean, computed
                    // over finite values only.
                    let finite_count = row.iter().filter(|v| v.is_finite()).count();
                    let row_mean = if finite_count > 0 {
                        row.iter().filter(|v| v.is_finite()).sum::<Float>()
                            / finite_count as Float
                    } else {
                        0.0
                    };
                    let missing_prob = if row_mean > 0.0 {
                        missing_rate * 1.5
                    } else {
                        missing_rate * 0.5
                    };

                    for val in row {
                        if self.rng.random::<Float>() < missing_prob {
                            *val = Float::NAN;
                        }
                    }
                }
            }
            MissingPattern::MNAR => {
                for val in missing_x.iter_mut() {
                    let missing_prob = if *val > 0.5 {
                        missing_rate * 2.0
                    } else {
                        missing_rate * 0.5
                    };
                    if self.rng.random::<Float>() < missing_prob {
                        *val = Float::NAN;
                    }
                }
            }
            MissingPattern::BlockMissing { block_size } => {
                let n_cols = missing_x.ncols();
                let block_size = (*block_size).max(1);
                let n_blocks = (missing_rate * n_cols as Float) as usize / block_size;
                // Keep the sampling range non-empty even when the block is as
                // wide as (or wider than) the feature matrix.
                let max_start = n_cols.saturating_sub(block_size).max(1);

                for _ in 0..n_blocks {
                    let start_col = self.rng.gen_range(0..max_start);
                    let end_col = (start_col + block_size).min(n_cols);

                    for mut row in missing_x.axis_iter_mut(Axis(0)) {
                        for j in start_col..end_col {
                            row[j] = Float::NAN;
                        }
                    }
                }
            }
        }

        Ok((missing_x, y.clone()))
    }
}

impl WorstCaseValidator {
    pub fn new(config: WorstCaseValidationConfig) -> Self {
        let generator = WorstCaseScenarioGenerator::new(config);
        Self { generator }
    }

    pub fn validate<F>(
        &mut self,
        x: &Array2<Float>,
        y: &Array1<Float>,
        model_fn: F,
    ) -> Result<WorstCaseValidationResult, Box<dyn std::error::Error>>
    where
        F: Fn(&Array2<Float>, &Array1<Float>) -> Result<Float, Box<dyn std::error::Error>>,
    {
        let baseline_score = model_fn(x, y)?;

        let scenarios = self.generator.generate_scenarios(x, y)?;

        let mut scenario_results = HashMap::new();
        let mut all_scores = Vec::new();
        let mut failure_count = 0;

        for (scenario_x, scenario_y, scenario_name) in scenarios {
            let scenario_score = model_fn(&scenario_x, &scenario_y).unwrap_or(0.0);
            all_scores.push(scenario_score);

            let performance_drop = (baseline_score - scenario_score) / baseline_score;

            if performance_drop > 0.5 {
                failure_count += 1;
            }

            let robustness_metrics =
                self.calculate_robustness_metrics(baseline_score, scenario_score, &scenario_x, x);

            let result = ScenarioResult {
                scenario_name: scenario_name.clone(),
                worst_case_score: scenario_score,
                baseline_score,
                performance_drop,
                failure_examples: vec![],
                robustness_metrics,
            };

            scenario_results.insert(scenario_name, result);
        }

        let overall_worst_case_score = all_scores.iter().fold(Float::INFINITY, |a, &b| a.min(b));

        let performance_degradation = (baseline_score - overall_worst_case_score) / baseline_score;
        let failure_rate = failure_count as Float / all_scores.len() as Float;
        let robustness_score = 1.0 - performance_degradation;

        let mut confidence_intervals = HashMap::new();
        for (scenario_name, result) in &scenario_results {
            let ci_lower = result.worst_case_score * 0.9;
            let ci_upper = result.worst_case_score * 1.1;
            confidence_intervals.insert(scenario_name.clone(), (ci_lower, ci_upper));
        }

        Ok(WorstCaseValidationResult {
            scenario_results,
            overall_worst_case_score,
            robustness_score,
            failure_rate,
            performance_degradation,
            confidence_intervals,
        })
    }

    fn calculate_robustness_metrics(
        &self,
        baseline_score: Float,
        scenario_score: Float,
        scenario_x: &Array2<Float>,
        original_x: &Array2<Float>,
    ) -> RobustnessMetrics {
        let stability_score = (scenario_score / baseline_score).min(1.0);

        let data_similarity = self.calculate_data_similarity(scenario_x, original_x);
        let consistency_score = stability_score * data_similarity;

        let resilience_score = if scenario_score > baseline_score * 0.7 {
            1.0
        } else {
            0.0
        };
        let recovery_score = stability_score;
        let breakdown_point = 1.0 - stability_score;

        RobustnessMetrics {
            stability_score,
            consistency_score,
            resilience_score,
            recovery_score,
            breakdown_point,
        }
    }

    fn calculate_data_similarity(&self, x1: &Array2<Float>, x2: &Array2<Float>) -> Float {
        if x1.dim() != x2.dim() {
            return 0.0;
        }

        let mut similarity_sum = 0.0;
        let mut count = 0;

        for (row1, row2) in x1.axis_iter(Axis(0)).zip(x2.axis_iter(Axis(0))) {
            let mut row_similarity = 0.0;
            let mut valid_features = 0;

            for (&val1, &val2) in row1.iter().zip(row2.iter()) {
                if val1.is_finite() && val2.is_finite() {
                    row_similarity += 1.0 - (val1 - val2).abs();
                    valid_features += 1;
                }
            }

            if valid_features > 0 {
                similarity_sum += row_similarity / valid_features as Float;
                count += 1;
            }
        }

        if count > 0 {
            similarity_sum / count as Float
        } else {
            0.0
        }
    }
}

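/// Convenience wrapper: builds a [`WorstCaseValidator`] from `config` (or the
/// default configuration when `None`) and runs it on `(x, y)` with `model_fn`.
///
/// Illustrative sketch of a call site (the constant-score closure stands in
/// for a real train/evaluate function; the data shapes are placeholders):
///
/// ```ignore
/// use scirs2_core::ndarray::{Array1, Array2};
///
/// let x = Array2::from_shape_vec((20, 3), (0..60).map(|i| i as Float * 0.1).collect())?;
/// let y = Array1::from_vec((0..20).map(|i| (i % 2) as Float).collect());
///
/// let result = worst_case_validate(&x, &y, |_x, _y| Ok(0.9), None)?;
/// assert!(result.robustness_score <= 1.0);
/// ```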
pub fn worst_case_validate<F>(
    x: &Array2<Float>,
    y: &Array1<Float>,
    model_fn: F,
    config: Option<WorstCaseValidationConfig>,
) -> Result<WorstCaseValidationResult, Box<dyn std::error::Error>>
where
    F: Fn(&Array2<Float>, &Array1<Float>) -> Result<Float, Box<dyn std::error::Error>>,
{
    let config = config.unwrap_or_default();
    let mut validator = WorstCaseValidator::new(config);
    validator.validate(x, y, model_fn)
}

#[allow(non_snake_case)]
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_worst_case_scenario_generator() {
        let config = WorstCaseValidationConfig::default();
        let mut generator = WorstCaseScenarioGenerator::new(config);

        let x = Array2::from_shape_vec((10, 3), (0..30).map(|i| i as Float).collect()).unwrap();
        let y = Array1::from_vec(vec![0.0, 1.0, 0.0, 1.0, 0.0, 1.0, 0.0, 1.0, 0.0, 1.0]);

        let scenarios = generator.generate_scenarios(&x, &y).unwrap();
        assert!(!scenarios.is_empty());
    }

    #[test]
    fn test_adversarial_example_generation() {
        let config = WorstCaseValidationConfig::default();
        let mut generator = WorstCaseScenarioGenerator::new(config);

        let x = Array2::from_shape_vec((5, 3), (0..15).map(|i| i as Float).collect()).unwrap();
        let y = Array1::from_vec(vec![0.0, 1.0, 0.0, 1.0, 0.0]);

        let (adv_x, adv_y) = generator
            .generate_adversarial_examples(&x, &y, 0.1, &AdversarialAttackMethod::FGSM, false)
            .unwrap();

        assert_eq!(adv_x.dim(), x.dim());
        assert_eq!(adv_y.len(), y.len());
    }

    #[test]
    fn test_worst_case_validation() {
        let config = WorstCaseValidationConfig {
            scenarios: vec![WorstCaseScenario::ExtremeOutliers {
                outlier_fraction: 0.1,
                outlier_magnitude: 2.0,
            }],
            n_worst_case_samples: 100,
            severity_levels: vec![1.0],
            ..Default::default()
        };

        let x =
            Array2::from_shape_vec((10, 3), (0..30).map(|i| i as Float * 0.1).collect()).unwrap();
        let y = Array1::from_vec(vec![0.0, 1.0, 0.0, 1.0, 0.0, 1.0, 0.0, 1.0, 0.0, 1.0]);

        let model_fn =
            |_x: &Array2<Float>, _y: &Array1<Float>| -> Result<Float, Box<dyn std::error::Error>> {
                Ok(0.8)
            };

        let result = worst_case_validate(&x, &y, model_fn, Some(config)).unwrap();

        assert!(result.robustness_score >= 0.0);
        assert!(result.robustness_score <= 1.0);
        assert!(!result.scenario_results.is_empty());
    }

    #[test]
    fn test_label_noise_generation() {
        let config = WorstCaseValidationConfig::default();
        let mut generator = WorstCaseScenarioGenerator::new(config);

        let x = Array2::from_shape_vec((10, 3), (0..30).map(|i| i as Float).collect()).unwrap();
        let y = Array1::from_vec(vec![0.0, 1.0, 0.0, 1.0, 0.0, 1.0, 0.0, 1.0, 0.0, 1.0]);

        let (noisy_x, noisy_y) = generator
            .generate_label_noise(&x, &y, 0.2, &NoisePattern::Uniform)
            .unwrap();

        assert_eq!(noisy_x.dim(), x.dim());
        assert_eq!(noisy_y.len(), y.len());
    }
}