Skip to main content

torsh_hub/
debugging.rs

1//! Model debugging tools for ToRSh Hub
2//!
3//! This module provides comprehensive debugging capabilities for models,
4//! including tensor inspection, gradient debugging, activation analysis,
5//! and interactive debugging utilities.
6
7// Framework infrastructure - components designed for future use
8#![allow(dead_code)]
9use serde::{Deserialize, Serialize};
10use std::collections::HashMap;
11use std::path::PathBuf;
12use std::time::{Duration, SystemTime};
13use torsh_core::error::{Result, TorshError};
14
/// Comprehensive model debugger.
///
/// Owns every active [`DebugSession`] (keyed by session ID) together with a
/// single hook registry and the tensor, gradient and activation analysis
/// components that all sessions share. Sessions are opened with
/// [`ModelDebugger::start_debugging`] and closed with
/// [`ModelDebugger::stop_debugging`], which produces a [`DebugReport`].
pub struct ModelDebugger {
    /// Active debugging sessions, keyed by session ID
    active_sessions: HashMap<String, DebugSession>,
    /// Debug hooks registry (one registry shared by all sessions)
    hooks: DebugHooksRegistry,
    /// Tensor inspector used to build tensor snapshots
    tensor_inspector: TensorInspector,
    /// Gradient debugger used to analyze parameter gradients
    gradient_debugger: GradientDebugger,
    /// Activation analyzer used to analyze layer activations
    activation_analyzer: ActivationAnalyzer,
    /// Debug configuration
    config: DebugConfig,
}
30
/// Debug configuration for a [`ModelDebugger`].
///
/// The `Default` implementation enables every feature except interactive
/// debugging.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct DebugConfig {
    /// Enable tensor value inspection (default: `true`)
    pub enable_tensor_inspection: bool,
    /// Enable gradient debugging (default: `true`)
    pub enable_gradient_debugging: bool,
    /// Enable activation analysis (default: `true`)
    pub enable_activation_analysis: bool,
    /// Enable NaN/Inf detection (default: `true`)
    pub enable_nan_detection: bool,
    /// Enable gradient explosion detection (default: `true`)
    pub enable_gradient_explosion_detection: bool,
    /// Maximum tensor size to inspect, in elements (default: 1,000,000)
    pub max_tensor_inspection_size: usize,
    /// Directory that debug reports are written to; created by
    /// `ModelDebugger::new` if missing (default: `./debug`)
    pub debug_dir: PathBuf,
    /// Enable interactive debugging; when set, every new session gets an
    /// interactive debugger state (default: `false`)
    pub enable_interactive: bool,
    /// Tensor value precision for display (default: 4)
    pub tensor_display_precision: usize,
    /// Enable layer-wise debugging (default: `true`)
    pub enable_layer_debugging: bool,
}
55
/// State of one active debugging session.
///
/// Created by `ModelDebugger::start_debugging`, stored in the debugger until
/// `ModelDebugger::stop_debugging` consumes it to build a report.
#[derive(Debug)]
pub struct DebugSession {
    /// Session identifier, generated by `ModelDebugger::start_debugging`
    pub session_id: String,
    /// ID of the model being debugged
    pub model_id: String,
    /// Session start time
    pub start_time: SystemTime,
    /// Debug hooks attached to this session
    pub active_hooks: Vec<DebugHook>,
    /// Debug information collected during the session
    pub debug_info: DebugInfo,
    /// Interactive debugger state; `Some` only when the debugger was
    /// configured with `enable_interactive`
    pub interactive_state: Option<InteractiveDebugState>,
    /// Debug statistics for this session
    pub statistics: DebugStatistics,
}
74
/// Debug hook for monitoring model execution.
///
/// A hook pairs an execution point (`hook_type`) and a layer/operation
/// `pattern` with an optional trigger condition and the actions to perform
/// when it fires.
#[derive(Debug, Clone)]
pub struct DebugHook {
    /// Hook identifier; `ModelDebugger::remove_hook` looks hooks up by this ID
    pub hook_id: String,
    /// Hook type (which execution point the hook attaches to)
    pub hook_type: HookType,
    /// Layer/operation pattern to match
    // NOTE(review): the pattern syntax (exact name, glob, regex?) is defined by
    // the matching code, which is not visible here — confirm.
    pub pattern: String,
    /// Condition for triggering
    // NOTE(review): presumably `None` means the hook triggers unconditionally —
    // confirm against the hook dispatcher.
    pub condition: Option<TriggerCondition>,
    /// Actions to perform when triggered
    pub actions: Vec<DebugAction>,
    /// Whether hook is active
    pub active: bool,
}
91
/// Types of debug hooks, i.e. the execution point a hook is attached to.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub enum HookType {
    /// Hook on forward pass
    ForwardHook,
    /// Hook on backward pass
    BackwardHook,
    /// Hook on parameter update
    ParameterUpdateHook,
    /// Hook on gradient computation
    GradientHook,
    /// Hook on tensor operation
    OperationHook,
}
106
/// Conditions for triggering debug hooks.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub enum TriggerCondition {
    /// Always trigger
    Always,
    /// Trigger on NaN values
    OnNaN,
    /// Trigger on infinite values
    OnInf,
    /// Trigger on gradient explosion
    // NOTE(review): presumably fires when the gradient magnitude exceeds
    // `threshold` — confirm against the evaluator.
    OnGradientExplosion { threshold: f32 },
    /// Trigger on value range
    // NOTE(review): whether this fires for values inside or outside
    // `[min, max]` is not defined here — confirm.
    OnValueRange { min: f32, max: f32 },
    /// Trigger on tensor shape mismatch
    OnShapeMismatch,
    /// Trigger on a specific iteration (the wrapped value is the iteration)
    OnIteration(usize),
    /// Custom condition, given as a string whose format is defined elsewhere
    Custom(String),
}
127
/// Actions to perform when a debug hook is triggered.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub enum DebugAction {
    /// Log tensor values
    LogTensor,
    /// Save tensor to the given file
    SaveTensor { path: PathBuf },
    /// Analyze tensor statistics
    AnalyzeTensor,
    /// Break into interactive debugger
    BreakInteractive,
    /// Log the given warning message
    LogWarning { message: String },
    /// Stop execution
    StopExecution,
    /// Capture stack trace
    CaptureStackTrace,
    /// Visualize tensor
    VisualizeTensor,
}
148
/// Registry for managing debug hooks.
///
/// A `ModelDebugger` holds a single registry shared by all of its sessions.
pub struct DebugHooksRegistry {
    /// Registered hooks (presumably keyed by `DebugHook::hook_id` — confirm)
    hooks: HashMap<String, DebugHook>,
    /// Hook execution order (presumably a list of hook IDs — confirm)
    execution_order: Vec<String>,
}
156
/// Tensor inspection utilities; produces the [`TensorSnapshot`]s recorded by
/// `ModelDebugger::inspect_tensor`.
pub struct TensorInspector {
    /// Configuration for inspection
    config: TensorInspectionConfig,
    /// Cached tensor statistics (presumably keyed by tensor name — confirm)
    tensor_cache: HashMap<String, TensorStatistics>,
}
164
/// Tensor inspection configuration used by [`TensorInspector`].
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct TensorInspectionConfig {
    /// Maximum number of values to display
    pub max_display_values: usize,
    /// Enable histogram generation
    pub enable_histograms: bool,
    /// Number of histogram bins
    pub histogram_bins: usize,
    /// Enable distribution analysis
    pub enable_distribution_analysis: bool,
    /// Sample size for large tensors
    pub sample_size: usize,
}
179
/// Gradient debugging utilities; produces the [`GradientInfo`] records stored
/// by `ModelDebugger::analyze_gradients`.
pub struct GradientDebugger {
    /// Gradient statistics (presumably keyed by parameter name — confirm)
    gradient_stats: HashMap<String, GradientStatistics>,
    /// Gradient explosion detection
    explosion_detector: GradientExplosionDetector,
    /// Gradient vanishing detection
    vanishing_detector: GradientVanishingDetector,
}
189
/// Activation analysis utilities; produces the [`ActivationPattern`]s stored
/// by `ModelDebugger::analyze_activations`.
pub struct ActivationAnalyzer {
    /// Activation statistics (presumably keyed by layer name — confirm)
    activation_stats: HashMap<String, ActivationStatistics>,
    /// Dead neuron detector
    dead_neuron_detector: DeadNeuronDetector,
    /// Activation distribution analyzer
    distribution_analyzer: ActivationDistributionAnalyzer,
}
199
/// Debug information collected during a session.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct DebugInfo {
    /// Tensor snapshots keyed by tensor name; each entry keeps every snapshot
    /// in the order it was taken
    pub tensor_snapshots: HashMap<String, Vec<TensorSnapshot>>,
    /// Gradient information history keyed by parameter name
    pub gradient_info: HashMap<String, Vec<GradientInfo>>,
    /// Most recent activation pattern per layer name (a new analysis of the
    /// same layer replaces the previous one)
    pub activation_patterns: HashMap<String, ActivationPattern>,
    /// Detected anomalies, appended to by each anomaly scan
    pub anomalies: Vec<Anomaly>,
    /// Performance issues
    pub performance_issues: Vec<PerformanceIssue>,
    /// Model health metrics
    pub health_metrics: ModelHealthMetrics,
}
216
/// Interactive debugging state.
///
/// Attached to a [`DebugSession`] only when interactive debugging is enabled
/// in the [`DebugConfig`].
#[derive(Debug)]
pub struct InteractiveDebugState {
    /// Current breakpoint, if execution is stopped at one
    pub current_breakpoint: Option<Breakpoint>,
    /// Execution stack
    pub execution_stack: Vec<StackFrame>,
    /// Available commands
    pub available_commands: Vec<DebugCommand>,
    /// Variable inspector
    pub variable_inspector: VariableInspector,
    /// Step mode
    pub step_mode: StepMode,
}
231
/// Per-session debug statistics.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct DebugStatistics {
    /// Number of hooks triggered
    pub hooks_triggered: usize,
    /// Number of anomalies detected, accumulated over all anomaly scans
    pub anomalies_detected: usize,
    /// Number of tensors inspected (incremented once per inspection)
    pub tensors_inspected: usize,
    /// Total debug time
    pub total_debug_time: Duration,
    /// Debug overhead ratio
    pub debug_overhead: f32,
}
246
/// Point-in-time snapshot of a tensor, taken for debugging.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct TensorSnapshot {
    /// Snapshot timestamp
    pub timestamp: SystemTime,
    /// Tensor shape
    pub shape: Vec<usize>,
    /// Data type
    pub dtype: String,
    /// Device location
    pub device: String,
    /// Tensor statistics
    pub statistics: TensorStatistics,
    /// Sample values (for large tensors)
    pub sample_values: Vec<f32>,
    /// Full values (for small tensors); presumably `None` when the tensor is
    /// too large to store in full
    pub full_values: Option<Vec<f32>>,
    /// Tensor metadata
    pub metadata: TensorMetadata,
}
267
/// Summary statistics of a tensor's values.
///
/// `ModelDebugger::detect_anomalies` reports an anomaly when `nan_count` or
/// `inf_count` of a tensor's latest snapshot is non-zero.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct TensorStatistics {
    /// Mean value
    pub mean: f32,
    /// Standard deviation
    pub std: f32,
    /// Minimum value
    pub min: f32,
    /// Maximum value
    pub max: f32,
    /// Number of NaN values
    pub nan_count: usize,
    /// Number of infinite values
    pub inf_count: usize,
    /// Number of zero values
    pub zero_count: usize,
    /// Sparsity ratio (presumably `zero_count / element count` — confirm)
    pub sparsity: f32,
    /// Value distribution
    pub distribution: ValueDistribution,
    /// Gradient norm (if available)
    pub gradient_norm: Option<f32>,
}
292
/// Descriptive metadata attached to a [`TensorSnapshot`].
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct TensorMetadata {
    /// Tensor name/identifier
    pub name: String,
    /// Layer that produced this tensor, if known
    pub producer_layer: Option<String>,
    /// Requires gradient flag
    pub requires_grad: bool,
    /// Memory usage in bytes
    pub memory_usage: u64,
    /// Creation timestamp
    pub created_at: SystemTime,
}
307
/// Value distribution analysis of a tensor.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ValueDistribution {
    /// Histogram bin counts
    pub histogram: Vec<usize>,
    /// Bin edges (presumably one more entry than `histogram` — confirm)
    pub bin_edges: Vec<f32>,
    /// Percentiles, keyed by percentile rank (presumably 0–100 — confirm)
    pub percentiles: HashMap<u8, f32>,
    /// Kurtosis
    pub kurtosis: f32,
    /// Skewness
    pub skewness: f32,
}
322
/// Gradient information for one parameter at one analysis step.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct GradientInfo {
    /// Parameter name
    pub parameter_name: String,
    /// Gradient statistics
    pub gradient_stats: GradientStatistics,
    /// Gradient flow analysis
    pub gradient_flow: GradientFlow,
    /// Update magnitude
    pub update_magnitude: f32,
    /// Whether gradient clipping was applied
    pub clipped: bool,
}
337
/// Summary statistics of a parameter's gradient.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct GradientStatistics {
    /// Gradient norm (which norm is computed is not defined here); a
    /// parameter whose latest norm exceeds 10.0 is reported as a gradient
    /// explosion by `ModelDebugger::detect_anomalies`
    pub norm: f32,
    /// Mean gradient
    pub mean: f32,
    /// Gradient standard deviation
    pub std: f32,
    /// Maximum gradient value
    pub max: f32,
    /// Minimum gradient value
    pub min: f32,
    /// Number of zero gradients
    pub zero_count: usize,
    /// Gradient sparsity
    pub sparsity: f32,
}
356
/// Gradient flow analysis.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct GradientFlow {
    /// Flow direction (forward/backward)
    pub direction: FlowDirection,
    /// Flow magnitude
    pub magnitude: f32,
    /// Names of layers identified as bottlenecks
    pub bottlenecks: Vec<String>,
    /// Flow efficiency
    pub efficiency: f32,
}
369
/// Direction of gradient flow.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub enum FlowDirection {
    /// Flow in the forward direction
    Forward,
    /// Flow in the backward direction
    Backward,
    /// Flow in both directions
    Bidirectional,
}
377
/// Activation pattern analysis of one layer.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ActivationPattern {
    /// Layer name
    pub layer_name: String,
    /// Activation statistics
    pub activation_stats: ActivationStatistics,
    /// Dead neurons (presumably neuron indices — confirm)
    pub dead_neurons: Vec<usize>,
    /// Saturated neurons (presumably neuron indices — confirm)
    pub saturated_neurons: Vec<usize>,
    /// Activation distribution
    pub distribution: ActivationDistribution,
}
392
/// Summary statistics of a layer's activations.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ActivationStatistics {
    /// Mean activation
    pub mean_activation: f32,
    /// Activation variance
    pub activation_variance: f32,
    /// Activation range (presumably `(min, max)` — confirm)
    pub activation_range: (f32, f32),
    /// Percentage of active neurons (scale 0–1 vs 0–100 not defined here)
    pub active_percentage: f32,
    /// Activation entropy
    pub entropy: f32,
}
407
/// Fitted distribution of a layer's activations.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ActivationDistribution {
    /// Distribution type
    pub distribution_type: DistributionType,
    /// Distribution parameters, keyed by parameter name
    pub parameters: HashMap<String, f32>,
    /// Goodness of fit
    pub goodness_of_fit: f32,
}
418
/// Distribution families an activation distribution can be classified as.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub enum DistributionType {
    /// Normal (Gaussian) distribution
    Normal,
    /// Uniform distribution
    Uniform,
    /// Exponential distribution
    Exponential,
    /// Log-normal distribution
    LogNormal,
    /// Beta distribution
    Beta,
    /// Gamma distribution
    Gamma,
    /// No known distribution identified
    Unknown,
}
430
/// An anomaly detected during debugging.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct Anomaly {
    /// Anomaly type
    pub anomaly_type: AnomalyType,
    /// Severity level
    pub severity: Severity,
    /// Location (layer/operation, or the tensor/parameter name)
    pub location: String,
    /// Human-readable description
    pub description: String,
    /// Detection timestamp
    pub timestamp: SystemTime,
    /// Suggested fixes, as human-readable hints
    pub suggested_fixes: Vec<String>,
    /// Context information
    pub context: AnomalyContext,
}
449
/// Types of anomalies.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub enum AnomalyType {
    /// NaN values present in a tensor
    NaNValues,
    /// Infinite values present in a tensor
    InfiniteValues,
    /// Gradient norm unusually large
    GradientExplosion,
    /// Gradient norm unusually small
    GradientVanishing,
    /// Neurons that never activate
    DeadNeurons,
    /// Activations stuck at their saturation limits
    ActivationSaturation,
    /// Memory usage that keeps growing
    MemoryLeak,
    /// Performance getting worse over time
    PerformanceDegradation,
    /// Tensor shapes that do not match
    ShapeMismatch,
    /// Other numeric instability
    NumericInstability,
}
464
/// Severity levels, listed from most to least severe.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub enum Severity {
    /// Must be fixed; results are unreliable
    Critical,
    /// Serious issue
    High,
    /// Moderate issue
    Medium,
    /// Minor issue
    Low,
    /// Informational only
    Info,
}
474
/// Context captured alongside an [`Anomaly`].
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct AnomalyContext {
    /// Names of the related tensors/parameters
    pub related_tensors: Vec<String>,
    /// Stack trace, if captured
    pub stack_trace: Option<Vec<String>>,
    /// Model state at detection time
    // NOTE(review): `ModelDebugger::detect_anomalies` currently always records
    // `ModelState::Training` and a fixed placeholder `environment`.
    pub model_state: ModelState,
    /// Environmental factors
    pub environment: EnvironmentInfo,
}
487
/// A detected performance issue.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct PerformanceIssue {
    /// Issue type
    pub issue_type: PerformanceIssueType,
    /// Names of the affected components
    pub affected_components: Vec<String>,
    /// Performance impact (scale/units not defined here)
    pub impact: f32,
    /// Recommended optimizations, as human-readable hints
    pub optimizations: Vec<String>,
}
500
/// Performance issue types.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub enum PerformanceIssueType {
    /// A single layer is slow
    SlowLayer,
    /// Memory is used inefficiently
    MemoryInefficiency,
    /// Computation is the bottleneck
    ComputeBottleneck,
    /// I/O is the bottleneck
    IOBottleneck,
    /// Synchronization adds overhead
    SynchronizationOverhead,
}
510
/// Model health metrics.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ModelHealthMetrics {
    /// Overall health score (0-1)
    pub overall_health: f32,
    /// Gradient health (presumably the same 0-1 scale — confirm)
    pub gradient_health: f32,
    /// Activation health
    pub activation_health: f32,
    /// Memory health
    pub memory_health: f32,
    /// Performance health
    pub performance_health: f32,
    /// Stability indicators
    pub stability_indicators: StabilityIndicators,
}
527
/// Stability indicators that feed into [`ModelHealthMetrics`].
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct StabilityIndicators {
    /// Numerical stability
    pub numerical_stability: f32,
    /// Training stability
    pub training_stability: f32,
    /// Memory stability
    pub memory_stability: f32,
    /// Convergence indicators
    pub convergence_indicators: ConvergenceIndicators,
}
540
/// Convergence indicators for loss, gradients, parameters and validation.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ConvergenceIndicators {
    /// Loss convergence
    pub loss_convergence: f32,
    /// Gradient convergence
    pub gradient_convergence: f32,
    /// Parameter convergence
    pub parameter_convergence: f32,
    /// Validation convergence
    pub validation_convergence: f32,
}
553
// Interactive debugging structures

