1use anyhow::Result;
7use scirs2_core::ndarray::*; use serde::{Deserialize, Serialize};
9use std::collections::HashMap;
10
/// Collects and analyzes per-head attention maps for a model.
///
/// State accumulates across calls to `analyze_attention_layer`; nothing is
/// ever pruned.
#[derive(Debug)]
pub struct AttentionDebugger {
    /// Thresholds and limits controlling the analyses.
    pub config: AttentionDebugConfig,
    /// Every attention map created so far, across all analyzed layers.
    attention_maps: Vec<AttentionMap>,
    /// Most recent analysis per head index.
    /// NOTE(review): keyed by head index only, so a later layer's analysis
    /// overwrites an earlier layer's for the same head — confirm intended.
    head_analysis: HashMap<usize, AttentionHeadAnalysis>,
}
18
/// Configuration for [`AttentionDebugger`].
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct AttentionDebugConfig {
    /// Record full attention maps for visualization.
    pub enable_attention_visualization: bool,
    /// Run per-head specialization/importance analysis.
    pub enable_head_analysis: bool,
    /// Classify attention patterns (diagonal, sparse, ...).
    pub enable_pattern_detection: bool,
    /// Weights strictly below this value count as sparse entries.
    pub attention_threshold: f32,
    /// Upper bound on heads analyzed per layer; extra heads are skipped.
    pub max_heads_to_analyze: usize,
}
28
29impl Default for AttentionDebugConfig {
30 fn default() -> Self {
31 Self {
32 enable_attention_visualization: true,
33 enable_head_analysis: true,
34 enable_pattern_detection: true,
35 attention_threshold: 0.01,
36 max_heads_to_analyze: 16,
37 }
38 }
39}
40
/// One head's attention matrix plus derived statistics.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct AttentionMap {
    /// Layer the map was taken from.
    pub layer_index: usize,
    /// Head within that layer.
    pub head_index: usize,
    /// Number of query positions (rows of the matrix).
    pub sequence_length: usize,
    /// Row-major copy of the attention matrix (`[query][key]`).
    pub attention_weights: Vec<Vec<f32>>,
    /// Coarse structural classification of the matrix.
    pub attention_pattern: AttentionPattern,
    /// Shannon entropy (bits) of the flattened, normalized weights.
    pub attention_entropy: f32,
    /// Fraction of entries below the configured attention threshold.
    pub sparsity_ratio: f32,
}
52
/// Per-head analysis result.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct AttentionHeadAnalysis {
    /// Head index within its layer.
    pub head_id: usize,
    /// Layer index the head belongs to.
    pub layer_id: usize,
    /// Inferred functional specialization.
    pub specialization_type: HeadSpecializationType,
    /// Summary statistics of the head's weight distribution.
    pub attention_distribution: AttentionDistribution,
    /// NOTE(review): currently always written as 0.0; pairwise redundancy is
    /// reported separately via `RedundancyAnalysis` — confirm intended.
    pub redundancy_score: f32,
    /// Normalized-entropy importance score in [0, 1].
    pub importance_score: f32,
    /// Attention patterns detected for this head.
    pub patterns_detected: Vec<AttentionPattern>,
}
64
65#[derive(Debug, Clone, PartialEq, Eq, Hash, Serialize, Deserialize)]
67pub enum HeadSpecializationType {
68 LocalSyntax, LongRange, Positional, ContentBased, Copying, Delimiter, Mixed, Redundant, }
77
78#[derive(Debug, Clone, PartialEq, Eq, Hash, Serialize, Deserialize)]
80pub enum AttentionPattern {
81 Diagonal, Block, Sparse, Uniform, Concentrated, Strided, Random, }
89
/// Summary statistics over all entries of one head's attention matrix.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct AttentionDistribution {
    /// Mean of all weights.
    pub mean_attention: f32,
    /// Population standard deviation of all weights.
    pub std_attention: f32,
    /// Largest weight.
    pub max_attention: f32,
    /// Smallest weight.
    pub min_attention: f32,
    /// Shannon entropy (bits) of the normalized weights.
    pub entropy: f32,
    /// Number of largest entries needed to cover 90% of total mass.
    pub effective_context_length: f32,
}
100
impl AttentionDebugger {
    /// Creates a debugger with no recorded maps or head analyses.
    pub fn new(config: AttentionDebugConfig) -> Self {
        Self {
            config,
            attention_maps: Vec::new(),
            head_analysis: HashMap::new(),
        }
    }

    /// Analyzes every head of one layer.
    ///
    /// `attention_weights` holds one 2D (query x key) matrix per head; heads
    /// past `config.max_heads_to_analyze` are skipped. Results are both
    /// returned and accumulated into the debugger's internal state.
    ///
    /// # Errors
    /// Fails if any analyzed head's weight array is not 2-dimensional.
    pub fn analyze_attention_layer(
        &mut self,
        layer_index: usize,
        attention_weights: &[ArrayD<f32>],
    ) -> Result<LayerAttentionAnalysis> {
        let mut head_analyses = Vec::new();
        let mut attention_maps = Vec::new();

        for (head_index, weights) in attention_weights.iter().enumerate() {
            if head_index >= self.config.max_heads_to_analyze {
                break;
            }

            let attention_map = self.create_attention_map(layer_index, head_index, weights)?;
            attention_maps.push(attention_map.clone());
            self.attention_maps.push(attention_map);

            let head_analysis = self.analyze_attention_head(layer_index, head_index, weights)?;
            head_analyses.push(head_analysis.clone());
            // NOTE(review): keyed by head index only — a later layer's entry
            // overwrites this one. Confirm whether a (layer, head) key was meant.
            self.head_analysis.insert(head_index, head_analysis);
        }

        let layer_diversity_score = self.compute_layer_diversity(&head_analyses);
        let redundancy_analysis = self.analyze_head_redundancy(&head_analyses);

        Ok(LayerAttentionAnalysis {
            layer_index,
            // Counts all supplied heads, even those skipped by the cap above.
            num_heads: attention_weights.len(),
            head_analyses,
            attention_maps,
            layer_diversity_score,
            redundancy_analysis,
        })
    }

