1#![allow(dead_code)]
9use serde::{Deserialize, Serialize};
10use std::collections::HashMap;
11use std::path::PathBuf;
12use std::time::{Duration, SystemTime};
13use torsh_core::error::{Result, TorshError};
14
/// Central coordinator for model debugging: owns the per-session state and
/// the shared inspection/analysis components.
pub struct ModelDebugger {
    /// Sessions currently being debugged, keyed by session id.
    active_sessions: HashMap<String, DebugSession>,
    /// Registry of debug hooks shared across sessions.
    hooks: DebugHooksRegistry,
    /// Produces per-tensor statistics and snapshots.
    tensor_inspector: TensorInspector,
    /// Produces per-parameter gradient statistics.
    gradient_debugger: GradientDebugger,
    /// Produces per-layer activation statistics.
    activation_analyzer: ActivationAnalyzer,
    /// Global debugger configuration.
    config: DebugConfig,
}

/// Feature switches and limits for the debugger.
///
/// NOTE(review): only `debug_dir` and `enable_interactive` are read within
/// this file; the remaining flags are presumably consumed by callers — verify
/// before removing any.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct DebugConfig {
    pub enable_tensor_inspection: bool,
    pub enable_gradient_debugging: bool,
    pub enable_activation_analysis: bool,
    pub enable_nan_detection: bool,
    pub enable_gradient_explosion_detection: bool,
    /// Upper bound on tensor elements considered for inspection.
    pub max_tensor_inspection_size: usize,
    /// Directory where debug reports are written (created by `ModelDebugger::new`).
    pub debug_dir: PathBuf,
    /// When true, new sessions are created with an `InteractiveDebugState`.
    pub enable_interactive: bool,
    /// Display precision for tensor values — not read within this file.
    pub tensor_display_precision: usize,
    pub enable_layer_debugging: bool,
}
55
/// State of one active debugging session (not serializable: holds the
/// interactive state).
#[derive(Debug)]
pub struct DebugSession {
    pub session_id: String,
    pub model_id: String,
    pub start_time: SystemTime,
    /// Hooks attached to this session (mirrors entries in the global registry).
    pub active_hooks: Vec<DebugHook>,
    /// Accumulated snapshots, gradient records, and anomalies.
    pub debug_info: DebugInfo,
    /// Present only when `DebugConfig::enable_interactive` is set.
    pub interactive_state: Option<InteractiveDebugState>,
    pub statistics: DebugStatistics,
}

/// A hook installed at some point in model execution.
#[derive(Debug, Clone)]
pub struct DebugHook {
    pub hook_id: String,
    pub hook_type: HookType,
    /// Pattern selecting which layers/operations the hook applies to.
    pub pattern: String,
    /// Optional gate; `None` presumably means "always fire" — confirm with callers.
    pub condition: Option<TriggerCondition>,
    /// Actions executed when the hook fires.
    pub actions: Vec<DebugAction>,
    pub active: bool,
}

/// Where in the execution pipeline a hook attaches.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub enum HookType {
    ForwardHook,
    BackwardHook,
    ParameterUpdateHook,
    GradientHook,
    OperationHook,
}

/// Condition deciding whether a hook's actions run.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub enum TriggerCondition {
    Always,
    OnNaN,
    OnInf,
    OnGradientExplosion { threshold: f32 },
    OnValueRange { min: f32, max: f32 },
    OnShapeMismatch,
    OnIteration(usize),
    /// Free-form condition expression evaluated elsewhere.
    Custom(String),
}

/// Action taken when a hook fires.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub enum DebugAction {
    LogTensor,
    SaveTensor { path: PathBuf },
    AnalyzeTensor,
    /// Drop into the interactive debugger.
    BreakInteractive,
    LogWarning { message: String },
    StopExecution,
    CaptureStackTrace,
    VisualizeTensor,
}
148
/// Global hook registry; `execution_order` preserves registration order.
pub struct DebugHooksRegistry {
    hooks: HashMap<String, DebugHook>,
    /// Hook ids in the order they should be executed.
    execution_order: Vec<String>,
}

/// Computes tensor snapshots and summary statistics.
pub struct TensorInspector {
    config: TensorInspectionConfig,
    /// Cache of previously computed statistics — not read within this file.
    tensor_cache: HashMap<String, TensorStatistics>,
}

/// Limits and toggles for tensor inspection.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct TensorInspectionConfig {
    /// Tensors larger than this are sampled rather than stored in full.
    pub max_display_values: usize,
    pub enable_histograms: bool,
    pub histogram_bins: usize,
    pub enable_distribution_analysis: bool,
    /// Number of values kept when sampling a large tensor.
    pub sample_size: usize,
}

/// Computes gradient statistics; detectors are held but not yet consulted in
/// this file.
pub struct GradientDebugger {
    gradient_stats: HashMap<String, GradientStatistics>,
    explosion_detector: GradientExplosionDetector,
    vanishing_detector: GradientVanishingDetector,
}

