// trustformers_debug/simulation_tools/what_if_analysis.rs

use super::types::*;
use chrono::{DateTime, Utc};
use serde::{Deserialize, Serialize};
use std::collections::HashMap;

12#[derive(Debug, Clone, Serialize, Deserialize)]
14pub struct WhatIfAnalysisResult {
15 pub timestamp: DateTime<Utc>,
17 pub base_scenario: Scenario,
19 pub scenarios: Vec<Scenario>,
21 pub impact_analysis: ScenarioImpactAnalysis,
23 pub sensitivity_analysis: FeatureSensitivityAnalysis,
25 pub counterfactual_insights: Vec<CounterfactualInsight>,
27 pub decision_boundary_exploration: DecisionBoundaryExploration,
29}
30
31#[derive(Debug, Clone, Serialize, Deserialize)]
33pub struct Scenario {
34 pub id: String,
36 pub description: String,
38 pub features: HashMap<String, f64>,
40 pub prediction: f64,
42 pub confidence: f64,
44 pub changed_features: Vec<FeatureChange>,
46 pub distance_from_base: f64,
48 pub plausibility: f64,
50}
51
52#[derive(Debug, Clone, Serialize, Deserialize)]
54pub struct FeatureChange {
55 pub feature_name: String,
57 pub original_value: f64,
59 pub new_value: f64,
61 pub change_magnitude: f64,
63 pub change_direction: ChangeDirection,
65 pub change_type: ChangeType,
67}
68
69#[derive(Debug, Clone, Serialize, Deserialize)]
71pub struct ScenarioImpactAnalysis {
72 pub high_impact_scenarios: Vec<String>,
74 pub prediction_flip_scenarios: Vec<String>,
76 pub avg_prediction_change: f64,
78 pub max_prediction_change: f64,
80 pub stability_analysis: PredictionStabilityAnalysis,
82 pub feature_importance_ranking: Vec<FeatureImportanceRank>,
84}
85
86#[derive(Debug, Clone, Serialize, Deserialize)]
88pub struct PredictionStabilityAnalysis {
89 pub stability_score: f64,
91 pub prediction_variance: f64,
93 pub prediction_flips: usize,
95 pub stability_by_magnitude: HashMap<String, f64>,
97}
98
99#[derive(Debug, Clone, Serialize, Deserialize)]
101pub struct FeatureImportanceRank {
102 pub feature_name: String,
104 pub importance_score: f64,
106 pub rank: usize,
108 pub avg_impact: f64,
110 pub change_frequency: usize,
112}
113
114#[derive(Debug, Clone, Serialize, Deserialize)]
116pub struct FeatureSensitivityAnalysis {
117 pub feature_sensitivities: HashMap<String, f64>,
119 pub most_sensitive_features: Vec<String>,
121 pub least_sensitive_features: Vec<String>,
123 pub non_linear_features: Vec<String>,
125 pub interaction_sensitivities: Vec<FeatureInteractionSensitivity>,
127}
128
129#[derive(Debug, Clone, Serialize, Deserialize)]
131pub struct FeatureInteractionSensitivity {
132 pub feature1: String,
134 pub feature2: String,
136 pub sensitivity_score: f64,
138 pub interaction_type: InteractionType,
140}
141
142#[derive(Debug, Clone, Serialize, Deserialize)]
144pub struct CounterfactualInsight {
145 pub description: String,
147 pub required_changes: Vec<FeatureChange>,
149 pub predicted_outcome: f64,
151 pub confidence: f64,
153 pub feasibility: ImplementationFeasibility,
155}
156
157#[derive(Debug, Clone, Serialize, Deserialize)]
159pub struct DecisionBoundaryExploration {
160 pub boundary_points: Vec<BoundaryPoint>,
162 pub boundary_complexity: BoundaryComplexity,
164 pub local_linearity: LocalLinearityAnalysis,
166 pub crossing_analysis: BoundaryCrossingAnalysis,
168}
169
170#[derive(Debug, Clone, Serialize, Deserialize)]
172pub struct BoundaryPoint {
173 pub coordinates: HashMap<String, f64>,
175 pub distance_to_boundary: f64,
177 pub prediction: f64,
179 pub gradient_direction: HashMap<String, f64>,
181}
182
183#[derive(Debug, Clone, Serialize, Deserialize)]
185pub struct BoundaryComplexity {
186 pub complexity_score: f64,
188 pub curvature: f64,
190 pub inflection_points: usize,
192 pub complexity_class: ComplexityClass,
194}
195
196#[derive(Debug, Clone, Serialize, Deserialize)]
198pub struct LocalLinearityAnalysis {
199 pub avg_linearity: f64,
201 pub linearity_by_region: HashMap<String, f64>,
203 pub most_linear_regions: Vec<String>,
205 pub most_nonlinear_regions: Vec<String>,
207}
208
209#[derive(Debug, Clone, Serialize, Deserialize)]
211pub struct BoundaryCrossingAnalysis {
212 pub crossing_count: usize,
214 pub avg_crossing_distance: f64,
216 pub crossing_directions: Vec<CrossingDirection>,
218 pub common_crossing_features: Vec<String>,
220}
221
222#[derive(Debug, Clone, Serialize, Deserialize)]
224pub struct CrossingDirection {
225 pub direction: HashMap<String, f64>,
227 pub magnitude: f64,
229 pub frequency: usize,
231}