    /// Converts one head's 2D weight matrix into an [`AttentionMap`] with
    /// pattern, entropy, and sparsity statistics.
    ///
    /// # Errors
    /// Returns an error unless `weights` is exactly 2-dimensional.
    fn create_attention_map(
        &self,
        layer_index: usize,
        head_index: usize,
        weights: &ArrayD<f32>,
    ) -> Result<AttentionMap> {
        let shape = weights.shape();
        if shape.len() != 2 {
            return Err(anyhow::anyhow!(
                "Expected 2D attention weights, got {}D",
                shape.len()
            ));
        }

        let seq_len = shape[0];
        // Copy into row-major Vec<Vec<f32>> for the slice-based helpers below.
        let attention_weights: Vec<Vec<f32>> =
            (0..seq_len).map(|i| (0..shape[1]).map(|j| weights[[i, j]]).collect()).collect();

        let pattern = self.detect_attention_pattern(&attention_weights);
        let entropy = self.compute_attention_entropy(&attention_weights);
        let sparsity = self.compute_sparsity_ratio(&attention_weights);

        Ok(AttentionMap {
            layer_index,
            head_index,
            sequence_length: seq_len,
            attention_weights,
            attention_pattern: pattern,
            attention_entropy: entropy,
            sparsity_ratio: sparsity,
        })
    }

    /// Builds the per-head analysis: specialization class, distribution
    /// statistics, detected pattern, and an entropy-based importance score.
    fn analyze_attention_head(
        &self,
        layer_index: usize,
        head_index: usize,
        weights: &ArrayD<f32>,
    ) -> Result<AttentionHeadAnalysis> {
        let specialization = self.classify_head_specialization(weights)?;
        let distribution = self.compute_attention_distribution(weights)?;
        let patterns = vec![self.detect_attention_pattern_from_weights(weights)?];

        Ok(AttentionHeadAnalysis {
            head_id: head_index,
            layer_id: layer_index,
            specialization_type: specialization,
            attention_distribution: distribution,
            // Placeholder: pairwise redundancy is reported via
            // `RedundancyAnalysis` instead of being written back here.
            redundancy_score: 0.0,
            importance_score: self.compute_head_importance(weights)?,
            patterns_detected: patterns,
        })
    }

    /// Classifies a weight matrix by testing, in priority order:
    /// diagonal (> 0.7 of mass near the diagonal), sparse (> 80% of entries
    /// below threshold), uniform (> 0.8 uniformity), block structure,
    /// then `Random` as the fallback.
    fn detect_attention_pattern(&self, weights: &[Vec<f32>]) -> AttentionPattern {
        let seq_len = weights.len();
        if seq_len == 0 {
            return AttentionPattern::Random;
        }

        let diagonal_strength = self.measure_diagonal_strength(weights);
        if diagonal_strength > 0.7 {
            return AttentionPattern::Diagonal;
        }

        let sparsity = self.compute_sparsity_ratio(weights);
        if sparsity > 0.8 {
            return AttentionPattern::Sparse;
        }

        let uniformity = self.measure_uniformity(weights);
        if uniformity > 0.8 {
            return AttentionPattern::Uniform;
        }

        if self.has_block_structure(weights) {
            return AttentionPattern::Block;
        }

        AttentionPattern::Random
    }

    /// Fraction of total attention mass within the band |i - j| <= 3 around
    /// the main diagonal; 0.0 for empty input or all-zero mass.
    fn measure_diagonal_strength(&self, weights: &[Vec<f32>]) -> f32 {
        let seq_len = weights.len();
        if seq_len == 0 {
            return 0.0;
        }

        let mut diagonal_sum = 0.0;
        let mut total_sum = 0.0;
        // Entries within this many positions of the diagonal count as local.
        let window_size = 3;
        for i in 0..seq_len {
            for j in 0..weights[i].len() {
                let weight = weights[i][j];
                total_sum += weight;

                if (i as i32 - j as i32).abs() <= window_size {
                    diagonal_sum += weight;
                }
            }
        }

        if total_sum > 0.0 {
            diagonal_sum / total_sum
        } else {
            0.0
        }
    }

    /// Closeness to a uniform distribution, computed as
    /// `1 - mean(|w - 1/seq_len|)` over all entries.
    /// Assumes each row is normalized over `seq_len` keys — TODO confirm
    /// for non-square matrices where row length != seq_len.
    fn measure_uniformity(&self, weights: &[Vec<f32>]) -> f32 {
        let seq_len = weights.len();
        if seq_len == 0 {
            return 0.0;
        }

        let expected_weight = 1.0 / seq_len as f32;
        let mut deviation_sum = 0.0;
        let mut count = 0;

        for row in weights {
            for &weight in row {
                deviation_sum += (weight - expected_weight).abs();
                count += 1;
            }
        }

        if count > 0 {
            1.0 - (deviation_sum / count as f32)
        } else {
            0.0
        }
    }