/// Computes activation statistics and dead-neuron lists per layer.
pub struct ActivationAnalyzer {
    activation_stats: HashMap<String, ActivationStatistics>,
    dead_neuron_detector: DeadNeuronDetector,
    distribution_analyzer: ActivationDistributionAnalyzer,
}
199
/// Everything recorded during a session: snapshots, gradients, activations,
/// anomalies, and health metrics.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct DebugInfo {
    /// Snapshot history per tensor name (appended by `inspect_tensor`).
    pub tensor_snapshots: HashMap<String, Vec<TensorSnapshot>>,
    /// Gradient record history per parameter name.
    pub gradient_info: HashMap<String, Vec<GradientInfo>>,
    /// Latest activation pattern per layer (overwritten on each analysis).
    pub activation_patterns: HashMap<String, ActivationPattern>,
    pub anomalies: Vec<Anomaly>,
    pub performance_issues: Vec<PerformanceIssue>,
    pub health_metrics: ModelHealthMetrics,
}

/// State of the interactive (breakpoint-driven) debugger.
#[derive(Debug)]
pub struct InteractiveDebugState {
    pub current_breakpoint: Option<Breakpoint>,
    pub execution_stack: Vec<StackFrame>,
    pub available_commands: Vec<DebugCommand>,
    pub variable_inspector: VariableInspector,
    pub step_mode: StepMode,
}

/// Counters describing debugger activity in a session.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct DebugStatistics {
    pub hooks_triggered: usize,
    pub anomalies_detected: usize,
    pub tensors_inspected: usize,
    pub total_debug_time: Duration,
    /// Fraction of runtime spent on debugging — not computed within this file.
    pub debug_overhead: f32,
}

/// Point-in-time capture of a tensor's shape, values, and statistics.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct TensorSnapshot {
    pub timestamp: SystemTime,
    pub shape: Vec<usize>,
    pub dtype: String,
    pub device: String,
    pub statistics: TensorStatistics,
    /// Strided sample, populated only when the tensor exceeds the display limit.
    pub sample_values: Vec<f32>,
    /// Full copy, populated only for small tensors.
    pub full_values: Option<Vec<f32>>,
    pub metadata: TensorMetadata,
}
267
/// Summary statistics over a tensor's elements.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct TensorStatistics {
    pub mean: f32,
    pub std: f32,
    pub min: f32,
    pub max: f32,
    pub nan_count: usize,
    pub inf_count: usize,
    pub zero_count: usize,
    /// Fraction of exactly-zero elements.
    pub sparsity: f32,
    pub distribution: ValueDistribution,
    /// Set only when the tensor is a gradient — not populated within this file.
    pub gradient_norm: Option<f32>,
}

/// Provenance information for an inspected tensor.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct TensorMetadata {
    pub name: String,
    pub producer_layer: Option<String>,
    pub requires_grad: bool,
    /// Byte size of the inspected buffer.
    pub memory_usage: u64,
    pub created_at: SystemTime,
}

/// Histogram and shape descriptors of a value distribution.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ValueDistribution {
    pub histogram: Vec<usize>,
    pub bin_edges: Vec<f32>,
    /// Percentile rank (e.g. 50) to value.
    pub percentiles: HashMap<u8, f32>,
    pub kurtosis: f32,
    pub skewness: f32,
}

/// One gradient analysis record for a parameter.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct GradientInfo {
    pub parameter_name: String,
    pub gradient_stats: GradientStatistics,
    pub gradient_flow: GradientFlow,
    /// Magnitude of the resulting parameter update.
    pub update_magnitude: f32,
    pub clipped: bool,
}

/// Summary statistics over a gradient slice.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct GradientStatistics {
    /// L2 norm of the gradient.
    pub norm: f32,
    pub mean: f32,
    pub std: f32,
    pub max: f32,
    pub min: f32,
    pub zero_count: usize,
    pub sparsity: f32,
}
356
/// Qualitative description of gradient propagation through the model.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct GradientFlow {
    pub direction: FlowDirection,
    pub magnitude: f32,
    /// Layers suspected of impeding gradient flow.
    pub bottlenecks: Vec<String>,
    pub efficiency: f32,
}

/// Direction of information flow.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub enum FlowDirection {
    Forward,
    Backward,
    Bidirectional,
}

/// Per-layer activation analysis result.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ActivationPattern {
    pub layer_name: String,
    pub activation_stats: ActivationStatistics,
    /// Indices of neurons considered dead (near-zero activation).
    pub dead_neurons: Vec<usize>,
    /// Indices of saturated neurons — not populated within this file.
    pub saturated_neurons: Vec<usize>,
    pub distribution: ActivationDistribution,
}

/// Summary statistics over a layer's activations.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ActivationStatistics {
    pub mean_activation: f32,
    pub activation_variance: f32,
    /// (min, max) activation values.
    pub activation_range: (f32, f32),
    /// Fraction of activations above the "active" threshold.
    pub active_percentage: f32,
    pub entropy: f32,
}

/// Fitted distribution for a layer's activations.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ActivationDistribution {
    pub distribution_type: DistributionType,
    /// Distribution parameters keyed by name.
    pub parameters: HashMap<String, f32>,
    pub goodness_of_fit: f32,
}

/// Families of distributions considered during fitting.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub enum DistributionType {
    Normal,
    Uniform,
    Exponential,
    LogNormal,
    Beta,
    Gamma,
    Unknown,
}
430
/// A detected problem with the model, with suggested remediation.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct Anomaly {
    pub anomaly_type: AnomalyType,
    pub severity: Severity,
    /// Tensor/parameter/layer name where the anomaly was observed.
    pub location: String,
    pub description: String,
    pub timestamp: SystemTime,
    /// Human-readable remediation suggestions.
    pub suggested_fixes: Vec<String>,
    pub context: AnomalyContext,
}

