trustformers_debug/interface/
simple.rs1use crate::core::session::{DebugConfig, DebugReport, DebugSession};
7use anyhow::Result;
8use serde::{Deserialize, Serialize};
9
/// How much analysis [`quick_debug`] performs.
///
/// Each level picks both a preset `DebugConfig` (see `smart_config_for_level`)
/// and the set of checks that are run on the debug session.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum QuickDebugLevel {
    /// All feature flags off; runs only a quick health check.
    Light,
    /// Tensor inspection, gradient debugging, model diagnostics and graph analysis
    /// enabled; runs the health, gradient and architecture checks.
    Standard,
    /// `DebugConfig::default()`; the result is the report returned when the session stops.
    Deep,
    /// All feature flags off with the lowest sampling rate (0.01); runs only an anomaly check.
    Production,
}
22
23pub async fn quick_debug<T>(_model: &T, level: QuickDebugLevel) -> Result<SimplifiedDebugResult> {
25 let config = smart_config_for_level(level);
26 let mut session = DebugSession::new(config);
27
28 session.start().await?;
30
31 match level {
33 QuickDebugLevel::Light => {
34 let health_summary = session.health_checker().quick_health_check().await?;
36 session.stop().await?;
37 Ok(SimplifiedDebugResult::Light(health_summary))
38 },
39 QuickDebugLevel::Standard => {
40 let health_summary = session.health_checker().quick_health_check().await?;
42 let gradient_analysis = session.gradient_debugger().quick_analysis().await?;
43 let gradient_summary = QuickGradientSummary::from_analysis(&gradient_analysis);
44 let architecture_summary = session.architecture_analyzer().quick_analysis().await?;
45 session.stop().await?;
46 Ok(SimplifiedDebugResult::Standard {
47 health: health_summary,
48 gradients: gradient_summary,
49 architecture: architecture_summary,
50 })
51 },
52 QuickDebugLevel::Deep => {
53 let report = session.stop().await?;
55 Ok(SimplifiedDebugResult::Deep(report))
56 },
57 QuickDebugLevel::Production => {
58 let anomaly_summary = session.anomaly_detector().quick_check().await?;
60 session.stop().await?;
61 Ok(SimplifiedDebugResult::Production(anomaly_summary))
62 },
63 }
64}
65
66pub async fn debug<T>(model: &T) -> Result<SimplifiedDebugResult> {
68 quick_debug(model, QuickDebugLevel::Standard).await
69}
70
71fn smart_config_for_level(level: QuickDebugLevel) -> DebugConfig {
73 match level {
74 QuickDebugLevel::Light => DebugConfig {
75 enable_tensor_inspection: false,
76 enable_gradient_debugging: false,
77 enable_model_diagnostics: false,
78 enable_visualization: false,
79 enable_memory_profiling: false,
80 enable_computation_graph_analysis: false,
81 max_tracked_tensors: 100,
82 max_gradient_history: 10,
83 sampling_rate: 0.1,
84 ..Default::default()
85 },
86 QuickDebugLevel::Standard => DebugConfig {
87 enable_tensor_inspection: true,
88 enable_gradient_debugging: true,
89 enable_model_diagnostics: true,
90 enable_visualization: false,
91 enable_memory_profiling: false,
92 enable_computation_graph_analysis: true,
93 max_tracked_tensors: 500,
94 max_gradient_history: 50,
95 sampling_rate: 0.5,
96 ..Default::default()
97 },
98 QuickDebugLevel::Deep => DebugConfig::default(),
99 QuickDebugLevel::Production => DebugConfig {
100 enable_tensor_inspection: false,
101 enable_gradient_debugging: false,
102 enable_model_diagnostics: false,
103 enable_visualization: false,
104 enable_memory_profiling: false,
105 enable_computation_graph_analysis: false,
106 max_tracked_tensors: 50,
107 max_gradient_history: 5,
108 sampling_rate: 0.01,
109 ..Default::default()
110 },
111 }
112}
113
114#[derive(Debug, Serialize, Deserialize)]
116pub enum SimplifiedDebugResult {
117 Light(QuickHealthSummary),
118 Standard {
119 health: QuickHealthSummary,
120 gradients: QuickGradientSummary,
121 architecture: QuickArchitectureSummary,
122 },
123 Deep(DebugReport),
124 Production(QuickAnomalySummary),
125}
126
127impl SimplifiedDebugResult {
128 pub fn summary(&self) -> String {
130 match self {
131 SimplifiedDebugResult::Light(health) => {
132 format!("Health Score: {:.2}/100 ({})", health.score, health.status)
133 },
134 SimplifiedDebugResult::Standard {
135 health,
136 gradients,
137 architecture,
138 } => {
139 format!(
140 "Health: {:.2}/100 | Gradients: {} | Architecture: {} parameters",
141 health.score, gradients.status, architecture.total_parameters
142 )
143 },
144 SimplifiedDebugResult::Deep(report) => {
145 let summary = report.summary();
146 format!(
147 "Issues: {} | Critical: {} | Session: {}",
148 summary.total_issues, summary.critical_issues, summary.session_id
149 )
150 },
151 SimplifiedDebugResult::Production(anomaly) => {
152 format!("Anomalies: {} detected", anomaly.anomaly_count)
153 },
154 }
155 }
156
157 pub fn has_critical_issues(&self) -> bool {
159 match self {
160 SimplifiedDebugResult::Light(health) => health.score < 30.0,
161 SimplifiedDebugResult::Standard { health, .. } => health.score < 30.0,
162 SimplifiedDebugResult::Deep(report) => report.summary().critical_issues > 0,
163 SimplifiedDebugResult::Production(anomaly) => anomaly.anomaly_count > 0,
164 }
165 }
166
167 pub fn recommendations(&self) -> Vec<String> {
169 match self {
170 SimplifiedDebugResult::Light(health) => health.recommendations.clone(),
171 SimplifiedDebugResult::Standard {
172 health, gradients, ..
173 } => {
174 let mut recs = health.recommendations.clone();
175 recs.extend(gradients.recommendations.clone());
176 recs
177 },
178 SimplifiedDebugResult::Deep(report) => report.summary().recommendations.clone(),
179 SimplifiedDebugResult::Production(anomaly) => anomaly.recommendations.clone(),
180 }
181 }
182}
183
184#[derive(Debug, Serialize, Deserialize)]
186pub struct QuickHealthSummary {
187 pub score: f64,
188 pub status: String,
189 pub recommendations: Vec<String>,
190}
191
192#[derive(Debug, Serialize, Deserialize)]
194pub struct QuickGradientSummary {
195 pub status: String,
196 pub vanishing_risk: f64,
197 pub exploding_risk: f64,
198 pub recommendations: Vec<String>,
199}
200
201impl QuickGradientSummary {
202 pub fn from_analysis(
204 analysis: &crate::gradient_debugger::debugger::GradientQuickAnalysis,
205 ) -> Self {
206 use crate::gradient_debugger::types::LayerHealth;
207
208 let status = match analysis.overall_health {
209 LayerHealth::Healthy => "Healthy".to_string(),
210 LayerHealth::Warning => "Warning".to_string(),
211 LayerHealth::Critical => "Critical".to_string(),
212 _ => "Unknown".to_string(),
213 };
214
215 let vanishing_risk = analysis
217 .problematic_layers
218 .iter()
219 .filter(|layer| layer.contains("Vanishing"))
220 .count() as f64
221 / analysis.active_layers.max(1) as f64;
222
223 let exploding_risk = analysis
224 .problematic_layers
225 .iter()
226 .filter(|layer| layer.contains("Exploding"))
227 .count() as f64
228 / analysis.active_layers.max(1) as f64;
229
230 let mut recommendations = Vec::new();
231 if vanishing_risk > 0.1 {
232 recommendations
233 .push("Consider using residual connections or skip connections".to_string());
234 }
235 if exploding_risk > 0.1 {
236 recommendations
237 .push("Consider gradient clipping or learning rate reduction".to_string());
238 }
239 if analysis.recent_alerts_count > 0 {
240 recommendations.push(format!(
241 "Address {} recent gradient alerts",
242 analysis.recent_alerts_count
243 ));
244 }
245 if recommendations.is_empty() {
246 recommendations.push("Gradients look stable".to_string());
247 }
248
249 Self {
250 status,
251 vanishing_risk,
252 exploding_risk,
253 recommendations,
254 }
255 }
256}
257
258#[derive(Debug, Serialize, Deserialize)]
260pub struct QuickArchitectureSummary {
261 pub total_parameters: u64,
262 pub model_size_mb: f64,
263 pub efficiency_score: f64,
264 pub recommendations: Vec<String>,
265}
266
267#[derive(Debug, Serialize, Deserialize)]
269pub struct QuickAnomalySummary {
270 pub anomaly_count: usize,
271 pub severity_level: String,
272 pub recommendations: Vec<String>,
273}