    /// Heuristic block-structure test: splits the diagonal into four blocks
    /// and reports `true` when the densest block's mean weight exceeds twice
    /// the average block mean. Always `false` for seq_len < 4.
    fn has_block_structure(&self, weights: &[Vec<f32>]) -> bool {
        let seq_len = weights.len();
        if seq_len < 4 {
            return false;
        }

        let block_size = seq_len / 4;
        let mut block_concentrations = Vec::new();

        for block_start in (0..seq_len).step_by(block_size) {
            let block_end = (block_start + block_size).min(seq_len);
            let mut block_sum = 0.0;
            let mut block_count = 0;

            // Only the diagonal block [start..end) x [start..end) is sampled.
            for i in block_start..block_end {
                for j in block_start..(block_end.min(weights[i].len())) {
                    block_sum += weights[i][j];
                    block_count += 1;
                }
            }

            if block_count > 0 {
                block_concentrations.push(block_sum / block_count as f32);
            }
        }

        if block_concentrations.len() < 2 {
            return false;
        }

        let max_concentration = block_concentrations.iter().cloned().fold(0.0f32, f32::max);
        let avg_concentration =
            block_concentrations.iter().sum::<f32>() / block_concentrations.len() as f32;

        max_concentration > avg_concentration * 2.0
    }

    /// Classifies a head's specialization by threshold priority:
    /// diagonal strength > 0.7 => `LocalSyntax`; long-range mass > 0.6 =>
    /// `LongRange`; positional bias > 0.8 => `Positional`; else
    /// `ContentBased`. Non-2D arrays classify as `Mixed`.
    fn classify_head_specialization(
        &self,
        weights: &ArrayD<f32>,
    ) -> Result<HeadSpecializationType> {
        let shape = weights.shape();
        if shape.len() != 2 {
            return Ok(HeadSpecializationType::Mixed);
        }

        let seq_len = shape[0];

        let weights_2d: Vec<Vec<f32>> =
            (0..seq_len).map(|i| (0..shape[1]).map(|j| weights[[i, j]]).collect()).collect();

        let diagonal_strength = self.measure_diagonal_strength(&weights_2d);
        let long_range_strength = self.measure_long_range_attention(&weights_2d);
        let positional_bias = self.measure_positional_bias(&weights_2d);

        Ok(if diagonal_strength > 0.7 {
            HeadSpecializationType::LocalSyntax
        } else if long_range_strength > 0.6 {
            HeadSpecializationType::LongRange
        } else if positional_bias > 0.8 {
            HeadSpecializationType::Positional
        } else {
            HeadSpecializationType::ContentBased
        })
    }

    /// Fraction of total mass at distances |i - j| > seq_len / 4;
    /// 0.0 for sequences shorter than 4 or all-zero mass.
    fn measure_long_range_attention(&self, weights: &[Vec<f32>]) -> f32 {
        let seq_len = weights.len();
        if seq_len < 4 {
            return 0.0;
        }

        let mut long_range_sum = 0.0;
        let mut total_sum = 0.0;
        // Anything farther than a quarter of the sequence counts as long-range.
        let long_range_threshold = seq_len / 4;
        for i in 0..seq_len {
            for j in 0..weights[i].len() {
                let weight = weights[i][j];
                total_sum += weight;

                if (i as i32 - j as i32).abs() > long_range_threshold as i32 {
                    long_range_sum += weight;
                }
            }
        }

        if total_sum > 0.0 {
            long_range_sum / total_sum
        } else {
            0.0
        }
    }

    /// Mean of `weight * (1 - |i - j| / seq_len)` over the first `seq_len`
    /// columns of each row — higher when mass sits near the diagonal.
    fn measure_positional_bias(&self, weights: &[Vec<f32>]) -> f32 {
        let seq_len = weights.len();
        if seq_len == 0 {
            return 0.0;
        }

        let mut position_correlation = 0.0;
        let mut count = 0;

        for i in 0..seq_len {
            for j in 0..weights[i].len().min(seq_len) {
                let position_similarity = 1.0 - (i as f32 - j as f32).abs() / seq_len as f32;
                position_correlation += weights[i][j] * position_similarity;
                count += 1;
            }
        }

        if count > 0 {
            position_correlation / count as f32
        } else {
            0.0
        }
    }

    /// Computes mean, population std-dev, min/max, entropy, and effective
    /// context length over all entries. Returns an all-zero distribution for
    /// an empty array.
    fn compute_attention_distribution(
        &self,
        weights: &ArrayD<f32>,
    ) -> Result<AttentionDistribution> {
        let values: Vec<f32> = weights.iter().cloned().collect();

        if values.is_empty() {
            return Ok(AttentionDistribution {
                mean_attention: 0.0,
                std_attention: 0.0,
                max_attention: 0.0,
                min_attention: 0.0,
                entropy: 0.0,
                effective_context_length: 0.0,
            });
        }

        let mean = values.iter().sum::<f32>() / values.len() as f32;
        // Population variance (divides by N, not N-1).
        let variance = values.iter().map(|x| (x - mean).powi(2)).sum::<f32>() / values.len() as f32;
        let std_dev = variance.sqrt();
        let max_val = values.iter().cloned().fold(f32::NEG_INFINITY, f32::max);
        let min_val = values.iter().cloned().fold(f32::INFINITY, f32::min);

        let entropy = self.compute_entropy(&values);

        let effective_length = self.compute_effective_context_length(&values);

        Ok(AttentionDistribution {
            mean_attention: mean,
            std_attention: std_dev,
            max_attention: max_val,
            min_attention: min_val,
            entropy,
            effective_context_length: effective_length,
        })
    }