/// Category of detected anomaly.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub enum AnomalyType {
    NaNValues,
    InfiniteValues,
    GradientExplosion,
    GradientVanishing,
    DeadNeurons,
    ActivationSaturation,
    MemoryLeak,
    PerformanceDegradation,
    ShapeMismatch,
    NumericInstability,
}

/// Severity ladder, from most to least urgent.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub enum Severity {
    Critical,
    High,
    Medium,
    Low,
    Info,
}

/// Execution context captured alongside an anomaly.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct AnomalyContext {
    pub related_tensors: Vec<String>,
    pub stack_trace: Option<Vec<String>>,
    pub model_state: ModelState,
    pub environment: EnvironmentInfo,
}

/// A detected performance problem.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct PerformanceIssue {
    pub issue_type: PerformanceIssueType,
    pub affected_components: Vec<String>,
    /// Estimated impact — units/scale not defined within this file.
    pub impact: f32,
    /// Suggested optimizations.
    pub optimizations: Vec<String>,
}

/// Category of performance issue.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub enum PerformanceIssueType {
    SlowLayer,
    MemoryInefficiency,
    ComputeBottleneck,
    IOBottleneck,
    SynchronizationOverhead,
}
510
/// Aggregate health scores (defaults initialize all scores to 1.0; the scale
/// is presumably 0.0–1.0 — confirm with consumers).
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ModelHealthMetrics {
    pub overall_health: f32,
    pub gradient_health: f32,
    pub activation_health: f32,
    pub memory_health: f32,
    pub performance_health: f32,
    pub stability_indicators: StabilityIndicators,
}

/// Stability scores feeding into overall health.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct StabilityIndicators {
    pub numerical_stability: f32,
    pub training_stability: f32,
    pub memory_stability: f32,
    pub convergence_indicators: ConvergenceIndicators,
}

/// Convergence scores for different training signals.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ConvergenceIndicators {
    pub loss_convergence: f32,
    pub gradient_convergence: f32,
    pub parameter_convergence: f32,
    pub validation_convergence: f32,
}

/// An interactive-debugger breakpoint.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct Breakpoint {
    pub id: String,
    pub location: BreakpointLocation,
    /// Optional condition expression evaluated elsewhere.
    pub condition: Option<String>,
    pub hit_count: usize,
    pub enabled: bool,
}

/// Where a breakpoint attaches.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub enum BreakpointLocation {
    Layer(String),
    Operation(String),
    Line(usize),
    Function(String),
}
576
/// One frame of the interactive debugger's execution stack.
#[derive(Debug, Clone)]
pub struct StackFrame {
    pub function: String,
    pub file: Option<PathBuf>,
    pub line: Option<usize>,
    /// Local variable name to stringified value.
    pub locals: HashMap<String, String>,
}

/// Commands accepted by the interactive debugger.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub enum DebugCommand {
    Continue,
    Step,
    StepInto,
    StepOut,
    /// Inspect the named variable.
    Inspect(String),
    /// Evaluate an expression.
    Evaluate(String),
    SetBreakpoint(BreakpointLocation),
    /// Remove the breakpoint with the given id.
    RemoveBreakpoint(String),
    ListVariables,
    PrintTensor(String),
    /// Save the named tensor to the given path.
    SaveTensor(String, PathBuf),
    Exit,
}

/// Tracks visible variables and a watch list during interactive debugging.
#[derive(Debug)]
pub struct VariableInspector {
    variables: HashMap<String, VariableInfo>,
    watch_list: Vec<String>,
}

/// Stringified description of one variable.
#[derive(Debug, Clone)]
pub struct VariableInfo {
    pub name: String,
    pub var_type: String,
    pub value: String,
    pub address: Option<String>,
    pub size: usize,
}

/// Stepping behavior for the interactive debugger.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub enum StepMode {
    Normal,
    StepInto,
    StepOver,
    StepOut,
}

/// Coarse execution state of the model.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub enum ModelState {
    Training,
    Evaluation,
    Inference,
    Paused,
    Error,
}

/// Runtime environment captured with an anomaly.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct EnvironmentInfo {
    pub device: String,
    /// Memory usage in bytes.
    pub memory_usage: u64,
    pub cpu_usage: f32,
    pub temperature: Option<f32>,
}
655
/// Flags gradient norms above `threshold`. `history`/`window_size` are not
/// read within this file — presumably for a sliding-window heuristic.
pub struct GradientExplosionDetector {
    threshold: f32,
    history: Vec<f32>,
    window_size: usize,
}

/// Flags gradient norms below `threshold`; per-layer history is not read
/// within this file.
pub struct GradientVanishingDetector {
    threshold: f32,
    layer_gradients: HashMap<String, Vec<f32>>,
}

/// Flags neurons whose activation magnitude stays below `threshold`.
/// `window`/`neuron_states` are not read within this file.
pub struct DeadNeuronDetector {
    threshold: f32,
    window: usize,
    neuron_states: HashMap<String, Vec<bool>>,
}

/// Fits distributions to activation samples (currently a placeholder).
pub struct ActivationDistributionAnalyzer {
    distributions: HashMap<String, ActivationDistribution>,
    config: DistributionAnalysisConfig,
}

