1use super::types::*;
7use serde::{Deserialize, Serialize};
8use std::collections::HashMap;
9
/// Complete snapshot of gradient flow through the network, assembled by
/// `GradientFlowVisualizer::generate_visualization`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct GradientFlowVisualization {
    /// Per-layer flow statistics, keyed by layer name.
    pub layer_flows: HashMap<String, GradientLayerFlow>,
    /// One record per layer describing flow at the current step.
    pub temporal_flows: Vec<TemporalGradientFlow>,
    /// Node-and-edge view of gradient flow between layers.
    pub flow_network: GradientFlowNetwork,
    /// Paths through the network deemed critical for gradient propagation.
    pub critical_paths: Vec<CriticalGradientPath>,
    /// Layers flagged for vanishing gradients.
    pub vanishing_regions: Vec<VanishingRegion>,
    /// Layers flagged for exploding gradients.
    pub exploding_regions: Vec<ExplodingRegion>,
    /// Layers where gradients are effectively zero most of the time.
    pub dead_zones: Vec<GradientDeadZone>,
    /// Copy of the configuration used to build this visualization.
    pub visualization_config: GradientVisualizationConfig,
}
22
/// Flow statistics for a single layer, derived from its gradient-norm history.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct GradientLayerFlow {
    pub layer_name: String,
    /// Recorded gradient norms, oldest first.
    pub gradient_magnitudes: Vec<f64>,
    /// Per-step direction records derived from the norm history.
    pub gradient_directions: Vec<GradientDirection>,
    /// 1.0 = perfectly steady norms; falls toward 0.0 with larger step-to-step variation.
    pub flow_consistency: f64,
    /// 0.0 = no throttling; approaches 1.0 when the minimum norm is tiny vs. the mean.
    pub bottleneck_score: f64,
    /// Average absolute step-to-step change in gradient norm.
    pub information_flow_rate: f64,
}
33
/// Direction/consistency record for one training step.
///
/// Only scalar norms are recorded upstream, so `direction_vector` is a 1-D
/// proxy (`[norm]`) rather than a true gradient direction.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct GradientDirection {
    pub step: usize,
    pub direction_vector: Vec<f64>,
    pub magnitude: f64,
    /// 1.0 = same as previous step's norm; → 0.0 for large relative jumps.
    pub consistency_score: f64,
}
42
/// Snapshot of one layer's gradient flow at a given step.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct TemporalGradientFlow {
    pub step: usize,
    pub layer_name: String,
    /// Most recent recorded gradient norm for this layer.
    pub gradient_magnitude: f64,
    /// Recent trend classification (see `FlowDirection`).
    pub flow_direction: FlowDirection,
    /// Inverse-variance stability of recent norms (1.0 = identical values).
    pub stability_score: f64,
}
52
/// Recent trend of a layer's gradient norm.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub enum FlowDirection {
    /// Norm is growing.
    Forward,
    /// Norm is shrinking steadily.
    Backward,
    /// Norm is shrinking with alternating deltas.
    Oscillating,
    /// Norm is essentially unchanged.
    Stagnant,
}
61
/// Graph view of gradient flow: one node per layer, edges between
/// consecutive layers, plus aggregate metrics.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct GradientFlowNetwork {
    pub nodes: Vec<FlowNode>,
    pub edges: Vec<FlowEdge>,
    pub network_metrics: NetworkMetrics,
}
69
/// One layer in the gradient-flow graph.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct FlowNode {
    pub layer_name: String,
    pub node_type: NodeType,
    /// Mean gradient norm over the recorded history.
    pub gradient_strength: f64,
    /// Number of connected layers. NOTE(review): currently populated with the
    /// total layer count rather than a real degree — confirm intent.
    pub connectivity: usize,
    /// `gradient_strength * flow_consistency`.
    pub influence_score: f64,
}
79
/// Heuristic role of a layer within the gradient-flow graph.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub enum NodeType {
    /// Produces large gradients (some norm > 10).
    Source,
    /// Absorbs gradients (total norm near zero).
    Sink,
    /// Throttles flow (high bottleneck score).
    Bottleneck,
    /// High information-flow rate.
    Amplifier,
    Normal,
}
89
/// Directed edge between two consecutive layers in the flow graph.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct FlowEdge {
    pub from_layer: String,
    pub to_layer: String,
    /// Mean of the two endpoints' information-flow rates.
    pub flow_strength: f64,
    /// Mean of the two endpoints' flow-consistency scores.
    pub flow_consistency: f64,
    pub edge_type: EdgeType,
}
99
/// Qualitative classification of an edge's gradient flow.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub enum EdgeType {
    Strong,
    Weak,
    Intermittent,
    Blocked,
}
108
/// Whole-network aggregates computed from nodes and edges.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct NetworkMetrics {
    /// Mean edge flow strength.
    pub overall_flow_efficiency: f64,
    /// Edge count relative to the maximum possible (n * (n - 1)).
    pub network_connectivity: f64,
    /// Fraction of nodes classified as bottlenecks.
    pub bottleneck_density: f64,
    /// Mean edge flow consistency.
    pub flow_stability: f64,
    /// `overall_flow_efficiency * network_connectivity`.
    pub information_propagation_speed: f64,
}
118
/// A path through the network considered critical for gradient propagation.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct CriticalGradientPath {
    pub path_id: String,
    /// Layers along the path, in order.
    pub layers: Vec<String>,
    pub path_length: usize,
    /// Sum of edge flow strengths along the path.
    pub total_flow_strength: f64,
    /// Layers on this path classified as bottlenecks.
    pub bottleneck_layers: Vec<String>,
    pub criticality_score: f64,
    pub optimization_potential: f64,
}
130
/// A region of the network exhibiting vanishing gradients.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct VanishingRegion {
    pub region_id: String,
    pub affected_layers: Vec<String>,
    pub severity_level: VanishingSeverity,
    pub extent: RegionExtent,
    /// Human-readable remediation suggestions.
    pub mitigation_suggestions: Vec<String>,
}
140
/// Severity scale for vanishing-gradient regions.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub enum VanishingSeverity {
    Mild,
    Moderate,
    Severe,
    Critical,
}
148
/// A region of the network exhibiting exploding gradients.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ExplodingRegion {
    pub region_id: String,
    pub affected_layers: Vec<String>,
    pub severity_level: ExplodingSeverity,
    pub extent: RegionExtent,
    /// Human-readable remediation suggestions.
    pub mitigation_suggestions: Vec<String>,
}
158
/// Severity scale for exploding-gradient regions.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub enum ExplodingSeverity {
    Mild,
    Moderate,
    Severe,
    Critical,
}
166
/// Spatial and temporal extent of a problem region.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct RegionExtent {
    pub start_layer: String,
    pub end_layer: String,
    /// Parameter count in the region. NOTE(review): currently filled with a
    /// fixed placeholder by the region detectors — confirm before relying on it.
    pub affected_parameters: usize,
    /// Number of recorded steps the region spans.
    pub duration_steps: usize,
}
175
/// A region where gradients are effectively zero most of the time.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct GradientDeadZone {
    pub zone_id: String,
    pub affected_layers: Vec<String>,
    /// Number of recorded steps with near-zero gradient norm.
    pub dead_duration: usize,
    pub recovery_potential: RecoveryPotential,
    /// True when the dead ratio is high enough to warrant intervention.
    pub intervention_required: bool,
}
185
/// Likelihood that a dead zone recovers without intervention.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub enum RecoveryPotential {
    High,
    Medium,
    Low,
    None,
}
193
/// User-tunable options for gradient-flow visualization.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct GradientVisualizationConfig {
    pub show_temporal_flows: bool,
    pub show_critical_paths: bool,
    pub show_problem_regions: bool,
    pub color_scheme: ColorScheme,
    /// Number of recent steps considered by temporal views.
    pub temporal_window: usize,
    /// Minimum flow strength worth displaying.
    pub flow_threshold: f64,
}
204
205impl Default for GradientVisualizationConfig {
206 fn default() -> Self {
207 Self {
208 show_temporal_flows: true,
209 show_critical_paths: true,
210 show_problem_regions: true,
211 color_scheme: ColorScheme::Default,
212 temporal_window: 50,
213 flow_threshold: 0.01,
214 }
215 }
216}
217
/// Color palette options for rendering the visualization.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub enum ColorScheme {
    Default,
    HighContrast,
    /// Palette chosen for color-blind accessibility.
    ColorBlind,
    Monochrome,
}
225
/// Builds `GradientFlowVisualization`s from recorded gradient histories.
#[derive(Debug)]
pub struct GradientFlowVisualizer {
    // Options controlling which views are produced and their thresholds.
    config: GradientVisualizationConfig,
}
231
232impl Default for GradientFlowVisualizer {
233 fn default() -> Self {
234 Self {
235 config: GradientVisualizationConfig::default(),
236 }
237 }
238}
239
240impl GradientFlowVisualizer {
241 pub fn new(config: GradientVisualizationConfig) -> Self {
242 Self { config }
243 }
244
245 pub fn generate_visualization(
246 &self,
247 gradient_histories: &HashMap<String, GradientHistory>,
248 current_step: usize,
249 ) -> GradientFlowVisualization {
250 let layer_flows = self.generate_layer_flows(gradient_histories);
251 let temporal_flows = self.generate_temporal_flows(gradient_histories, current_step);
252 let flow_network = self.build_gradient_flow_network(&layer_flows);
253 let critical_paths = self.identify_critical_gradient_paths(&flow_network);
254 let vanishing_regions = self.identify_vanishing_regions(gradient_histories);
255 let exploding_regions = self.identify_exploding_regions(gradient_histories);
256 let dead_zones = self.identify_gradient_dead_zones(gradient_histories);
257
258 GradientFlowVisualization {
259 layer_flows,
260 temporal_flows,
261 flow_network,
262 critical_paths,
263 vanishing_regions,
264 exploding_regions,
265 dead_zones,
266 visualization_config: self.config.clone(),
267 }
268 }
269
270 fn generate_layer_flows(
271 &self,
272 gradient_histories: &HashMap<String, GradientHistory>,
273 ) -> HashMap<String, GradientLayerFlow> {
274 let mut layer_flows = HashMap::new();
275
276 for (layer_name, history) in gradient_histories {
277 let gradient_magnitudes: Vec<f64> = history.gradient_norms.iter().cloned().collect();
278 let gradient_directions = self.compute_gradient_directions(history);
279 let flow_consistency = self.compute_flow_consistency(history);
280 let bottleneck_score = self.compute_bottleneck_score(history);
281 let information_flow_rate = self.compute_information_flow_rate(history);
282
283 let flow_data = GradientLayerFlow {
284 layer_name: layer_name.clone(),
285 gradient_magnitudes,
286 gradient_directions,
287 flow_consistency,
288 bottleneck_score,
289 information_flow_rate,
290 };
291
292 layer_flows.insert(layer_name.clone(), flow_data);
293 }
294
295 layer_flows
296 }
297
298 fn compute_gradient_directions(&self, history: &GradientHistory) -> Vec<GradientDirection> {
299 let mut directions = Vec::new();
300
301 for (i, (&norm, &step)) in
302 history.gradient_norms.iter().zip(history.step_numbers.iter()).enumerate()
303 {
304 let direction_vector = vec![norm]; let magnitude = norm;
307 let consistency_score = if i > 0 {
308 let prev_norm = history.gradient_norms[i - 1];
309 1.0 - ((norm - prev_norm).abs() / (norm + prev_norm + 1e-8))
310 } else {
311 1.0
312 };
313
314 directions.push(GradientDirection {
315 step,
316 direction_vector,
317 magnitude,
318 consistency_score,
319 });
320 }
321
322 directions
323 }
324
325 fn compute_flow_consistency(&self, history: &GradientHistory) -> f64 {
326 if history.gradient_norms.len() < 2 {
327 return 1.0;
328 }
329
330 let variations: Vec<f64> = history
331 .gradient_norms
332 .iter()
333 .collect::<Vec<&f64>>()
334 .windows(2)
335 .map(|pair| (*pair[1] - *pair[0]).abs() / (*pair[0] + 1e-8))
336 .collect();
337
338 let avg_variation = variations.iter().sum::<f64>() / variations.len() as f64;
339 (1.0_f64 / (1.0 + avg_variation)).min(1.0)
340 }
341
342 fn compute_bottleneck_score(&self, history: &GradientHistory) -> f64 {
343 if history.gradient_norms.is_empty() {
344 return 0.0;
345 }
346
347 let mean = history.gradient_norms.iter().sum::<f64>() / history.gradient_norms.len() as f64;
348 let min_val = history.gradient_norms.iter().cloned().fold(f64::INFINITY, f64::min);
349
350 if mean == 0.0 {
351 return 1.0;
352 }
353
354 1.0 - (min_val / mean).min(1.0)
355 }
356
357 fn compute_information_flow_rate(&self, history: &GradientHistory) -> f64 {
358 if history.gradient_norms.len() < 2 {
359 return 0.0;
360 }
361
362 let total_change: f64 = history
363 .gradient_norms
364 .iter()
365 .collect::<Vec<&f64>>()
366 .windows(2)
367 .map(|pair| (*pair[1] - *pair[0]).abs())
368 .sum();
369
370 let time_span = history.gradient_norms.len() as f64;
371 total_change / time_span
372 }
373
374 fn generate_temporal_flows(
375 &self,
376 gradient_histories: &HashMap<String, GradientHistory>,
377 current_step: usize,
378 ) -> Vec<TemporalGradientFlow> {
379 let mut temporal_flows = Vec::new();
380
381 for (layer_name, history) in gradient_histories {
382 if let Some(latest_norm) = history.gradient_norms.back() {
383 let flow_direction = self.get_latest_flow_direction(history);
384 let stability_score = self.compute_stability_score(history);
385
386 temporal_flows.push(TemporalGradientFlow {
387 step: current_step,
388 layer_name: layer_name.clone(),
389 gradient_magnitude: *latest_norm,
390 flow_direction,
391 stability_score,
392 });
393 }
394 }
395
396 temporal_flows
397 }
398
399 fn get_latest_flow_direction(&self, history: &GradientHistory) -> FlowDirection {
400 if history.gradient_norms.len() < 3 {
401 return FlowDirection::Forward;
402 }
403
404 let recent: Vec<f64> = history.gradient_norms.iter().rev().take(3).cloned().collect();
405 let trend = recent[0] - recent[2]; if trend.abs() < 1e-6 {
408 FlowDirection::Stagnant
409 } else if trend > 0.0 {
410 FlowDirection::Forward
411 } else {
412 let changes: Vec<f64> = recent.windows(2).map(|pair| pair[0] - pair[1]).collect();
414 let sign_changes = changes.windows(2).filter(|pair| pair[0] * pair[1] < 0.0).count();
415
416 if sign_changes > 0 {
417 FlowDirection::Oscillating
418 } else {
419 FlowDirection::Backward
420 }
421 }
422 }
423
424 fn compute_stability_score(&self, history: &GradientHistory) -> f64 {
425 if history.gradient_norms.len() < 3 {
426 return 1.0;
427 }
428
429 let recent: Vec<f64> = history.gradient_norms.iter().rev().take(5).cloned().collect();
430 let mean = recent.iter().sum::<f64>() / recent.len() as f64;
431 let variance =
432 recent.iter().map(|&x| (x - mean).powi(2)).sum::<f64>() / recent.len() as f64;
433
434 1.0 / (1.0 + variance)
435 }
436
437 fn build_gradient_flow_network(
438 &self,
439 layer_flows: &HashMap<String, GradientLayerFlow>,
440 ) -> GradientFlowNetwork {
441 let mut nodes = Vec::new();
442 let mut edges = Vec::new();
443
444 for (layer_name, flow) in layer_flows {
446 let node_type = self.classify_node_type(flow);
447 let gradient_strength = flow.gradient_magnitudes.iter().sum::<f64>()
448 / flow.gradient_magnitudes.len() as f64;
449 let connectivity = layer_flows.len(); let influence_score = gradient_strength * flow.flow_consistency;
451
452 nodes.push(FlowNode {
453 layer_name: layer_name.clone(),
454 node_type,
455 gradient_strength,
456 connectivity,
457 influence_score,
458 });
459 }
460
461 let layer_names: Vec<String> = layer_flows.keys().cloned().collect();
463 for i in 0..layer_names.len().saturating_sub(1) {
464 let from_layer = &layer_names[i];
465 let to_layer = &layer_names[i + 1];
466
467 if let (Some(from_flow), Some(to_flow)) =
468 (layer_flows.get(from_layer), layer_flows.get(to_layer))
469 {
470 let flow_strength =
471 (from_flow.information_flow_rate + to_flow.information_flow_rate) / 2.0;
472 let flow_consistency =
473 (from_flow.flow_consistency + to_flow.flow_consistency) / 2.0;
474 let edge_type = self.classify_edge_type(flow_strength, flow_consistency);
475
476 edges.push(FlowEdge {
477 from_layer: from_layer.clone(),
478 to_layer: to_layer.clone(),
479 flow_strength,
480 flow_consistency,
481 edge_type,
482 });
483 }
484 }
485
486 let network_metrics = self.compute_network_metrics(&nodes, &edges);
487
488 GradientFlowNetwork {
489 nodes,
490 edges,
491 network_metrics,
492 }
493 }
494
495 fn classify_node_type(&self, flow: &GradientLayerFlow) -> NodeType {
496 if flow.bottleneck_score > 0.8 {
497 NodeType::Bottleneck
498 } else if flow.information_flow_rate > 1.0 {
499 NodeType::Amplifier
500 } else if flow.gradient_magnitudes.iter().sum::<f64>() < 0.01 {
501 NodeType::Sink
502 } else if flow.gradient_magnitudes.iter().any(|&x| x > 10.0) {
503 NodeType::Source
504 } else {
505 NodeType::Normal
506 }
507 }
508
509 fn classify_edge_type(&self, flow_strength: f64, flow_consistency: f64) -> EdgeType {
510 if flow_strength > 1.0 && flow_consistency > 0.8 {
511 EdgeType::Strong
512 } else if flow_strength < 0.1 || flow_consistency < 0.3 {
513 EdgeType::Weak
514 } else if flow_consistency < 0.6 {
515 EdgeType::Intermittent
516 } else {
517 EdgeType::Blocked
518 }
519 }
520
521 fn compute_network_metrics(&self, nodes: &[FlowNode], edges: &[FlowEdge]) -> NetworkMetrics {
522 let overall_flow_efficiency =
523 edges.iter().map(|e| e.flow_strength).sum::<f64>() / edges.len().max(1) as f64;
524 let network_connectivity = edges.len() as f64
525 / (nodes.len().max(1) * (nodes.len().saturating_sub(1)).max(1)) as f64;
526 let bottleneck_density =
527 nodes.iter().filter(|n| matches!(n.node_type, NodeType::Bottleneck)).count() as f64
528 / nodes.len() as f64;
529 let flow_stability =
530 edges.iter().map(|e| e.flow_consistency).sum::<f64>() / edges.len().max(1) as f64;
531 let information_propagation_speed = overall_flow_efficiency * network_connectivity;
532
533 NetworkMetrics {
534 overall_flow_efficiency,
535 network_connectivity,
536 bottleneck_density,
537 flow_stability,
538 information_propagation_speed,
539 }
540 }
541
542 fn identify_critical_gradient_paths(
543 &self,
544 network: &GradientFlowNetwork,
545 ) -> Vec<CriticalGradientPath> {
546 let mut paths = Vec::new();
547
548 if network.nodes.len() < 2 {
550 return paths;
551 }
552
553 let path_layers: Vec<String> = network.nodes.iter().map(|n| n.layer_name.clone()).collect();
554 let total_flow_strength = network.edges.iter().map(|e| e.flow_strength).sum();
555 let bottleneck_layers: Vec<String> = network
556 .nodes
557 .iter()
558 .filter(|n| matches!(n.node_type, NodeType::Bottleneck))
559 .map(|n| n.layer_name.clone())
560 .collect();
561
562 paths.push(CriticalGradientPath {
563 path_id: "main_path".to_string(),
564 path_length: path_layers.len(),
565 layers: path_layers,
566 total_flow_strength,
567 bottleneck_layers,
568 criticality_score: 0.8, optimization_potential: 0.6,
570 });
571
572 paths
573 }
574
575 fn identify_vanishing_regions(
576 &self,
577 gradient_histories: &HashMap<String, GradientHistory>,
578 ) -> Vec<VanishingRegion> {
579 let mut regions = Vec::new();
580 let mut region_id = 0;
581
582 for (layer_name, history) in gradient_histories {
583 let avg_gradient =
584 history.gradient_norms.iter().sum::<f64>() / history.gradient_norms.len() as f64;
585 if avg_gradient < 1e-5 {
586 region_id += 1;
587 regions.push(VanishingRegion {
588 region_id: format!("vanishing_{}", region_id),
589 affected_layers: vec![layer_name.clone()],
590 severity_level: if avg_gradient < 1e-7 {
591 VanishingSeverity::Critical
592 } else {
593 VanishingSeverity::Moderate
594 },
595 extent: RegionExtent {
596 start_layer: layer_name.clone(),
597 end_layer: layer_name.clone(),
598 affected_parameters: 1000, duration_steps: history.gradient_norms.len(),
600 },
601 mitigation_suggestions: vec![
602 "Consider better weight initialization".to_string(),
603 "Add skip connections".to_string(),
604 "Use gradient clipping".to_string(),
605 ],
606 });
607 }
608 }
609
610 regions
611 }
612
613 fn identify_exploding_regions(
614 &self,
615 gradient_histories: &HashMap<String, GradientHistory>,
616 ) -> Vec<ExplodingRegion> {
617 let mut regions = Vec::new();
618 let mut region_id = 0;
619
620 for (layer_name, history) in gradient_histories {
621 let max_gradient = history.gradient_norms.iter().cloned().fold(0.0, f64::max);
622 if max_gradient > 100.0 {
623 region_id += 1;
624 regions.push(ExplodingRegion {
625 region_id: format!("exploding_{}", region_id),
626 affected_layers: vec![layer_name.clone()],
627 severity_level: if max_gradient > 1000.0 {
628 ExplodingSeverity::Critical
629 } else {
630 ExplodingSeverity::Moderate
631 },
632 extent: RegionExtent {
633 start_layer: layer_name.clone(),
634 end_layer: layer_name.clone(),
635 affected_parameters: 1000, duration_steps: history.gradient_norms.len(),
637 },
638 mitigation_suggestions: vec![
639 "Apply gradient clipping".to_string(),
640 "Reduce learning rate".to_string(),
641 "Check weight initialization".to_string(),
642 ],
643 });
644 }
645 }
646
647 regions
648 }
649
650 fn identify_gradient_dead_zones(
651 &self,
652 gradient_histories: &HashMap<String, GradientHistory>,
653 ) -> Vec<GradientDeadZone> {
654 let mut dead_zones = Vec::new();
655 let mut zone_id = 0;
656
657 for (layer_name, history) in gradient_histories {
658 let zero_gradients = history.gradient_norms.iter().filter(|&&x| x < 1e-8).count();
659 let dead_ratio = zero_gradients as f64 / history.gradient_norms.len() as f64;
660
661 if dead_ratio > 0.5 {
662 zone_id += 1;
663 dead_zones.push(GradientDeadZone {
664 zone_id: format!("dead_zone_{}", zone_id),
665 affected_layers: vec![layer_name.clone()],
666 dead_duration: zero_gradients,
667 recovery_potential: if dead_ratio > 0.9 {
668 RecoveryPotential::Low
669 } else {
670 RecoveryPotential::Medium
671 },
672 intervention_required: dead_ratio > 0.8,
673 });
674 }
675 }
676
677 dead_zones
678 }
679
680 pub fn create_visualization(
682 &self,
683 gradient_histories: &HashMap<String, GradientHistory>,
684 ) -> GradientFlowVisualization {
685 self.generate_visualization(gradient_histories, 0)
687 }
688}