    /// Shannon entropy (bits) of `values` normalized to sum to 1.
    /// Returns 0.0 for empty input or a non-positive total.
    fn compute_entropy(&self, values: &[f32]) -> f32 {
        if values.is_empty() {
            return 0.0;
        }

        let sum: f32 = values.iter().sum();
        if sum <= 0.0 {
            return 0.0;
        }

        let mut entropy = 0.0;
        for &value in values {
            // Non-positive values contribute nothing (log undefined at 0).
            if value > 0.0 {
                let prob = value / sum;
                entropy -= prob * prob.log2();
            }
        }

        entropy
    }

    /// Number of largest entries needed to accumulate 90% of the total mass;
    /// falls back to `values.len()` if the target is never reached.
    fn compute_effective_context_length(&self, values: &[f32]) -> f32 {
        if values.is_empty() {
            return 0.0;
        }

        let sum: f32 = values.iter().sum();
        if sum <= 0.0 {
            return 0.0;
        }

        // Descending sort; NaNs compare Equal and keep their relative order.
        let mut sorted_values = values.to_vec();
        sorted_values.sort_by(|a, b| b.partial_cmp(a).unwrap_or(std::cmp::Ordering::Equal));

        let mut cumulative_sum = 0.0;
        // 90% of the total mass defines the "effective" context.
        let target_sum = sum * 0.9;
        for (i, &value) in sorted_values.iter().enumerate() {
            cumulative_sum += value;
            if cumulative_sum >= target_sum {
                return (i + 1) as f32;
            }
        }

        values.len() as f32
    }

    /// Bridges an `ArrayD` to the slice-based pattern detector; non-2D
    /// arrays classify as `Random`.
    fn detect_attention_pattern_from_weights(
        &self,
        weights: &ArrayD<f32>,
    ) -> Result<AttentionPattern> {
        let shape = weights.shape();
        if shape.len() != 2 {
            return Ok(AttentionPattern::Random);
        }

        let weights_2d: Vec<Vec<f32>> = (0..shape[0])
            .map(|i| (0..shape[1]).map(|j| weights[[i, j]]).collect())
            .collect();

        Ok(self.detect_attention_pattern(&weights_2d))
    }

    /// Importance as normalized entropy: entropy / log2(N), yielding a value
    /// in [0, 1]; 0.0 for empty or single-element input.
    fn compute_head_importance(&self, weights: &ArrayD<f32>) -> Result<f32> {
        let values: Vec<f32> = weights.iter().cloned().collect();

        if values.is_empty() {
            return Ok(0.0);
        }

        let entropy = self.compute_entropy(&values);
        // Maximum possible entropy for N outcomes is log2(N).
        let max_entropy = (values.len() as f32).log2();

        if max_entropy > 0.0 {
            Ok(entropy / max_entropy)
        } else {
            Ok(0.0)
        }
    }

    /// Entropy of the flattened 2D weight matrix.
    fn compute_attention_entropy(&self, weights: &[Vec<f32>]) -> f32 {
        let values: Vec<f32> = weights.iter().flatten().cloned().collect();
        self.compute_entropy(&values)
    }

    /// Fraction of entries strictly below `config.attention_threshold`;
    /// 0.0 for an empty matrix.
    fn compute_sparsity_ratio(&self, weights: &[Vec<f32>]) -> f32 {
        let total_count = weights.iter().map(|row| row.len()).sum::<usize>();
        if total_count == 0 {
            return 0.0;
        }

        let sparse_count = weights
            .iter()
            .flatten()
            .filter(|&&w| w < self.config.attention_threshold)
            .count();

        sparse_count as f32 / total_count as f32
    }

    /// Diversity score: distinct specialization types present, divided by 8
    /// (the number of `HeadSpecializationType` variants). 0.0 for < 2 heads.
    fn compute_layer_diversity(&self, head_analyses: &[AttentionHeadAnalysis]) -> f32 {
        if head_analyses.len() < 2 {
            return 0.0;
        }

        let mut specialization_counts: HashMap<HeadSpecializationType, usize> = HashMap::new();
        for analysis in head_analyses {
            *specialization_counts.entry(analysis.specialization_type.clone()).or_insert(0) += 1;
        }

        let num_types = specialization_counts.len() as f32;
        // 8 = total HeadSpecializationType variants; keep in sync with the enum.
        let max_types = 8.0;
        num_types / max_types
    }

    /// Collects all head pairs whose similarity exceeds 0.8 and the layer's
    /// mean pairwise similarity.
    /// NOTE(review): `redundancy_groups` is always left empty here — confirm
    /// whether grouping was meant to be implemented.
    fn analyze_head_redundancy(
        &self,
        head_analyses: &[AttentionHeadAnalysis],
    ) -> RedundancyAnalysis {
        let mut redundant_heads = Vec::new();
        let redundancy_groups = Vec::new();

        for i in 0..head_analyses.len() {
            for j in (i + 1)..head_analyses.len() {
                let similarity = self.compute_head_similarity(&head_analyses[i], &head_analyses[j]);
                if similarity > 0.8 {
                    redundant_heads.push((i, j, similarity));
                }
            }
        }

        RedundancyAnalysis {
            redundant_head_pairs: redundant_heads,
            redundancy_groups,
            overall_redundancy_score: self.compute_overall_redundancy(head_analyses),
        }
    }