/// Settings for distribution fitting.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct DistributionAnalysisConfig {
    pub enable_fitting: bool,
    pub confidence_level: f32,
    pub sample_size: usize,
}
698
699impl Default for DebugConfig {
700 fn default() -> Self {
701 Self {
702 enable_tensor_inspection: true,
703 enable_gradient_debugging: true,
704 enable_activation_analysis: true,
705 enable_nan_detection: true,
706 enable_gradient_explosion_detection: true,
707 max_tensor_inspection_size: 1000000,
708 debug_dir: PathBuf::from("./debug"),
709 enable_interactive: false,
710 tensor_display_precision: 4,
711 enable_layer_debugging: true,
712 }
713 }
714}
715
716impl ModelDebugger {
717 pub fn new(config: DebugConfig) -> Result<Self> {
719 std::fs::create_dir_all(&config.debug_dir)?;
720
721 Ok(Self {
722 active_sessions: HashMap::new(),
723 hooks: DebugHooksRegistry::new(),
724 tensor_inspector: TensorInspector::new(TensorInspectionConfig::default()),
725 gradient_debugger: GradientDebugger::new(),
726 activation_analyzer: ActivationAnalyzer::new(),
727 config,
728 })
729 }
730
731 pub fn start_debugging(&mut self, model_id: &str) -> Result<String> {
733 let session_id = format!(
734 "debug_{}_{}",
735 model_id,
736 SystemTime::now()
737 .duration_since(std::time::UNIX_EPOCH)
738 .expect("system time should be after UNIX epoch")
739 .as_secs()
740 );
741
742 let session = DebugSession {
743 session_id: session_id.clone(),
744 model_id: model_id.to_string(),
745 start_time: SystemTime::now(),
746 active_hooks: Vec::new(),
747 debug_info: DebugInfo::new(),
748 interactive_state: if self.config.enable_interactive {
749 Some(InteractiveDebugState::new())
750 } else {
751 None
752 },
753 statistics: DebugStatistics::default(),
754 };
755
756 self.active_sessions.insert(session_id.clone(), session);
757
758 println!("Started debugging session: {}", session_id);
759 Ok(session_id)
760 }
761
762 pub fn stop_debugging(&mut self, session_id: &str) -> Result<DebugReport> {
764 let session = self.active_sessions.remove(session_id).ok_or_else(|| {
765 TorshError::InvalidArgument(format!("Unknown session: {}", session_id))
766 })?;
767
768 let report = self.generate_debug_report(session)?;
769
770 self.save_debug_report(session_id, &report)?;
772
773 println!("Completed debugging session: {}", session_id);
774 Ok(report)
775 }
776
777 pub fn add_hook(&mut self, session_id: &str, hook: DebugHook) -> Result<()> {
779 if let Some(session) = self.active_sessions.get_mut(session_id) {
780 session.active_hooks.push(hook.clone());
781 }
782
783 self.hooks.register_hook(hook)?;
784 Ok(())
785 }
786
787 pub fn remove_hook(&mut self, session_id: &str, hook_id: &str) -> Result<()> {
789 if let Some(session) = self.active_sessions.get_mut(session_id) {
790 session.active_hooks.retain(|h| h.hook_id != hook_id);
791 }
792
793 self.hooks.unregister_hook(hook_id)?;
794 Ok(())
795 }
796
797 pub fn inspect_tensor(
799 &mut self,
800 session_id: &str,
801 tensor_name: &str,
802 tensor_data: &[f32],
803 shape: &[usize],
804 ) -> Result<TensorSnapshot> {
805 let snapshot = self
806 .tensor_inspector
807 .inspect(tensor_name, tensor_data, shape)?;
808
809 if let Some(session) = self.active_sessions.get_mut(session_id) {
810 session
811 .debug_info
812 .tensor_snapshots
813 .entry(tensor_name.to_string())
814 .or_insert_with(Vec::new)
815 .push(snapshot.clone());
816 session.statistics.tensors_inspected += 1;
817 }
818
819 Ok(snapshot)
820 }
821
822 pub fn analyze_gradients(
824 &mut self,
825 session_id: &str,
826 parameter_name: &str,
827 gradients: &[f32],
828 ) -> Result<GradientInfo> {
829 let gradient_info = self.gradient_debugger.analyze(parameter_name, gradients)?;
830
831 if let Some(session) = self.active_sessions.get_mut(session_id) {
832 session
833 .debug_info
834 .gradient_info
835 .entry(parameter_name.to_string())
836 .or_insert_with(Vec::new)
837 .push(gradient_info.clone());
838 }
839
840 Ok(gradient_info)
841 }
842
843 pub fn analyze_activations(
845 &mut self,
846 session_id: &str,
847 layer_name: &str,
848 activations: &[f32],
849 ) -> Result<ActivationPattern> {
850 let pattern = self.activation_analyzer.analyze(layer_name, activations)?;
851
852 if let Some(session) = self.active_sessions.get_mut(session_id) {
853 session
854 .debug_info
855 .activation_patterns
856 .insert(layer_name.to_string(), pattern.clone());
857 }
858
859 Ok(pattern)
860 }
861
862 pub fn detect_anomalies(&mut self, session_id: &str) -> Result<Vec<Anomaly>> {
864 let mut anomalies = Vec::new();
865
866 if let Some(session) = self.active_sessions.get(session_id) {
867 for (tensor_name, snapshots) in &session.debug_info.tensor_snapshots {
869 if let Some(latest) = snapshots.last() {
870 if latest.statistics.nan_count > 0 {
871 anomalies.push(Anomaly {
872 anomaly_type: AnomalyType::NaNValues,
873 severity: Severity::Critical,
874 location: tensor_name.clone(),
875 description: format!(
876 "Found {} NaN values in tensor {}",
877 latest.statistics.nan_count, tensor_name
878 ),
879 timestamp: SystemTime::now(),
880 suggested_fixes: vec![
881 "Check for division by zero".to_string(),
882 "Verify input data quality".to_string(),
883 "Add gradient clipping".to_string(),
884 ],
885 context: AnomalyContext {
886 related_tensors: vec![tensor_name.clone()],
887 stack_trace: None,
888 model_state: ModelState::Training,
889 environment: EnvironmentInfo {
890 device: "CPU".to_string(),
891 memory_usage: 1024 * 1024 * 100,
892 cpu_usage: 75.0,
893 temperature: None,
894 },
895 },
896 });
897 }
898
899 if latest.statistics.inf_count > 0 {
900 anomalies.push(Anomaly {
901 anomaly_type: AnomalyType::InfiniteValues,
902 severity: Severity::High,
903 location: tensor_name.clone(),
904 description: format!(
905 "Found {} infinite values in tensor {}",
906 latest.statistics.inf_count, tensor_name
907 ),
908 timestamp: SystemTime::now(),
909 suggested_fixes: vec![
910 "Check for overflow in computations".to_string(),
911 "Reduce learning rate".to_string(),
912 "Add numerical stability checks".to_string(),
913 ],
914 context: AnomalyContext {
915 related_tensors: vec![tensor_name.clone()],
916 stack_trace: None,
917 model_state: ModelState::Training,
918 environment: EnvironmentInfo {
919 device: "CPU".to_string(),
920 memory_usage: 1024 * 1024 * 100,
921 cpu_usage: 75.0,
922 temperature: None,
923 },
924 },
925 });
926 }
927 }
928 }
929
930 for (param_name, gradient_infos) in &session.debug_info.gradient_info {
932 if let Some(latest) = gradient_infos.last() {
933 if latest.gradient_stats.norm > 10.0 {
934 anomalies.push(Anomaly {
935 anomaly_type: AnomalyType::GradientExplosion,
936 severity: Severity::High,
937 location: param_name.clone(),
938 description: format!(
939 "Gradient explosion detected: norm = {:.4}",
940 latest.gradient_stats.norm
941 ),
942 timestamp: SystemTime::now(),
943 suggested_fixes: vec![
944 "Add gradient clipping".to_string(),
945 "Reduce learning rate".to_string(),
946 "Check for numerical instability".to_string(),
947 ],
948 context: AnomalyContext {
949 related_tensors: vec![param_name.clone()],
950 stack_trace: None,
951 model_state: ModelState::Training,
952 environment: EnvironmentInfo {
953 device: "CPU".to_string(),
954 memory_usage: 1024 * 1024 * 100,
955 cpu_usage: 75.0,
956 temperature: None,
957 },
958 },
959 });
960 }
961 }
962 }
963 }
964
965 if let Some(session) = self.active_sessions.get_mut(session_id) {
967 session.debug_info.anomalies.extend(anomalies.clone());
968 session.statistics.anomalies_detected += anomalies.len();
969 }
970
971 Ok(anomalies)
972 }
973
974 fn generate_debug_report(&self, session: DebugSession) -> Result<DebugReport> {
976 let duration = SystemTime::now()
977 .duration_since(session.start_time)
978 .unwrap_or_default();
979
980 Ok(DebugReport {
981 session_info: DebugSessionInfo {
982 session_id: session.session_id,
983 model_id: session.model_id,
984 start_time: session.start_time,
985 end_time: SystemTime::now(),
986 duration,
987 },
988 debug_info: session.debug_info,
989 summary: DebugSummary {
990 total_anomalies: session.statistics.anomalies_detected,
991 critical_issues: 0, performance_score: 0.85, stability_score: 0.78, recommendations: vec![
995 "Consider adding gradient clipping".to_string(),
996 "Monitor memory usage more closely".to_string(),
997 ],
998 },
999 statistics: session.statistics,
1000 })
1001 }
1002
1003 fn save_debug_report(&self, session_id: &str, report: &DebugReport) -> Result<()> {
1004 let file_path = self
1005 .config
1006 .debug_dir
1007 .join(format!("{}_debug_report.json", session_id));
1008 let content = serde_json::to_string_pretty(report)?;
1009 std::fs::write(file_path, content)?;
1010 Ok(())
1011 }
1012}
1013
/// Final report produced when a session stops; persisted as JSON.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct DebugReport {
    pub session_info: DebugSessionInfo,
    pub debug_info: DebugInfo,
    pub statistics: DebugStatistics,
    pub summary: DebugSummary,
}