/// A breakpoint for the interactive debugger.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct Breakpoint {
    /// Breakpoint ID
    pub id: String,
    /// Location
    pub location: BreakpointLocation,
    /// Optional condition (a string whose syntax is defined elsewhere)
    pub condition: Option<String>,
    /// Hit count
    pub hit_count: usize,
    /// Enabled flag
    pub enabled: bool,
}
568
/// Where a [`Breakpoint`] is placed.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub enum BreakpointLocation {
    /// A named layer
    Layer(String),
    /// A named operation
    Operation(String),
    /// A line number
    Line(usize),
    /// A named function
    Function(String),
}
576
/// One frame of the interactive debugger's execution stack.
#[derive(Debug, Clone)]
pub struct StackFrame {
    /// Function name
    pub function: String,
    /// File path
    pub file: Option<PathBuf>,
    /// Line number
    pub line: Option<usize>,
    /// Local variables (presumably name -> string representation — confirm)
    pub locals: HashMap<String, String>,
}
588
/// Commands accepted by the interactive debugger.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub enum DebugCommand {
    /// Resume execution
    Continue,
    /// Execute one step
    Step,
    /// Step into the next call
    StepInto,
    /// Step out of the current call
    StepOut,
    /// Inspect the named item
    Inspect(String),
    /// Evaluate an expression (syntax defined elsewhere)
    Evaluate(String),
    /// Set a breakpoint at the given location
    SetBreakpoint(BreakpointLocation),
    /// Remove a breakpoint (presumably by breakpoint ID — confirm)
    RemoveBreakpoint(String),
    /// List variables
    ListVariables,
    /// Print the named tensor
    PrintTensor(String),
    /// Save the named tensor to the given path
    SaveTensor(String, PathBuf),
    /// Exit the interactive debugger
    Exit,
}
604
/// Variable inspector for the interactive debugger.
#[derive(Debug)]
pub struct VariableInspector {
    /// Current variables (presumably keyed by variable name — confirm)
    variables: HashMap<String, VariableInfo>,
    /// Watch list (presumably variable names — confirm)
    watch_list: Vec<String>,
}
612
/// Description of one variable shown by the [`VariableInspector`].
#[derive(Debug, Clone)]
pub struct VariableInfo {
    /// Variable name
    pub name: String,
    /// Variable type
    pub var_type: String,
    /// Variable value (string representation)
    pub value: String,
    /// Memory address, if available
    pub address: Option<String>,
    /// Size in bytes
    pub size: usize,
}
626
/// Stepping mode of the interactive debugger.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub enum StepMode {
    /// Run normally until the next breakpoint
    Normal,
    /// Step into calls
    StepInto,
    /// Step over calls
    StepOver,
    /// Step out of the current call
    StepOut,
}
634
/// State the model was in when an anomaly was recorded.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub enum ModelState {
    /// Training
    Training,
    /// Evaluation
    Evaluation,
    /// Inference
    Inference,
    /// Paused
    Paused,
    /// Error state
    Error,
}
643
/// Environment information attached to an [`AnomalyContext`].
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct EnvironmentInfo {
    /// Device information
    pub device: String,
    /// Memory usage (presumably bytes — confirm)
    pub memory_usage: u64,
    /// CPU usage (presumably a percentage — confirm)
    pub cpu_usage: f32,
    /// Temperature, if available
    pub temperature: Option<f32>,
}
655
// Detection utilities

