// trustformers_debug/gradient_debugger/visualization.rs
1//! Gradient Flow Visualization Data Generation
2//!
3//! This module provides comprehensive visualization data generation for gradient flow
4//! analysis, including network topology, temporal flows, and critical path identification.
5
6use super::types::*;
7use serde::{Deserialize, Serialize};
8use std::collections::HashMap;
9
/// Complete gradient flow visualization data
///
/// Top-level payload assembled by `GradientFlowVisualizer::generate_visualization`;
/// every field derives `Serialize`/`Deserialize` for export.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct GradientFlowVisualization {
    /// Per-layer flow summaries, keyed by layer name.
    pub layer_flows: HashMap<String, GradientLayerFlow>,
    /// Latest per-layer flow snapshots stamped with the current step.
    pub temporal_flows: Vec<TemporalGradientFlow>,
    /// Graph view of gradient flow between layers.
    pub flow_network: GradientFlowNetwork,
    /// Paths considered critical for gradient propagation.
    pub critical_paths: Vec<CriticalGradientPath>,
    /// Layers whose average gradient norm is vanishingly small.
    pub vanishing_regions: Vec<VanishingRegion>,
    /// Layers whose peak gradient norm is excessively large.
    pub exploding_regions: Vec<ExplodingRegion>,
    /// Layers whose gradients are mostly near zero.
    pub dead_zones: Vec<GradientDeadZone>,
    /// Configuration the visualization was generated with.
    pub visualization_config: GradientVisualizationConfig,
}
22
/// Gradient flow data for a specific layer
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct GradientLayerFlow {
    /// Name of the layer this summary describes.
    pub layer_name: String,
    /// Recorded gradient norms, one per step.
    pub gradient_magnitudes: Vec<f64>,
    /// Per-step direction/consistency records.
    pub gradient_directions: Vec<GradientDirection>,
    /// Step-to-step consistency in (0, 1]; 1.0 means no variation.
    pub flow_consistency: f64,
    /// 1 - min/mean of the norms; higher suggests a flow bottleneck.
    pub bottleneck_score: f64,
    /// Average absolute change in gradient norm per recorded step.
    pub information_flow_rate: f64,
}
33
/// Direction and characteristics of gradient flow
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct GradientDirection {
    /// Training step this sample belongs to.
    pub step: usize,
    /// Direction vector; currently a single-element placeholder holding the
    /// gradient norm (real gradient vectors are not available in this module).
    pub direction_vector: Vec<f64>,
    /// Gradient norm at this step.
    pub magnitude: f64,
    /// Similarity to the previous step's norm, in (0, 1].
    pub consistency_score: f64,
}
42
/// Temporal gradient flow at a specific timestep
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct TemporalGradientFlow {
    /// Step the snapshot was taken at.
    pub step: usize,
    /// Layer this snapshot belongs to.
    pub layer_name: String,
    /// Most recent gradient norm recorded for the layer.
    pub gradient_magnitude: f64,
    /// Trend classification over the recent history.
    pub flow_direction: FlowDirection,
    /// 1 / (1 + variance) of the recent norms; higher = steadier.
    pub stability_score: f64,
}
52
/// Flow direction classification
#[derive(Debug, Clone, Serialize, Deserialize)]
pub enum FlowDirection {
    /// Gradient magnitude trending upward.
    Forward,
    /// Gradient magnitude trending downward.
    Backward,
    /// Sign of the step-to-step change alternates in the recent window.
    Oscillating,
    /// Essentially no change in magnitude.
    Stagnant,
}
61
/// Network representation of gradient flows
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct GradientFlowNetwork {
    /// One node per layer.
    pub nodes: Vec<FlowNode>,
    /// Flow edges between adjacent layers.
    pub edges: Vec<FlowEdge>,
    /// Aggregate metrics over the whole network.
    pub network_metrics: NetworkMetrics,
}
69
/// Node in the gradient flow network
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct FlowNode {
    /// Layer represented by this node.
    pub layer_name: String,
    /// Role classification (source, sink, bottleneck, ...).
    pub node_type: NodeType,
    /// Mean gradient magnitude for the layer.
    pub gradient_strength: f64,
    /// Number of connections (currently the total layer count — simplified).
    pub connectivity: usize,
    /// `gradient_strength` weighted by flow consistency.
    pub influence_score: f64,
}
79
/// Type of node in the flow network
#[derive(Debug, Clone, Serialize, Deserialize)]
pub enum NodeType {
    /// Layer exhibiting very large gradient magnitudes.
    Source,
    /// Layer where gradient magnitudes are near zero.
    Sink,
    /// Layer with a high bottleneck score.
    Bottleneck,
    /// Layer with a high information flow rate.
    Amplifier,
    /// No notable flow characteristic.
    Normal,
}
89
/// Edge in the gradient flow network
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct FlowEdge {
    /// Upstream layer name.
    pub from_layer: String,
    /// Downstream layer name.
    pub to_layer: String,
    /// Mean information flow rate of the two endpoint layers.
    pub flow_strength: f64,
    /// Mean flow consistency of the two endpoint layers.
    pub flow_consistency: f64,
    /// Qualitative classification of the edge.
    pub edge_type: EdgeType,
}
99
/// Type of edge in the flow network
#[derive(Debug, Clone, Serialize, Deserialize)]
pub enum EdgeType {
    /// High strength and high consistency.
    Strong,
    /// Low strength or very low consistency.
    Weak,
    /// Moderate consistency.
    Intermittent,
    /// Fallback classification produced by `classify_edge_type`.
    Blocked,
}
108
/// Network-level gradient flow metrics
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct NetworkMetrics {
    /// Mean edge flow strength.
    pub overall_flow_efficiency: f64,
    /// Edge count relative to the possible directed connections.
    pub network_connectivity: f64,
    /// Fraction of nodes classified as bottlenecks.
    pub bottleneck_density: f64,
    /// Mean edge flow consistency.
    pub flow_stability: f64,
    /// Product of flow efficiency and connectivity.
    pub information_propagation_speed: f64,
}
118
/// Critical path in gradient flow
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct CriticalGradientPath {
    /// Stable identifier for the path.
    pub path_id: String,
    /// Layer names along the path, in order.
    pub layers: Vec<String>,
    /// Number of layers in the path.
    pub path_length: usize,
    /// Sum of edge flow strengths along the path.
    pub total_flow_strength: f64,
    /// Layers on the path classified as bottlenecks.
    pub bottleneck_layers: Vec<String>,
    /// How critical the path is (currently a placeholder constant).
    pub criticality_score: f64,
    /// Estimated optimization headroom (currently a placeholder constant).
    pub optimization_potential: f64,
}
130
/// Region where gradients are vanishing
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct VanishingRegion {
    /// Stable identifier, e.g. "vanishing_1".
    pub region_id: String,
    /// Layers whose average gradient norm is below the vanishing threshold.
    pub affected_layers: Vec<String>,
    /// Severity bucket based on the average magnitude.
    pub severity_level: VanishingSeverity,
    /// Spatial/temporal span of the region.
    pub extent: RegionExtent,
    /// Human-readable remediation hints.
    pub mitigation_suggestions: Vec<String>,
}
140
/// Severity buckets for vanishing-gradient regions.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub enum VanishingSeverity {
    Mild,
    Moderate,
    Severe,
    Critical,
}
148
/// Region where gradients are exploding
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ExplodingRegion {
    /// Stable identifier, e.g. "exploding_1".
    pub region_id: String,
    /// Layers whose peak gradient norm exceeds the exploding threshold.
    pub affected_layers: Vec<String>,
    /// Severity bucket based on the peak magnitude.
    pub severity_level: ExplodingSeverity,
    /// Spatial/temporal span of the region.
    pub extent: RegionExtent,
    /// Human-readable remediation hints.
    pub mitigation_suggestions: Vec<String>,
}
158
/// Severity buckets for exploding-gradient regions.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub enum ExplodingSeverity {
    Mild,
    Moderate,
    Severe,
    Critical,
}
166
/// Spatial extent of a gradient region
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct RegionExtent {
    /// First affected layer.
    pub start_layer: String,
    /// Last affected layer.
    pub end_layer: String,
    /// Parameter count in the region (currently a placeholder value).
    pub affected_parameters: usize,
    /// Number of recorded steps the region spans.
    pub duration_steps: usize,
}
175
/// Zone where gradient flow is effectively dead
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct GradientDeadZone {
    /// Stable identifier, e.g. "dead_zone_1".
    pub zone_id: String,
    /// Layers dominated by near-zero gradients.
    pub affected_layers: Vec<String>,
    /// Count of recorded steps with a near-zero gradient norm.
    pub dead_duration: usize,
    /// Estimated chance the zone recovers on its own.
    pub recovery_potential: RecoveryPotential,
    /// True when the dead ratio is high enough to warrant intervention.
    pub intervention_required: bool,
}
185
/// Likelihood that a dead zone recovers without intervention.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub enum RecoveryPotential {
    High,
    Medium,
    Low,
    None,
}
193
/// Configuration for gradient visualization
///
/// NOTE(review): these options are stored and serialized with the output,
/// but the generation methods in this module do not currently consult the
/// toggles/thresholds — confirm consumers apply them downstream.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct GradientVisualizationConfig {
    /// Intended toggle for per-step temporal flow snapshots.
    pub show_temporal_flows: bool,
    /// Intended toggle for critical path analysis.
    pub show_critical_paths: bool,
    /// Intended toggle for vanishing/exploding/dead-zone regions.
    pub show_problem_regions: bool,
    /// Palette used when rendering.
    pub color_scheme: ColorScheme,
    /// Number of recent steps considered for temporal views.
    pub temporal_window: usize,
    /// Minimum flow magnitude worth displaying.
    pub flow_threshold: f64,
}
204
205impl Default for GradientVisualizationConfig {
206    fn default() -> Self {
207        Self {
208            show_temporal_flows: true,
209            show_critical_paths: true,
210            show_problem_regions: true,
211            color_scheme: ColorScheme::Default,
212            temporal_window: 50,
213            flow_threshold: 0.01,
214        }
215    }
216}
217
/// Color palette options for rendering the visualization.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub enum ColorScheme {
    Default,
    HighContrast,
    ColorBlind,
    Monochrome,
}
225
/// Gradient flow visualization generator
///
/// Stateless apart from its configuration; all output is derived from the
/// gradient histories passed into the generation methods.
#[derive(Debug, Default)]
pub struct GradientFlowVisualizer {
    /// Configuration snapshot, cloned into every generated visualization.
    config: GradientVisualizationConfig,
}
231
232impl GradientFlowVisualizer {
233    pub fn new(config: GradientVisualizationConfig) -> Self {
234        Self { config }
235    }
236
237    pub fn generate_visualization(
238        &self,
239        gradient_histories: &HashMap<String, GradientHistory>,
240        current_step: usize,
241    ) -> GradientFlowVisualization {
242        let layer_flows = self.generate_layer_flows(gradient_histories);
243        let temporal_flows = self.generate_temporal_flows(gradient_histories, current_step);
244        let flow_network = self.build_gradient_flow_network(&layer_flows);
245        let critical_paths = self.identify_critical_gradient_paths(&flow_network);
246        let vanishing_regions = self.identify_vanishing_regions(gradient_histories);
247        let exploding_regions = self.identify_exploding_regions(gradient_histories);
248        let dead_zones = self.identify_gradient_dead_zones(gradient_histories);
249
250        GradientFlowVisualization {
251            layer_flows,
252            temporal_flows,
253            flow_network,
254            critical_paths,
255            vanishing_regions,
256            exploding_regions,
257            dead_zones,
258            visualization_config: self.config.clone(),
259        }
260    }
261
262    fn generate_layer_flows(
263        &self,
264        gradient_histories: &HashMap<String, GradientHistory>,
265    ) -> HashMap<String, GradientLayerFlow> {
266        let mut layer_flows = HashMap::new();
267
268        for (layer_name, history) in gradient_histories {
269            let gradient_magnitudes: Vec<f64> = history.gradient_norms.iter().cloned().collect();
270            let gradient_directions = self.compute_gradient_directions(history);
271            let flow_consistency = self.compute_flow_consistency(history);
272            let bottleneck_score = self.compute_bottleneck_score(history);
273            let information_flow_rate = self.compute_information_flow_rate(history);
274
275            let flow_data = GradientLayerFlow {
276                layer_name: layer_name.clone(),
277                gradient_magnitudes,
278                gradient_directions,
279                flow_consistency,
280                bottleneck_score,
281                information_flow_rate,
282            };
283
284            layer_flows.insert(layer_name.clone(), flow_data);
285        }
286
287        layer_flows
288    }
289
290    fn compute_gradient_directions(&self, history: &GradientHistory) -> Vec<GradientDirection> {
291        let mut directions = Vec::new();
292
293        for (i, (&norm, &step)) in
294            history.gradient_norms.iter().zip(history.step_numbers.iter()).enumerate()
295        {
296            // Simplified direction computation - in practice, this would use actual gradient vectors
297            let direction_vector = vec![norm]; // Placeholder
298            let magnitude = norm;
299            let consistency_score = if i > 0 {
300                let prev_norm = history.gradient_norms[i - 1];
301                1.0 - ((norm - prev_norm).abs() / (norm + prev_norm + 1e-8))
302            } else {
303                1.0
304            };
305
306            directions.push(GradientDirection {
307                step,
308                direction_vector,
309                magnitude,
310                consistency_score,
311            });
312        }
313
314        directions
315    }
316
317    fn compute_flow_consistency(&self, history: &GradientHistory) -> f64 {
318        if history.gradient_norms.len() < 2 {
319            return 1.0;
320        }
321
322        let variations: Vec<f64> = history
323            .gradient_norms
324            .iter()
325            .collect::<Vec<&f64>>()
326            .windows(2)
327            .map(|pair| (*pair[1] - *pair[0]).abs() / (*pair[0] + 1e-8))
328            .collect();
329
330        let avg_variation = variations.iter().sum::<f64>() / variations.len() as f64;
331        (1.0_f64 / (1.0 + avg_variation)).min(1.0)
332    }
333
334    fn compute_bottleneck_score(&self, history: &GradientHistory) -> f64 {
335        if history.gradient_norms.is_empty() {
336            return 0.0;
337        }
338
339        let mean = history.gradient_norms.iter().sum::<f64>() / history.gradient_norms.len() as f64;
340        let min_val = history.gradient_norms.iter().cloned().fold(f64::INFINITY, f64::min);
341
342        if mean == 0.0 {
343            return 1.0;
344        }
345
346        1.0 - (min_val / mean).min(1.0)
347    }
348
349    fn compute_information_flow_rate(&self, history: &GradientHistory) -> f64 {
350        if history.gradient_norms.len() < 2 {
351            return 0.0;
352        }
353
354        let total_change: f64 = history
355            .gradient_norms
356            .iter()
357            .collect::<Vec<&f64>>()
358            .windows(2)
359            .map(|pair| (*pair[1] - *pair[0]).abs())
360            .sum();
361
362        let time_span = history.gradient_norms.len() as f64;
363        total_change / time_span
364    }
365
366    fn generate_temporal_flows(
367        &self,
368        gradient_histories: &HashMap<String, GradientHistory>,
369        current_step: usize,
370    ) -> Vec<TemporalGradientFlow> {
371        let mut temporal_flows = Vec::new();
372
373        for (layer_name, history) in gradient_histories {
374            if let Some(latest_norm) = history.gradient_norms.back() {
375                let flow_direction = self.get_latest_flow_direction(history);
376                let stability_score = self.compute_stability_score(history);
377
378                temporal_flows.push(TemporalGradientFlow {
379                    step: current_step,
380                    layer_name: layer_name.clone(),
381                    gradient_magnitude: *latest_norm,
382                    flow_direction,
383                    stability_score,
384                });
385            }
386        }
387
388        temporal_flows
389    }
390
391    fn get_latest_flow_direction(&self, history: &GradientHistory) -> FlowDirection {
392        if history.gradient_norms.len() < 3 {
393            return FlowDirection::Forward;
394        }
395
396        let recent: Vec<f64> = history.gradient_norms.iter().rev().take(3).cloned().collect();
397        let trend = recent[0] - recent[2]; // Latest - oldest in recent window
398
399        if trend.abs() < 1e-6 {
400            FlowDirection::Stagnant
401        } else if trend > 0.0 {
402            FlowDirection::Forward
403        } else {
404            // Check for oscillation
405            let changes: Vec<f64> = recent.windows(2).map(|pair| pair[0] - pair[1]).collect();
406            let sign_changes = changes.windows(2).filter(|pair| pair[0] * pair[1] < 0.0).count();
407
408            if sign_changes > 0 {
409                FlowDirection::Oscillating
410            } else {
411                FlowDirection::Backward
412            }
413        }
414    }
415
416    fn compute_stability_score(&self, history: &GradientHistory) -> f64 {
417        if history.gradient_norms.len() < 3 {
418            return 1.0;
419        }
420
421        let recent: Vec<f64> = history.gradient_norms.iter().rev().take(5).cloned().collect();
422        let mean = recent.iter().sum::<f64>() / recent.len() as f64;
423        let variance =
424            recent.iter().map(|&x| (x - mean).powi(2)).sum::<f64>() / recent.len() as f64;
425
426        1.0 / (1.0 + variance)
427    }
428
429    fn build_gradient_flow_network(
430        &self,
431        layer_flows: &HashMap<String, GradientLayerFlow>,
432    ) -> GradientFlowNetwork {
433        let mut nodes = Vec::new();
434        let mut edges = Vec::new();
435
436        // Create nodes
437        for (layer_name, flow) in layer_flows {
438            let node_type = self.classify_node_type(flow);
439            let gradient_strength = flow.gradient_magnitudes.iter().sum::<f64>()
440                / flow.gradient_magnitudes.len() as f64;
441            let connectivity = layer_flows.len(); // Simplified
442            let influence_score = gradient_strength * flow.flow_consistency;
443
444            nodes.push(FlowNode {
445                layer_name: layer_name.clone(),
446                node_type,
447                gradient_strength,
448                connectivity,
449                influence_score,
450            });
451        }
452
453        // Create edges (simplified - would need actual layer connectivity information)
454        let layer_names: Vec<String> = layer_flows.keys().cloned().collect();
455        for i in 0..layer_names.len().saturating_sub(1) {
456            let from_layer = &layer_names[i];
457            let to_layer = &layer_names[i + 1];
458
459            if let (Some(from_flow), Some(to_flow)) =
460                (layer_flows.get(from_layer), layer_flows.get(to_layer))
461            {
462                let flow_strength =
463                    (from_flow.information_flow_rate + to_flow.information_flow_rate) / 2.0;
464                let flow_consistency =
465                    (from_flow.flow_consistency + to_flow.flow_consistency) / 2.0;
466                let edge_type = self.classify_edge_type(flow_strength, flow_consistency);
467
468                edges.push(FlowEdge {
469                    from_layer: from_layer.clone(),
470                    to_layer: to_layer.clone(),
471                    flow_strength,
472                    flow_consistency,
473                    edge_type,
474                });
475            }
476        }
477
478        let network_metrics = self.compute_network_metrics(&nodes, &edges);
479
480        GradientFlowNetwork {
481            nodes,
482            edges,
483            network_metrics,
484        }
485    }
486
487    fn classify_node_type(&self, flow: &GradientLayerFlow) -> NodeType {
488        if flow.bottleneck_score > 0.8 {
489            NodeType::Bottleneck
490        } else if flow.information_flow_rate > 1.0 {
491            NodeType::Amplifier
492        } else if flow.gradient_magnitudes.iter().sum::<f64>() < 0.01 {
493            NodeType::Sink
494        } else if flow.gradient_magnitudes.iter().any(|&x| x > 10.0) {
495            NodeType::Source
496        } else {
497            NodeType::Normal
498        }
499    }
500
501    fn classify_edge_type(&self, flow_strength: f64, flow_consistency: f64) -> EdgeType {
502        if flow_strength > 1.0 && flow_consistency > 0.8 {
503            EdgeType::Strong
504        } else if flow_strength < 0.1 || flow_consistency < 0.3 {
505            EdgeType::Weak
506        } else if flow_consistency < 0.6 {
507            EdgeType::Intermittent
508        } else {
509            EdgeType::Blocked
510        }
511    }
512
513    fn compute_network_metrics(&self, nodes: &[FlowNode], edges: &[FlowEdge]) -> NetworkMetrics {
514        let overall_flow_efficiency =
515            edges.iter().map(|e| e.flow_strength).sum::<f64>() / edges.len().max(1) as f64;
516        let network_connectivity = edges.len() as f64
517            / (nodes.len().max(1) * (nodes.len().saturating_sub(1)).max(1)) as f64;
518        let bottleneck_density =
519            nodes.iter().filter(|n| matches!(n.node_type, NodeType::Bottleneck)).count() as f64
520                / nodes.len() as f64;
521        let flow_stability =
522            edges.iter().map(|e| e.flow_consistency).sum::<f64>() / edges.len().max(1) as f64;
523        let information_propagation_speed = overall_flow_efficiency * network_connectivity;
524
525        NetworkMetrics {
526            overall_flow_efficiency,
527            network_connectivity,
528            bottleneck_density,
529            flow_stability,
530            information_propagation_speed,
531        }
532    }
533
534    fn identify_critical_gradient_paths(
535        &self,
536        network: &GradientFlowNetwork,
537    ) -> Vec<CriticalGradientPath> {
538        let mut paths = Vec::new();
539
540        // Simplified path identification - would use graph algorithms in practice
541        if network.nodes.len() < 2 {
542            return paths;
543        }
544
545        let path_layers: Vec<String> = network.nodes.iter().map(|n| n.layer_name.clone()).collect();
546        let total_flow_strength = network.edges.iter().map(|e| e.flow_strength).sum();
547        let bottleneck_layers: Vec<String> = network
548            .nodes
549            .iter()
550            .filter(|n| matches!(n.node_type, NodeType::Bottleneck))
551            .map(|n| n.layer_name.clone())
552            .collect();
553
554        paths.push(CriticalGradientPath {
555            path_id: "main_path".to_string(),
556            path_length: path_layers.len(),
557            layers: path_layers,
558            total_flow_strength,
559            bottleneck_layers,
560            criticality_score: 0.8, // Simplified
561            optimization_potential: 0.6,
562        });
563
564        paths
565    }
566
567    fn identify_vanishing_regions(
568        &self,
569        gradient_histories: &HashMap<String, GradientHistory>,
570    ) -> Vec<VanishingRegion> {
571        let mut regions = Vec::new();
572        let mut region_id = 0;
573
574        for (layer_name, history) in gradient_histories {
575            let avg_gradient =
576                history.gradient_norms.iter().sum::<f64>() / history.gradient_norms.len() as f64;
577            if avg_gradient < 1e-5 {
578                region_id += 1;
579                regions.push(VanishingRegion {
580                    region_id: format!("vanishing_{}", region_id),
581                    affected_layers: vec![layer_name.clone()],
582                    severity_level: if avg_gradient < 1e-7 {
583                        VanishingSeverity::Critical
584                    } else {
585                        VanishingSeverity::Moderate
586                    },
587                    extent: RegionExtent {
588                        start_layer: layer_name.clone(),
589                        end_layer: layer_name.clone(),
590                        affected_parameters: 1000, // Placeholder
591                        duration_steps: history.gradient_norms.len(),
592                    },
593                    mitigation_suggestions: vec![
594                        "Consider better weight initialization".to_string(),
595                        "Add skip connections".to_string(),
596                        "Use gradient clipping".to_string(),
597                    ],
598                });
599            }
600        }
601
602        regions
603    }
604
605    fn identify_exploding_regions(
606        &self,
607        gradient_histories: &HashMap<String, GradientHistory>,
608    ) -> Vec<ExplodingRegion> {
609        let mut regions = Vec::new();
610        let mut region_id = 0;
611
612        for (layer_name, history) in gradient_histories {
613            let max_gradient = history.gradient_norms.iter().cloned().fold(0.0, f64::max);
614            if max_gradient > 100.0 {
615                region_id += 1;
616                regions.push(ExplodingRegion {
617                    region_id: format!("exploding_{}", region_id),
618                    affected_layers: vec![layer_name.clone()],
619                    severity_level: if max_gradient > 1000.0 {
620                        ExplodingSeverity::Critical
621                    } else {
622                        ExplodingSeverity::Moderate
623                    },
624                    extent: RegionExtent {
625                        start_layer: layer_name.clone(),
626                        end_layer: layer_name.clone(),
627                        affected_parameters: 1000, // Placeholder
628                        duration_steps: history.gradient_norms.len(),
629                    },
630                    mitigation_suggestions: vec![
631                        "Apply gradient clipping".to_string(),
632                        "Reduce learning rate".to_string(),
633                        "Check weight initialization".to_string(),
634                    ],
635                });
636            }
637        }
638
639        regions
640    }
641
642    fn identify_gradient_dead_zones(
643        &self,
644        gradient_histories: &HashMap<String, GradientHistory>,
645    ) -> Vec<GradientDeadZone> {
646        let mut dead_zones = Vec::new();
647        let mut zone_id = 0;
648
649        for (layer_name, history) in gradient_histories {
650            let zero_gradients = history.gradient_norms.iter().filter(|&&x| x < 1e-8).count();
651            let dead_ratio = zero_gradients as f64 / history.gradient_norms.len() as f64;
652
653            if dead_ratio > 0.5 {
654                zone_id += 1;
655                dead_zones.push(GradientDeadZone {
656                    zone_id: format!("dead_zone_{}", zone_id),
657                    affected_layers: vec![layer_name.clone()],
658                    dead_duration: zero_gradients,
659                    recovery_potential: if dead_ratio > 0.9 {
660                        RecoveryPotential::Low
661                    } else {
662                        RecoveryPotential::Medium
663                    },
664                    intervention_required: dead_ratio > 0.8,
665                });
666            }
667        }
668
669        dead_zones
670    }
671
672    /// Create comprehensive visualization data for gradient flows
673    pub fn create_visualization(
674        &self,
675        gradient_histories: &HashMap<String, GradientHistory>,
676    ) -> GradientFlowVisualization {
677        // Use existing methods to generate the visualization
678        self.generate_visualization(gradient_histories, 0)
679    }
680}