/// Identifying and timing information for a finished session.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct DebugSessionInfo {
    pub session_id: String,
    pub model_id: String,
    pub start_time: SystemTime,
    pub end_time: SystemTime,
    pub duration: Duration,
}

/// High-level rollup of a session's findings.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct DebugSummary {
    pub total_anomalies: usize,
    pub critical_issues: usize,
    pub performance_score: f32,
    pub stability_score: f32,
    pub recommendations: Vec<String>,
}
1044
1045impl DebugInfo {
1047 fn new() -> Self {
1048 Self {
1049 tensor_snapshots: HashMap::new(),
1050 gradient_info: HashMap::new(),
1051 activation_patterns: HashMap::new(),
1052 anomalies: Vec::new(),
1053 performance_issues: Vec::new(),
1054 health_metrics: ModelHealthMetrics::default(),
1055 }
1056 }
1057}
1058
1059impl InteractiveDebugState {
1060 fn new() -> Self {
1061 Self {
1062 current_breakpoint: None,
1063 execution_stack: Vec::new(),
1064 available_commands: vec![
1065 DebugCommand::Continue,
1066 DebugCommand::Step,
1067 DebugCommand::StepInto,
1068 DebugCommand::StepOut,
1069 DebugCommand::ListVariables,
1070 ],
1071 variable_inspector: VariableInspector::new(),
1072 step_mode: StepMode::Normal,
1073 }
1074 }
1075}
1076
1077impl VariableInspector {
1078 fn new() -> Self {
1079 Self {
1080 variables: HashMap::new(),
1081 watch_list: Vec::new(),
1082 }
1083 }
1084}
1085
1086impl DebugHooksRegistry {
1087 fn new() -> Self {
1088 Self {
1089 hooks: HashMap::new(),
1090 execution_order: Vec::new(),
1091 }
1092 }
1093
1094 fn register_hook(&mut self, hook: DebugHook) -> Result<()> {
1095 self.hooks.insert(hook.hook_id.clone(), hook.clone());
1096 self.execution_order.push(hook.hook_id);
1097 Ok(())
1098 }
1099
1100 fn unregister_hook(&mut self, hook_id: &str) -> Result<()> {
1101 self.hooks.remove(hook_id);
1102 self.execution_order.retain(|id| id != hook_id);
1103 Ok(())
1104 }
1105}
1106
1107impl TensorInspector {
1108 fn new(config: TensorInspectionConfig) -> Self {
1109 Self {
1110 config,
1111 tensor_cache: HashMap::new(),
1112 }
1113 }
1114
1115 fn inspect(
1116 &mut self,
1117 tensor_name: &str,
1118 data: &[f32],
1119 shape: &[usize],
1120 ) -> Result<TensorSnapshot> {
1121 let statistics = self.calculate_statistics(data);
1122
1123 let (sample_values, full_values) = if data.len() > self.config.max_display_values {
1124 (self.sample_values(data), None)
1125 } else {
1126 (Vec::new(), Some(data.to_vec()))
1127 };
1128
1129 Ok(TensorSnapshot {
1130 timestamp: SystemTime::now(),
1131 shape: shape.to_vec(),
1132 dtype: "f32".to_string(),
1133 device: "CPU".to_string(),
1134 statistics,
1135 sample_values,
1136 full_values,
1137 metadata: TensorMetadata {
1138 name: tensor_name.to_string(),
1139 producer_layer: None,
1140 requires_grad: false,
1141 memory_usage: std::mem::size_of_val(data) as u64,
1142 created_at: SystemTime::now(),
1143 },
1144 })
1145 }
1146
1147 fn calculate_statistics(&self, data: &[f32]) -> TensorStatistics {
1148 let len = data.len();
1149 if len == 0 {
1150 return TensorStatistics::default();
1151 }
1152
1153 let sum: f32 = data.iter().sum();
1154 let mean = sum / len as f32;
1155
1156 let variance: f32 = data.iter().map(|x| (x - mean).powi(2)).sum::<f32>() / len as f32;
1157 let std = variance.sqrt();
1158
1159 let min = data.iter().fold(f32::INFINITY, |a, &b| a.min(b));
1160 let max = data.iter().fold(f32::NEG_INFINITY, |a, &b| a.max(b));
1161
1162 let nan_count = data.iter().filter(|&&x| x.is_nan()).count();
1163 let inf_count = data.iter().filter(|&&x| x.is_infinite()).count();
1164 let zero_count = data.iter().filter(|&&x| x == 0.0).count();
1165
1166 let sparsity = zero_count as f32 / len as f32;
1167
1168 TensorStatistics {
1169 mean,
1170 std,
1171 min,
1172 max,
1173 nan_count,
1174 inf_count,
1175 zero_count,
1176 sparsity,
1177 distribution: ValueDistribution::default(),
1178 gradient_norm: None,
1179 }
1180 }
1181
1182 fn sample_values(&self, data: &[f32]) -> Vec<f32> {
1183 let sample_size = self.config.sample_size.min(data.len());
1184 let step = data.len() / sample_size;
1185 (0..sample_size).map(|i| data[i * step]).collect()
1186 }
1187}
1188
1189impl GradientDebugger {
1190 fn new() -> Self {
1191 Self {
1192 gradient_stats: HashMap::new(),
1193 explosion_detector: GradientExplosionDetector::new(10.0),
1194 vanishing_detector: GradientVanishingDetector::new(1e-6),
1195 }
1196 }
1197
1198 fn analyze(&mut self, parameter_name: &str, gradients: &[f32]) -> Result<GradientInfo> {
1199 let stats = self.calculate_gradient_statistics(gradients);
1200 let norm = stats.norm;
1201
1202 Ok(GradientInfo {
1203 parameter_name: parameter_name.to_string(),
1204 gradient_flow: GradientFlow {
1205 direction: FlowDirection::Backward,
1206 magnitude: norm,
1207 bottlenecks: vec![],
1208 efficiency: 0.85,
1209 },
1210 update_magnitude: norm,
1211 gradient_stats: stats,
1212 clipped: false,
1213 })
1214 }
1215
1216 fn calculate_gradient_statistics(&self, gradients: &[f32]) -> GradientStatistics {
1217 if gradients.is_empty() {
1218 return GradientStatistics::default();
1219 }
1220
1221 let norm = gradients.iter().map(|x| x * x).sum::<f32>().sqrt();
1222 let mean = gradients.iter().sum::<f32>() / gradients.len() as f32;
1223 let variance =
1224 gradients.iter().map(|x| (x - mean).powi(2)).sum::<f32>() / gradients.len() as f32;
1225 let std = variance.sqrt();
1226 let max = gradients.iter().fold(f32::NEG_INFINITY, |a, &b| a.max(b));
1227 let min = gradients.iter().fold(f32::INFINITY, |a, &b| a.min(b));
1228 let zero_count = gradients.iter().filter(|&&x| x == 0.0).count();
1229 let sparsity = zero_count as f32 / gradients.len() as f32;
1230
1231 GradientStatistics {
1232 norm,
1233 mean,
1234 std,
1235 max,
1236 min,
1237 zero_count,
1238 sparsity,
1239 }
1240 }
1241}
1242
1243impl ActivationAnalyzer {
1244 fn new() -> Self {
1245 Self {
1246 activation_stats: HashMap::new(),
1247 dead_neuron_detector: DeadNeuronDetector::new(0.01),
1248 distribution_analyzer: ActivationDistributionAnalyzer::new(),
1249 }
1250 }
1251
1252 fn analyze(&mut self, layer_name: &str, activations: &[f32]) -> Result<ActivationPattern> {
1253 let stats = self.calculate_activation_statistics(activations);
1254 let dead_neurons = self.dead_neuron_detector.detect(layer_name, activations);
1255 let distribution = self.distribution_analyzer.analyze(activations);
1256
1257 Ok(ActivationPattern {
1258 layer_name: layer_name.to_string(),
1259 activation_stats: stats,
1260 dead_neurons,
1261 saturated_neurons: vec![], distribution,
1263 })
1264 }
1265
1266 fn calculate_activation_statistics(&self, activations: &[f32]) -> ActivationStatistics {
1267 if activations.is_empty() {
1268 return ActivationStatistics::default();
1269 }
1270
1271 let mean = activations.iter().sum::<f32>() / activations.len() as f32;
1272 let variance =
1273 activations.iter().map(|x| (x - mean).powi(2)).sum::<f32>() / activations.len() as f32;
1274 let min = activations.iter().fold(f32::INFINITY, |a, &b| a.min(b));
1275 let max = activations.iter().fold(f32::NEG_INFINITY, |a, &b| a.max(b));
1276 let active_count = activations.iter().filter(|&&x| x > 0.01).count();
1277 let active_percentage = active_count as f32 / activations.len() as f32;
1278
1279 ActivationStatistics {
1280 mean_activation: mean,
1281 activation_variance: variance,
1282 activation_range: (min, max),
1283 active_percentage,
1284 entropy: 0.0, }
1286 }
1287}
1288
1289impl GradientExplosionDetector {
1291 fn new(threshold: f32) -> Self {
1292 Self {
1293 threshold,
1294 history: Vec::new(),
1295 window_size: 10,
1296 }
1297 }
1298}
1299
1300impl GradientVanishingDetector {
1301 fn new(threshold: f32) -> Self {
1302 Self {
1303 threshold,
1304 layer_gradients: HashMap::new(),
1305 }
1306 }
1307}
1308
1309impl DeadNeuronDetector {
1310 fn new(threshold: f32) -> Self {
1311 Self {
1312 threshold,
1313 window: 100,
1314 neuron_states: HashMap::new(),
1315 }
1316 }
1317
1318 fn detect(&mut self, _layer_name: &str, activations: &[f32]) -> Vec<usize> {
1319 activations
1320 .iter()
1321 .enumerate()
1322 .filter_map(|(i, &activation)| {
1323 if activation.abs() < self.threshold {
1324 Some(i)
1325 } else {
1326 None
1327 }
1328 })
1329 .collect()
1330 }
1331}
1332
1333impl ActivationDistributionAnalyzer {
1334 fn new() -> Self {
1335 Self {
1336 distributions: HashMap::new(),
1337 config: DistributionAnalysisConfig::default(),
1338 }
1339 }
1340
1341 fn analyze(&mut self, _activations: &[f32]) -> ActivationDistribution {
1342 ActivationDistribution {
1344 distribution_type: DistributionType::Normal,
1345 parameters: HashMap::new(),
1346 goodness_of_fit: 0.85,
1347 }
1348 }
1349}
1350
1351impl Default for TensorStatistics {
1353 fn default() -> Self {
1354 Self {
1355 mean: 0.0,
1356 std: 0.0,
1357 min: 0.0,
1358 max: 0.0,
1359 nan_count: 0,
1360 inf_count: 0,
1361 zero_count: 0,
1362 sparsity: 0.0,
1363 distribution: ValueDistribution::default(),
1364 gradient_norm: None,
1365 }
1366 }
1367}
1368
1369impl Default for ValueDistribution {
1370 fn default() -> Self {
1371 Self {
1372 histogram: vec![],
1373 bin_edges: vec![],
1374 percentiles: HashMap::new(),
1375 kurtosis: 0.0,
1376 skewness: 0.0,
1377 }
1378 }
1379}
1380
1381impl Default for GradientStatistics {
1382 fn default() -> Self {
1383 Self {
1384 norm: 0.0,
1385 mean: 0.0,
1386 std: 0.0,
1387 max: 0.0,
1388 min: 0.0,
1389 zero_count: 0,
1390 sparsity: 0.0,
1391 }
1392 }
1393}
1394
1395impl Default for ActivationStatistics {
1396 fn default() -> Self {
1397 Self {
1398 mean_activation: 0.0,
1399 activation_variance: 0.0,
1400 activation_range: (0.0, 0.0),
1401 active_percentage: 0.0,
1402 entropy: 0.0,
1403 }
1404 }
1405}
1406
1407impl Default for ModelHealthMetrics {
1408 fn default() -> Self {
1409 Self {
1410 overall_health: 1.0,
1411 gradient_health: 1.0,
1412 activation_health: 1.0,
1413 memory_health: 1.0,
1414 performance_health: 1.0,
1415 stability_indicators: StabilityIndicators::default(),
1416 }
1417 }
1418}
1419
1420impl Default for StabilityIndicators {
1421 fn default() -> Self {
1422 Self {
1423 numerical_stability: 1.0,
1424 training_stability: 1.0,
1425 memory_stability: 1.0,
1426 convergence_indicators: ConvergenceIndicators::default(),
1427 }
1428 }
1429}
1430
1431impl Default for ConvergenceIndicators {
1432 fn default() -> Self {
1433 Self {
1434 loss_convergence: 1.0,
1435 gradient_convergence: 1.0,
1436 parameter_convergence: 1.0,
1437 validation_convergence: 1.0,
1438 }
1439 }
1440}
1441
1442impl Default for DebugStatistics {
1443 fn default() -> Self {
1444 Self {
1445 hooks_triggered: 0,
1446 anomalies_detected: 0,
1447 tensors_inspected: 0,
1448 total_debug_time: Duration::from_secs(0),
1449 debug_overhead: 0.0,
1450 }
1451 }
1452}
1453
1454impl Default for TensorInspectionConfig {
1455 fn default() -> Self {
1456 Self {
1457 max_display_values: 100,
1458 enable_histograms: true,
1459 histogram_bins: 50,
1460 enable_distribution_analysis: true,
1461 sample_size: 1000,
1462 }
1463 }
1464}
1465
1466impl Default for DistributionAnalysisConfig {
1467 fn default() -> Self {
1468 Self {
1469 enable_fitting: true,
1470 confidence_level: 0.95,
1471 sample_size: 10000,
1472 }
1473 }
1474}
1475
#[cfg(test)]
mod tests {
    use super::*;