/// Detector for exploding gradients, tracked over a sliding history window.
pub struct GradientExplosionDetector {
    /// Explosion threshold
    threshold: f32,
    /// History window of recent values (presumably gradient norms — confirm)
    history: Vec<f32>,
    /// Window size
    window_size: usize,
}
665
/// Detector for vanishing gradients.
pub struct GradientVanishingDetector {
    /// Vanishing threshold
    threshold: f32,
    /// Recorded gradient values per layer (presumably keyed by layer name)
    layer_gradients: HashMap<String, Vec<f32>>,
}
672
/// Detector for neurons that stay inactive.
pub struct DeadNeuronDetector {
    /// Activation threshold
    threshold: f32,
    /// Monitoring window
    window: usize,
    /// Per-layer neuron activity flags (presumably keyed by layer name)
    neuron_states: HashMap<String, Vec<bool>>,
}
681
/// Analyzer that fits [`ActivationDistribution`]s to activations.
pub struct ActivationDistributionAnalyzer {
    /// Distribution cache (presumably keyed by layer name — confirm)
    distributions: HashMap<String, ActivationDistribution>,
    /// Analysis configuration
    config: DistributionAnalysisConfig,
}
688
/// Configuration for the [`ActivationDistributionAnalyzer`].
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct DistributionAnalysisConfig {
    /// Enable distribution fitting
    pub enable_fitting: bool,
    /// Confidence level
    pub confidence_level: f32,
    /// Sample size
    pub sample_size: usize,
}
698
699impl Default for DebugConfig {
700    fn default() -> Self {
701        Self {
702            enable_tensor_inspection: true,
703            enable_gradient_debugging: true,
704            enable_activation_analysis: true,
705            enable_nan_detection: true,
706            enable_gradient_explosion_detection: true,
707            max_tensor_inspection_size: 1000000,
708            debug_dir: PathBuf::from("./debug"),
709            enable_interactive: false,
710            tensor_display_precision: 4,
711            enable_layer_debugging: true,
712        }
713    }
714}
715
716impl ModelDebugger {
717    /// Create a new model debugger
718    pub fn new(config: DebugConfig) -> Result<Self> {
719        std::fs::create_dir_all(&config.debug_dir)?;
720
721        Ok(Self {
722            active_sessions: HashMap::new(),
723            hooks: DebugHooksRegistry::new(),
724            tensor_inspector: TensorInspector::new(TensorInspectionConfig::default()),
725            gradient_debugger: GradientDebugger::new(),
726            activation_analyzer: ActivationAnalyzer::new(),
727            config,
728        })
729    }
730
731    /// Start a debugging session
732    pub fn start_debugging(&mut self, model_id: &str) -> Result<String> {
733        let session_id = format!(
734            "debug_{}_{}",
735            model_id,
736            SystemTime::now()
737                .duration_since(std::time::UNIX_EPOCH)
738                .expect("system time should be after UNIX epoch")
739                .as_secs()
740        );
741
742        let session = DebugSession {
743            session_id: session_id.clone(),
744            model_id: model_id.to_string(),
745            start_time: SystemTime::now(),
746            active_hooks: Vec::new(),
747            debug_info: DebugInfo::new(),
748            interactive_state: if self.config.enable_interactive {
749                Some(InteractiveDebugState::new())
750            } else {
751                None
752            },
753            statistics: DebugStatistics::default(),
754        };
755
756        self.active_sessions.insert(session_id.clone(), session);
757
758        println!("Started debugging session: {}", session_id);
759        Ok(session_id)
760    }
761
762    /// Stop debugging and generate report
763    pub fn stop_debugging(&mut self, session_id: &str) -> Result<DebugReport> {
764        let session = self.active_sessions.remove(session_id).ok_or_else(|| {
765            TorshError::InvalidArgument(format!("Unknown session: {}", session_id))
766        })?;
767
768        let report = self.generate_debug_report(session)?;
769
770        // Save debug report
771        self.save_debug_report(session_id, &report)?;
772
773        println!("Completed debugging session: {}", session_id);
774        Ok(report)
775    }
776
777    /// Add a debug hook
778    pub fn add_hook(&mut self, session_id: &str, hook: DebugHook) -> Result<()> {
779        if let Some(session) = self.active_sessions.get_mut(session_id) {
780            session.active_hooks.push(hook.clone());
781        }
782
783        self.hooks.register_hook(hook)?;
784        Ok(())
785    }
786
787    /// Remove a debug hook
788    pub fn remove_hook(&mut self, session_id: &str, hook_id: &str) -> Result<()> {
789        if let Some(session) = self.active_sessions.get_mut(session_id) {
790            session.active_hooks.retain(|h| h.hook_id != hook_id);
791        }
792
793        self.hooks.unregister_hook(hook_id)?;
794        Ok(())
795    }
796
797    /// Inspect a tensor
798    pub fn inspect_tensor(
799        &mut self,
800        session_id: &str,
801        tensor_name: &str,
802        tensor_data: &[f32],
803        shape: &[usize],
804    ) -> Result<TensorSnapshot> {
805        let snapshot = self
806            .tensor_inspector
807            .inspect(tensor_name, tensor_data, shape)?;
808
809        if let Some(session) = self.active_sessions.get_mut(session_id) {
810            session
811                .debug_info
812                .tensor_snapshots
813                .entry(tensor_name.to_string())
814                .or_insert_with(Vec::new)
815                .push(snapshot.clone());
816            session.statistics.tensors_inspected += 1;
817        }
818
819        Ok(snapshot)
820    }
821
822    /// Analyze gradients
823    pub fn analyze_gradients(
824        &mut self,
825        session_id: &str,
826        parameter_name: &str,
827        gradients: &[f32],
828    ) -> Result<GradientInfo> {
829        let gradient_info = self.gradient_debugger.analyze(parameter_name, gradients)?;
830
831        if let Some(session) = self.active_sessions.get_mut(session_id) {
832            session
833                .debug_info
834                .gradient_info
835                .entry(parameter_name.to_string())
836                .or_insert_with(Vec::new)
837                .push(gradient_info.clone());
838        }
839
840        Ok(gradient_info)
841    }
842
843    /// Analyze activations
844    pub fn analyze_activations(
845        &mut self,
846        session_id: &str,
847        layer_name: &str,
848        activations: &[f32],
849    ) -> Result<ActivationPattern> {
850        let pattern = self.activation_analyzer.analyze(layer_name, activations)?;
851
852        if let Some(session) = self.active_sessions.get_mut(session_id) {
853            session
854                .debug_info
855                .activation_patterns
856                .insert(layer_name.to_string(), pattern.clone());
857        }
858
859        Ok(pattern)
860    }
861
862    /// Detect anomalies
863    pub fn detect_anomalies(&mut self, session_id: &str) -> Result<Vec<Anomaly>> {
864        let mut anomalies = Vec::new();
865
866        if let Some(session) = self.active_sessions.get(session_id) {
867            // Check for NaN/Inf values
868            for (tensor_name, snapshots) in &session.debug_info.tensor_snapshots {
869                if let Some(latest) = snapshots.last() {
870                    if latest.statistics.nan_count > 0 {
871                        anomalies.push(Anomaly {
872                            anomaly_type: AnomalyType::NaNValues,
873                            severity: Severity::Critical,
874                            location: tensor_name.clone(),
875                            description: format!(
876                                "Found {} NaN values in tensor {}",
877                                latest.statistics.nan_count, tensor_name
878                            ),
879                            timestamp: SystemTime::now(),
880                            suggested_fixes: vec![
881                                "Check for division by zero".to_string(),
882                                "Verify input data quality".to_string(),
883                                "Add gradient clipping".to_string(),
884                            ],
885                            context: AnomalyContext {
886                                related_tensors: vec![tensor_name.clone()],
887                                stack_trace: None,
888                                model_state: ModelState::Training,
889                                environment: EnvironmentInfo {
890                                    device: "CPU".to_string(),
891                                    memory_usage: 1024 * 1024 * 100,
892                                    cpu_usage: 75.0,
893                                    temperature: None,
894                                },
895                            },
896                        });
897                    }
898
899                    if latest.statistics.inf_count > 0 {
900                        anomalies.push(Anomaly {
901                            anomaly_type: AnomalyType::InfiniteValues,
902                            severity: Severity::High,
903                            location: tensor_name.clone(),
904                            description: format!(
905                                "Found {} infinite values in tensor {}",
906                                latest.statistics.inf_count, tensor_name
907                            ),
908                            timestamp: SystemTime::now(),
909                            suggested_fixes: vec![
910                                "Check for overflow in computations".to_string(),
911                                "Reduce learning rate".to_string(),
912                                "Add numerical stability checks".to_string(),
913                            ],
914                            context: AnomalyContext {
915                                related_tensors: vec![tensor_name.clone()],
916                                stack_trace: None,
917                                model_state: ModelState::Training,
918                                environment: EnvironmentInfo {
919                                    device: "CPU".to_string(),
920                                    memory_usage: 1024 * 1024 * 100,
921                                    cpu_usage: 75.0,
922                                    temperature: None,
923                                },
924                            },
925                        });
926                    }
927                }
928            }
929
930            // Check for gradient explosion
931            for (param_name, gradient_infos) in &session.debug_info.gradient_info {
932                if let Some(latest) = gradient_infos.last() {
933                    if latest.gradient_stats.norm > 10.0 {
934                        anomalies.push(Anomaly {
935                            anomaly_type: AnomalyType::GradientExplosion,
936                            severity: Severity::High,
937                            location: param_name.clone(),
938                            description: format!(
939                                "Gradient explosion detected: norm = {:.4}",
940                                latest.gradient_stats.norm
941                            ),
942                            timestamp: SystemTime::now(),
943                            suggested_fixes: vec![
944                                "Add gradient clipping".to_string(),
945                                "Reduce learning rate".to_string(),
946                                "Check for numerical instability".to_string(),
947                            ],
948                            context: AnomalyContext {
949                                related_tensors: vec![param_name.clone()],
950                                stack_trace: None,
951                                model_state: ModelState::Training,
952                                environment: EnvironmentInfo {
953                                    device: "CPU".to_string(),
954                                    memory_usage: 1024 * 1024 * 100,
955                                    cpu_usage: 75.0,
956                                    temperature: None,
957                                },
958                            },
959                        });
960                    }
961                }
962            }
963        }
964
965        // Update session with detected anomalies
966        if let Some(session) = self.active_sessions.get_mut(session_id) {
967            session.debug_info.anomalies.extend(anomalies.clone());
968            session.statistics.anomalies_detected += anomalies.len();
969        }
970
971        Ok(anomalies)
972    }
973
974    /// Generate debug report
975    fn generate_debug_report(&self, session: DebugSession) -> Result<DebugReport> {
976        let duration = SystemTime::now()
977            .duration_since(session.start_time)
978            .unwrap_or_default();
979
980        Ok(DebugReport {
981            session_info: DebugSessionInfo {
982                session_id: session.session_id,
983                model_id: session.model_id,
984                start_time: session.start_time,
985                end_time: SystemTime::now(),
986                duration,
987            },
988            debug_info: session.debug_info,
989            summary: DebugSummary {
990                total_anomalies: session.statistics.anomalies_detected,
991                critical_issues: 0,      // Would be calculated
992                performance_score: 0.85, // Would be calculated
993                stability_score: 0.78,   // Would be calculated
994                recommendations: vec![
995                    "Consider adding gradient clipping".to_string(),
996                    "Monitor memory usage more closely".to_string(),
997                ],
998            },
999            statistics: session.statistics,
1000        })
1001    }
1002
1003    fn save_debug_report(&self, session_id: &str, report: &DebugReport) -> Result<()> {
1004        let file_path = self
1005            .config
1006            .debug_dir
1007            .join(format!("{}_debug_report.json", session_id));
1008        let content = serde_json::to_string_pretty(report)?;
1009        std::fs::write(file_path, content)?;
1010        Ok(())
1011    }
1012}
1013
/// Debug report structure
///
/// Built from a consumed `DebugSession` by
/// `ModelDebugger::generate_debug_report`. `ModelDebugger::save_debug_report`
/// writes it to disk as pretty-printed JSON.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct DebugReport {
    /// Session information
    pub session_info: DebugSessionInfo,
    /// Debug information collected
    pub debug_info: DebugInfo,
    /// Debug statistics
    pub statistics: DebugStatistics,
    /// Summary and recommendations
    pub summary: DebugSummary,
}
1026
/// Identifying and timing information for a debugging session.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct DebugSessionInfo {
    /// Identifier of the debugging session.
    pub session_id: String,
    /// Identifier of the model that was debugged.
    pub model_id: String,
    /// When the session started.
    pub start_time: SystemTime,
    /// When the report was generated.
    pub end_time: SystemTime,
    /// Time elapsed since `start_time` at report generation; zero if the
    /// system clock was earlier than `start_time`.
    pub duration: Duration,
}
1035
/// High-level summary of a debugging session.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct DebugSummary {
    /// Number of anomalies counted in the session statistics.
    pub total_anomalies: usize,
    /// Number of critical issues.
    /// NOTE(review): `generate_debug_report` currently always sets this to 0.
    pub critical_issues: usize,
    /// Overall performance score.
    /// NOTE(review): currently a fixed placeholder (0.85); presumably meant to
    /// lie in [0, 1] — confirm.
    pub performance_score: f32,
    /// Overall stability score.
    /// NOTE(review): currently a fixed placeholder (0.78).
    pub stability_score: f32,
    /// Human-readable recommendations for the user.
    pub recommendations: Vec<String>,
}
1044
1045// Implementation for component structs
1046impl DebugInfo {
1047    fn new() -> Self {
1048        Self {
1049            tensor_snapshots: HashMap::new(),
1050            gradient_info: HashMap::new(),
1051            activation_patterns: HashMap::new(),
1052            anomalies: Vec::new(),
1053            performance_issues: Vec::new(),
1054            health_metrics: ModelHealthMetrics::default(),
1055        }
1056    }
1057}
1058
1059impl InteractiveDebugState {
1060    fn new() -> Self {
1061        Self {
1062            current_breakpoint: None,
1063            execution_stack: Vec::new(),
1064            available_commands: vec![
1065                DebugCommand::Continue,
1066                DebugCommand::Step,
1067                DebugCommand::StepInto,
1068                DebugCommand::StepOut,
1069                DebugCommand::ListVariables,
1070            ],
1071            variable_inspector: VariableInspector::new(),
1072            step_mode: StepMode::Normal,
1073        }
1074    }
1075}
1076
1077impl VariableInspector {
1078    fn new() -> Self {
1079        Self {
1080            variables: HashMap::new(),
1081            watch_list: Vec::new(),
1082        }
1083    }
1084}
1085
1086impl DebugHooksRegistry {
1087    fn new() -> Self {
1088        Self {
1089            hooks: HashMap::new(),
1090            execution_order: Vec::new(),
1091        }
1092    }
1093
1094    fn register_hook(&mut self, hook: DebugHook) -> Result<()> {
1095        self.hooks.insert(hook.hook_id.clone(), hook.clone());
1096        self.execution_order.push(hook.hook_id);
1097        Ok(())
1098    }
1099
1100    fn unregister_hook(&mut self, hook_id: &str) -> Result<()> {
1101        self.hooks.remove(hook_id);
1102        self.execution_order.retain(|id| id != hook_id);
1103        Ok(())
1104    }
1105}
1106
1107impl TensorInspector {
1108    fn new(config: TensorInspectionConfig) -> Self {
1109        Self {
1110            config,
1111            tensor_cache: HashMap::new(),
1112        }
1113    }
1114
1115    fn inspect(
1116        &mut self,
1117        tensor_name: &str,
1118        data: &[f32],
1119        shape: &[usize],
1120    ) -> Result<TensorSnapshot> {
1121        let statistics = self.calculate_statistics(data);
1122
1123        let (sample_values, full_values) = if data.len() > self.config.max_display_values {
1124            (self.sample_values(data), None)
1125        } else {
1126            (Vec::new(), Some(data.to_vec()))
1127        };
1128
1129        Ok(TensorSnapshot {
1130            timestamp: SystemTime::now(),
1131            shape: shape.to_vec(),
1132            dtype: "f32".to_string(),
1133            device: "CPU".to_string(),
1134            statistics,
1135            sample_values,
1136            full_values,
1137            metadata: TensorMetadata {
1138                name: tensor_name.to_string(),
1139                producer_layer: None,
1140                requires_grad: false,
1141                memory_usage: std::mem::size_of_val(data) as u64,
1142                created_at: SystemTime::now(),
1143            },
1144        })
1145    }
1146
1147    fn calculate_statistics(&self, data: &[f32]) -> TensorStatistics {
1148        let len = data.len();
1149        if len == 0 {
1150            return TensorStatistics::default();
1151        }
1152
1153        let sum: f32 = data.iter().sum();
1154        let mean = sum / len as f32;
1155
1156        let variance: f32 = data.iter().map(|x| (x - mean).powi(2)).sum::<f32>() / len as f32;
1157        let std = variance.sqrt();
1158
1159        let min = data.iter().fold(f32::INFINITY, |a, &b| a.min(b));
1160        let max = data.iter().fold(f32::NEG_INFINITY, |a, &b| a.max(b));
1161
1162        let nan_count = data.iter().filter(|&&x| x.is_nan()).count();
1163        let inf_count = data.iter().filter(|&&x| x.is_infinite()).count();
1164        let zero_count = data.iter().filter(|&&x| x == 0.0).count();
1165
1166        let sparsity = zero_count as f32 / len as f32;
1167
1168        TensorStatistics {
1169            mean,
1170            std,
1171            min,
1172            max,
1173            nan_count,
1174            inf_count,
1175            zero_count,
1176            sparsity,
1177            distribution: ValueDistribution::default(),
1178            gradient_norm: None,
1179        }
1180    }
1181
1182    fn sample_values(&self, data: &[f32]) -> Vec<f32> {
1183        let sample_size = self.config.sample_size.min(data.len());
1184        let step = data.len() / sample_size;
1185        (0..sample_size).map(|i| data[i * step]).collect()
1186    }
1187}
1188
1189impl GradientDebugger {
1190    fn new() -> Self {
1191        Self {
1192            gradient_stats: HashMap::new(),
1193            explosion_detector: GradientExplosionDetector::new(10.0),
1194            vanishing_detector: GradientVanishingDetector::new(1e-6),
1195        }
1196    }
1197
1198    fn analyze(&mut self, parameter_name: &str, gradients: &[f32]) -> Result<GradientInfo> {
1199        let stats = self.calculate_gradient_statistics(gradients);
1200        let norm = stats.norm;
1201
1202        Ok(GradientInfo {
1203            parameter_name: parameter_name.to_string(),
1204            gradient_flow: GradientFlow {
1205                direction: FlowDirection::Backward,
1206                magnitude: norm,
1207                bottlenecks: vec![],
1208                efficiency: 0.85,
1209            },
1210            update_magnitude: norm,
1211            gradient_stats: stats,
1212            clipped: false,
1213        })
1214    }
1215
1216    fn calculate_gradient_statistics(&self, gradients: &[f32]) -> GradientStatistics {
1217        if gradients.is_empty() {
1218            return GradientStatistics::default();
1219        }
1220
1221        let norm = gradients.iter().map(|x| x * x).sum::<f32>().sqrt();
1222        let mean = gradients.iter().sum::<f32>() / gradients.len() as f32;
1223        let variance =
1224            gradients.iter().map(|x| (x - mean).powi(2)).sum::<f32>() / gradients.len() as f32;
1225        let std = variance.sqrt();
1226        let max = gradients.iter().fold(f32::NEG_INFINITY, |a, &b| a.max(b));
1227        let min = gradients.iter().fold(f32::INFINITY, |a, &b| a.min(b));
1228        let zero_count = gradients.iter().filter(|&&x| x == 0.0).count();
1229        let sparsity = zero_count as f32 / gradients.len() as f32;
1230
1231        GradientStatistics {
1232            norm,
1233            mean,
1234            std,
1235            max,
1236            min,
1237            zero_count,
1238            sparsity,
1239        }
1240    }
1241}
1242
1243impl ActivationAnalyzer {
1244    fn new() -> Self {
1245        Self {
1246            activation_stats: HashMap::new(),
1247            dead_neuron_detector: DeadNeuronDetector::new(0.01),
1248            distribution_analyzer: ActivationDistributionAnalyzer::new(),
1249        }
1250    }
1251
1252    fn analyze(&mut self, layer_name: &str, activations: &[f32]) -> Result<ActivationPattern> {
1253        let stats = self.calculate_activation_statistics(activations);
1254        let dead_neurons = self.dead_neuron_detector.detect(layer_name, activations);
1255        let distribution = self.distribution_analyzer.analyze(activations);
1256
1257        Ok(ActivationPattern {
1258            layer_name: layer_name.to_string(),
1259            activation_stats: stats,
1260            dead_neurons,
1261            saturated_neurons: vec![], // Would implement saturation detection
1262            distribution,
1263        })
1264    }
1265
1266    fn calculate_activation_statistics(&self, activations: &[f32]) -> ActivationStatistics {
1267        if activations.is_empty() {
1268            return ActivationStatistics::default();
1269        }
1270
1271        let mean = activations.iter().sum::<f32>() / activations.len() as f32;
1272        let variance =
1273            activations.iter().map(|x| (x - mean).powi(2)).sum::<f32>() / activations.len() as f32;
1274        let min = activations.iter().fold(f32::INFINITY, |a, &b| a.min(b));
1275        let max = activations.iter().fold(f32::NEG_INFINITY, |a, &b| a.max(b));
1276        let active_count = activations.iter().filter(|&&x| x > 0.01).count();
1277        let active_percentage = active_count as f32 / activations.len() as f32;
1278
1279        ActivationStatistics {
1280            mean_activation: mean,
1281            activation_variance: variance,
1282            activation_range: (min, max),
1283            active_percentage,
1284            entropy: 0.0, // Would calculate entropy
1285        }
1286    }
1287}
1288
1289// Component implementations
1290impl GradientExplosionDetector {
1291    fn new(threshold: f32) -> Self {
1292        Self {
1293            threshold,
1294            history: Vec::new(),
1295            window_size: 10,
1296        }
1297    }
1298}
1299
1300impl GradientVanishingDetector {
1301    fn new(threshold: f32) -> Self {
1302        Self {
1303            threshold,
1304            layer_gradients: HashMap::new(),
1305        }
1306    }
1307}
1308
1309impl DeadNeuronDetector {
1310    fn new(threshold: f32) -> Self {
1311        Self {
1312            threshold,
1313            window: 100,
1314            neuron_states: HashMap::new(),
1315        }
1316    }
1317
1318    fn detect(&mut self, _layer_name: &str, activations: &[f32]) -> Vec<usize> {
1319        activations
1320            .iter()
1321            .enumerate()
1322            .filter_map(|(i, &activation)| {
1323                if activation.abs() < self.threshold {
1324                    Some(i)
1325                } else {
1326                    None
1327                }
1328            })
1329            .collect()
1330    }
1331}
1332
1333impl ActivationDistributionAnalyzer {
1334    fn new() -> Self {
1335        Self {
1336            distributions: HashMap::new(),
1337            config: DistributionAnalysisConfig::default(),
1338        }
1339    }
1340
1341    fn analyze(&mut self, _activations: &[f32]) -> ActivationDistribution {
1342        // Simple implementation - would use proper statistical analysis
1343        ActivationDistribution {
1344            distribution_type: DistributionType::Normal,
1345            parameters: HashMap::new(),
1346            goodness_of_fit: 0.85,
1347        }
1348    }
1349}
1350
1351// Default implementations
1352impl Default for TensorStatistics {
1353    fn default() -> Self {
1354        Self {
1355            mean: 0.0,
1356            std: 0.0,
1357            min: 0.0,
1358            max: 0.0,
1359            nan_count: 0,
1360            inf_count: 0,
1361            zero_count: 0,
1362            sparsity: 0.0,
1363            distribution: ValueDistribution::default(),
1364            gradient_norm: None,
1365        }
1366    }
1367}
1368
1369impl Default for ValueDistribution {
1370    fn default() -> Self {
1371        Self {
1372            histogram: vec![],
1373            bin_edges: vec![],
1374            percentiles: HashMap::new(),
1375            kurtosis: 0.0,
1376            skewness: 0.0,
1377        }
1378    }
1379}
1380
1381impl Default for GradientStatistics {
1382    fn default() -> Self {
1383        Self {
1384            norm: 0.0,
1385            mean: 0.0,
1386            std: 0.0,
1387            max: 0.0,
1388            min: 0.0,
1389            zero_count: 0,
1390            sparsity: 0.0,
1391        }
1392    }
1393}
1394
1395impl Default for ActivationStatistics {
1396    fn default() -> Self {
1397        Self {
1398            mean_activation: 0.0,
1399            activation_variance: 0.0,
1400            activation_range: (0.0, 0.0),
1401            active_percentage: 0.0,
1402            entropy: 0.0,
1403        }
1404    }
1405}
1406
1407impl Default for ModelHealthMetrics {
1408    fn default() -> Self {
1409        Self {
1410            overall_health: 1.0,
1411            gradient_health: 1.0,
1412            activation_health: 1.0,
1413            memory_health: 1.0,
1414            performance_health: 1.0,
1415            stability_indicators: StabilityIndicators::default(),
1416        }
1417    }
1418}
1419
1420impl Default for StabilityIndicators {
1421    fn default() -> Self {
1422        Self {
1423            numerical_stability: 1.0,
1424            training_stability: 1.0,
1425            memory_stability: 1.0,
1426            convergence_indicators: ConvergenceIndicators::default(),
1427        }
1428    }
1429}
1430
1431impl Default for ConvergenceIndicators {
1432    fn default() -> Self {
1433        Self {
1434            loss_convergence: 1.0,
1435            gradient_convergence: 1.0,
1436            parameter_convergence: 1.0,
1437            validation_convergence: 1.0,
1438        }
1439    }
1440}
1441
1442impl Default for DebugStatistics {
1443    fn default() -> Self {
1444        Self {
1445            hooks_triggered: 0,
1446            anomalies_detected: 0,
1447            tensors_inspected: 0,
1448            total_debug_time: Duration::from_secs(0),
1449            debug_overhead: 0.0,
1450        }
1451    }
1452}
1453
1454impl Default for TensorInspectionConfig {
1455    fn default() -> Self {
1456        Self {
1457            max_display_values: 100,
1458            enable_histograms: true,
1459            histogram_bins: 50,
1460            enable_distribution_analysis: true,
1461            sample_size: 1000,
1462        }
1463    }
1464}
1465
1466impl Default for DistributionAnalysisConfig {
1467    fn default() -> Self {
1468        Self {
1469            enable_fitting: true,
1470            confidence_level: 0.95,
1471            sample_size: 10000,
1472        }
1473    }
1474}
1475
#[cfg(test)]
mod tests {
    use super::*;

