// trustformers_debug/gradient_debugger/visualization.rs

//! Gradient Flow Visualization Data Generation
//!
//! This module provides comprehensive visualization data generation for gradient flow
//! analysis, including network topology, temporal flows, and critical path identification.
6use super::types::*;
7use serde::{Deserialize, Serialize};
8use std::collections::HashMap;
9
/// Complete gradient flow visualization data
///
/// Aggregate output of the visualizer: per-layer flow statistics, per-step
/// snapshots, a node/edge network view, detected problem regions, and the
/// configuration used to produce it all.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct GradientFlowVisualization {
    /// Per-layer flow statistics, keyed by layer name.
    pub layer_flows: HashMap<String, GradientLayerFlow>,
    /// Per-layer snapshots taken at a specific step.
    pub temporal_flows: Vec<TemporalGradientFlow>,
    /// Node/edge network view of the flows.
    pub flow_network: GradientFlowNetwork,
    /// Paths judged critical for gradient propagation.
    pub critical_paths: Vec<CriticalGradientPath>,
    /// Layers flagged for vanishing gradients.
    pub vanishing_regions: Vec<VanishingRegion>,
    /// Layers flagged for exploding gradients.
    pub exploding_regions: Vec<ExplodingRegion>,
    /// Layers where gradient flow is effectively dead.
    pub dead_zones: Vec<GradientDeadZone>,
    /// Configuration the visualization was generated with.
    pub visualization_config: GradientVisualizationConfig,
}
22
/// Gradient flow data for a specific layer
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct GradientLayerFlow {
    pub layer_name: String,
    /// Recorded gradient norms, in history order.
    pub gradient_magnitudes: Vec<f64>,
    /// Per-step direction records derived from the norms.
    pub gradient_directions: Vec<GradientDirection>,
    /// Steadiness of the norms over time (1.0 = perfectly steady).
    pub flow_consistency: f64,
    /// How far the weakest gradient falls below the mean (1.0 = severe).
    pub bottleneck_score: f64,
    /// Average absolute step-to-step change in gradient norm.
    pub information_flow_rate: f64,
}
33
/// Direction and characteristics of gradient flow
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct GradientDirection {
    /// Training step this sample belongs to.
    pub step: usize,
    /// Direction vector; currently a one-element placeholder holding only
    /// the scalar norm (see `compute_gradient_directions`).
    pub direction_vector: Vec<f64>,
    /// Gradient norm at this step.
    pub magnitude: f64,
    /// Similarity to the previous step's norm (1.0 = unchanged).
    pub consistency_score: f64,
}
42
/// Temporal gradient flow at a specific timestep
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct TemporalGradientFlow {
    /// Step at which the snapshot was taken.
    pub step: usize,
    pub layer_name: String,
    /// Most recent recorded gradient norm for the layer.
    pub gradient_magnitude: f64,
    /// Trend classification of the recent norms.
    pub flow_direction: FlowDirection,
    /// Inverse-variance stability of the recent norms (1.0 = stable).
    pub stability_score: f64,
}
52
/// Flow direction classification
#[derive(Debug, Clone, Serialize, Deserialize)]
pub enum FlowDirection {
    /// Norms trending upward (also the default when too little data exists).
    Forward,
    /// Norms trending steadily downward.
    Backward,
    /// Alternating increases and decreases in the recent window.
    Oscillating,
    /// No meaningful recent change.
    Stagnant,
}
61
/// Network representation of gradient flows
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct GradientFlowNetwork {
    /// One node per layer.
    pub nodes: Vec<FlowNode>,
    /// Directed edges between layers.
    pub edges: Vec<FlowEdge>,
    /// Aggregate metrics over the whole network.
    pub network_metrics: NetworkMetrics,
}
69
/// Node in the gradient flow network
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct FlowNode {
    pub layer_name: String,
    pub node_type: NodeType,
    /// Mean gradient magnitude for the layer.
    pub gradient_strength: f64,
    /// Number of connected layers (currently simplified to the total layer count).
    pub connectivity: usize,
    /// Gradient strength weighted by flow consistency.
    pub influence_score: f64,
}
79
/// Type of node in the flow network
#[derive(Debug, Clone, Serialize, Deserialize)]
pub enum NodeType {
    /// Layer with at least one unusually large gradient magnitude.
    Source,
    /// Layer whose gradients are nearly zero overall.
    Sink,
    /// Layer where flow is constricted (high bottleneck score).
    Bottleneck,
    /// Layer with a high information flow rate.
    Amplifier,
    /// No notable characteristics.
    Normal,
}
89
/// Edge in the gradient flow network
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct FlowEdge {
    pub from_layer: String,
    pub to_layer: String,
    /// Average information flow rate of the two endpoints.
    pub flow_strength: f64,
    /// Average flow consistency of the two endpoints.
    pub flow_consistency: f64,
    pub edge_type: EdgeType,
}
99
/// Type of edge in the flow network
///
/// Assigned by `classify_edge_type` from an edge's averaged strength and
/// consistency.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub enum EdgeType {
    /// Strong and highly consistent flow.
    Strong,
    /// Very weak or very inconsistent flow.
    Weak,
    /// Moderately inconsistent flow.
    Intermittent,
    /// Fallthrough classification; see the review note on `classify_edge_type`.
    Blocked,
}
108
/// Network-level gradient flow metrics
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct NetworkMetrics {
    /// Mean edge flow strength.
    pub overall_flow_efficiency: f64,
    /// Edge count relative to the maximum possible number of edges.
    pub network_connectivity: f64,
    /// Fraction of nodes classified as bottlenecks.
    pub bottleneck_density: f64,
    /// Mean edge flow consistency.
    pub flow_stability: f64,
    /// Product of efficiency and connectivity.
    pub information_propagation_speed: f64,
}
118
/// Critical path in gradient flow
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct CriticalGradientPath {
    /// Stable identifier for the path.
    pub path_id: String,
    /// Layer names along the path, in order.
    pub layers: Vec<String>,
    /// Number of layers in the path.
    pub path_length: usize,
    /// Sum of edge flow strengths along the path.
    pub total_flow_strength: f64,
    /// Layers on the path classified as bottlenecks.
    pub bottleneck_layers: Vec<String>,
    /// How critical the path is for training (currently a fixed heuristic).
    pub criticality_score: f64,
    /// Estimated headroom for optimization (currently a fixed heuristic).
    pub optimization_potential: f64,
}
130
/// Region where gradients are vanishing
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct VanishingRegion {
    /// Stable identifier (e.g. "vanishing_1").
    pub region_id: String,
    /// Layers exhibiting vanishing gradients.
    pub affected_layers: Vec<String>,
    pub severity_level: VanishingSeverity,
    /// Spatial/temporal extent of the problem.
    pub extent: RegionExtent,
    /// Human-readable remediation suggestions.
    pub mitigation_suggestions: Vec<String>,
}
140
/// Severity grades for a vanishing-gradient region, mildest to worst.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub enum VanishingSeverity {
    Mild,
    Moderate,
    Severe,
    Critical,
}
148
/// Region where gradients are exploding
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ExplodingRegion {
    /// Stable identifier (e.g. "exploding_1").
    pub region_id: String,
    /// Layers exhibiting exploding gradients.
    pub affected_layers: Vec<String>,
    pub severity_level: ExplodingSeverity,
    /// Spatial/temporal extent of the problem.
    pub extent: RegionExtent,
    /// Human-readable remediation suggestions.
    pub mitigation_suggestions: Vec<String>,
}
158
/// Severity grades for an exploding-gradient region, mildest to worst.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub enum ExplodingSeverity {
    Mild,
    Moderate,
    Severe,
    Critical,
}
166
/// Spatial extent of a gradient region
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct RegionExtent {
    /// First layer in the region.
    pub start_layer: String,
    /// Last layer in the region.
    pub end_layer: String,
    /// Parameter count in the region (currently a placeholder value; the
    /// true count is not available to the visualizer).
    pub affected_parameters: usize,
    /// Number of recorded steps the region spans.
    pub duration_steps: usize,
}
175
/// Zone where gradient flow is effectively dead
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct GradientDeadZone {
    /// Stable identifier (e.g. "dead_zone_1").
    pub zone_id: String,
    /// Layers in which the gradients are effectively zero.
    pub affected_layers: Vec<String>,
    /// Number of recorded steps with near-zero gradients.
    pub dead_duration: usize,
    /// Estimated likelihood the layer can recover on its own.
    pub recovery_potential: RecoveryPotential,
    /// Whether manual intervention is recommended.
    pub intervention_required: bool,
}
185
/// Estimated chance a dead zone recovers without intervention.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub enum RecoveryPotential {
    High,
    Medium,
    Low,
    None,
}
193
/// Configuration for gradient visualization
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct GradientVisualizationConfig {
    /// Include per-step temporal flow snapshots.
    pub show_temporal_flows: bool,
    /// Include critical-path analysis.
    pub show_critical_paths: bool,
    /// Include vanishing/exploding/dead-zone regions.
    pub show_problem_regions: bool,
    /// Color palette for rendering.
    pub color_scheme: ColorScheme,
    /// Number of recent steps considered for temporal views.
    pub temporal_window: usize,
    /// Minimum flow magnitude worth displaying.
    pub flow_threshold: f64,
}
204
205impl Default for GradientVisualizationConfig {
206    fn default() -> Self {
207        Self {
208            show_temporal_flows: true,
209            show_critical_paths: true,
210            show_problem_regions: true,
211            color_scheme: ColorScheme::Default,
212            temporal_window: 50,
213            flow_threshold: 0.01,
214        }
215    }
216}
217
/// Color palettes available for rendering the visualization.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub enum ColorScheme {
    Default,
    HighContrast,
    ColorBlind,
    Monochrome,
}
225
/// Gradient flow visualization generator
#[derive(Debug)]
pub struct GradientFlowVisualizer {
    /// Settings controlling overlays and thresholds.
    config: GradientVisualizationConfig,
}
231
232impl Default for GradientFlowVisualizer {
233    fn default() -> Self {
234        Self {
235            config: GradientVisualizationConfig::default(),
236        }
237    }
238}
239
240impl GradientFlowVisualizer {
    /// Create a visualizer with the given configuration.
    pub fn new(config: GradientVisualizationConfig) -> Self {
        Self { config }
    }
244
245    pub fn generate_visualization(
246        &self,
247        gradient_histories: &HashMap<String, GradientHistory>,
248        current_step: usize,
249    ) -> GradientFlowVisualization {
250        let layer_flows = self.generate_layer_flows(gradient_histories);
251        let temporal_flows = self.generate_temporal_flows(gradient_histories, current_step);
252        let flow_network = self.build_gradient_flow_network(&layer_flows);
253        let critical_paths = self.identify_critical_gradient_paths(&flow_network);
254        let vanishing_regions = self.identify_vanishing_regions(gradient_histories);
255        let exploding_regions = self.identify_exploding_regions(gradient_histories);
256        let dead_zones = self.identify_gradient_dead_zones(gradient_histories);
257
258        GradientFlowVisualization {
259            layer_flows,
260            temporal_flows,
261            flow_network,
262            critical_paths,
263            vanishing_regions,
264            exploding_regions,
265            dead_zones,
266            visualization_config: self.config.clone(),
267        }
268    }
269
270    fn generate_layer_flows(
271        &self,
272        gradient_histories: &HashMap<String, GradientHistory>,
273    ) -> HashMap<String, GradientLayerFlow> {
274        let mut layer_flows = HashMap::new();
275
276        for (layer_name, history) in gradient_histories {
277            let gradient_magnitudes: Vec<f64> = history.gradient_norms.iter().cloned().collect();
278            let gradient_directions = self.compute_gradient_directions(history);
279            let flow_consistency = self.compute_flow_consistency(history);
280            let bottleneck_score = self.compute_bottleneck_score(history);
281            let information_flow_rate = self.compute_information_flow_rate(history);
282
283            let flow_data = GradientLayerFlow {
284                layer_name: layer_name.clone(),
285                gradient_magnitudes,
286                gradient_directions,
287                flow_consistency,
288                bottleneck_score,
289                information_flow_rate,
290            };
291
292            layer_flows.insert(layer_name.clone(), flow_data);
293        }
294
295        layer_flows
296    }
297
298    fn compute_gradient_directions(&self, history: &GradientHistory) -> Vec<GradientDirection> {
299        let mut directions = Vec::new();
300
301        for (i, (&norm, &step)) in
302            history.gradient_norms.iter().zip(history.step_numbers.iter()).enumerate()
303        {
304            // Simplified direction computation - in practice, this would use actual gradient vectors
305            let direction_vector = vec![norm]; // Placeholder
306            let magnitude = norm;
307            let consistency_score = if i > 0 {
308                let prev_norm = history.gradient_norms[i - 1];
309                1.0 - ((norm - prev_norm).abs() / (norm + prev_norm + 1e-8))
310            } else {
311                1.0
312            };
313
314            directions.push(GradientDirection {
315                step,
316                direction_vector,
317                magnitude,
318                consistency_score,
319            });
320        }
321
322        directions
323    }
324
325    fn compute_flow_consistency(&self, history: &GradientHistory) -> f64 {
326        if history.gradient_norms.len() < 2 {
327            return 1.0;
328        }
329
330        let variations: Vec<f64> = history
331            .gradient_norms
332            .iter()
333            .collect::<Vec<&f64>>()
334            .windows(2)
335            .map(|pair| (*pair[1] - *pair[0]).abs() / (*pair[0] + 1e-8))
336            .collect();
337
338        let avg_variation = variations.iter().sum::<f64>() / variations.len() as f64;
339        (1.0_f64 / (1.0 + avg_variation)).min(1.0)
340    }
341
342    fn compute_bottleneck_score(&self, history: &GradientHistory) -> f64 {
343        if history.gradient_norms.is_empty() {
344            return 0.0;
345        }
346
347        let mean = history.gradient_norms.iter().sum::<f64>() / history.gradient_norms.len() as f64;
348        let min_val = history.gradient_norms.iter().cloned().fold(f64::INFINITY, f64::min);
349
350        if mean == 0.0 {
351            return 1.0;
352        }
353
354        1.0 - (min_val / mean).min(1.0)
355    }
356
357    fn compute_information_flow_rate(&self, history: &GradientHistory) -> f64 {
358        if history.gradient_norms.len() < 2 {
359            return 0.0;
360        }
361
362        let total_change: f64 = history
363            .gradient_norms
364            .iter()
365            .collect::<Vec<&f64>>()
366            .windows(2)
367            .map(|pair| (*pair[1] - *pair[0]).abs())
368            .sum();
369
370        let time_span = history.gradient_norms.len() as f64;
371        total_change / time_span
372    }
373
374    fn generate_temporal_flows(
375        &self,
376        gradient_histories: &HashMap<String, GradientHistory>,
377        current_step: usize,
378    ) -> Vec<TemporalGradientFlow> {
379        let mut temporal_flows = Vec::new();
380
381        for (layer_name, history) in gradient_histories {
382            if let Some(latest_norm) = history.gradient_norms.back() {
383                let flow_direction = self.get_latest_flow_direction(history);
384                let stability_score = self.compute_stability_score(history);
385
386                temporal_flows.push(TemporalGradientFlow {
387                    step: current_step,
388                    layer_name: layer_name.clone(),
389                    gradient_magnitude: *latest_norm,
390                    flow_direction,
391                    stability_score,
392                });
393            }
394        }
395
396        temporal_flows
397    }
398
399    fn get_latest_flow_direction(&self, history: &GradientHistory) -> FlowDirection {
400        if history.gradient_norms.len() < 3 {
401            return FlowDirection::Forward;
402        }
403
404        let recent: Vec<f64> = history.gradient_norms.iter().rev().take(3).cloned().collect();
405        let trend = recent[0] - recent[2]; // Latest - oldest in recent window
406
407        if trend.abs() < 1e-6 {
408            FlowDirection::Stagnant
409        } else if trend > 0.0 {
410            FlowDirection::Forward
411        } else {
412            // Check for oscillation
413            let changes: Vec<f64> = recent.windows(2).map(|pair| pair[0] - pair[1]).collect();
414            let sign_changes = changes.windows(2).filter(|pair| pair[0] * pair[1] < 0.0).count();
415
416            if sign_changes > 0 {
417                FlowDirection::Oscillating
418            } else {
419                FlowDirection::Backward
420            }
421        }
422    }
423
424    fn compute_stability_score(&self, history: &GradientHistory) -> f64 {
425        if history.gradient_norms.len() < 3 {
426            return 1.0;
427        }
428
429        let recent: Vec<f64> = history.gradient_norms.iter().rev().take(5).cloned().collect();
430        let mean = recent.iter().sum::<f64>() / recent.len() as f64;
431        let variance =
432            recent.iter().map(|&x| (x - mean).powi(2)).sum::<f64>() / recent.len() as f64;
433
434        1.0 / (1.0 + variance)
435    }
436
437    fn build_gradient_flow_network(
438        &self,
439        layer_flows: &HashMap<String, GradientLayerFlow>,
440    ) -> GradientFlowNetwork {
441        let mut nodes = Vec::new();
442        let mut edges = Vec::new();
443
444        // Create nodes
445        for (layer_name, flow) in layer_flows {
446            let node_type = self.classify_node_type(flow);
447            let gradient_strength = flow.gradient_magnitudes.iter().sum::<f64>()
448                / flow.gradient_magnitudes.len() as f64;
449            let connectivity = layer_flows.len(); // Simplified
450            let influence_score = gradient_strength * flow.flow_consistency;
451
452            nodes.push(FlowNode {
453                layer_name: layer_name.clone(),
454                node_type,
455                gradient_strength,
456                connectivity,
457                influence_score,
458            });
459        }
460
461        // Create edges (simplified - would need actual layer connectivity information)
462        let layer_names: Vec<String> = layer_flows.keys().cloned().collect();
463        for i in 0..layer_names.len().saturating_sub(1) {
464            let from_layer = &layer_names[i];
465            let to_layer = &layer_names[i + 1];
466
467            if let (Some(from_flow), Some(to_flow)) =
468                (layer_flows.get(from_layer), layer_flows.get(to_layer))
469            {
470                let flow_strength =
471                    (from_flow.information_flow_rate + to_flow.information_flow_rate) / 2.0;
472                let flow_consistency =
473                    (from_flow.flow_consistency + to_flow.flow_consistency) / 2.0;
474                let edge_type = self.classify_edge_type(flow_strength, flow_consistency);
475
476                edges.push(FlowEdge {
477                    from_layer: from_layer.clone(),
478                    to_layer: to_layer.clone(),
479                    flow_strength,
480                    flow_consistency,
481                    edge_type,
482                });
483            }
484        }
485
486        let network_metrics = self.compute_network_metrics(&nodes, &edges);
487
488        GradientFlowNetwork {
489            nodes,
490            edges,
491            network_metrics,
492        }
493    }
494
495    fn classify_node_type(&self, flow: &GradientLayerFlow) -> NodeType {
496        if flow.bottleneck_score > 0.8 {
497            NodeType::Bottleneck
498        } else if flow.information_flow_rate > 1.0 {
499            NodeType::Amplifier
500        } else if flow.gradient_magnitudes.iter().sum::<f64>() < 0.01 {
501            NodeType::Sink
502        } else if flow.gradient_magnitudes.iter().any(|&x| x > 10.0) {
503            NodeType::Source
504        } else {
505            NodeType::Normal
506        }
507    }
508
    /// Heuristic edge classification from averaged strength/consistency,
    /// checked in order: Strong (strong and consistent), Weak (very weak or
    /// very inconsistent), Intermittent (moderately inconsistent).
    ///
    /// NOTE(review): the fallthrough — consistency >= 0.6 with moderate
    /// strength — returns `Blocked`, which looks inverted for the
    /// healthiest remaining case; confirm the intended mapping before
    /// relying on it.
    fn classify_edge_type(&self, flow_strength: f64, flow_consistency: f64) -> EdgeType {
        if flow_strength > 1.0 && flow_consistency > 0.8 {
            EdgeType::Strong
        } else if flow_strength < 0.1 || flow_consistency < 0.3 {
            EdgeType::Weak
        } else if flow_consistency < 0.6 {
            EdgeType::Intermittent
        } else {
            EdgeType::Blocked
        }
    }
520
521    fn compute_network_metrics(&self, nodes: &[FlowNode], edges: &[FlowEdge]) -> NetworkMetrics {
522        let overall_flow_efficiency =
523            edges.iter().map(|e| e.flow_strength).sum::<f64>() / edges.len().max(1) as f64;
524        let network_connectivity = edges.len() as f64
525            / (nodes.len().max(1) * (nodes.len().saturating_sub(1)).max(1)) as f64;
526        let bottleneck_density =
527            nodes.iter().filter(|n| matches!(n.node_type, NodeType::Bottleneck)).count() as f64
528                / nodes.len() as f64;
529        let flow_stability =
530            edges.iter().map(|e| e.flow_consistency).sum::<f64>() / edges.len().max(1) as f64;
531        let information_propagation_speed = overall_flow_efficiency * network_connectivity;
532
533        NetworkMetrics {
534            overall_flow_efficiency,
535            network_connectivity,
536            bottleneck_density,
537            flow_stability,
538            information_propagation_speed,
539        }
540    }
541
542    fn identify_critical_gradient_paths(
543        &self,
544        network: &GradientFlowNetwork,
545    ) -> Vec<CriticalGradientPath> {
546        let mut paths = Vec::new();
547
548        // Simplified path identification - would use graph algorithms in practice
549        if network.nodes.len() < 2 {
550            return paths;
551        }
552
553        let path_layers: Vec<String> = network.nodes.iter().map(|n| n.layer_name.clone()).collect();
554        let total_flow_strength = network.edges.iter().map(|e| e.flow_strength).sum();
555        let bottleneck_layers: Vec<String> = network
556            .nodes
557            .iter()
558            .filter(|n| matches!(n.node_type, NodeType::Bottleneck))
559            .map(|n| n.layer_name.clone())
560            .collect();
561
562        paths.push(CriticalGradientPath {
563            path_id: "main_path".to_string(),
564            path_length: path_layers.len(),
565            layers: path_layers,
566            total_flow_strength,
567            bottleneck_layers,
568            criticality_score: 0.8, // Simplified
569            optimization_potential: 0.6,
570        });
571
572        paths
573    }
574
575    fn identify_vanishing_regions(
576        &self,
577        gradient_histories: &HashMap<String, GradientHistory>,
578    ) -> Vec<VanishingRegion> {
579        let mut regions = Vec::new();
580        let mut region_id = 0;
581
582        for (layer_name, history) in gradient_histories {
583            let avg_gradient =
584                history.gradient_norms.iter().sum::<f64>() / history.gradient_norms.len() as f64;
585            if avg_gradient < 1e-5 {
586                region_id += 1;
587                regions.push(VanishingRegion {
588                    region_id: format!("vanishing_{}", region_id),
589                    affected_layers: vec![layer_name.clone()],
590                    severity_level: if avg_gradient < 1e-7 {
591                        VanishingSeverity::Critical
592                    } else {
593                        VanishingSeverity::Moderate
594                    },
595                    extent: RegionExtent {
596                        start_layer: layer_name.clone(),
597                        end_layer: layer_name.clone(),
598                        affected_parameters: 1000, // Placeholder
599                        duration_steps: history.gradient_norms.len(),
600                    },
601                    mitigation_suggestions: vec![
602                        "Consider better weight initialization".to_string(),
603                        "Add skip connections".to_string(),
604                        "Use gradient clipping".to_string(),
605                    ],
606                });
607            }
608        }
609
610        regions
611    }
612
613    fn identify_exploding_regions(
614        &self,
615        gradient_histories: &HashMap<String, GradientHistory>,
616    ) -> Vec<ExplodingRegion> {
617        let mut regions = Vec::new();
618        let mut region_id = 0;
619
620        for (layer_name, history) in gradient_histories {
621            let max_gradient = history.gradient_norms.iter().cloned().fold(0.0, f64::max);
622            if max_gradient > 100.0 {
623                region_id += 1;
624                regions.push(ExplodingRegion {
625                    region_id: format!("exploding_{}", region_id),
626                    affected_layers: vec![layer_name.clone()],
627                    severity_level: if max_gradient > 1000.0 {
628                        ExplodingSeverity::Critical
629                    } else {
630                        ExplodingSeverity::Moderate
631                    },
632                    extent: RegionExtent {
633                        start_layer: layer_name.clone(),
634                        end_layer: layer_name.clone(),
635                        affected_parameters: 1000, // Placeholder
636                        duration_steps: history.gradient_norms.len(),
637                    },
638                    mitigation_suggestions: vec![
639                        "Apply gradient clipping".to_string(),
640                        "Reduce learning rate".to_string(),
641                        "Check weight initialization".to_string(),
642                    ],
643                });
644            }
645        }
646
647        regions
648    }
649
650    fn identify_gradient_dead_zones(
651        &self,
652        gradient_histories: &HashMap<String, GradientHistory>,
653    ) -> Vec<GradientDeadZone> {
654        let mut dead_zones = Vec::new();
655        let mut zone_id = 0;
656
657        for (layer_name, history) in gradient_histories {
658            let zero_gradients = history.gradient_norms.iter().filter(|&&x| x < 1e-8).count();
659            let dead_ratio = zero_gradients as f64 / history.gradient_norms.len() as f64;
660
661            if dead_ratio > 0.5 {
662                zone_id += 1;
663                dead_zones.push(GradientDeadZone {
664                    zone_id: format!("dead_zone_{}", zone_id),
665                    affected_layers: vec![layer_name.clone()],
666                    dead_duration: zero_gradients,
667                    recovery_potential: if dead_ratio > 0.9 {
668                        RecoveryPotential::Low
669                    } else {
670                        RecoveryPotential::Medium
671                    },
672                    intervention_required: dead_ratio > 0.8,
673                });
674            }
675        }
676
677        dead_zones
678    }
679
    /// Create comprehensive visualization data for gradient flows
    ///
    /// Convenience wrapper around [`Self::generate_visualization`] that
    /// stamps the temporal snapshots with step 0.
    pub fn create_visualization(
        &self,
        gradient_histories: &HashMap<String, GradientHistory>,
    ) -> GradientFlowVisualization {
        // Use existing methods to generate the visualization
        self.generate_visualization(gradient_histories, 0)
    }
688}