    /// Averages a same-type indicator (0 or 1) with a distribution-statistic
    /// similarity.
    /// NOTE(review): `dist_similarity` can go negative when the summed diffs
    /// exceed 3 (entropy differences are unbounded) — consider clamping.
    fn compute_head_similarity(
        &self,
        head1: &AttentionHeadAnalysis,
        head2: &AttentionHeadAnalysis,
    ) -> f32 {
        let type_similarity =
            if head1.specialization_type == head2.specialization_type { 1.0 } else { 0.0 };

        let dist_similarity = {
            let d1 = &head1.attention_distribution;
            let d2 = &head2.attention_distribution;

            let mean_diff = (d1.mean_attention - d2.mean_attention).abs();
            let std_diff = (d1.std_attention - d2.std_attention).abs();
            let entropy_diff = (d1.entropy - d2.entropy).abs();

            1.0 - (mean_diff + std_diff + entropy_diff) / 3.0
        };

        (type_similarity + dist_similarity) / 2.0
    }

    /// Mean pairwise head similarity across all head pairs of the layer;
    /// 0.0 when there are fewer than two heads.
    fn compute_overall_redundancy(&self, head_analyses: &[AttentionHeadAnalysis]) -> f32 {
        if head_analyses.len() < 2 {
            return 0.0;
        }

        let mut total_similarity = 0.0;
        let mut pair_count = 0;

        for i in 0..head_analyses.len() {
            for j in (i + 1)..head_analyses.len() {
                total_similarity +=
                    self.compute_head_similarity(&head_analyses[i], &head_analyses[j]);
                pair_count += 1;
            }
        }

        if pair_count > 0 {
            total_similarity / pair_count as f32
        } else {
            0.0
        }
    }
}
658
/// Per-layer result returned by [`AttentionDebugger::analyze_attention_layer`].
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct LayerAttentionAnalysis {
    /// Index of the analyzed layer.
    pub layer_index: usize,
    /// Total heads supplied for the layer (may exceed the analyzed count
    /// when `max_heads_to_analyze` caps the loop).
    pub num_heads: usize,
    /// One analysis per analyzed head.
    pub head_analyses: Vec<AttentionHeadAnalysis>,
    /// One attention map per analyzed head.
    pub attention_maps: Vec<AttentionMap>,
    /// Distinct specialization types / 8.
    pub layer_diversity_score: f32,
    /// Pairwise head-redundancy report for this layer.
    pub redundancy_analysis: RedundancyAnalysis,
}
669
/// Redundancy report for the heads of one layer.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct RedundancyAnalysis {
    /// `(head_i, head_j, similarity)` for pairs with similarity > 0.8.
    pub redundant_head_pairs: Vec<(usize, usize, f32)>,
    /// Groups of mutually redundant heads.
    /// NOTE(review): never populated by the current analysis — confirm.
    pub redundancy_groups: Vec<Vec<usize>>,
    /// Mean pairwise head similarity across the layer.
    pub overall_redundancy_score: f32,
}
677
/// Whole-model attention debugger layered on top of [`AttentionDebugger`].
#[derive(Debug)]
pub struct TransformerDebugger {
    /// Model-level limits plus the nested attention config.
    pub config: TransformerDebugConfig,
    /// Per-layer results cached by the last `analyze_transformer_attention` call.
    layer_analyses: Vec<LayerAttentionAnalysis>,
    /// Per-layer worker; owns its own copy of the attention config.
    attention_debugger: AttentionDebugger,
}
685
/// Configuration for [`TransformerDebugger`].
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct TransformerDebugConfig {
    /// Settings forwarded to the per-layer [`AttentionDebugger`].
    pub attention_config: AttentionDebugConfig,
    /// Run per-layer analysis.
    pub enable_layer_analysis: bool,
    /// Run the cross-layer evolution/consistency analysis.
    pub enable_cross_layer_analysis: bool,
    /// Upper bound on layers analyzed; extra layers are skipped.
    pub max_layers_to_analyze: usize,
}
694
695impl Default for TransformerDebugConfig {
696 fn default() -> Self {
697 Self {
698 attention_config: AttentionDebugConfig::default(),
699 enable_layer_analysis: true,
700 enable_cross_layer_analysis: true,
701 max_layers_to_analyze: 48, }
703 }
704}
705
impl TransformerDebugger {
    /// Creates a debugger, cloning the attention sub-config into a nested
    /// [`AttentionDebugger`].
    pub fn new(config: TransformerDebugConfig) -> Self {
        let attention_debugger = AttentionDebugger::new(config.attention_config.clone());

        Self {
            config,
            layer_analyses: Vec::new(),
            attention_debugger,
        }
    }