    // NOTE(review): these tests create the `./debug` directory on disk via
    // ModelDebugger::new — consider a temp dir to avoid polluting the CWD.

    /// Construction succeeds with the default configuration.
    #[test]
    fn test_debugger_creation() {
        let config = DebugConfig::default();
        let debugger = ModelDebugger::new(config);
        assert!(debugger.is_ok());
    }

    /// A session can be started and then stopped (which writes its report).
    #[test]
    fn test_debugging_session() {
        let config = DebugConfig::default();
        let mut debugger = ModelDebugger::new(config).unwrap();

        let session_id = debugger.start_debugging("test_model").unwrap();
        assert!(!session_id.is_empty());

        let result = debugger.stop_debugging(&session_id);
        assert!(result.is_ok());
    }

    /// Snapshot of a small tensor keeps the shape and computes the mean.
    #[test]
    fn test_tensor_inspection() {
        let config = DebugConfig::default();
        let mut debugger = ModelDebugger::new(config).unwrap();

        let session_id = debugger.start_debugging("test_model").unwrap();

        let test_data = vec![1.0, 2.0, 3.0, 4.0, 5.0];
        let shape = vec![5];

        let snapshot = debugger
            .inspect_tensor(&session_id, "test_tensor", &test_data, &shape)
            .unwrap();

        assert_eq!(snapshot.shape, shape);
        assert_eq!(snapshot.statistics.mean, 3.0);
    }

    /// A tensor containing NaN triggers a NaNValues anomaly.
    #[test]
    fn test_anomaly_detection() {
        let config = DebugConfig::default();
        let mut debugger = ModelDebugger::new(config).unwrap();

        let session_id = debugger.start_debugging("test_model").unwrap();

        let bad_data = vec![1.0, 2.0, f32::NAN, 4.0, 5.0];
        let shape = vec![5];
        debugger
            .inspect_tensor(&session_id, "bad_tensor", &bad_data, &shape)
            .unwrap();

        let anomalies = debugger.detect_anomalies(&session_id).unwrap();
        assert!(!anomalies.is_empty());

        let nan_anomaly = anomalies
            .iter()
            .find(|a| matches!(a.anomaly_type, AnomalyType::NaNValues));
        assert!(nan_anomaly.is_some());
    }
}