1use super::types::*;
7use serde::{Deserialize, Serialize};
8use std::collections::HashMap;
9
/// Complete snapshot of gradient flow through a network at a point in
/// training, assembled by `GradientFlowVisualizer::generate_visualization`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct GradientFlowVisualization {
    /// Per-layer flow statistics keyed by layer name.
    pub layer_flows: HashMap<String, GradientLayerFlow>,
    /// One entry per layer for the step the snapshot was taken at.
    pub temporal_flows: Vec<TemporalGradientFlow>,
    /// Node/edge graph view of gradient flow between layers.
    pub flow_network: GradientFlowNetwork,
    /// Heuristic high-importance gradient propagation paths.
    pub critical_paths: Vec<CriticalGradientPath>,
    /// Layers flagged for vanishing gradients.
    pub vanishing_regions: Vec<VanishingRegion>,
    /// Layers flagged for exploding gradients.
    pub exploding_regions: Vec<ExplodingRegion>,
    /// Layers whose gradients are (near-)zero most of the time.
    pub dead_zones: Vec<GradientDeadZone>,
    /// Copy of the configuration used to produce this snapshot.
    pub visualization_config: GradientVisualizationConfig,
}
22
/// Flow statistics derived from a single layer's gradient history.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct GradientLayerFlow {
    pub layer_name: String,
    /// Recorded gradient norms, in history order.
    pub gradient_magnitudes: Vec<f64>,
    /// Per-step direction entries (scalar placeholder vectors).
    pub gradient_directions: Vec<GradientDirection>,
    /// 1.0 = perfectly steady norms, toward 0.0 with large step-to-step jumps.
    pub flow_consistency: f64,
    /// 1.0 - min/mean of the norms; higher means a stronger bottleneck.
    pub bottleneck_score: f64,
    /// Mean absolute change of the norm per recorded step.
    pub information_flow_rate: f64,
}
33
/// Direction sample for one training step.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct GradientDirection {
    /// Training step this sample was recorded at.
    pub step: usize,
    /// Currently a 1-element vector holding the norm (only scalar norms are
    /// stored in the history, not full gradient vectors).
    pub direction_vector: Vec<f64>,
    pub magnitude: f64,
    /// Similarity to the previous step's norm (1.0 = unchanged).
    pub consistency_score: f64,
}
42
/// Latest gradient reading for one layer at a given step.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct TemporalGradientFlow {
    pub step: usize,
    pub layer_name: String,
    /// Most recent gradient norm recorded for the layer.
    pub gradient_magnitude: f64,
    /// Recent trend classification of the norm sequence.
    pub flow_direction: FlowDirection,
    /// Inverse-variance score of recent norms (1.0 = very stable).
    pub stability_score: f64,
}
52
/// Trend of a layer's recent gradient norms.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub enum FlowDirection {
    /// Norms increasing (also the default when history is too short).
    Forward,
    /// Norms decreasing monotonically.
    Backward,
    /// Norms alternating up and down.
    Oscillating,
    /// Norms essentially unchanged.
    Stagnant,
}
61
/// Graph view of gradient flow: one node per layer, edges between
/// (name-)adjacent layers, plus aggregate metrics.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct GradientFlowNetwork {
    pub nodes: Vec<FlowNode>,
    pub edges: Vec<FlowEdge>,
    pub network_metrics: NetworkMetrics,
}
69
/// A single layer in the gradient-flow graph.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct FlowNode {
    pub layer_name: String,
    /// Heuristic role of the layer in gradient propagation.
    pub node_type: NodeType,
    /// Mean of the layer's gradient magnitudes.
    pub gradient_strength: f64,
    /// Number of layers this node is considered connected to.
    pub connectivity: usize,
    /// gradient_strength weighted by flow consistency.
    pub influence_score: f64,
}
79
/// Heuristic role classification for a layer node.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub enum NodeType {
    /// Layer with very large gradient spikes.
    Source,
    /// Layer where gradients nearly disappear.
    Sink,
    /// Layer with a high bottleneck score.
    Bottleneck,
    /// Layer with a high information flow rate.
    Amplifier,
    Normal,
}
89
/// Directed edge between two adjacent layers in the flow graph.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct FlowEdge {
    pub from_layer: String,
    pub to_layer: String,
    /// Average of the two endpoints' information flow rates.
    pub flow_strength: f64,
    /// Average of the two endpoints' flow consistencies.
    pub flow_consistency: f64,
    pub edge_type: EdgeType,
}
99
/// Qualitative classification of an edge's gradient flow.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub enum EdgeType {
    /// High strength and high consistency.
    Strong,
    /// Low strength or very low consistency.
    Weak,
    /// Moderate strength with unsteady consistency.
    Intermittent,
    /// Essentially no gradient flow.
    Blocked,
}
108
/// Aggregate metrics over the gradient-flow graph.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct NetworkMetrics {
    /// Mean edge flow strength.
    pub overall_flow_efficiency: f64,
    /// Edge count relative to the maximum possible directed edge count.
    pub network_connectivity: f64,
    /// Fraction of nodes classified as bottlenecks.
    pub bottleneck_density: f64,
    /// Mean edge flow consistency.
    pub flow_stability: f64,
    /// overall_flow_efficiency * network_connectivity.
    pub information_propagation_speed: f64,
}
118
/// A sequence of layers considered critical for gradient propagation.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct CriticalGradientPath {
    pub path_id: String,
    /// Layers on the path, in order.
    pub layers: Vec<String>,
    /// Number of layers on the path.
    pub path_length: usize,
    /// Sum of flow strengths along the path's edges.
    pub total_flow_strength: f64,
    /// Layers on the path classified as bottlenecks.
    pub bottleneck_layers: Vec<String>,
    pub criticality_score: f64,
    pub optimization_potential: f64,
}
130
/// A group of layers exhibiting vanishing gradients (very small mean norms).
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct VanishingRegion {
    pub region_id: String,
    pub affected_layers: Vec<String>,
    pub severity_level: VanishingSeverity,
    pub extent: RegionExtent,
    /// Human-readable remediation hints for display in the UI/report.
    pub mitigation_suggestions: Vec<String>,
}
140
/// Severity grading for a vanishing-gradient region.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub enum VanishingSeverity {
    Mild,
    Moderate,
    Severe,
    Critical,
}
148
/// A group of layers exhibiting exploding gradients (very large peak norms).
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ExplodingRegion {
    pub region_id: String,
    pub affected_layers: Vec<String>,
    pub severity_level: ExplodingSeverity,
    pub extent: RegionExtent,
    /// Human-readable remediation hints for display in the UI/report.
    pub mitigation_suggestions: Vec<String>,
}
158
/// Severity grading for an exploding-gradient region.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub enum ExplodingSeverity {
    Mild,
    Moderate,
    Severe,
    Critical,
}
166
/// Spatial/temporal extent of a problem region.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct RegionExtent {
    pub start_layer: String,
    pub end_layer: String,
    /// Estimated parameter count affected (currently a rough placeholder).
    pub affected_parameters: usize,
    /// Number of recorded steps the condition has persisted for.
    pub duration_steps: usize,
}
175
/// Layers whose gradients are near-zero for most recorded steps.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct GradientDeadZone {
    pub zone_id: String,
    pub affected_layers: Vec<String>,
    /// Number of steps with a (near-)zero gradient norm.
    pub dead_duration: usize,
    pub recovery_potential: RecoveryPotential,
    /// True when the dead ratio is high enough to warrant intervention.
    pub intervention_required: bool,
}
185
/// Estimated likelihood that a dead zone recovers without intervention.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub enum RecoveryPotential {
    High,
    Medium,
    Low,
    None,
}
193
/// Rendering options for gradient-flow visualizations.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct GradientVisualizationConfig {
    pub show_temporal_flows: bool,
    pub show_critical_paths: bool,
    pub show_problem_regions: bool,
    pub color_scheme: ColorScheme,
    /// Number of recent steps considered for temporal views.
    pub temporal_window: usize,
    /// Minimum flow magnitude worth displaying.
    pub flow_threshold: f64,
}
204
205impl Default for GradientVisualizationConfig {
206 fn default() -> Self {
207 Self {
208 show_temporal_flows: true,
209 show_critical_paths: true,
210 show_problem_regions: true,
211 color_scheme: ColorScheme::Default,
212 temporal_window: 50,
213 flow_threshold: 0.01,
214 }
215 }
216}
217
/// Color palette options for rendered output.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub enum ColorScheme {
    Default,
    HighContrast,
    ColorBlind,
    Monochrome,
}
225
/// Stateless generator that turns recorded gradient histories into a
/// `GradientFlowVisualization`.
#[derive(Debug, Default)]
pub struct GradientFlowVisualizer {
    /// Options controlling which overlays are produced.
    config: GradientVisualizationConfig,
}
231
232impl GradientFlowVisualizer {
233 pub fn new(config: GradientVisualizationConfig) -> Self {
234 Self { config }
235 }
236
237 pub fn generate_visualization(
238 &self,
239 gradient_histories: &HashMap<String, GradientHistory>,
240 current_step: usize,
241 ) -> GradientFlowVisualization {
242 let layer_flows = self.generate_layer_flows(gradient_histories);
243 let temporal_flows = self.generate_temporal_flows(gradient_histories, current_step);
244 let flow_network = self.build_gradient_flow_network(&layer_flows);
245 let critical_paths = self.identify_critical_gradient_paths(&flow_network);
246 let vanishing_regions = self.identify_vanishing_regions(gradient_histories);
247 let exploding_regions = self.identify_exploding_regions(gradient_histories);
248 let dead_zones = self.identify_gradient_dead_zones(gradient_histories);
249
250 GradientFlowVisualization {
251 layer_flows,
252 temporal_flows,
253 flow_network,
254 critical_paths,
255 vanishing_regions,
256 exploding_regions,
257 dead_zones,
258 visualization_config: self.config.clone(),
259 }
260 }
261
262 fn generate_layer_flows(
263 &self,
264 gradient_histories: &HashMap<String, GradientHistory>,
265 ) -> HashMap<String, GradientLayerFlow> {
266 let mut layer_flows = HashMap::new();
267
268 for (layer_name, history) in gradient_histories {
269 let gradient_magnitudes: Vec<f64> = history.gradient_norms.iter().cloned().collect();
270 let gradient_directions = self.compute_gradient_directions(history);
271 let flow_consistency = self.compute_flow_consistency(history);
272 let bottleneck_score = self.compute_bottleneck_score(history);
273 let information_flow_rate = self.compute_information_flow_rate(history);
274
275 let flow_data = GradientLayerFlow {
276 layer_name: layer_name.clone(),
277 gradient_magnitudes,
278 gradient_directions,
279 flow_consistency,
280 bottleneck_score,
281 information_flow_rate,
282 };
283
284 layer_flows.insert(layer_name.clone(), flow_data);
285 }
286
287 layer_flows
288 }
289
290 fn compute_gradient_directions(&self, history: &GradientHistory) -> Vec<GradientDirection> {
291 let mut directions = Vec::new();
292
293 for (i, (&norm, &step)) in
294 history.gradient_norms.iter().zip(history.step_numbers.iter()).enumerate()
295 {
296 let direction_vector = vec![norm]; let magnitude = norm;
299 let consistency_score = if i > 0 {
300 let prev_norm = history.gradient_norms[i - 1];
301 1.0 - ((norm - prev_norm).abs() / (norm + prev_norm + 1e-8))
302 } else {
303 1.0
304 };
305
306 directions.push(GradientDirection {
307 step,
308 direction_vector,
309 magnitude,
310 consistency_score,
311 });
312 }
313
314 directions
315 }
316
317 fn compute_flow_consistency(&self, history: &GradientHistory) -> f64 {
318 if history.gradient_norms.len() < 2 {
319 return 1.0;
320 }
321
322 let variations: Vec<f64> = history
323 .gradient_norms
324 .iter()
325 .collect::<Vec<&f64>>()
326 .windows(2)
327 .map(|pair| (*pair[1] - *pair[0]).abs() / (*pair[0] + 1e-8))
328 .collect();
329
330 let avg_variation = variations.iter().sum::<f64>() / variations.len() as f64;
331 (1.0_f64 / (1.0 + avg_variation)).min(1.0)
332 }
333
334 fn compute_bottleneck_score(&self, history: &GradientHistory) -> f64 {
335 if history.gradient_norms.is_empty() {
336 return 0.0;
337 }
338
339 let mean = history.gradient_norms.iter().sum::<f64>() / history.gradient_norms.len() as f64;
340 let min_val = history.gradient_norms.iter().cloned().fold(f64::INFINITY, f64::min);
341
342 if mean == 0.0 {
343 return 1.0;
344 }
345
346 1.0 - (min_val / mean).min(1.0)
347 }
348
349 fn compute_information_flow_rate(&self, history: &GradientHistory) -> f64 {
350 if history.gradient_norms.len() < 2 {
351 return 0.0;
352 }
353
354 let total_change: f64 = history
355 .gradient_norms
356 .iter()
357 .collect::<Vec<&f64>>()
358 .windows(2)
359 .map(|pair| (*pair[1] - *pair[0]).abs())
360 .sum();
361
362 let time_span = history.gradient_norms.len() as f64;
363 total_change / time_span
364 }
365
366 fn generate_temporal_flows(
367 &self,
368 gradient_histories: &HashMap<String, GradientHistory>,
369 current_step: usize,
370 ) -> Vec<TemporalGradientFlow> {
371 let mut temporal_flows = Vec::new();
372
373 for (layer_name, history) in gradient_histories {
374 if let Some(latest_norm) = history.gradient_norms.back() {
375 let flow_direction = self.get_latest_flow_direction(history);
376 let stability_score = self.compute_stability_score(history);
377
378 temporal_flows.push(TemporalGradientFlow {
379 step: current_step,
380 layer_name: layer_name.clone(),
381 gradient_magnitude: *latest_norm,
382 flow_direction,
383 stability_score,
384 });
385 }
386 }
387
388 temporal_flows
389 }
390
391 fn get_latest_flow_direction(&self, history: &GradientHistory) -> FlowDirection {
392 if history.gradient_norms.len() < 3 {
393 return FlowDirection::Forward;
394 }
395
396 let recent: Vec<f64> = history.gradient_norms.iter().rev().take(3).cloned().collect();
397 let trend = recent[0] - recent[2]; if trend.abs() < 1e-6 {
400 FlowDirection::Stagnant
401 } else if trend > 0.0 {
402 FlowDirection::Forward
403 } else {
404 let changes: Vec<f64> = recent.windows(2).map(|pair| pair[0] - pair[1]).collect();
406 let sign_changes = changes.windows(2).filter(|pair| pair[0] * pair[1] < 0.0).count();
407
408 if sign_changes > 0 {
409 FlowDirection::Oscillating
410 } else {
411 FlowDirection::Backward
412 }
413 }
414 }
415
416 fn compute_stability_score(&self, history: &GradientHistory) -> f64 {
417 if history.gradient_norms.len() < 3 {
418 return 1.0;
419 }
420
421 let recent: Vec<f64> = history.gradient_norms.iter().rev().take(5).cloned().collect();
422 let mean = recent.iter().sum::<f64>() / recent.len() as f64;
423 let variance =
424 recent.iter().map(|&x| (x - mean).powi(2)).sum::<f64>() / recent.len() as f64;
425
426 1.0 / (1.0 + variance)
427 }
428
429 fn build_gradient_flow_network(
430 &self,
431 layer_flows: &HashMap<String, GradientLayerFlow>,
432 ) -> GradientFlowNetwork {
433 let mut nodes = Vec::new();
434 let mut edges = Vec::new();
435
436 for (layer_name, flow) in layer_flows {
438 let node_type = self.classify_node_type(flow);
439 let gradient_strength = flow.gradient_magnitudes.iter().sum::<f64>()
440 / flow.gradient_magnitudes.len() as f64;
441 let connectivity = layer_flows.len(); let influence_score = gradient_strength * flow.flow_consistency;
443
444 nodes.push(FlowNode {
445 layer_name: layer_name.clone(),
446 node_type,
447 gradient_strength,
448 connectivity,
449 influence_score,
450 });
451 }
452
453 let layer_names: Vec<String> = layer_flows.keys().cloned().collect();
455 for i in 0..layer_names.len().saturating_sub(1) {
456 let from_layer = &layer_names[i];
457 let to_layer = &layer_names[i + 1];
458
459 if let (Some(from_flow), Some(to_flow)) =
460 (layer_flows.get(from_layer), layer_flows.get(to_layer))
461 {
462 let flow_strength =
463 (from_flow.information_flow_rate + to_flow.information_flow_rate) / 2.0;
464 let flow_consistency =
465 (from_flow.flow_consistency + to_flow.flow_consistency) / 2.0;
466 let edge_type = self.classify_edge_type(flow_strength, flow_consistency);
467
468 edges.push(FlowEdge {
469 from_layer: from_layer.clone(),
470 to_layer: to_layer.clone(),
471 flow_strength,
472 flow_consistency,
473 edge_type,
474 });
475 }
476 }
477
478 let network_metrics = self.compute_network_metrics(&nodes, &edges);
479
480 GradientFlowNetwork {
481 nodes,
482 edges,
483 network_metrics,
484 }
485 }
486
487 fn classify_node_type(&self, flow: &GradientLayerFlow) -> NodeType {
488 if flow.bottleneck_score > 0.8 {
489 NodeType::Bottleneck
490 } else if flow.information_flow_rate > 1.0 {
491 NodeType::Amplifier
492 } else if flow.gradient_magnitudes.iter().sum::<f64>() < 0.01 {
493 NodeType::Sink
494 } else if flow.gradient_magnitudes.iter().any(|&x| x > 10.0) {
495 NodeType::Source
496 } else {
497 NodeType::Normal
498 }
499 }
500
501 fn classify_edge_type(&self, flow_strength: f64, flow_consistency: f64) -> EdgeType {
502 if flow_strength > 1.0 && flow_consistency > 0.8 {
503 EdgeType::Strong
504 } else if flow_strength < 0.1 || flow_consistency < 0.3 {
505 EdgeType::Weak
506 } else if flow_consistency < 0.6 {
507 EdgeType::Intermittent
508 } else {
509 EdgeType::Blocked
510 }
511 }
512
513 fn compute_network_metrics(&self, nodes: &[FlowNode], edges: &[FlowEdge]) -> NetworkMetrics {
514 let overall_flow_efficiency =
515 edges.iter().map(|e| e.flow_strength).sum::<f64>() / edges.len().max(1) as f64;
516 let network_connectivity = edges.len() as f64
517 / (nodes.len().max(1) * (nodes.len().saturating_sub(1)).max(1)) as f64;
518 let bottleneck_density =
519 nodes.iter().filter(|n| matches!(n.node_type, NodeType::Bottleneck)).count() as f64
520 / nodes.len() as f64;
521 let flow_stability =
522 edges.iter().map(|e| e.flow_consistency).sum::<f64>() / edges.len().max(1) as f64;
523 let information_propagation_speed = overall_flow_efficiency * network_connectivity;
524
525 NetworkMetrics {
526 overall_flow_efficiency,
527 network_connectivity,
528 bottleneck_density,
529 flow_stability,
530 information_propagation_speed,
531 }
532 }
533
534 fn identify_critical_gradient_paths(
535 &self,
536 network: &GradientFlowNetwork,
537 ) -> Vec<CriticalGradientPath> {
538 let mut paths = Vec::new();
539
540 if network.nodes.len() < 2 {
542 return paths;
543 }
544
545 let path_layers: Vec<String> = network.nodes.iter().map(|n| n.layer_name.clone()).collect();
546 let total_flow_strength = network.edges.iter().map(|e| e.flow_strength).sum();
547 let bottleneck_layers: Vec<String> = network
548 .nodes
549 .iter()
550 .filter(|n| matches!(n.node_type, NodeType::Bottleneck))
551 .map(|n| n.layer_name.clone())
552 .collect();
553
554 paths.push(CriticalGradientPath {
555 path_id: "main_path".to_string(),
556 path_length: path_layers.len(),
557 layers: path_layers,
558 total_flow_strength,
559 bottleneck_layers,
560 criticality_score: 0.8, optimization_potential: 0.6,
562 });
563
564 paths
565 }
566
567 fn identify_vanishing_regions(
568 &self,
569 gradient_histories: &HashMap<String, GradientHistory>,
570 ) -> Vec<VanishingRegion> {
571 let mut regions = Vec::new();
572 let mut region_id = 0;
573
574 for (layer_name, history) in gradient_histories {
575 let avg_gradient =
576 history.gradient_norms.iter().sum::<f64>() / history.gradient_norms.len() as f64;
577 if avg_gradient < 1e-5 {
578 region_id += 1;
579 regions.push(VanishingRegion {
580 region_id: format!("vanishing_{}", region_id),
581 affected_layers: vec![layer_name.clone()],
582 severity_level: if avg_gradient < 1e-7 {
583 VanishingSeverity::Critical
584 } else {
585 VanishingSeverity::Moderate
586 },
587 extent: RegionExtent {
588 start_layer: layer_name.clone(),
589 end_layer: layer_name.clone(),
590 affected_parameters: 1000, duration_steps: history.gradient_norms.len(),
592 },
593 mitigation_suggestions: vec![
594 "Consider better weight initialization".to_string(),
595 "Add skip connections".to_string(),
596 "Use gradient clipping".to_string(),
597 ],
598 });
599 }
600 }
601
602 regions
603 }
604
605 fn identify_exploding_regions(
606 &self,
607 gradient_histories: &HashMap<String, GradientHistory>,
608 ) -> Vec<ExplodingRegion> {
609 let mut regions = Vec::new();
610 let mut region_id = 0;
611
612 for (layer_name, history) in gradient_histories {
613 let max_gradient = history.gradient_norms.iter().cloned().fold(0.0, f64::max);
614 if max_gradient > 100.0 {
615 region_id += 1;
616 regions.push(ExplodingRegion {
617 region_id: format!("exploding_{}", region_id),
618 affected_layers: vec![layer_name.clone()],
619 severity_level: if max_gradient > 1000.0 {
620 ExplodingSeverity::Critical
621 } else {
622 ExplodingSeverity::Moderate
623 },
624 extent: RegionExtent {
625 start_layer: layer_name.clone(),
626 end_layer: layer_name.clone(),
627 affected_parameters: 1000, duration_steps: history.gradient_norms.len(),
629 },
630 mitigation_suggestions: vec![
631 "Apply gradient clipping".to_string(),
632 "Reduce learning rate".to_string(),
633 "Check weight initialization".to_string(),
634 ],
635 });
636 }
637 }
638
639 regions
640 }
641
642 fn identify_gradient_dead_zones(
643 &self,
644 gradient_histories: &HashMap<String, GradientHistory>,
645 ) -> Vec<GradientDeadZone> {
646 let mut dead_zones = Vec::new();
647 let mut zone_id = 0;
648
649 for (layer_name, history) in gradient_histories {
650 let zero_gradients = history.gradient_norms.iter().filter(|&&x| x < 1e-8).count();
651 let dead_ratio = zero_gradients as f64 / history.gradient_norms.len() as f64;
652
653 if dead_ratio > 0.5 {
654 zone_id += 1;
655 dead_zones.push(GradientDeadZone {
656 zone_id: format!("dead_zone_{}", zone_id),
657 affected_layers: vec![layer_name.clone()],
658 dead_duration: zero_gradients,
659 recovery_potential: if dead_ratio > 0.9 {
660 RecoveryPotential::Low
661 } else {
662 RecoveryPotential::Medium
663 },
664 intervention_required: dead_ratio > 0.8,
665 });
666 }
667 }
668
669 dead_zones
670 }
671
672 pub fn create_visualization(
674 &self,
675 gradient_histories: &HashMap<String, GradientHistory>,
676 ) -> GradientFlowVisualization {
677 self.generate_visualization(gradient_histories, 0)
679 }
680}