    /// Analyzes attention for a whole model.
    ///
    /// `model_attention_weights[layer][head]` is one head's 2D attention
    /// matrix. Layers beyond `config.max_layers_to_analyze` are skipped.
    /// Per-layer results are cached on `self` for the summary step.
    ///
    /// # Errors
    /// Propagates per-layer failures (e.g. non-2D weight arrays).
    pub fn analyze_transformer_attention(
        &mut self,
        model_attention_weights: &[Vec<ArrayD<f32>>],
    ) -> Result<TransformerAttentionAnalysis> {
        let mut layer_analyses = Vec::new();

        for (layer_idx, layer_weights) in model_attention_weights.iter().enumerate() {
            if layer_idx >= self.config.max_layers_to_analyze {
                break;
            }

            let layer_analysis =
                self.attention_debugger.analyze_attention_layer(layer_idx, layer_weights)?;
            layer_analyses.push(layer_analysis);
        }

        // Cache for generate_model_attention_summary below.
        self.layer_analyses = layer_analyses.clone();

        let cross_layer_analysis = if self.config.enable_cross_layer_analysis {
            Some(self.perform_cross_layer_analysis(&layer_analyses)?)
        } else {
            None
        };

        Ok(TransformerAttentionAnalysis {
            // Counts all supplied layers, even those skipped by the cap above.
            num_layers: model_attention_weights.len(),
            layer_analyses,
            cross_layer_analysis,
            model_attention_summary: self.generate_model_attention_summary()?,
        })
    }

    /// Runs the three cross-layer analyses plus the diversity trend.
    fn perform_cross_layer_analysis(
        &self,
        layer_analyses: &[LayerAttentionAnalysis],
    ) -> Result<CrossLayerAnalysis> {
        let attention_evolution = self.analyze_attention_evolution(layer_analyses)?;
        let head_consistency = self.analyze_head_consistency(layer_analyses)?;
        let pattern_progression = self.analyze_pattern_progression(layer_analyses)?;

        Ok(CrossLayerAnalysis {
            attention_evolution,
            head_consistency,
            pattern_progression,
            layer_diversity_trend: self.compute_layer_diversity_trend(layer_analyses),
        })
    }

    /// Tracks mean per-layer entropy and sparsity across layers and
    /// classifies the overall entropy trend.
    /// NOTE(review): dividing by `attention_maps.len()` yields NaN when a
    /// layer has zero analyzed heads — confirm callers never supply that.
    fn analyze_attention_evolution(
        &self,
        layer_analyses: &[LayerAttentionAnalysis],
    ) -> Result<AttentionEvolution> {
        let mut entropy_trend = Vec::new();
        let mut sparsity_trend = Vec::new();

        for layer in layer_analyses {
            let layer_entropy: f32 =
                layer.attention_maps.iter().map(|map| map.attention_entropy).sum::<f32>()
                    / layer.attention_maps.len() as f32;
            let layer_sparsity: f32 =
                layer.attention_maps.iter().map(|map| map.sparsity_ratio).sum::<f32>()
                    / layer.attention_maps.len() as f32;

            entropy_trend.push(layer_entropy);
            sparsity_trend.push(layer_sparsity);
        }

        let evolution_type = self.classify_evolution_type(&entropy_trend);

        Ok(AttentionEvolution {
            entropy_trend,
            sparsity_trend,
            evolution_type,
        })
    }

    /// Compares first- vs last-layer entropy: change beyond +/-20% classifies
    /// as Increasing/Decreasing, otherwise Stable. Fewer than three layers
    /// always classify as Stable.
    fn classify_evolution_type(&self, entropy_trend: &[f32]) -> EvolutionType {
        if entropy_trend.len() < 3 {
            return EvolutionType::Stable;
        }

        let start_entropy = entropy_trend[0];
        let end_entropy = entropy_trend[entropy_trend.len() - 1];
        // 1e-8 floor avoids division by zero; a near-zero start inflates the ratio.
        let change_ratio = (end_entropy - start_entropy) / start_entropy.max(1e-8);

        match change_ratio {
            r if r > 0.2 => EvolutionType::Increasing,
            r if r < -0.2 => EvolutionType::Decreasing,
            _ => EvolutionType::Stable,
        }
    }

    /// Records, per specialization type, the layer indices where it appears,
    /// and scores how stable the specialization mix is across layers.
    /// NOTE(review): `pattern_consistency` is always left empty — confirm.
    fn analyze_head_consistency(
        &self,
        layer_analyses: &[LayerAttentionAnalysis],
    ) -> Result<HeadConsistency> {
        let mut specialization_consistency = HashMap::new();
        let pattern_consistency = HashMap::new();

        for layer in layer_analyses {
            for head in &layer.head_analyses {
                let spec_type = &head.specialization_type;
                // One entry per head, so a layer index may repeat in the Vec.
                let layer_counts =
                    specialization_consistency.entry(spec_type.clone()).or_insert_with(Vec::new);
                layer_counts.push(layer.layer_index);
            }
        }

        Ok(HeadConsistency {
            specialization_consistency,
            pattern_consistency,
            consistency_score: self.compute_consistency_score(layer_analyses),
        })
    }

