// trustformers_debug/neural_network_debugging.rs
//! Advanced Neural Network Debugging Utilities
//!
//! This module provides specialized debugging utilities for modern neural network architectures,
//! with particular focus on transformer models, attention mechanisms, and large-scale training scenarios.

6use anyhow::Result;
7use scirs2_core::ndarray::*; // SciRS2 Integration Policy - was: use ndarray::{Array, ArrayD, IxDyn};
8use serde::{Deserialize, Serialize};
9use std::collections::HashMap;
10
/// Advanced attention mechanism debugger for transformer architectures.
///
/// Accumulates per-head attention maps and behavioral analyses as layers are
/// inspected via `analyze_attention_layer`.
#[derive(Debug)]
pub struct AttentionDebugger {
    /// Debugging options (feature toggles, sparsity threshold, head limit).
    pub config: AttentionDebugConfig,
    /// Every attention map created so far, across all analyzed layers.
    attention_maps: Vec<AttentionMap>,
    /// Most recent analysis per head index.
    // NOTE(review): keyed by head index only, so analyzing a later layer
    // overwrites earlier layers' entries for the same head index.
    head_analysis: HashMap<usize, AttentionHeadAnalysis>,
}
18
/// Configuration for attention debugging.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct AttentionDebugConfig {
    /// Whether attention maps should be produced for visualization.
    pub enable_attention_visualization: bool,
    /// Whether per-head behavioral analysis is performed.
    pub enable_head_analysis: bool,
    /// Whether pattern detection (diagonal/sparse/uniform/block) runs.
    pub enable_pattern_detection: bool,
    /// Weights strictly below this value count as "sparse".
    pub attention_threshold: f32,
    /// Upper bound on how many heads per layer are analyzed.
    pub max_heads_to_analyze: usize,
}
28
29impl Default for AttentionDebugConfig {
30    fn default() -> Self {
31        Self {
32            enable_attention_visualization: true,
33            enable_head_analysis: true,
34            enable_pattern_detection: true,
35            attention_threshold: 0.01,
36            max_heads_to_analyze: 16,
37        }
38    }
39}
40
/// Attention map for a specific layer and head.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct AttentionMap {
    /// Index of the transformer layer this map came from.
    pub layer_index: usize,
    /// Index of the attention head within the layer.
    pub head_index: usize,
    /// Number of rows in the weight matrix (query positions).
    pub sequence_length: usize,
    /// Full row-major copy of the 2D attention weight matrix.
    pub attention_weights: Vec<Vec<f32>>,
    /// Heuristically detected dominant pattern.
    pub attention_pattern: AttentionPattern,
    /// Shannon entropy (bits) of the flattened weights.
    pub attention_entropy: f32,
    /// Fraction of weights below the configured attention threshold.
    pub sparsity_ratio: f32,
}
52
/// Analysis of individual attention head behavior.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct AttentionHeadAnalysis {
    /// Head index within its layer.
    pub head_id: usize,
    /// Layer index the head belongs to.
    pub layer_id: usize,
    /// Coarse behavioral label for the head.
    pub specialization_type: HeadSpecializationType,
    /// Summary statistics of the head's weight distribution.
    pub attention_distribution: AttentionDistribution,
    /// Similarity-based redundancy; set to 0.0 at creation and intended to
    /// be filled in by later cross-head comparison.
    pub redundancy_score: f32,
    /// Entropy-based importance in [0, 1].
    pub importance_score: f32,
    /// Patterns detected for this head (currently a single entry).
    pub patterns_detected: Vec<AttentionPattern>,
}
64
/// Types of attention head specializations commonly found in transformers.
///
/// Only `LocalSyntax`, `LongRange`, `Positional`, and `ContentBased` are
/// currently assigned by `classify_head_specialization` (plus `Mixed` for
/// non-2D inputs); the remaining variants are reserved labels.
#[derive(Debug, Clone, PartialEq, Eq, Hash, Serialize, Deserialize)]
pub enum HeadSpecializationType {
    LocalSyntax,  // Focuses on local syntactic patterns
    LongRange,    // Captures long-range dependencies
    Positional,   // Primarily position-based attention
    ContentBased, // Content-driven attention patterns
    Copying,      // Copy mechanisms
    Delimiter,    // Focuses on delimiters and boundaries
    Mixed,        // Mixed functionality
    Redundant,    // Highly redundant with other heads
}
77
/// Attention pattern types.
///
/// `detect_attention_pattern` currently assigns only `Diagonal`, `Sparse`,
/// `Uniform`, `Block`, and `Random`; `Concentrated` and `Strided` are
/// reserved labels.
#[derive(Debug, Clone, PartialEq, Eq, Hash, Serialize, Deserialize)]
pub enum AttentionPattern {
    Diagonal,     // Attention along diagonal (local patterns)
    Block,        // Block-structured attention
    Sparse,       // Sparse attention patterns
    Uniform,      // Uniform attention distribution
    Concentrated, // Highly concentrated attention
    Strided,      // Strided patterns
    Random,       // Random/chaotic patterns
}
89
/// Distribution characteristics of attention weights.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct AttentionDistribution {
    /// Arithmetic mean of all weights.
    pub mean_attention: f32,
    /// Population standard deviation of the weights.
    pub std_attention: f32,
    /// Largest weight observed.
    pub max_attention: f32,
    /// Smallest weight observed.
    pub min_attention: f32,
    /// Shannon entropy (bits) of the weights as a distribution.
    pub entropy: f32,
    /// Number of top-ranked weights carrying 90% of the attention mass.
    pub effective_context_length: f32,
}
100
101impl AttentionDebugger {
102    /// Create a new attention debugger
103    pub fn new(config: AttentionDebugConfig) -> Self {
104        Self {
105            config,
106            attention_maps: Vec::new(),
107            head_analysis: HashMap::new(),
108        }
109    }
110
    /// Analyze attention weights for a transformer layer.
    ///
    /// `attention_weights` holds one 2D weight matrix per head; at most
    /// `config.max_heads_to_analyze` heads are processed. Results are both
    /// returned and accumulated on `self` for later inspection.
    ///
    /// # Errors
    /// Propagates failures from per-head map creation (e.g. non-2D weights).
    pub fn analyze_attention_layer(
        &mut self,
        layer_index: usize,
        attention_weights: &[ArrayD<f32>], // One array per head
    ) -> Result<LayerAttentionAnalysis> {
        let mut head_analyses = Vec::new();
        let mut attention_maps = Vec::new();

        for (head_index, weights) in attention_weights.iter().enumerate() {
            // Respect the configured cap on per-layer head analysis.
            if head_index >= self.config.max_heads_to_analyze {
                break;
            }

            // Create attention map
            let attention_map = self.create_attention_map(layer_index, head_index, weights)?;
            attention_maps.push(attention_map.clone());
            self.attention_maps.push(attention_map);

            // Analyze head behavior
            let head_analysis = self.analyze_attention_head(layer_index, head_index, weights)?;
            head_analyses.push(head_analysis.clone());
            // NOTE(review): keyed by head index alone — a later layer's head
            // with the same index overwrites this entry.
            self.head_analysis.insert(head_index, head_analysis);
        }

        let layer_diversity_score = self.compute_layer_diversity(&head_analyses);
        let redundancy_analysis = self.analyze_head_redundancy(&head_analyses);

        Ok(LayerAttentionAnalysis {
            layer_index,
            // Reflects the heads supplied, not the (possibly capped) number
            // actually analyzed above.
            num_heads: attention_weights.len(),
            head_analyses,
            attention_maps,
            layer_diversity_score,
            redundancy_analysis,
        })
    }
148
149    /// Create attention map from weights
150    fn create_attention_map(
151        &self,
152        layer_index: usize,
153        head_index: usize,
154        weights: &ArrayD<f32>,
155    ) -> Result<AttentionMap> {
156        let shape = weights.shape();
157        if shape.len() != 2 {
158            return Err(anyhow::anyhow!(
159                "Expected 2D attention weights, got {}D",
160                shape.len()
161            ));
162        }
163
164        let seq_len = shape[0];
165        let attention_weights: Vec<Vec<f32>> =
166            (0..seq_len).map(|i| (0..shape[1]).map(|j| weights[[i, j]]).collect()).collect();
167
168        let pattern = self.detect_attention_pattern(&attention_weights);
169        let entropy = self.compute_attention_entropy(&attention_weights);
170        let sparsity = self.compute_sparsity_ratio(&attention_weights);
171
172        Ok(AttentionMap {
173            layer_index,
174            head_index,
175            sequence_length: seq_len,
176            attention_weights,
177            attention_pattern: pattern,
178            attention_entropy: entropy,
179            sparsity_ratio: sparsity,
180        })
181    }
182
183    /// Analyze individual attention head
184    fn analyze_attention_head(
185        &self,
186        layer_index: usize,
187        head_index: usize,
188        weights: &ArrayD<f32>,
189    ) -> Result<AttentionHeadAnalysis> {
190        let specialization = self.classify_head_specialization(weights)?;
191        let distribution = self.compute_attention_distribution(weights)?;
192        let patterns = vec![self.detect_attention_pattern_from_weights(weights)?];
193
194        Ok(AttentionHeadAnalysis {
195            head_id: head_index,
196            layer_id: layer_index,
197            specialization_type: specialization,
198            attention_distribution: distribution,
199            redundancy_score: 0.0, // Will be computed later
200            importance_score: self.compute_head_importance(weights)?,
201            patterns_detected: patterns,
202        })
203    }
204
205    /// Detect attention pattern from weights matrix
206    fn detect_attention_pattern(&self, weights: &[Vec<f32>]) -> AttentionPattern {
207        let seq_len = weights.len();
208        if seq_len == 0 {
209            return AttentionPattern::Random;
210        }
211
212        // Check for diagonal pattern
213        let diagonal_strength = self.measure_diagonal_strength(weights);
214        if diagonal_strength > 0.7 {
215            return AttentionPattern::Diagonal;
216        }
217
218        // Check for sparse pattern
219        let sparsity = self.compute_sparsity_ratio(weights);
220        if sparsity > 0.8 {
221            return AttentionPattern::Sparse;
222        }
223
224        // Check for uniform pattern
225        let uniformity = self.measure_uniformity(weights);
226        if uniformity > 0.8 {
227            return AttentionPattern::Uniform;
228        }
229
230        // Check for block pattern
231        if self.has_block_structure(weights) {
232            return AttentionPattern::Block;
233        }
234
235        AttentionPattern::Random
236    }
237
238    /// Measure diagonal strength in attention pattern
239    fn measure_diagonal_strength(&self, weights: &[Vec<f32>]) -> f32 {
240        let seq_len = weights.len();
241        if seq_len == 0 {
242            return 0.0;
243        }
244
245        let mut diagonal_sum = 0.0;
246        let mut total_sum = 0.0;
247        let window_size = 3; // Look at diagonal +/- window
248
249        for i in 0..seq_len {
250            for j in 0..weights[i].len() {
251                let weight = weights[i][j];
252                total_sum += weight;
253
254                if (i as i32 - j as i32).abs() <= window_size {
255                    diagonal_sum += weight;
256                }
257            }
258        }
259
260        if total_sum > 0.0 {
261            diagonal_sum / total_sum
262        } else {
263            0.0
264        }
265    }
266
267    /// Measure uniformity of attention distribution
268    fn measure_uniformity(&self, weights: &[Vec<f32>]) -> f32 {
269        let seq_len = weights.len();
270        if seq_len == 0 {
271            return 0.0;
272        }
273
274        let expected_weight = 1.0 / seq_len as f32;
275        let mut deviation_sum = 0.0;
276        let mut count = 0;
277
278        for row in weights {
279            for &weight in row {
280                deviation_sum += (weight - expected_weight).abs();
281                count += 1;
282            }
283        }
284
285        if count > 0 {
286            1.0 - (deviation_sum / count as f32)
287        } else {
288            0.0
289        }
290    }
291
    /// Check if attention has block structure.
    ///
    /// Splits the sequence into four equal-sized segments and averages the
    /// weights inside each *diagonal* block only; reports `true` when one
    /// block's average exceeds twice the mean of all block averages.
    fn has_block_structure(&self, weights: &[Vec<f32>]) -> bool {
        // Simplified block detection - look for concentrated regions
        let seq_len = weights.len();
        if seq_len < 4 {
            return false;
        }

        let block_size = seq_len / 4;
        let mut block_concentrations = Vec::new();

        for block_start in (0..seq_len).step_by(block_size) {
            let block_end = (block_start + block_size).min(seq_len);
            let mut block_sum = 0.0;
            let mut block_count = 0;

            // Only the diagonal block [start..end) x [start..end) is sampled;
            // the column range is clamped to each row's actual width.
            for i in block_start..block_end {
                for j in block_start..(block_end.min(weights[i].len())) {
                    block_sum += weights[i][j];
                    block_count += 1;
                }
            }

            if block_count > 0 {
                block_concentrations.push(block_sum / block_count as f32);
            }
        }

        // Check if some blocks have significantly higher concentration
        if block_concentrations.len() < 2 {
            return false;
        }

        let max_concentration = block_concentrations.iter().cloned().fold(0.0f32, f32::max);
        let avg_concentration =
            block_concentrations.iter().sum::<f32>() / block_concentrations.len() as f32;

        max_concentration > avg_concentration * 2.0
    }
331
332    /// Classify attention head specialization
333    fn classify_head_specialization(
334        &self,
335        weights: &ArrayD<f32>,
336    ) -> Result<HeadSpecializationType> {
337        let shape = weights.shape();
338        if shape.len() != 2 {
339            return Ok(HeadSpecializationType::Mixed);
340        }
341
342        let seq_len = shape[0];
343
344        // Convert to 2D vec for analysis
345        let weights_2d: Vec<Vec<f32>> =
346            (0..seq_len).map(|i| (0..shape[1]).map(|j| weights[[i, j]]).collect()).collect();
347
348        // Analyze patterns to determine specialization
349        let diagonal_strength = self.measure_diagonal_strength(&weights_2d);
350        let long_range_strength = self.measure_long_range_attention(&weights_2d);
351        let positional_bias = self.measure_positional_bias(&weights_2d);
352
353        Ok(if diagonal_strength > 0.7 {
354            HeadSpecializationType::LocalSyntax
355        } else if long_range_strength > 0.6 {
356            HeadSpecializationType::LongRange
357        } else if positional_bias > 0.8 {
358            HeadSpecializationType::Positional
359        } else {
360            HeadSpecializationType::ContentBased
361        })
362    }
363
364    /// Measure long-range attention strength
365    fn measure_long_range_attention(&self, weights: &[Vec<f32>]) -> f32 {
366        let seq_len = weights.len();
367        if seq_len < 4 {
368            return 0.0;
369        }
370
371        let mut long_range_sum = 0.0;
372        let mut total_sum = 0.0;
373        let long_range_threshold = seq_len / 4; // Consider distances > 1/4 sequence length as long-range
374
375        for i in 0..seq_len {
376            for j in 0..weights[i].len() {
377                let weight = weights[i][j];
378                total_sum += weight;
379
380                if (i as i32 - j as i32).abs() > long_range_threshold as i32 {
381                    long_range_sum += weight;
382                }
383            }
384        }
385
386        if total_sum > 0.0 {
387            long_range_sum / total_sum
388        } else {
389            0.0
390        }
391    }
392
393    /// Measure positional bias in attention
394    fn measure_positional_bias(&self, weights: &[Vec<f32>]) -> f32 {
395        let seq_len = weights.len();
396        if seq_len == 0 {
397            return 0.0;
398        }
399
400        // Measure how much attention depends on absolute position vs content
401        let mut position_correlation = 0.0;
402        let mut count = 0;
403
404        for i in 0..seq_len {
405            for j in 0..weights[i].len().min(seq_len) {
406                let position_similarity = 1.0 - (i as f32 - j as f32).abs() / seq_len as f32;
407                position_correlation += weights[i][j] * position_similarity;
408                count += 1;
409            }
410        }
411
412        if count > 0 {
413            position_correlation / count as f32
414        } else {
415            0.0
416        }
417    }
418
419    /// Compute attention distribution statistics
420    fn compute_attention_distribution(
421        &self,
422        weights: &ArrayD<f32>,
423    ) -> Result<AttentionDistribution> {
424        let values: Vec<f32> = weights.iter().cloned().collect();
425
426        if values.is_empty() {
427            return Ok(AttentionDistribution {
428                mean_attention: 0.0,
429                std_attention: 0.0,
430                max_attention: 0.0,
431                min_attention: 0.0,
432                entropy: 0.0,
433                effective_context_length: 0.0,
434            });
435        }
436
437        let mean = values.iter().sum::<f32>() / values.len() as f32;
438        let variance = values.iter().map(|x| (x - mean).powi(2)).sum::<f32>() / values.len() as f32;
439        let std_dev = variance.sqrt();
440        let max_val = values.iter().cloned().fold(f32::NEG_INFINITY, f32::max);
441        let min_val = values.iter().cloned().fold(f32::INFINITY, f32::min);
442
443        // Compute entropy
444        let entropy = self.compute_entropy(&values);
445
446        // Compute effective context length (based on attention concentration)
447        let effective_length = self.compute_effective_context_length(&values);
448
449        Ok(AttentionDistribution {
450            mean_attention: mean,
451            std_attention: std_dev,
452            max_attention: max_val,
453            min_attention: min_val,
454            entropy,
455            effective_context_length: effective_length,
456        })
457    }
458
459    /// Compute entropy of attention distribution
460    fn compute_entropy(&self, values: &[f32]) -> f32 {
461        if values.is_empty() {
462            return 0.0;
463        }
464
465        let sum: f32 = values.iter().sum();
466        if sum <= 0.0 {
467            return 0.0;
468        }
469
470        let mut entropy = 0.0;
471        for &value in values {
472            if value > 0.0 {
473                let prob = value / sum;
474                entropy -= prob * prob.log2();
475            }
476        }
477
478        entropy
479    }
480
481    /// Compute effective context length based on attention concentration
482    fn compute_effective_context_length(&self, values: &[f32]) -> f32 {
483        if values.is_empty() {
484            return 0.0;
485        }
486
487        let sum: f32 = values.iter().sum();
488        if sum <= 0.0 {
489            return 0.0;
490        }
491
492        // Compute 90th percentile threshold
493        let mut sorted_values = values.to_vec();
494        sorted_values.sort_by(|a, b| b.partial_cmp(a).unwrap_or(std::cmp::Ordering::Equal));
495
496        let mut cumulative_sum = 0.0;
497        let target_sum = sum * 0.9; // 90% of attention mass
498
499        for (i, &value) in sorted_values.iter().enumerate() {
500            cumulative_sum += value;
501            if cumulative_sum >= target_sum {
502                return (i + 1) as f32;
503            }
504        }
505
506        values.len() as f32
507    }
508
509    /// Detect attention pattern from weights array
510    fn detect_attention_pattern_from_weights(
511        &self,
512        weights: &ArrayD<f32>,
513    ) -> Result<AttentionPattern> {
514        let shape = weights.shape();
515        if shape.len() != 2 {
516            return Ok(AttentionPattern::Random);
517        }
518
519        let weights_2d: Vec<Vec<f32>> = (0..shape[0])
520            .map(|i| (0..shape[1]).map(|j| weights[[i, j]]).collect())
521            .collect();
522
523        Ok(self.detect_attention_pattern(&weights_2d))
524    }
525
526    /// Compute head importance score
527    fn compute_head_importance(&self, weights: &ArrayD<f32>) -> Result<f32> {
528        let values: Vec<f32> = weights.iter().cloned().collect();
529
530        if values.is_empty() {
531            return Ok(0.0);
532        }
533
534        // Importance based on entropy (higher entropy = more important for diverse attention)
535        let entropy = self.compute_entropy(&values);
536        let max_entropy = (values.len() as f32).log2();
537
538        if max_entropy > 0.0 {
539            Ok(entropy / max_entropy)
540        } else {
541            Ok(0.0)
542        }
543    }
544
545    /// Compute attention entropy
546    fn compute_attention_entropy(&self, weights: &[Vec<f32>]) -> f32 {
547        let values: Vec<f32> = weights.iter().flatten().cloned().collect();
548        self.compute_entropy(&values)
549    }
550
551    /// Compute sparsity ratio
552    fn compute_sparsity_ratio(&self, weights: &[Vec<f32>]) -> f32 {
553        let total_count = weights.iter().map(|row| row.len()).sum::<usize>();
554        if total_count == 0 {
555            return 0.0;
556        }
557
558        let sparse_count = weights
559            .iter()
560            .flatten()
561            .filter(|&&w| w < self.config.attention_threshold)
562            .count();
563
564        sparse_count as f32 / total_count as f32
565    }
566
    /// Compute layer diversity score.
    ///
    /// Diversity is the fraction of distinct specialization types present
    /// among the analyzed heads; returns 0.0 for fewer than two heads.
    fn compute_layer_diversity(&self, head_analyses: &[AttentionHeadAnalysis]) -> f32 {
        if head_analyses.len() < 2 {
            return 0.0;
        }

        // Measure diversity based on different specialization types
        let mut specialization_counts: HashMap<HeadSpecializationType, usize> = HashMap::new();
        for analysis in head_analyses {
            *specialization_counts.entry(analysis.specialization_type.clone()).or_insert(0) += 1;
        }

        let num_types = specialization_counts.len() as f32;
        // NOTE(review): must match the number of HeadSpecializationType
        // variants (currently 8); update if variants are added.
        let max_types = 8.0; // Maximum possible specialization types

        num_types / max_types
    }
584
585    /// Analyze head redundancy
586    fn analyze_head_redundancy(
587        &self,
588        head_analyses: &[AttentionHeadAnalysis],
589    ) -> RedundancyAnalysis {
590        let mut redundant_heads = Vec::new();
591        let redundancy_groups = Vec::new();
592
593        // Group heads by similar behavior
594        for i in 0..head_analyses.len() {
595            for j in (i + 1)..head_analyses.len() {
596                let similarity = self.compute_head_similarity(&head_analyses[i], &head_analyses[j]);
597                if similarity > 0.8 {
598                    redundant_heads.push((i, j, similarity));
599                }
600            }
601        }
602
603        RedundancyAnalysis {
604            redundant_head_pairs: redundant_heads,
605            redundancy_groups,
606            overall_redundancy_score: self.compute_overall_redundancy(head_analyses),
607        }
608    }
609
610    /// Compute similarity between two attention heads
611    fn compute_head_similarity(
612        &self,
613        head1: &AttentionHeadAnalysis,
614        head2: &AttentionHeadAnalysis,
615    ) -> f32 {
616        // Similarity based on specialization type and attention distribution
617        let type_similarity =
618            if head1.specialization_type == head2.specialization_type { 1.0 } else { 0.0 };
619
620        let dist_similarity = {
621            let d1 = &head1.attention_distribution;
622            let d2 = &head2.attention_distribution;
623
624            let mean_diff = (d1.mean_attention - d2.mean_attention).abs();
625            let std_diff = (d1.std_attention - d2.std_attention).abs();
626            let entropy_diff = (d1.entropy - d2.entropy).abs();
627
628            1.0 - (mean_diff + std_diff + entropy_diff) / 3.0
629        };
630
631        (type_similarity + dist_similarity) / 2.0
632    }
633
634    /// Compute overall redundancy score for the layer
635    fn compute_overall_redundancy(&self, head_analyses: &[AttentionHeadAnalysis]) -> f32 {
636        if head_analyses.len() < 2 {
637            return 0.0;
638        }
639
640        let mut total_similarity = 0.0;
641        let mut pair_count = 0;
642
643        for i in 0..head_analyses.len() {
644            for j in (i + 1)..head_analyses.len() {
645                total_similarity +=
646                    self.compute_head_similarity(&head_analyses[i], &head_analyses[j]);
647                pair_count += 1;
648            }
649        }
650
651        if pair_count > 0 {
652            total_similarity / pair_count as f32
653        } else {
654            0.0
655        }
656    }
657}
658
/// Analysis results for a transformer layer's attention mechanism.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct LayerAttentionAnalysis {
    /// Index of the analyzed layer.
    pub layer_index: usize,
    /// Number of heads supplied for the layer (may exceed those analyzed).
    pub num_heads: usize,
    /// Per-head behavioral analyses (capped by `max_heads_to_analyze`).
    pub head_analyses: Vec<AttentionHeadAnalysis>,
    /// Per-head attention maps for this layer.
    pub attention_maps: Vec<AttentionMap>,
    /// Fraction of possible specialization types present in this layer.
    pub layer_diversity_score: f32,
    /// Pairwise head-redundancy findings for this layer.
    pub redundancy_analysis: RedundancyAnalysis,
}
669
/// Analysis of attention head redundancy.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct RedundancyAnalysis {
    /// Head pairs whose similarity exceeds the redundancy threshold.
    pub redundant_head_pairs: Vec<(usize, usize, f32)>, // (head1, head2, similarity)
    /// Groups of mutually redundant heads.
    pub redundancy_groups: Vec<Vec<usize>>,
    /// Mean pairwise similarity across all heads in the layer.
    pub overall_redundancy_score: f32,
}
677
/// Transformer-specific debugging utilities.
///
/// Wraps an [`AttentionDebugger`] and adds whole-model, cross-layer analysis.
#[derive(Debug)]
pub struct TransformerDebugger {
    /// Model-level debugging options.
    pub config: TransformerDebugConfig,
    /// Cached per-layer results from the most recent analysis run.
    layer_analyses: Vec<LayerAttentionAnalysis>,
    /// Per-layer attention analysis engine.
    attention_debugger: AttentionDebugger,
}
685
/// Configuration for transformer debugging.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct TransformerDebugConfig {
    /// Options forwarded to the embedded attention debugger.
    pub attention_config: AttentionDebugConfig,
    /// Whether per-layer analysis is performed.
    pub enable_layer_analysis: bool,
    /// Whether cross-layer (evolution/consistency/progression) analysis runs.
    pub enable_cross_layer_analysis: bool,
    /// Upper bound on layers analyzed per model.
    pub max_layers_to_analyze: usize,
}
694
695impl Default for TransformerDebugConfig {
696    fn default() -> Self {
697        Self {
698            attention_config: AttentionDebugConfig::default(),
699            enable_layer_analysis: true,
700            enable_cross_layer_analysis: true,
701            max_layers_to_analyze: 48, // Support for large models
702        }
703    }
704}
705
706impl TransformerDebugger {
707    /// Create a new transformer debugger
708    pub fn new(config: TransformerDebugConfig) -> Self {
709        let attention_debugger = AttentionDebugger::new(config.attention_config.clone());
710
711        Self {
712            config,
713            layer_analyses: Vec::new(),
714            attention_debugger,
715        }
716    }
717
    /// Analyze entire transformer model attention patterns.
    ///
    /// `model_attention_weights[layer][head]` is one attention weight matrix.
    /// At most `config.max_layers_to_analyze` layers are processed; per-layer
    /// results are cached on `self`, and cross-layer analysis runs when
    /// enabled in the configuration.
    ///
    /// # Errors
    /// Propagates failures from layer analysis, cross-layer analysis, or
    /// summary generation.
    pub fn analyze_transformer_attention(
        &mut self,
        model_attention_weights: &[Vec<ArrayD<f32>>], // [layer][head] -> attention weights
    ) -> Result<TransformerAttentionAnalysis> {
        let mut layer_analyses = Vec::new();

        for (layer_idx, layer_weights) in model_attention_weights.iter().enumerate() {
            // Cap the work on very deep models.
            if layer_idx >= self.config.max_layers_to_analyze {
                break;
            }

            let layer_analysis =
                self.attention_debugger.analyze_attention_layer(layer_idx, layer_weights)?;
            layer_analyses.push(layer_analysis);
        }

        self.layer_analyses = layer_analyses.clone();

        // Perform cross-layer analysis
        let cross_layer_analysis = if self.config.enable_cross_layer_analysis {
            Some(self.perform_cross_layer_analysis(&layer_analyses)?)
        } else {
            None
        };

        Ok(TransformerAttentionAnalysis {
            // NOTE(review): reports the total layers supplied, which may
            // exceed the number actually analyzed when truncated above.
            num_layers: model_attention_weights.len(),
            layer_analyses,
            cross_layer_analysis,
            model_attention_summary: self.generate_model_attention_summary()?,
        })
    }
751
752    /// Perform cross-layer attention analysis
753    fn perform_cross_layer_analysis(
754        &self,
755        layer_analyses: &[LayerAttentionAnalysis],
756    ) -> Result<CrossLayerAnalysis> {
757        let attention_evolution = self.analyze_attention_evolution(layer_analyses)?;
758        let head_consistency = self.analyze_head_consistency(layer_analyses)?;
759        let pattern_progression = self.analyze_pattern_progression(layer_analyses)?;
760
761        Ok(CrossLayerAnalysis {
762            attention_evolution,
763            head_consistency,
764            pattern_progression,
765            layer_diversity_trend: self.compute_layer_diversity_trend(layer_analyses),
766        })
767    }
768
769    /// Analyze how attention patterns evolve across layers
770    fn analyze_attention_evolution(
771        &self,
772        layer_analyses: &[LayerAttentionAnalysis],
773    ) -> Result<AttentionEvolution> {
774        let mut entropy_trend = Vec::new();
775        let mut sparsity_trend = Vec::new();
776
777        for layer in layer_analyses {
778            let layer_entropy: f32 =
779                layer.attention_maps.iter().map(|map| map.attention_entropy).sum::<f32>()
780                    / layer.attention_maps.len() as f32;
781            let layer_sparsity: f32 =
782                layer.attention_maps.iter().map(|map| map.sparsity_ratio).sum::<f32>()
783                    / layer.attention_maps.len() as f32;
784
785            entropy_trend.push(layer_entropy);
786            sparsity_trend.push(layer_sparsity);
787        }
788
789        let evolution_type = self.classify_evolution_type(&entropy_trend);
790
791        Ok(AttentionEvolution {
792            entropy_trend,
793            sparsity_trend,
794            evolution_type,
795        })
796    }
797
798    /// Classify the type of attention evolution
799    fn classify_evolution_type(&self, entropy_trend: &[f32]) -> EvolutionType {
800        if entropy_trend.len() < 3 {
801            return EvolutionType::Stable;
802        }
803
804        let start_entropy = entropy_trend[0];
805        let end_entropy = entropy_trend[entropy_trend.len() - 1];
806        let change_ratio = (end_entropy - start_entropy) / start_entropy.max(1e-8);
807
808        match change_ratio {
809            r if r > 0.2 => EvolutionType::Increasing,
810            r if r < -0.2 => EvolutionType::Decreasing,
811            _ => EvolutionType::Stable,
812        }
813    }
814
    /// Analyze head consistency across layers.
    ///
    /// Records, for each specialization type, the layer indices in which a
    /// head of that type appears (one entry per head occurrence).
    fn analyze_head_consistency(
        &self,
        layer_analyses: &[LayerAttentionAnalysis],
    ) -> Result<HeadConsistency> {
        let mut specialization_consistency = HashMap::new();
        // NOTE(review): pattern_consistency is never populated here; it is
        // always returned empty.
        let pattern_consistency = HashMap::new();

        // Track how specialization types are distributed across layers
        for layer in layer_analyses {
            for head in &layer.head_analyses {
                let spec_type = &head.specialization_type;
                let layer_counts =
                    specialization_consistency.entry(spec_type.clone()).or_insert_with(Vec::new);
                layer_counts.push(layer.layer_index);
            }
        }

        Ok(HeadConsistency {
            specialization_consistency,
            pattern_consistency,
            consistency_score: self.compute_consistency_score(layer_analyses),
        })
    }
839
840    /// Compute overall consistency score
841    fn compute_consistency_score(&self, layer_analyses: &[LayerAttentionAnalysis]) -> f32 {
842        if layer_analyses.len() < 2 {
843            return 1.0;
844        }
845
846        // Measure how similar the distribution of head types is across layers
847        let mut layer_distributions = Vec::new();
848
849        for layer in layer_analyses {
850            let mut distribution: HashMap<HeadSpecializationType, f32> = HashMap::new();
851            for head in &layer.head_analyses {
852                *distribution.entry(head.specialization_type.clone()).or_insert(0.0) += 1.0;
853            }
854
855            // Normalize
856            let total: f32 = distribution.values().sum();
857            if total > 0.0 {
858                for value in distribution.values_mut() {
859                    *value /= total;
860                }
861            }
862
863            layer_distributions.push(distribution);
864        }
865
866        // Compute average pairwise similarity
867        let mut total_similarity = 0.0;
868        let mut pair_count = 0;
869
870        for i in 0..layer_distributions.len() {
871            for j in (i + 1)..layer_distributions.len() {
872                let similarity = self.compute_distribution_similarity(
873                    &layer_distributions[i],
874                    &layer_distributions[j],
875                );
876                total_similarity += similarity;
877                pair_count += 1;
878            }
879        }
880
881        if pair_count > 0 {
882            total_similarity / pair_count as f32
883        } else {
884            1.0
885        }
886    }
887
888    /// Compute similarity between two distributions
889    fn compute_distribution_similarity(
890        &self,
891        dist1: &HashMap<HeadSpecializationType, f32>,
892        dist2: &HashMap<HeadSpecializationType, f32>,
893    ) -> f32 {
894        let mut all_keys: std::collections::HashSet<_> = dist1.keys().collect();
895        all_keys.extend(dist2.keys());
896
897        let mut similarity = 0.0;
898        for key in all_keys {
899            let val1 = dist1.get(key).unwrap_or(&0.0);
900            let val2 = dist2.get(key).unwrap_or(&0.0);
901            similarity += (val1 - val2).abs();
902        }
903
904        1.0 - (similarity / 2.0) // Normalize to [0, 1]
905    }
906
907    /// Analyze pattern progression across layers
908    fn analyze_pattern_progression(
909        &self,
910        layer_analyses: &[LayerAttentionAnalysis],
911    ) -> Result<PatternProgression> {
912        let mut pattern_evolution = Vec::new();
913
914        for layer in layer_analyses {
915            let mut pattern_counts: HashMap<AttentionPattern, usize> = HashMap::new();
916            for map in &layer.attention_maps {
917                *pattern_counts.entry(map.attention_pattern.clone()).or_insert(0) += 1;
918            }
919            pattern_evolution.push(pattern_counts);
920        }
921
922        let dominant_pattern_sequence = self.extract_dominant_patterns(&pattern_evolution);
923
924        Ok(PatternProgression {
925            pattern_evolution,
926            dominant_pattern_sequence,
927        })
928    }
929
930    /// Extract sequence of dominant patterns across layers
931    fn extract_dominant_patterns(
932        &self,
933        pattern_evolution: &[HashMap<AttentionPattern, usize>],
934    ) -> Vec<AttentionPattern> {
935        pattern_evolution
936            .iter()
937            .map(|patterns| {
938                patterns
939                    .iter()
940                    .max_by_key(|(_, &count)| count)
941                    .map(|(pattern, _)| pattern.clone())
942                    .unwrap_or(AttentionPattern::Random)
943            })
944            .collect()
945    }
946
947    /// Compute layer diversity trend
948    fn compute_layer_diversity_trend(&self, layer_analyses: &[LayerAttentionAnalysis]) -> Vec<f32> {
949        layer_analyses.iter().map(|layer| layer.layer_diversity_score).collect()
950    }
951
952    /// Generate model attention summary
953    fn generate_model_attention_summary(&self) -> Result<ModelAttentionSummary> {
954        if self.layer_analyses.is_empty() {
955            return Ok(ModelAttentionSummary::default());
956        }
957
958        let total_heads: usize = self.layer_analyses.iter().map(|layer| layer.num_heads).sum();
959        let avg_diversity: f32 =
960            self.layer_analyses.iter().map(|layer| layer.layer_diversity_score).sum::<f32>()
961                / self.layer_analyses.len() as f32;
962        let avg_redundancy: f32 = self
963            .layer_analyses
964            .iter()
965            .map(|layer| layer.redundancy_analysis.overall_redundancy_score)
966            .sum::<f32>()
967            / self.layer_analyses.len() as f32;
968
969        // Count specialization types across all layers
970        let mut specialization_distribution: HashMap<HeadSpecializationType, usize> =
971            HashMap::new();
972        for layer in &self.layer_analyses {
973            for head in &layer.head_analyses {
974                *specialization_distribution
975                    .entry(head.specialization_type.clone())
976                    .or_insert(0) += 1;
977            }
978        }
979
980        Ok(ModelAttentionSummary {
981            total_layers: self.layer_analyses.len(),
982            total_heads,
983            average_diversity_score: avg_diversity,
984            average_redundancy_score: avg_redundancy,
985            specialization_distribution,
986            model_attention_health: self
987                .assess_model_attention_health(avg_diversity, avg_redundancy),
988        })
989    }
990
991    /// Assess overall model attention health
992    fn assess_model_attention_health(
993        &self,
994        diversity: f32,
995        redundancy: f32,
996    ) -> AttentionHealthStatus {
997        let health_score = diversity * (1.0 - redundancy); // High diversity, low redundancy is good
998
999        match health_score {
1000            s if s > 0.7 => AttentionHealthStatus::Excellent,
1001            s if s > 0.5 => AttentionHealthStatus::Good,
1002            s if s > 0.3 => AttentionHealthStatus::Fair,
1003            _ => AttentionHealthStatus::Poor,
1004        }
1005    }
1006}
1007
/// Complete transformer attention analysis results
#[derive(Debug, Serialize, Deserialize)]
pub struct TransformerAttentionAnalysis {
    /// Number of transformer layers covered by this analysis.
    pub num_layers: usize,
    /// Per-layer attention analysis results, in layer order.
    pub layer_analyses: Vec<LayerAttentionAnalysis>,
    /// Cross-layer comparison; `None` when cross-layer analysis was not performed.
    pub cross_layer_analysis: Option<CrossLayerAnalysis>,
    /// Aggregated model-level summary statistics.
    pub model_attention_summary: ModelAttentionSummary,
}
1016
/// Cross-layer attention analysis
#[derive(Debug, Serialize, Deserialize)]
pub struct CrossLayerAnalysis {
    /// How entropy/sparsity of attention evolve from layer to layer.
    pub attention_evolution: AttentionEvolution,
    /// Consistency of head specializations and patterns across layers.
    pub head_consistency: HeadConsistency,
    /// Per-layer pattern histograms and the dominant-pattern sequence.
    pub pattern_progression: PatternProgression,
    /// Each layer's diversity score, in layer order.
    pub layer_diversity_trend: Vec<f32>,
}
1025
/// Attention evolution across layers
#[derive(Debug, Serialize, Deserialize)]
pub struct AttentionEvolution {
    /// Attention entropy per layer, in layer order.
    pub entropy_trend: Vec<f32>,
    /// Attention sparsity ratio per layer, in layer order.
    pub sparsity_trend: Vec<f32>,
    /// Classification of the overall trend (increasing/decreasing/stable).
    pub evolution_type: EvolutionType,
}
1033
/// Types of attention evolution patterns
///
/// Describes how attention characteristics change with layer depth.
#[derive(Debug, Clone, PartialEq, Eq, Hash, Serialize, Deserialize)]
pub enum EvolutionType {
    Increasing, // Attention becomes more diverse
    Decreasing, // Attention becomes more focused
    Stable,     // Attention patterns remain consistent
}
1041
/// Head consistency analysis across layers
#[derive(Debug, Serialize, Deserialize)]
pub struct HeadConsistency {
    /// For each specialization type, the layers (indices) where it appears.
    pub specialization_consistency: HashMap<HeadSpecializationType, Vec<usize>>,
    /// For each attention pattern, the layers (indices) where it appears.
    pub pattern_consistency: HashMap<AttentionPattern, Vec<usize>>,
    /// Mean pairwise similarity of per-layer specialization distributions,
    /// in [0, 1]; 1.0 means every layer has the same head-type mix.
    pub consistency_score: f32,
}
1049
/// Pattern progression analysis
#[derive(Debug, Serialize, Deserialize)]
pub struct PatternProgression {
    /// Per-layer histogram of attention-pattern occurrence counts.
    pub pattern_evolution: Vec<HashMap<AttentionPattern, usize>>,
    /// The most frequent pattern of each layer, in layer order.
    pub dominant_pattern_sequence: Vec<AttentionPattern>,
}
1056
/// Model-level attention summary
#[derive(Debug, Serialize, Deserialize)]
pub struct ModelAttentionSummary {
    /// Number of layers included in the summary.
    pub total_layers: usize,
    /// Total attention heads summed over all layers.
    pub total_heads: usize,
    /// Mean of the per-layer diversity scores.
    pub average_diversity_score: f32,
    /// Mean of the per-layer redundancy scores.
    pub average_redundancy_score: f32,
    /// Count of heads per specialization type, across all layers.
    pub specialization_distribution: HashMap<HeadSpecializationType, usize>,
    /// Overall health grade derived from diversity and redundancy.
    pub model_attention_health: AttentionHealthStatus,
}
1067
// Manual Default impl: `AttentionHealthStatus` has no `Default`, so the
// derive is unavailable; an empty summary conservatively grades as `Poor`.
impl Default for ModelAttentionSummary {
    fn default() -> Self {
        Self {
            total_layers: 0,
            total_heads: 0,
            average_diversity_score: 0.0,
            average_redundancy_score: 0.0,
            specialization_distribution: HashMap::new(),
            // No data to assess, so report the lowest health grade.
            model_attention_health: AttentionHealthStatus::Poor,
        }
    }
}
1080
/// Overall attention health status
///
/// Graded from `diversity * (1 - redundancy)`: >0.7 Excellent,
/// >0.5 Good, >0.3 Fair, otherwise Poor.
#[derive(Debug, Clone, PartialEq, Eq, Hash, Serialize, Deserialize)]
pub enum AttentionHealthStatus {
    Excellent,
    Good,
    Fair,
    Poor,
}
1089
// Convenience macro for creating attention debugger with default config.
// Builds a fresh `AttentionDebugger` with `AttentionDebugConfig::default()`
// and analyzes the given attention weights as layer 0, returning the
// analysis result. Uses `$crate::` paths so it works from downstream crates.
#[macro_export]
macro_rules! debug_attention {
    ($attention_weights:expr) => {{
        let mut debugger = $crate::neural_network_debugging::AttentionDebugger::new(
            $crate::neural_network_debugging::AttentionDebugConfig::default(),
        );
        debugger.analyze_attention_layer(0, $attention_weights)
    }};
}
1100
// Convenience macro for transformer debugging.
// Builds a fresh `TransformerDebugger` with `TransformerDebugConfig::default()`
// and runs a full transformer attention analysis on the given model weights,
// returning the result. Uses `$crate::` paths so it works from downstream crates.
#[macro_export]
macro_rules! debug_transformer {
    ($model_weights:expr) => {{
        let mut debugger = $crate::neural_network_debugging::TransformerDebugger::new(
            $crate::neural_network_debugging::TransformerDebugConfig::default(),
        );
        debugger.analyze_transformer_attention($model_weights)
    }};
}