1use super::adversarial_analysis::*;
8use super::edge_case_discovery::*;
9use super::perturbation_testing::*;
10use super::reporting::SimulationReport;
11use super::types::*;
12use super::what_if_analysis::*;
13use anyhow::Result;
14use chrono::Utc;
15use std::collections::HashMap;
16
/// Orchestrates model-simulation analyses (what-if, perturbation testing,
/// adversarial probing, edge-case discovery) and accumulates their results
/// for later reporting.
#[derive(Debug)]
pub struct SimulationAnalyzer {
    // Enable flags, sample counts and method lists that control which
    // analyses run and how many scenarios/perturbations/examples each
    // generates (e.g. `enable_what_if_analysis`, `num_what_if_scenarios`).
    config: SimulationConfig,
    // Completed analyses, appended in chronological order (most recent last).
    what_if_results: Vec<WhatIfAnalysisResult>,
    perturbation_results: Vec<PerturbationTestResult>,
    adversarial_results: Vec<AdversarialProbingResult>,
    edge_case_results: Vec<EdgeCaseDiscoveryResult>,
}
26
27impl SimulationAnalyzer {
28 pub fn new(config: SimulationConfig) -> Self {
30 Self {
31 config,
32 what_if_results: Vec::new(),
33 perturbation_results: Vec::new(),
34 adversarial_results: Vec::new(),
35 edge_case_results: Vec::new(),
36 }
37 }
38
    /// Runs a full what-if analysis around `base_input`.
    ///
    /// Evaluates `model_fn` on the base input, generates randomized variant
    /// scenarios, then derives impact, sensitivity, counterfactual and
    /// decision-boundary summaries from them. The result is appended to
    /// `self.what_if_results` and also returned.
    ///
    /// # Errors
    /// Returns an error when `config.enable_what_if_analysis` is false, or
    /// when scenario generation fails.
    pub async fn analyze_what_if(
        &mut self,
        base_input: &HashMap<String, f64>,
        model_fn: Box<dyn Fn(&HashMap<String, f64>) -> f64 + Send + Sync>,
    ) -> Result<WhatIfAnalysisResult> {
        if !self.config.enable_what_if_analysis {
            return Err(anyhow::anyhow!("What-if analysis is disabled"));
        }

        let base_prediction = model_fn(base_input);
        // The unmodified input acts as the reference point for all scenarios.
        let base_scenario = Scenario {
            id: "base".to_string(),
            description: "Original input scenario".to_string(),
            features: base_input.clone(),
            prediction: base_prediction,
            confidence: 1.0,
            changed_features: vec![],
            distance_from_base: 0.0,
            plausibility: 1.0,
        };

        let scenarios = self.generate_what_if_scenarios(base_input, &model_fn).await?;

        let impact_analysis = self.analyze_scenario_impacts(&base_scenario, &scenarios);

        let sensitivity_analysis = self.analyze_feature_sensitivity_from_scenarios(&scenarios);

        let counterfactual_insights =
            self.generate_counterfactual_insights(&base_scenario, &scenarios);

        let decision_boundary_exploration = self.explore_decision_boundary(&scenarios);

        let result = WhatIfAnalysisResult {
            timestamp: Utc::now(),
            base_scenario,
            scenarios,
            impact_analysis,
            sensitivity_analysis,
            counterfactual_insights,
            decision_boundary_exploration,
        };

        // Keep a copy for later reporting.
        self.what_if_results.push(result.clone());
        Ok(result)
    }
90
    /// Runs perturbation testing across every configured intensity.
    ///
    /// For each intensity in `config.perturbation_intensities` the model is
    /// probed with random perturbations of `base_input`; the per-intensity
    /// results are then summarized into robustness, sensitivity-hotspot and
    /// failure-mode analyses. The result is appended to
    /// `self.perturbation_results` and returned.
    ///
    /// # Errors
    /// Returns an error when `config.enable_perturbation_testing` is false,
    /// or when an intensity run fails.
    pub async fn test_perturbations(
        &mut self,
        base_input: &HashMap<String, f64>,
        model_fn: Box<dyn Fn(&HashMap<String, f64>) -> f64 + Send + Sync>,
    ) -> Result<PerturbationTestResult> {
        if !self.config.enable_perturbation_testing {
            return Err(anyhow::anyhow!("Perturbation testing is disabled"));
        }

        let mut results_by_intensity = HashMap::new();

        // Keyed by the intensity's string form (f64 is not Eq/Hash).
        for &intensity in &self.config.perturbation_intensities {
            let intensity_result =
                self.test_perturbation_intensity(base_input, &model_fn, intensity).await?;
            results_by_intensity.insert(intensity.to_string(), intensity_result);
        }

        let robustness_assessment = self.assess_robustness(&results_by_intensity);

        let sensitivity_hotspots = self.identify_sensitivity_hotspots(&results_by_intensity);

        let failure_modes = self.analyze_failure_modes(&results_by_intensity);

        let result = PerturbationTestResult {
            timestamp: Utc::now(),
            base_input: base_input.clone(),
            results_by_intensity,
            robustness_assessment,
            sensitivity_hotspots,
            failure_modes,
        };

        // Keep a copy for later reporting.
        self.perturbation_results.push(result.clone());
        Ok(result)
    }
131
    /// Probes the model with every configured adversarial attack method.
    ///
    /// Generates adversarial examples per method in
    /// `config.adversarial_methods`, then summarizes attack success,
    /// adversarial robustness, and defense recommendations. The result is
    /// appended to `self.adversarial_results` and returned.
    ///
    /// # Errors
    /// Returns an error when `config.enable_adversarial_probing` is false,
    /// or when example generation fails.
    pub async fn probe_adversarial(
        &mut self,
        base_input: &HashMap<String, f64>,
        model_fn: Box<dyn Fn(&HashMap<String, f64>) -> f64 + Send + Sync>,
    ) -> Result<AdversarialProbingResult> {
        if !self.config.enable_adversarial_probing {
            return Err(anyhow::anyhow!("Adversarial probing is disabled"));
        }

        let mut adversarial_examples = HashMap::new();

        for method in &self.config.adversarial_methods {
            let examples =
                self.generate_adversarial_examples(base_input, &model_fn, method).await?;
            adversarial_examples.insert(method.clone(), examples);
        }

        let attack_success_analysis = self.analyze_attack_success(&adversarial_examples);

        let robustness_assessment = self.assess_adversarial_robustness(&adversarial_examples);

        let defense_recommendations = self.generate_defense_recommendations(&adversarial_examples);

        let result = AdversarialProbingResult {
            timestamp: Utc::now(),
            base_input: base_input.clone(),
            adversarial_examples,
            attack_success_analysis,
            robustness_assessment,
            defense_recommendations,
        };

        // Keep a copy for later reporting.
        self.adversarial_results.push(result.clone());
        Ok(result)
    }
172
    /// Searches the bounded `input_space` for edge cases.
    ///
    /// `input_space` maps each feature name to its (min, max) range. Found
    /// edge cases are classified, coverage-analyzed and risk-assessed; the
    /// result is appended to `self.edge_case_results` and returned.
    ///
    /// # Errors
    /// Returns an error when `config.enable_edge_case_discovery` is false,
    /// or when the search fails.
    pub async fn discover_edge_cases(
        &mut self,
        input_space: &HashMap<String, (f64, f64)>,
        model_fn: Box<dyn Fn(&HashMap<String, f64>) -> f64 + Send + Sync>,
    ) -> Result<EdgeCaseDiscoveryResult> {
        if !self.config.enable_edge_case_discovery {
            return Err(anyhow::anyhow!("Edge case discovery is disabled"));
        }

        let edge_cases = self.search_edge_cases(input_space, &model_fn).await?;

        let classification = self.classify_edge_cases(&edge_cases);

        let coverage_analysis = self.analyze_edge_case_coverage(&edge_cases, input_space);

        let risk_assessment = self.assess_edge_case_risks(&edge_cases);

        let result = EdgeCaseDiscoveryResult {
            timestamp: Utc::now(),
            edge_cases,
            classification,
            coverage_analysis,
            risk_assessment,
        };

        // Keep a copy for later reporting.
        self.edge_case_results.push(result.clone());
        Ok(result)
    }
206
207 pub async fn generate_report(&self) -> Result<SimulationReport> {
209 Ok(SimulationReport {
210 timestamp: Utc::now(),
211 config: self.config.clone(),
212 what_if_analyses_count: self.what_if_results.len(),
213 perturbation_tests_count: self.perturbation_results.len(),
214 adversarial_probes_count: self.adversarial_results.len(),
215 edge_case_discoveries_count: self.edge_case_results.len(),
216 recent_what_if_results: self.what_if_results.iter().rev().take(3).cloned().collect(),
217 recent_perturbation_results: self
218 .perturbation_results
219 .iter()
220 .rev()
221 .take(3)
222 .cloned()
223 .collect(),
224 recent_adversarial_results: self
225 .adversarial_results
226 .iter()
227 .rev()
228 .take(3)
229 .cloned()
230 .collect(),
231 recent_edge_case_results: self
232 .edge_case_results
233 .iter()
234 .rev()
235 .take(3)
236 .cloned()
237 .collect(),
238 simulation_summary: self.generate_simulation_summary(),
239 })
240 }
241
242 async fn generate_what_if_scenarios(
245 &self,
246 base_input: &HashMap<String, f64>,
247 model_fn: &(dyn Fn(&HashMap<String, f64>) -> f64 + Send + Sync),
248 ) -> Result<Vec<Scenario>> {
249 let mut scenarios = Vec::new();
250 let _base_prediction = model_fn(base_input);
251
252 for i in 0..self.config.num_what_if_scenarios {
253 let mut scenario_input = base_input.clone();
254 let mut changed_features = Vec::new();
255
256 use scirs2_core::random::*; let mut rng = thread_rng();
259 let num_features_to_change = 1 + (rng.random_range(0..3)); let features: Vec<String> = base_input.keys().cloned().collect();
261
262 for _ in 0..num_features_to_change {
263 if let Some(feature_name) = features.get(rng.random_range(0..features.len())) {
264 let original_value = base_input[feature_name];
265 let change_factor = 0.8 + (rng.random::<f64>() * 0.4); let new_value = original_value * change_factor;
267
268 scenario_input.insert(feature_name.clone(), new_value);
269
270 changed_features.push(FeatureChange {
271 feature_name: feature_name.clone(),
272 original_value,
273 new_value,
274 change_magnitude: (new_value - original_value).abs(),
275 change_direction: if new_value > original_value {
276 ChangeDirection::Increase
277 } else {
278 ChangeDirection::Decrease
279 },
280 change_type: if (new_value - original_value).abs() / original_value.abs()
281 > 0.1
282 {
283 ChangeType::Significant
284 } else {
285 ChangeType::Incremental
286 },
287 });
288 }
289 }
290
291 let prediction = model_fn(&scenario_input);
292 let distance_from_base = self.calculate_distance(base_input, &scenario_input);
293
294 scenarios.push(Scenario {
295 id: format!("scenario_{}", i),
296 description: format!("What-if scenario {}", i),
297 features: scenario_input,
298 prediction,
299 confidence: 0.8, changed_features,
301 distance_from_base,
302 plausibility: 1.0 - (distance_from_base / 10.0).min(1.0), });
304 }
305
306 Ok(scenarios)
307 }
308
309 fn calculate_distance(
310 &self,
311 input1: &HashMap<String, f64>,
312 input2: &HashMap<String, f64>,
313 ) -> f64 {
314 input1
315 .iter()
316 .map(|(key, value)| {
317 let other_value = input2.get(key).unwrap_or(&0.0);
318 (value - other_value).powi(2)
319 })
320 .sum::<f64>()
321 .sqrt()
322 }
323
324 fn analyze_scenario_impacts(
325 &self,
326 base_scenario: &Scenario,
327 scenarios: &[Scenario],
328 ) -> ScenarioImpactAnalysis {
329 let prediction_changes: Vec<f64> = scenarios
330 .iter()
331 .map(|s| (s.prediction - base_scenario.prediction).abs())
332 .collect();
333
334 let avg_prediction_change =
335 prediction_changes.iter().sum::<f64>() / prediction_changes.len() as f64;
336 let max_prediction_change = prediction_changes.iter().cloned().fold(0.0, f64::max);
337
338 let high_impact_scenarios: Vec<String> = scenarios
339 .iter()
340 .filter(|s| {
341 (s.prediction - base_scenario.prediction).abs() > avg_prediction_change * 2.0
342 })
343 .map(|s| s.id.clone())
344 .collect();
345
346 let prediction_flip_scenarios: Vec<String> = scenarios
347 .iter()
348 .filter(|s| (s.prediction > 0.5) != (base_scenario.prediction > 0.5))
349 .map(|s| s.id.clone())
350 .collect();
351
352 let mut feature_impacts: HashMap<String, Vec<f64>> = HashMap::new();
354 for scenario in scenarios {
355 for change in &scenario.changed_features {
356 feature_impacts
357 .entry(change.feature_name.clone())
358 .or_default()
359 .push((scenario.prediction - base_scenario.prediction).abs());
360 }
361 }
362
363 let feature_importance_ranking: Vec<FeatureImportanceRank> = feature_impacts
364 .iter()
365 .enumerate()
366 .map(|(rank, (feature_name, impacts))| {
367 let avg_impact = impacts.iter().sum::<f64>() / impacts.len() as f64;
368 FeatureImportanceRank {
369 feature_name: feature_name.clone(),
370 importance_score: avg_impact,
371 rank: rank + 1,
372 avg_impact,
373 change_frequency: impacts.len(),
374 }
375 })
376 .collect();
377
378 let stability_analysis = PredictionStabilityAnalysis {
379 stability_score: 1.0
380 - (max_prediction_change / base_scenario.prediction.abs()).min(1.0),
381 prediction_variance: {
382 let predictions: Vec<f64> = scenarios.iter().map(|s| s.prediction).collect();
383 let mean = predictions.iter().sum::<f64>() / predictions.len() as f64;
384 predictions.iter().map(|p| (p - mean).powi(2)).sum::<f64>()
385 / predictions.len() as f64
386 },
387 prediction_flips: prediction_flip_scenarios.len(),
388 stability_by_magnitude: HashMap::new(), };
390
391 ScenarioImpactAnalysis {
392 high_impact_scenarios,
393 prediction_flip_scenarios,
394 avg_prediction_change,
395 max_prediction_change,
396 stability_analysis,
397 feature_importance_ranking,
398 }
399 }
400
401 fn analyze_feature_sensitivity_from_scenarios(
402 &self,
403 scenarios: &[Scenario],
404 ) -> FeatureSensitivityAnalysis {
405 let mut feature_sensitivities = HashMap::new();
406 let mut feature_change_counts = HashMap::new();
407
408 for scenario in scenarios {
409 for change in &scenario.changed_features {
410 let sensitivity = change.change_magnitude / scenario.distance_from_base;
411 *feature_sensitivities.entry(change.feature_name.clone()).or_insert(0.0) +=
412 sensitivity;
413 *feature_change_counts.entry(change.feature_name.clone()).or_insert(0) += 1;
414 }
415 }
416
417 for (feature, sensitivity) in feature_sensitivities.iter_mut() {
419 let count = feature_change_counts[feature] as f64;
420 if count > 0.0 {
421 *sensitivity /= count;
422 }
423 }
424
425 let mut sorted_features: Vec<_> = feature_sensitivities.iter().collect();
426 sorted_features.sort_by(|a, b| b.1.partial_cmp(a.1).unwrap_or(std::cmp::Ordering::Equal));
427
428 let most_sensitive_features: Vec<String> =
429 sorted_features.iter().take(5).map(|(name, _)| (*name).clone()).collect();
430
431 let least_sensitive_features: Vec<String> =
432 sorted_features.iter().rev().take(5).map(|(name, _)| (*name).clone()).collect();
433
434 FeatureSensitivityAnalysis {
435 feature_sensitivities,
436 most_sensitive_features,
437 least_sensitive_features,
438 non_linear_features: vec![], interaction_sensitivities: vec![], }
441 }
442
443 fn generate_counterfactual_insights(
444 &self,
445 base_scenario: &Scenario,
446 scenarios: &[Scenario],
447 ) -> Vec<CounterfactualInsight> {
448 let mut insights = Vec::new();
449
450 for scenario in scenarios {
452 let prediction_change = (scenario.prediction - base_scenario.prediction).abs();
453 if prediction_change > 0.1 {
454 insights.push(CounterfactualInsight {
456 description: format!(
457 "Changing {} features can alter prediction by {:.3}",
458 scenario.changed_features.len(),
459 prediction_change
460 ),
461 required_changes: scenario.changed_features.clone(),
462 predicted_outcome: scenario.prediction,
463 confidence: scenario.confidence,
464 feasibility: if scenario.changed_features.len() <= 2 {
465 ImplementationFeasibility::Easy
466 } else {
467 ImplementationFeasibility::Moderate
468 },
469 });
470 }
471 }
472
473 insights
474 }
475
476 fn explore_decision_boundary(&self, scenarios: &[Scenario]) -> DecisionBoundaryExploration {
477 let boundary_points: Vec<BoundaryPoint> = scenarios.iter()
479 .filter(|s| (s.prediction - 0.5).abs() < 0.1) .take(10)
481 .map(|s| BoundaryPoint {
482 coordinates: s.features.clone(),
483 distance_to_boundary: (s.prediction - 0.5).abs(),
484 prediction: s.prediction,
485 gradient_direction: HashMap::new(), })
487 .collect();
488
489 DecisionBoundaryExploration {
490 boundary_points: boundary_points.clone(),
491 boundary_complexity: BoundaryComplexity {
492 complexity_score: 0.6,
493 curvature: 0.3,
494 inflection_points: 2,
495 complexity_class: ComplexityClass::Polynomial,
496 },
497 local_linearity: LocalLinearityAnalysis {
498 avg_linearity: 0.7,
499 linearity_by_region: HashMap::new(),
500 most_linear_regions: vec![],
501 most_nonlinear_regions: vec![],
502 },
503 crossing_analysis: BoundaryCrossingAnalysis {
504 crossing_count: boundary_points.len(),
505 avg_crossing_distance: boundary_points
506 .iter()
507 .map(|p| p.distance_to_boundary)
508 .sum::<f64>()
509 / boundary_points.len() as f64,
510 crossing_directions: vec![],
511 common_crossing_features: vec![],
512 },
513 }
514 }
515
516 async fn test_perturbation_intensity(
517 &self,
518 base_input: &HashMap<String, f64>,
519 model_fn: &(dyn Fn(&HashMap<String, f64>) -> f64 + Send + Sync),
520 intensity: f64,
521 ) -> Result<PerturbationIntensityResult> {
522 let base_prediction = model_fn(base_input);
523 let mut perturbation_details = Vec::new();
524 let mut successful_perturbations = 0;
525 let mut failed_perturbations = 0;
526 let mut prediction_changes = Vec::new();
527
528 for i in 0..self.config.num_perturbation_samples {
530 let perturbed_input = self.generate_perturbation(base_input, intensity);
531 let perturbed_prediction = model_fn(&perturbed_input);
532 let prediction_change = (perturbed_prediction - base_prediction).abs();
533
534 let is_successful = prediction_change < 0.1; if is_successful {
537 successful_perturbations += 1;
538 } else {
539 failed_perturbations += 1;
540 }
541
542 prediction_changes.push(prediction_change);
543
544 let perturbation_vector: HashMap<String, f64> = base_input
545 .iter()
546 .map(|(key, &base_val)| {
547 let perturbed_val = perturbed_input.get(key).unwrap_or(&base_val);
548 (key.clone(), perturbed_val - base_val)
549 })
550 .collect();
551
552 let perturbation_magnitude =
553 perturbation_vector.values().map(|&v| v.powi(2)).sum::<f64>().sqrt();
554
555 perturbation_details.push(PerturbationDetail {
556 id: format!("pert_{}_{}", intensity, i),
557 original_input: base_input.clone(),
558 perturbed_input,
559 original_prediction: base_prediction,
560 perturbed_prediction,
561 prediction_change,
562 perturbation_vector,
563 perturbation_magnitude,
564 is_successful,
565 });
566 }
567
568 let avg_prediction_change =
569 prediction_changes.iter().sum::<f64>() / prediction_changes.len() as f64;
570 let max_prediction_change = prediction_changes.iter().cloned().fold(0.0, f64::max);
571 let std_prediction_change = {
572 let variance = prediction_changes
573 .iter()
574 .map(|&x| (x - avg_prediction_change).powi(2))
575 .sum::<f64>()
576 / prediction_changes.len() as f64;
577 variance.sqrt()
578 };
579
580 Ok(PerturbationIntensityResult {
581 intensity,
582 num_perturbations: self.config.num_perturbation_samples,
583 successful_perturbations,
584 failed_perturbations,
585 avg_prediction_change,
586 max_prediction_change,
587 std_prediction_change,
588 perturbation_details,
589 })
590 }
591
592 fn generate_perturbation(
593 &self,
594 base_input: &HashMap<String, f64>,
595 intensity: f64,
596 ) -> HashMap<String, f64> {
597 let mut perturbed_input = base_input.clone();
598 use scirs2_core::random::*; let mut rng = thread_rng();
600
601 for (_key, value) in perturbed_input.iter_mut() {
602 let noise = (rng.random::<f64>() - 0.5) * 2.0 * intensity;
604 *value += noise;
605 }
606
607 perturbed_input
608 }
609
610 fn assess_robustness(
611 &self,
612 results: &HashMap<String, PerturbationIntensityResult>,
613 ) -> RobustnessAssessment {
614 let success_rates: Vec<f64> = results
616 .values()
617 .map(|r| r.successful_perturbations as f64 / r.num_perturbations as f64)
618 .collect();
619
620 let robustness_score = success_rates.iter().sum::<f64>() / success_rates.len() as f64;
621
622 let robustness_class = match robustness_score {
623 x if x > 0.9 => RobustnessClass::VeryRobust,
624 x if x > 0.7 => RobustnessClass::Robust,
625 x if x > 0.5 => RobustnessClass::SomewhatRobust,
626 x if x > 0.3 => RobustnessClass::Sensitive,
627 _ => RobustnessClass::Fragile,
628 };
629
630 let critical_threshold = results
632 .iter()
633 .find(|(_, result)| {
634 let success_rate =
635 result.successful_perturbations as f64 / result.num_perturbations as f64;
636 success_rate < 0.5
637 })
638 .map(|(intensity, _)| intensity.parse::<f64>().unwrap_or(1.0))
639 .unwrap_or(1.0);
640
641 RobustnessAssessment {
642 robustness_score,
643 robustness_class,
644 feature_robustness: HashMap::new(), critical_threshold,
646 improvement_recommendations: vec![
647 "Consider adding regularization".to_string(),
648 "Increase training data diversity".to_string(),
649 ],
650 }
651 }
652
653 fn identify_sensitivity_hotspots(
654 &self,
655 _results: &HashMap<String, PerturbationIntensityResult>,
656 ) -> Vec<SensitivityHotspot> {
657 vec![SensitivityHotspot {
659 location: HashMap::new(), sensitivity_score: 0.8,
661 sensitivity_radius: 0.1,
662 sensitive_features: vec!["feature1".to_string()],
663 hotspot_type: HotspotType::Local,
664 }]
665 }
666
    /// Returns a fixed, illustrative failure-mode analysis.
    ///
    /// NOTE(review): the perturbation results are currently ignored
    /// (`_results`); every figure below is a hard-coded placeholder, not
    /// derived from data.
    fn analyze_failure_modes(
        &self,
        _results: &HashMap<String, PerturbationIntensityResult>,
    ) -> FailureModesAnalysis {
        // Single canonical failure mode: sensitivity to input noise.
        let failure_modes = vec![FailureMode {
            id: "noise_sensitivity".to_string(),
            description: "Model sensitive to input noise".to_string(),
            triggering_conditions: vec![TriggeringCondition {
                feature: "any".to_string(),
                condition_type: ConditionType::Exceeds,
                threshold: 0.1,
                description: "Noise level exceeds 10%".to_string(),
            }],
            severity: FailureSeverity::Moderate,
            frequency: 0.3,
            example_inputs: vec![],
        }];

        FailureModesAnalysis {
            failure_modes,
            failure_frequency: FailureFrequencyAnalysis {
                overall_failure_rate: 0.3,
                failure_rate_by_intensity: HashMap::new(),
                failure_rate_by_feature: HashMap::new(),
                time_to_failure: TimeToFailureAnalysis {
                    avg_time_to_failure: 5.0,
                    median_time_to_failure: 3.0,
                    distribution_parameters: HashMap::new(),
                },
            },
            failure_severity: FailureSeverityAnalysis {
                avg_severity: 2.5,
                severity_distribution: HashMap::new(),
                most_severe_modes: vec!["noise_sensitivity".to_string()],
                cascading_failures: CascadingFailureAnalysis {
                    cascading_events: 0,
                    avg_cascade_length: 0.0,
                    cascade_triggers: vec![],
                    amplification_factors: HashMap::new(),
                },
            },
            // Suggested remediation tied to the failure mode above.
            mitigation_strategies: vec![MitigationStrategy {
                name: "Data Augmentation".to_string(),
                description: "Add noise during training".to_string(),
                target_failure_modes: vec!["noise_sensitivity".to_string()],
                effectiveness: 0.8,
                implementation_cost: ImplementationCost::Medium,
                implementation_steps: vec![
                    "Add noise to training data".to_string(),
                    "Retrain model".to_string(),
                ],
            }],
        }
    }
722
723 async fn generate_adversarial_examples(
724 &self,
725 base_input: &HashMap<String, f64>,
726 model_fn: &(dyn Fn(&HashMap<String, f64>) -> f64 + Send + Sync),
727 method: &AdversarialMethod,
728 ) -> Result<Vec<AdversarialExample>> {
729 let mut examples = Vec::new();
730 let base_prediction = model_fn(base_input);
731
732 for i in 0..self.config.num_adversarial_examples {
733 let adversarial_input = match method {
734 AdversarialMethod::FGSM => self.generate_fgsm_example(base_input, model_fn),
735 AdversarialMethod::PGD => self.generate_pgd_example(base_input, model_fn),
736 AdversarialMethod::CW => self.generate_cw_example(base_input, model_fn),
737 AdversarialMethod::DeepFool => self.generate_deepfool_example(base_input, model_fn),
738 AdversarialMethod::UAP => self.generate_uap_example(base_input, model_fn),
739 AdversarialMethod::Boundary => self.generate_boundary_example(base_input, model_fn),
740 };
741
742 let adversarial_prediction = model_fn(&adversarial_input);
743
744 let perturbation: HashMap<String, f64> = base_input
745 .iter()
746 .map(|(key, &base_val)| {
747 let adv_val = adversarial_input.get(key).unwrap_or(&base_val);
748 (key.clone(), adv_val - base_val)
749 })
750 .collect();
751
752 let perturbation_norm = perturbation.values().map(|&v| v.powi(2)).sum::<f64>().sqrt();
753
754 let is_successful = (adversarial_prediction - base_prediction).abs() > 0.1;
755
756 examples.push(AdversarialExample {
757 id: format!("adv_{:?}_{}", method, i),
758 attack_method: method.clone(),
759 original_input: base_input.clone(),
760 adversarial_input,
761 original_prediction: base_prediction,
762 adversarial_prediction,
763 perturbation,
764 perturbation_norm,
765 is_successful,
766 confidence: 0.8, });
768 }
769
770 Ok(examples)
771 }
772
773 fn generate_fgsm_example(
775 &self,
776 base_input: &HashMap<String, f64>,
777 _model_fn: &(dyn Fn(&HashMap<String, f64>) -> f64 + Send + Sync),
778 ) -> HashMap<String, f64> {
779 let epsilon = 0.01;
780 let mut adversarial_input = base_input.clone();
781 use scirs2_core::random::*; let mut rng = thread_rng();
783
784 for (_key, value) in adversarial_input.iter_mut() {
786 let sign = if rng.random::<f64>() > 0.5 { 1.0 } else { -1.0 };
787 *value += epsilon * sign;
788 }
789
790 adversarial_input
791 }
792
793 fn generate_pgd_example(
794 &self,
795 base_input: &HashMap<String, f64>,
796 _model_fn: &(dyn Fn(&HashMap<String, f64>) -> f64 + Send + Sync),
797 ) -> HashMap<String, f64> {
798 let mut adversarial_input = base_input.clone();
800 let epsilon = 0.001;
801 let iterations = 10;
802 use scirs2_core::random::*; let mut rng = thread_rng();
804
805 for _ in 0..iterations {
806 for (_key, value) in adversarial_input.iter_mut() {
807 let sign = if rng.random::<f64>() > 0.5 { 1.0 } else { -1.0 };
808 *value += epsilon * sign;
809 }
810 }
811
812 adversarial_input
813 }
814
    /// Carlini–Wagner placeholder: currently delegates to the FGSM
    /// surrogate. TODO(review): replace with a real optimization-based
    /// CW attack.
    fn generate_cw_example(
        &self,
        base_input: &HashMap<String, f64>,
        model_fn: &(dyn Fn(&HashMap<String, f64>) -> f64 + Send + Sync),
    ) -> HashMap<String, f64> {
        self.generate_fgsm_example(base_input, model_fn)
    }
823
    /// DeepFool placeholder: currently delegates to the FGSM surrogate.
    /// TODO(review): implement the iterative linearized-boundary search.
    fn generate_deepfool_example(
        &self,
        base_input: &HashMap<String, f64>,
        model_fn: &(dyn Fn(&HashMap<String, f64>) -> f64 + Send + Sync),
    ) -> HashMap<String, f64> {
        self.generate_fgsm_example(base_input, model_fn)
    }
832
    /// Universal-adversarial-perturbation placeholder: currently delegates
    /// to the FGSM surrogate. TODO(review): implement a shared perturbation
    /// computed across inputs.
    fn generate_uap_example(
        &self,
        base_input: &HashMap<String, f64>,
        model_fn: &(dyn Fn(&HashMap<String, f64>) -> f64 + Send + Sync),
    ) -> HashMap<String, f64> {
        self.generate_fgsm_example(base_input, model_fn)
    }
841
    /// Boundary-attack placeholder: currently delegates to the FGSM
    /// surrogate. TODO(review): implement a decision-boundary walk.
    fn generate_boundary_example(
        &self,
        base_input: &HashMap<String, f64>,
        model_fn: &(dyn Fn(&HashMap<String, f64>) -> f64 + Send + Sync),
    ) -> HashMap<String, f64> {
        self.generate_fgsm_example(base_input, model_fn)
    }
850
851 fn analyze_attack_success(
852 &self,
853 adversarial_examples: &HashMap<AdversarialMethod, Vec<AdversarialExample>>,
854 ) -> AttackSuccessAnalysis {
855 let mut success_rate_by_method = HashMap::new();
856 let mut total_successful = 0;
857 let mut total_examples = 0;
858 let mut total_perturbation = 0.0;
859
860 for (method, examples) in adversarial_examples {
861 let successful = examples.iter().filter(|e| e.is_successful).count();
862 let success_rate = successful as f64 / examples.len() as f64;
863
864 success_rate_by_method.insert(method.clone(), success_rate);
865 total_successful += successful;
866 total_examples += examples.len();
867
868 total_perturbation += examples.iter().map(|e| e.perturbation_norm).sum::<f64>();
869 }
870
871 let overall_success_rate = total_successful as f64 / total_examples as f64;
872 let avg_perturbation_magnitude = total_perturbation / total_examples as f64;
873
874 let most_effective_methods: Vec<AdversarialMethod> = success_rate_by_method
875 .iter()
876 .filter(|(_, &rate)| rate > 0.5)
877 .map(|(method, _)| method.clone())
878 .collect();
879
880 AttackSuccessAnalysis {
881 success_rate_by_method,
882 overall_success_rate,
883 avg_perturbation_magnitude,
884 most_effective_methods,
885 attack_difficulty: AttackDifficultyAnalysis {
886 easy_targets: vec!["feature1".to_string()],
887 hard_targets: vec!["feature2".to_string()],
888 perturbation_by_feature: HashMap::new(),
889 complexity_assessment: ComplexityAssessment {
890 complexity_score: 0.6,
891 features_required: 2,
892 min_perturbation: 0.01,
893 sophistication_level: SophisticationLevel::Intermediate,
894 },
895 },
896 }
897 }
898
899 fn assess_adversarial_robustness(
900 &self,
901 adversarial_examples: &HashMap<AdversarialMethod, Vec<AdversarialExample>>,
902 ) -> AdversarialRobustnessAssessment {
903 let robustness_by_attack: HashMap<AdversarialMethod, f64> = adversarial_examples
905 .iter()
906 .map(|(method, examples)| {
907 let failed_attacks = examples.iter().filter(|e| !e.is_successful).count();
908 let robustness = failed_attacks as f64 / examples.len() as f64;
909 (method.clone(), robustness)
910 })
911 .collect();
912
913 let overall_robustness =
914 robustness_by_attack.values().sum::<f64>() / robustness_by_attack.len() as f64;
915
916 AdversarialRobustnessAssessment {
917 robustness_score: overall_robustness,
918 robustness_by_attack,
919 vulnerability_hotspots: vec![], certified_robustness: CertifiedRobustnessAnalysis {
921 certified_radius: 0.01,
922 certification_confidence: 0.8,
923 certification_method: "Simplified".to_string(),
924 robustness_guarantees: vec![],
925 },
926 }
927 }
928
929 fn generate_defense_recommendations(
930 &self,
931 _adversarial_examples: &HashMap<AdversarialMethod, Vec<AdversarialExample>>,
932 ) -> Vec<DefenseRecommendation> {
933 vec![
934 DefenseRecommendation {
935 name: "Adversarial Training".to_string(),
936 description: "Train with adversarial examples".to_string(),
937 target_vulnerabilities: vec!["FGSM".to_string(), "PGD".to_string()],
938 effectiveness: 0.8,
939 complexity: DefenseComplexity::Moderate,
940 performance_impact: PerformanceImpact::Medium,
941 },
942 DefenseRecommendation {
943 name: "Input Preprocessing".to_string(),
944 description: "Add noise reduction preprocessing".to_string(),
945 target_vulnerabilities: vec!["All".to_string()],
946 effectiveness: 0.6,
947 complexity: DefenseComplexity::Simple,
948 performance_impact: PerformanceImpact::Low,
949 },
950 ]
951 }
952
953 async fn search_edge_cases(
954 &self,
955 input_space: &HashMap<String, (f64, f64)>,
956 model_fn: &(dyn Fn(&HashMap<String, f64>) -> f64 + Send + Sync),
957 ) -> Result<Vec<EdgeCase>> {
958 let mut edge_cases = Vec::new();
959 use scirs2_core::random::*; let mut rng = thread_rng();
961
962 for i in 0..100 {
964 let mut test_input = HashMap::new();
966
967 for (feature, (min_val, max_val)) in input_space {
969 let value = if rng.random::<f64>() > 0.5 {
970 *min_val + (*max_val - *min_val) * 0.01 } else {
972 *max_val - (*max_val - *min_val) * 0.01 };
974 test_input.insert(feature.clone(), value);
975 }
976
977 let prediction = model_fn(&test_input);
978
979 if !(0.1..=0.9).contains(&prediction) || prediction.is_nan() {
981 edge_cases.push(EdgeCase {
982 id: format!("edge_{}", i),
983 description: format!("Edge case with extreme prediction: {:.3}", prediction),
984 trigger_input: test_input,
985 model_output: prediction,
986 expected_output: None,
987 edge_case_type: if prediction.is_nan() {
988 EdgeCaseType::ModelConfusion
989 } else if !(0.1..=0.9).contains(&prediction) {
990 EdgeCaseType::DistributionBoundary
991 } else {
992 EdgeCaseType::Outlier
993 },
994 severity: if prediction.is_nan() {
995 EdgeCaseSeverity::Critical
996 } else {
997 EdgeCaseSeverity::Medium
998 },
999 likelihood: 0.1, detection_method: "Boundary exploration".to_string(),
1001 });
1002 }
1003 }
1004
1005 Ok(edge_cases)
1006 }
1007
1008 fn classify_edge_cases(&self, edge_cases: &[EdgeCase]) -> EdgeCaseClassification {
1009 let mut by_type = HashMap::new();
1010 let mut by_severity = HashMap::new();
1011
1012 for edge_case in edge_cases {
1013 *by_type.entry(edge_case.edge_case_type.clone()).or_insert(0) += 1;
1014 *by_severity.entry(edge_case.severity.clone()).or_insert(0) += 1;
1015 }
1016
1017 EdgeCaseClassification {
1018 by_type,
1019 by_severity,
1020 common_patterns: vec![], systematic_issues: vec![], }
1023 }
1024
1025 fn analyze_edge_case_coverage(
1026 &self,
1027 _edge_cases: &[EdgeCase],
1028 _input_space: &HashMap<String, (f64, f64)>,
1029 ) -> CoverageAnalysis {
1030 CoverageAnalysis {
1032 feature_space_coverage: 0.3, boundary_coverage: 0.8, uncovered_regions: vec![], coverage_gaps: vec![], }
1037 }
1038
1039 fn assess_edge_case_risks(&self, edge_cases: &[EdgeCase]) -> EdgeCaseRiskAssessment {
1040 let overall_risk = edge_cases
1041 .iter()
1042 .map(|ec| match ec.severity {
1043 EdgeCaseSeverity::Critical => 1.0,
1044 EdgeCaseSeverity::High => 0.8,
1045 EdgeCaseSeverity::Medium => 0.5,
1046 EdgeCaseSeverity::Low => 0.2,
1047 })
1048 .sum::<f64>()
1049 / edge_cases.len() as f64;
1050
1051 let high_risk_cases: Vec<String> = edge_cases
1052 .iter()
1053 .filter(|ec| {
1054 matches!(
1055 ec.severity,
1056 EdgeCaseSeverity::High | EdgeCaseSeverity::Critical
1057 )
1058 })
1059 .map(|ec| ec.id.clone())
1060 .collect();
1061
1062 EdgeCaseRiskAssessment {
1063 overall_risk,
1064 risk_by_type: HashMap::new(), high_risk_cases,
1066 mitigation_priorities: vec![], }
1068 }
1069
1070 fn generate_simulation_summary(&self) -> HashMap<String, String> {
1071 let mut summary = HashMap::new();
1072
1073 summary.insert(
1074 "total_what_if_analyses".to_string(),
1075 self.what_if_results.len().to_string(),
1076 );
1077 summary.insert(
1078 "total_perturbation_tests".to_string(),
1079 self.perturbation_results.len().to_string(),
1080 );
1081 summary.insert(
1082 "total_adversarial_probes".to_string(),
1083 self.adversarial_results.len().to_string(),
1084 );
1085 summary.insert(
1086 "total_edge_case_discoveries".to_string(),
1087 self.edge_case_results.len().to_string(),
1088 );
1089
1090 if let Some(latest_perturbation) = self.perturbation_results.last() {
1091 summary.insert(
1092 "latest_robustness_score".to_string(),
1093 format!(
1094 "{:.2}",
1095 latest_perturbation.robustness_assessment.robustness_score
1096 ),
1097 );
1098 }
1099
1100 if let Some(latest_adversarial) = self.adversarial_results.last() {
1101 summary.insert(
1102 "latest_adversarial_robustness".to_string(),
1103 format!(
1104 "{:.2}",
1105 latest_adversarial.robustness_assessment.robustness_score
1106 ),
1107 );
1108 }
1109
1110 summary
1111 }
1112}
1113
#[cfg(test)]
mod tests {
    use super::*;