    /// Mean pairwise similarity of per-layer specialization distributions;
    /// 1.0 when fewer than two layers exist.
    fn compute_consistency_score(&self, layer_analyses: &[LayerAttentionAnalysis]) -> f32 {
        if layer_analyses.len() < 2 {
            return 1.0;
        }

        let mut layer_distributions = Vec::new();

        // Build a normalized specialization-type histogram per layer.
        for layer in layer_analyses {
            let mut distribution: HashMap<HeadSpecializationType, f32> = HashMap::new();
            for head in &layer.head_analyses {
                *distribution.entry(head.specialization_type.clone()).or_insert(0.0) += 1.0;
            }

            let total: f32 = distribution.values().sum();
            if total > 0.0 {
                for value in distribution.values_mut() {
                    *value /= total;
                }
            }

            layer_distributions.push(distribution);
        }

        let mut total_similarity = 0.0;
        let mut pair_count = 0;

        for i in 0..layer_distributions.len() {
            for j in (i + 1)..layer_distributions.len() {
                let similarity = self.compute_distribution_similarity(
                    &layer_distributions[i],
                    &layer_distributions[j],
                );
                total_similarity += similarity;
                pair_count += 1;
            }
        }

        if pair_count > 0 {
            total_similarity / pair_count as f32
        } else {
            1.0
        }
    }

    /// Similarity of two normalized histograms as 1 minus their total
    /// variation distance (half the L1 distance over the union of keys).
    fn compute_distribution_similarity(
        &self,
        dist1: &HashMap<HeadSpecializationType, f32>,
        dist2: &HashMap<HeadSpecializationType, f32>,
    ) -> f32 {
        let mut all_keys: std::collections::HashSet<_> = dist1.keys().collect();
        all_keys.extend(dist2.keys());

        let mut similarity = 0.0;
        for key in all_keys {
            let val1 = dist1.get(key).unwrap_or(&0.0);
            let val2 = dist2.get(key).unwrap_or(&0.0);
            similarity += (val1 - val2).abs();
        }

        // L1 distance of two probability distributions is at most 2.
        1.0 - (similarity / 2.0)
    }

    /// Counts detected patterns per layer and extracts the per-layer
    /// dominant pattern sequence.
    fn analyze_pattern_progression(
        &self,
        layer_analyses: &[LayerAttentionAnalysis],
    ) -> Result<PatternProgression> {
        let mut pattern_evolution = Vec::new();

        for layer in layer_analyses {
            let mut pattern_counts: HashMap<AttentionPattern, usize> = HashMap::new();
            for map in &layer.attention_maps {
                *pattern_counts.entry(map.attention_pattern.clone()).or_insert(0) += 1;
            }
            pattern_evolution.push(pattern_counts);
        }

        let dominant_pattern_sequence = self.extract_dominant_patterns(&pattern_evolution);

        Ok(PatternProgression {
            pattern_evolution,
            dominant_pattern_sequence,
        })
    }

    /// Picks the most frequent pattern per layer, `Random` for empty layers.
    /// NOTE(review): ties are broken by HashMap iteration order, which is
    /// nondeterministic across runs — confirm acceptable.
    fn extract_dominant_patterns(
        &self,
        pattern_evolution: &[HashMap<AttentionPattern, usize>],
    ) -> Vec<AttentionPattern> {
        pattern_evolution
            .iter()
            .map(|patterns| {
                patterns
                    .iter()
                    .max_by_key(|(_, &count)| count)
                    .map(|(pattern, _)| pattern.clone())
                    .unwrap_or(AttentionPattern::Random)
            })
            .collect()
    }

    /// Per-layer diversity scores in layer order.
    fn compute_layer_diversity_trend(&self, layer_analyses: &[LayerAttentionAnalysis]) -> Vec<f32> {
        layer_analyses.iter().map(|layer| layer.layer_diversity_score).collect()
    }

    /// Aggregates the cached per-layer analyses into model-wide totals,
    /// averages, a specialization histogram, and a health verdict.
    /// Returns the all-zero default summary when nothing has been analyzed.
    fn generate_model_attention_summary(&self) -> Result<ModelAttentionSummary> {
        if self.layer_analyses.is_empty() {
            return Ok(ModelAttentionSummary::default());
        }

        let total_heads: usize = self.layer_analyses.iter().map(|layer| layer.num_heads).sum();
        let avg_diversity: f32 =
            self.layer_analyses.iter().map(|layer| layer.layer_diversity_score).sum::<f32>()
                / self.layer_analyses.len() as f32;
        let avg_redundancy: f32 = self
            .layer_analyses
            .iter()
            .map(|layer| layer.redundancy_analysis.overall_redundancy_score)
            .sum::<f32>()
            / self.layer_analyses.len() as f32;

        let mut specialization_distribution: HashMap<HeadSpecializationType, usize> =
            HashMap::new();
        for layer in &self.layer_analyses {
            for head in &layer.head_analyses {
                *specialization_distribution
                    .entry(head.specialization_type.clone())
                    .or_insert(0) += 1;
            }
        }

        Ok(ModelAttentionSummary {
            total_layers: self.layer_analyses.len(),
            total_heads,
            average_diversity_score: avg_diversity,
            average_redundancy_score: avg_redundancy,
            specialization_distribution,
            model_attention_health: self
                .assess_model_attention_health(avg_diversity, avg_redundancy),
        })
    }