    /// Build a debugger from the default configuration, panicking on failure.
    fn default_debugger() -> ModelDebugger {
        ModelDebugger::new(DebugConfig::default()).unwrap()
    }

    #[test]
    fn test_debugger_creation() {
        assert!(ModelDebugger::new(DebugConfig::default()).is_ok());
    }

    #[test]
    fn test_debugging_session() {
        let mut debugger = default_debugger();

        let session_id = debugger.start_debugging("test_model").unwrap();
        assert!(!session_id.is_empty());

        assert!(debugger.stop_debugging(&session_id).is_ok());
    }

    #[test]
    fn test_tensor_inspection() {
        let mut debugger = default_debugger();
        let session_id = debugger.start_debugging("test_model").unwrap();

        let test_data = vec![1.0, 2.0, 3.0, 4.0, 5.0];
        let shape = vec![5];
        let snapshot = debugger
            .inspect_tensor(&session_id, "test_tensor", &test_data, &shape)
            .unwrap();

        assert_eq!(snapshot.shape, shape);
        // 15 / 5 is exactly representable in f32, so exact equality is safe.
        assert_eq!(snapshot.statistics.mean, 3.0);
    }

    #[test]
    fn test_anomaly_detection() {
        let mut debugger = default_debugger();
        let session_id = debugger.start_debugging("test_model").unwrap();

        // A tensor containing a NaN must surface as a NaNValues anomaly.
        let bad_data = vec![1.0, 2.0, f32::NAN, 4.0, 5.0];
        let shape = vec![5];
        debugger
            .inspect_tensor(&session_id, "bad_tensor", &bad_data, &shape)
            .unwrap();

        let anomalies = debugger.detect_anomalies(&session_id).unwrap();
        assert!(!anomalies.is_empty());
        assert!(anomalies
            .iter()
            .any(|a| matches!(a.anomaly_type, AnomalyType::NaNValues)));
    }
}