    // A fresh analyzer starts with no recorded results.
    #[tokio::test]
    async fn test_simulation_analyzer_creation() {
        let config = SimulationConfig::default();
        let analyzer = SimulationAnalyzer::new(config);
        assert_eq!(analyzer.what_if_results.len(), 0);
    }

    // What-if analysis runs end-to-end with a trivial surrogate model
    // (scaled sum of feature values).
    #[tokio::test]
    async fn test_what_if_analysis() {
        let config = SimulationConfig::default();
        let mut analyzer = SimulationAnalyzer::new(config);

        let mut base_input = HashMap::new();
        base_input.insert("feature1".to_string(), 1.0);
        base_input.insert("feature2".to_string(), 2.0);

        // Deterministic surrogate model: 0.1 * sum of features.
        let model_fn =
            Box::new(|input: &HashMap<String, f64>| -> f64 { input.values().sum::<f64>() * 0.1 });

        let result = analyzer.analyze_what_if(&base_input, model_fn).await;
        assert!(result.is_ok());
    }

    // Perturbation testing runs end-to-end with the same surrogate model.
    #[tokio::test]
    async fn test_perturbation_testing() {
        let config = SimulationConfig::default();
        let mut analyzer = SimulationAnalyzer::new(config);

        let mut base_input = HashMap::new();
        base_input.insert("feature1".to_string(), 1.0);
        base_input.insert("feature2".to_string(), 2.0);

        let model_fn =
            Box::new(|input: &HashMap<String, f64>| -> f64 { input.values().sum::<f64>() * 0.1 });

        let result = analyzer.test_perturbations(&base_input, model_fn).await;
        assert!(result.is_ok());
    }
}