    /// Health = diversity * (1 - redundancy), bucketed at 0.7 / 0.5 / 0.3.
    fn assess_model_attention_health(
        &self,
        diversity: f32,
        redundancy: f32,
    ) -> AttentionHealthStatus {
        let health_score = diversity * (1.0 - redundancy);
        match health_score {
            s if s > 0.7 => AttentionHealthStatus::Excellent,
            s if s > 0.5 => AttentionHealthStatus::Good,
            s if s > 0.3 => AttentionHealthStatus::Fair,
            _ => AttentionHealthStatus::Poor,
        }
    }
}
1007
/// Whole-model result of [`TransformerDebugger::analyze_transformer_attention`].
#[derive(Debug, Serialize, Deserialize)]
pub struct TransformerAttentionAnalysis {
    /// Total layers supplied (may exceed the analyzed count when capped).
    pub num_layers: usize,
    /// One analysis per analyzed layer.
    pub layer_analyses: Vec<LayerAttentionAnalysis>,
    /// Present only when cross-layer analysis is enabled in the config.
    pub cross_layer_analysis: Option<CrossLayerAnalysis>,
    /// Model-wide aggregate statistics and health verdict.
    pub model_attention_summary: ModelAttentionSummary,
}
1016
/// Analyses that compare attention behavior across layers.
#[derive(Debug, Serialize, Deserialize)]
pub struct CrossLayerAnalysis {
    /// Entropy/sparsity trends across layers.
    pub attention_evolution: AttentionEvolution,
    /// Stability of head specialization across layers.
    pub head_consistency: HeadConsistency,
    /// Pattern counts and dominant pattern per layer.
    pub pattern_progression: PatternProgression,
    /// Per-layer diversity scores in layer order.
    pub layer_diversity_trend: Vec<f32>,
}
1025
/// Mean entropy and sparsity per layer, plus the overall entropy trend.
#[derive(Debug, Serialize, Deserialize)]
pub struct AttentionEvolution {
    /// Mean attention entropy per layer, in layer order.
    pub entropy_trend: Vec<f32>,
    /// Mean sparsity ratio per layer, in layer order.
    pub sparsity_trend: Vec<f32>,
    /// Classification of the entropy trend (first vs last layer).
    pub evolution_type: EvolutionType,
}
1033
1034#[derive(Debug, Clone, PartialEq, Eq, Hash, Serialize, Deserialize)]
1036pub enum EvolutionType {
1037 Increasing, Decreasing, Stable, }
1041
/// How consistently head specializations recur across layers.
#[derive(Debug, Serialize, Deserialize)]
pub struct HeadConsistency {
    /// Layer indices where each specialization type appears (one entry per
    /// head, so indices may repeat).
    pub specialization_consistency: HashMap<HeadSpecializationType, Vec<usize>>,
    /// Layer indices where each pattern appears.
    /// NOTE(review): never populated by the current analysis — confirm.
    pub pattern_consistency: HashMap<AttentionPattern, Vec<usize>>,
    /// Mean pairwise similarity of per-layer specialization distributions.
    pub consistency_score: f32,
}
1049
/// Per-layer attention-pattern counts and the dominant pattern sequence.
#[derive(Debug, Serialize, Deserialize)]
pub struct PatternProgression {
    /// Pattern frequency histogram per layer, in layer order.
    pub pattern_evolution: Vec<HashMap<AttentionPattern, usize>>,
    /// Most frequent pattern per layer (`Random` for empty layers).
    pub dominant_pattern_sequence: Vec<AttentionPattern>,
}
1056
/// Model-wide aggregates over all analyzed layers.
#[derive(Debug, Serialize, Deserialize)]
pub struct ModelAttentionSummary {
    /// Number of layers actually analyzed.
    pub total_layers: usize,
    /// Sum of per-layer head counts.
    pub total_heads: usize,
    /// Mean of per-layer diversity scores.
    pub average_diversity_score: f32,
    /// Mean of per-layer overall redundancy scores.
    pub average_redundancy_score: f32,
    /// Head count per specialization type across the whole model.
    pub specialization_distribution: HashMap<HeadSpecializationType, usize>,
    /// Verdict from `diversity * (1 - redundancy)` bucketing.
    pub model_attention_health: AttentionHealthStatus,
}
1067
1068impl Default for ModelAttentionSummary {
1069 fn default() -> Self {
1070 Self {
1071 total_layers: 0,
1072 total_heads: 0,
1073 average_diversity_score: 0.0,
1074 average_redundancy_score: 0.0,
1075 specialization_distribution: HashMap::new(),
1076 model_attention_health: AttentionHealthStatus::Poor,
1077 }
1078 }
1079}
1080
/// Bucketed health verdict from `diversity * (1 - redundancy)`.
#[derive(Debug, Clone, PartialEq, Eq, Hash, Serialize, Deserialize)]
pub enum AttentionHealthStatus {
    /// Health score above 0.7.
    Excellent,
    /// Health score above 0.5.
    Good,
    /// Health score above 0.3.
    Fair,
    /// Health score 0.3 or below; also the default summary's value.
    Poor,
}
1089
/// Analyzes one layer's attention weights (`&[ArrayD<f32>]`, one matrix per
/// head) with a default-configured `AttentionDebugger`, using layer index 0.
/// Expands to a `Result<LayerAttentionAnalysis>`.
#[macro_export]
macro_rules! debug_attention {
    ($attention_weights:expr) => {{
        let mut debugger = $crate::neural_network_debugging::AttentionDebugger::new(
            $crate::neural_network_debugging::AttentionDebugConfig::default(),
        );
        debugger.analyze_attention_layer(0, $attention_weights)
    }};
}
1100
1101#[macro_export]
1103macro_rules! debug_transformer {
1104 ($model_weights:expr) => {{
1105 let mut debugger = $crate::neural_network_debugging::TransformerDebugger::new(
1106 $crate::neural_network_debugging::TransformerDebugConfig::default(),
1107 );
1108 debugger.analyze_transformer_attention($model_weights)
1109 }};
1110}