1use anyhow::Result;
7use scirs2_core::ndarray::*; use serde::{Deserialize, Serialize};
9use std::collections::HashMap;
10
/// Inspects attention weight matrices and accumulates per-head analyses
/// and visualizable attention maps across calls.
#[derive(Debug)]
pub struct AttentionDebugger {
    /// Tunable thresholds and feature switches for the analysis.
    pub config: AttentionDebugConfig,
    // Every attention map created so far, across all analyzed layers.
    attention_maps: Vec<AttentionMap>,
    // Latest analysis per head index. NOTE(review): keyed only by head
    // index, so heads from later layers overwrite earlier layers' entries.
    head_analysis: HashMap<usize, AttentionHeadAnalysis>,
}
18
/// Configuration for [`AttentionDebugger`].
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct AttentionDebugConfig {
    /// Collect per-head attention maps for visualization.
    pub enable_attention_visualization: bool,
    /// Run per-head specialization/distribution analysis.
    pub enable_head_analysis: bool,
    /// Run structural pattern detection (diagonal, sparse, ...).
    pub enable_pattern_detection: bool,
    /// Weights below this value count as "sparse" entries.
    pub attention_threshold: f32,
    /// Upper bound on heads analyzed per layer; extra heads are skipped.
    pub max_heads_to_analyze: usize,
}
28
impl Default for AttentionDebugConfig {
    /// All analyses enabled; 0.01 sparsity threshold; at most 16 heads.
    fn default() -> Self {
        Self {
            enable_attention_visualization: true,
            enable_head_analysis: true,
            enable_pattern_detection: true,
            attention_threshold: 0.01,
            max_heads_to_analyze: 16,
        }
    }
}
40
/// One head's attention matrix plus derived pattern statistics.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct AttentionMap {
    /// Index of the transformer layer this map came from.
    pub layer_index: usize,
    /// Index of the head within that layer.
    pub head_index: usize,
    /// Number of query positions (rows of the matrix).
    pub sequence_length: usize,
    /// Row-major copy of the attention weights.
    pub attention_weights: Vec<Vec<f32>>,
    /// Structural pattern detected for this matrix.
    pub attention_pattern: AttentionPattern,
    /// Shannon entropy of the flattened weights.
    pub attention_entropy: f32,
    /// Fraction of weights below the configured threshold.
    pub sparsity_ratio: f32,
}
52
/// Per-head diagnostic summary produced by the attention debugger.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct AttentionHeadAnalysis {
    /// Head index within its layer.
    pub head_id: usize,
    /// Layer index the head belongs to.
    pub layer_id: usize,
    /// Heuristic specialization label for the head.
    pub specialization_type: HeadSpecializationType,
    /// Statistics of the head's weight distribution.
    pub attention_distribution: AttentionDistribution,
    /// Pairwise redundancy score. NOTE(review): always 0.0 per head;
    /// redundancy is computed at layer level instead.
    pub redundancy_score: f32,
    /// Normalized-entropy importance proxy in [0, 1].
    pub importance_score: f32,
    /// Structural patterns detected for this head.
    pub patterns_detected: Vec<AttentionPattern>,
}
64
65#[derive(Debug, Clone, PartialEq, Eq, Hash, Serialize, Deserialize)]
67pub enum HeadSpecializationType {
68 LocalSyntax, LongRange, Positional, ContentBased, Copying, Delimiter, Mixed, Redundant, }
77
78#[derive(Debug, Clone, PartialEq, Eq, Hash, Serialize, Deserialize)]
80pub enum AttentionPattern {
81 Diagonal, Block, Sparse, Uniform, Concentrated, Strided, Random, }
89
/// Summary statistics of one head's attention weight distribution.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct AttentionDistribution {
    /// Mean of all weights.
    pub mean_attention: f32,
    /// Population standard deviation of all weights.
    pub std_attention: f32,
    /// Largest single weight.
    pub max_attention: f32,
    /// Smallest single weight.
    pub min_attention: f32,
    /// Shannon entropy of the normalized weights.
    pub entropy: f32,
    /// Number of top weights needed to cover 90% of the mass.
    pub effective_context_length: f32,
}
100
101impl AttentionDebugger {
    /// Creates a debugger with the given config and empty history.
    pub fn new(config: AttentionDebugConfig) -> Self {
        Self {
            config,
            attention_maps: Vec::new(),
            head_analysis: HashMap::new(),
        }
    }
110
    /// Analyzes every head of one layer (up to `max_heads_to_analyze`),
    /// recording attention maps and head analyses both in the returned
    /// summary and in this debugger's history.
    ///
    /// `attention_weights` holds one rank-2 matrix per head; a non-2D
    /// matrix causes an error from `create_attention_map`.
    pub fn analyze_attention_layer(
        &mut self,
        layer_index: usize,
        attention_weights: &[ArrayD<f32>],
    ) -> Result<LayerAttentionAnalysis> {
        let mut head_analyses = Vec::new();
        let mut attention_maps = Vec::new();

        for (head_index, weights) in attention_weights.iter().enumerate() {
            // Respect the configured cap on analyzed heads.
            if head_index >= self.config.max_heads_to_analyze {
                break;
            }

            let attention_map = self.create_attention_map(layer_index, head_index, weights)?;
            attention_maps.push(attention_map.clone());
            self.attention_maps.push(attention_map);

            let head_analysis = self.analyze_attention_head(layer_index, head_index, weights)?;
            head_analyses.push(head_analysis.clone());
            // TODO(review): keyed by head index only -- analyzing a second
            // layer overwrites the stored analyses of the first layer.
            self.head_analysis.insert(head_index, head_analysis);
        }

        let layer_diversity_score = self.compute_layer_diversity(&head_analyses);
        let redundancy_analysis = self.analyze_head_redundancy(&head_analyses);

        Ok(LayerAttentionAnalysis {
            layer_index,
            // Reports the total head count, even if fewer were analyzed.
            num_heads: attention_weights.len(),
            head_analyses,
            attention_maps,
            layer_diversity_score,
            redundancy_analysis,
        })
    }
148
149 fn create_attention_map(
151 &self,
152 layer_index: usize,
153 head_index: usize,
154 weights: &ArrayD<f32>,
155 ) -> Result<AttentionMap> {
156 let shape = weights.shape();
157 if shape.len() != 2 {
158 return Err(anyhow::anyhow!(
159 "Expected 2D attention weights, got {}D",
160 shape.len()
161 ));
162 }
163
164 let seq_len = shape[0];
165 let attention_weights: Vec<Vec<f32>> =
166 (0..seq_len).map(|i| (0..shape[1]).map(|j| weights[[i, j]]).collect()).collect();
167
168 let pattern = self.detect_attention_pattern(&attention_weights);
169 let entropy = self.compute_attention_entropy(&attention_weights);
170 let sparsity = self.compute_sparsity_ratio(&attention_weights);
171
172 Ok(AttentionMap {
173 layer_index,
174 head_index,
175 sequence_length: seq_len,
176 attention_weights,
177 attention_pattern: pattern,
178 attention_entropy: entropy,
179 sparsity_ratio: sparsity,
180 })
181 }
182
    /// Builds a full per-head analysis: specialization class, distribution
    /// statistics, importance score, and detected patterns.
    fn analyze_attention_head(
        &self,
        layer_index: usize,
        head_index: usize,
        weights: &ArrayD<f32>,
    ) -> Result<AttentionHeadAnalysis> {
        let specialization = self.classify_head_specialization(weights)?;
        let distribution = self.compute_attention_distribution(weights)?;
        let patterns = vec![self.detect_attention_pattern_from_weights(weights)?];

        Ok(AttentionHeadAnalysis {
            head_id: head_index,
            layer_id: layer_index,
            specialization_type: specialization,
            attention_distribution: distribution,
            // Placeholder: pairwise redundancy is computed at layer level
            // (see `analyze_head_redundancy`), not per head.
            redundancy_score: 0.0,
            importance_score: self.compute_head_importance(weights)?,
            patterns_detected: patterns,
        })
    }
204
    /// Classifies the structural pattern of a 2D weight grid by testing a
    /// fixed sequence of heuristics; the first match wins.
    ///
    /// NOTE(review): `Concentrated` and `Strided` are never produced here;
    /// anything unmatched falls through to `Random`.
    fn detect_attention_pattern(&self, weights: &[Vec<f32>]) -> AttentionPattern {
        let seq_len = weights.len();
        if seq_len == 0 {
            return AttentionPattern::Random;
        }

        // >70% of mass within +/-3 of the diagonal -> diagonal pattern.
        let diagonal_strength = self.measure_diagonal_strength(weights);
        if diagonal_strength > 0.7 {
            return AttentionPattern::Diagonal;
        }

        // >80% of entries under the sparsity threshold -> sparse.
        let sparsity = self.compute_sparsity_ratio(weights);
        if sparsity > 0.8 {
            return AttentionPattern::Sparse;
        }

        // Close to a flat 1/seq_len distribution -> uniform.
        let uniformity = self.measure_uniformity(weights);
        if uniformity > 0.8 {
            return AttentionPattern::Uniform;
        }

        // One quarter-block holding >2x the average mass -> block structure.
        if self.has_block_structure(weights) {
            return AttentionPattern::Block;
        }

        AttentionPattern::Random
    }
237
238 fn measure_diagonal_strength(&self, weights: &[Vec<f32>]) -> f32 {
240 let seq_len = weights.len();
241 if seq_len == 0 {
242 return 0.0;
243 }
244
245 let mut diagonal_sum = 0.0;
246 let mut total_sum = 0.0;
247 let window_size = 3; for i in 0..seq_len {
250 for j in 0..weights[i].len() {
251 let weight = weights[i][j];
252 total_sum += weight;
253
254 if (i as i32 - j as i32).abs() <= window_size {
255 diagonal_sum += weight;
256 }
257 }
258 }
259
260 if total_sum > 0.0 {
261 diagonal_sum / total_sum
262 } else {
263 0.0
264 }
265 }
266
267 fn measure_uniformity(&self, weights: &[Vec<f32>]) -> f32 {
269 let seq_len = weights.len();
270 if seq_len == 0 {
271 return 0.0;
272 }
273
274 let expected_weight = 1.0 / seq_len as f32;
275 let mut deviation_sum = 0.0;
276 let mut count = 0;
277
278 for row in weights {
279 for &weight in row {
280 deviation_sum += (weight - expected_weight).abs();
281 count += 1;
282 }
283 }
284
285 if count > 0 {
286 1.0 - (deviation_sum / count as f32)
287 } else {
288 0.0
289 }
290 }
291
    /// Heuristically checks for block-structured attention by splitting the
    /// sequence into quarters and comparing each diagonal block's average
    /// weight. Returns `true` when the strongest block holds more than
    /// twice the average concentration.
    fn has_block_structure(&self, weights: &[Vec<f32>]) -> bool {
        let seq_len = weights.len();
        // Too short to form meaningful blocks.
        if seq_len < 4 {
            return false;
        }

        let block_size = seq_len / 4;
        let mut block_concentrations = Vec::new();

        for block_start in (0..seq_len).step_by(block_size) {
            let block_end = (block_start + block_size).min(seq_len);
            let mut block_sum = 0.0;
            let mut block_count = 0;

            // Only the diagonal block [start..end) x [start..end) is
            // sampled; off-diagonal blocks are ignored.
            for i in block_start..block_end {
                for j in block_start..(block_end.min(weights[i].len())) {
                    block_sum += weights[i][j];
                    block_count += 1;
                }
            }

            if block_count > 0 {
                block_concentrations.push(block_sum / block_count as f32);
            }
        }

        if block_concentrations.len() < 2 {
            return false;
        }

        let max_concentration = block_concentrations.iter().cloned().fold(0.0f32, f32::max);
        let avg_concentration =
            block_concentrations.iter().sum::<f32>() / block_concentrations.len() as f32;

        max_concentration > avg_concentration * 2.0
    }
331
    /// Assigns a specialization label from simple positional heuristics.
    /// Non-2D weights are labeled `Mixed`; otherwise the first threshold
    /// that fires wins (local > long-range > positional > content-based).
    fn classify_head_specialization(
        &self,
        weights: &ArrayD<f32>,
    ) -> Result<HeadSpecializationType> {
        let shape = weights.shape();
        if shape.len() != 2 {
            return Ok(HeadSpecializationType::Mixed);
        }

        let seq_len = shape[0];

        // Materialize as nested Vecs for the slice-based helpers.
        let weights_2d: Vec<Vec<f32>> =
            (0..seq_len).map(|i| (0..shape[1]).map(|j| weights[[i, j]]).collect()).collect();

        let diagonal_strength = self.measure_diagonal_strength(&weights_2d);
        let long_range_strength = self.measure_long_range_attention(&weights_2d);
        let positional_bias = self.measure_positional_bias(&weights_2d);

        Ok(if diagonal_strength > 0.7 {
            HeadSpecializationType::LocalSyntax
        } else if long_range_strength > 0.6 {
            HeadSpecializationType::LongRange
        } else if positional_bias > 0.8 {
            HeadSpecializationType::Positional
        } else {
            HeadSpecializationType::ContentBased
        })
    }
363
364 fn measure_long_range_attention(&self, weights: &[Vec<f32>]) -> f32 {
366 let seq_len = weights.len();
367 if seq_len < 4 {
368 return 0.0;
369 }
370
371 let mut long_range_sum = 0.0;
372 let mut total_sum = 0.0;
373 let long_range_threshold = seq_len / 4; for i in 0..seq_len {
376 for j in 0..weights[i].len() {
377 let weight = weights[i][j];
378 total_sum += weight;
379
380 if (i as i32 - j as i32).abs() > long_range_threshold as i32 {
381 long_range_sum += weight;
382 }
383 }
384 }
385
386 if total_sum > 0.0 {
387 long_range_sum / total_sum
388 } else {
389 0.0
390 }
391 }
392
    /// Average of weight x positional-similarity over all cells, where
    /// similarity decays linearly with |i - j|.
    ///
    /// NOTE(review): for row-normalized attention the per-cell average is
    /// on the order of 1/seq_len, so the 0.8 threshold used by
    /// `classify_head_specialization` looks hard to reach for long
    /// sequences -- confirm the intended scale.
    fn measure_positional_bias(&self, weights: &[Vec<f32>]) -> f32 {
        let seq_len = weights.len();
        if seq_len == 0 {
            return 0.0;
        }

        let mut position_correlation = 0.0;
        let mut count = 0;

        for i in 0..seq_len {
            // Columns beyond seq_len are ignored so similarity stays in [0, 1].
            for j in 0..weights[i].len().min(seq_len) {
                let position_similarity = 1.0 - (i as f32 - j as f32).abs() / seq_len as f32;
                position_correlation += weights[i][j] * position_similarity;
                count += 1;
            }
        }

        if count > 0 {
            position_correlation / count as f32
        } else {
            0.0
        }
    }
418
419 fn compute_attention_distribution(
421 &self,
422 weights: &ArrayD<f32>,
423 ) -> Result<AttentionDistribution> {
424 let values: Vec<f32> = weights.iter().cloned().collect();
425
426 if values.is_empty() {
427 return Ok(AttentionDistribution {
428 mean_attention: 0.0,
429 std_attention: 0.0,
430 max_attention: 0.0,
431 min_attention: 0.0,
432 entropy: 0.0,
433 effective_context_length: 0.0,
434 });
435 }
436
437 let mean = values.iter().sum::<f32>() / values.len() as f32;
438 let variance = values.iter().map(|x| (x - mean).powi(2)).sum::<f32>() / values.len() as f32;
439 let std_dev = variance.sqrt();
440 let max_val = values.iter().cloned().fold(f32::NEG_INFINITY, f32::max);
441 let min_val = values.iter().cloned().fold(f32::INFINITY, f32::min);
442
443 let entropy = self.compute_entropy(&values);
445
446 let effective_length = self.compute_effective_context_length(&values);
448
449 Ok(AttentionDistribution {
450 mean_attention: mean,
451 std_attention: std_dev,
452 max_attention: max_val,
453 min_attention: min_val,
454 entropy,
455 effective_context_length: effective_length,
456 })
457 }
458
459 fn compute_entropy(&self, values: &[f32]) -> f32 {
461 if values.is_empty() {
462 return 0.0;
463 }
464
465 let sum: f32 = values.iter().sum();
466 if sum <= 0.0 {
467 return 0.0;
468 }
469
470 let mut entropy = 0.0;
471 for &value in values {
472 if value > 0.0 {
473 let prob = value / sum;
474 entropy -= prob * prob.log2();
475 }
476 }
477
478 entropy
479 }
480
481 fn compute_effective_context_length(&self, values: &[f32]) -> f32 {
483 if values.is_empty() {
484 return 0.0;
485 }
486
487 let sum: f32 = values.iter().sum();
488 if sum <= 0.0 {
489 return 0.0;
490 }
491
492 let mut sorted_values = values.to_vec();
494 sorted_values.sort_by(|a, b| b.partial_cmp(a).unwrap_or(std::cmp::Ordering::Equal));
495
496 let mut cumulative_sum = 0.0;
497 let target_sum = sum * 0.9; for (i, &value) in sorted_values.iter().enumerate() {
500 cumulative_sum += value;
501 if cumulative_sum >= target_sum {
502 return (i + 1) as f32;
503 }
504 }
505
506 values.len() as f32
507 }
508
509 fn detect_attention_pattern_from_weights(
511 &self,
512 weights: &ArrayD<f32>,
513 ) -> Result<AttentionPattern> {
514 let shape = weights.shape();
515 if shape.len() != 2 {
516 return Ok(AttentionPattern::Random);
517 }
518
519 let weights_2d: Vec<Vec<f32>> = (0..shape[0])
520 .map(|i| (0..shape[1]).map(|j| weights[[i, j]]).collect())
521 .collect();
522
523 Ok(self.detect_attention_pattern(&weights_2d))
524 }
525
526 fn compute_head_importance(&self, weights: &ArrayD<f32>) -> Result<f32> {
528 let values: Vec<f32> = weights.iter().cloned().collect();
529
530 if values.is_empty() {
531 return Ok(0.0);
532 }
533
534 let entropy = self.compute_entropy(&values);
536 let max_entropy = (values.len() as f32).log2();
537
538 if max_entropy > 0.0 {
539 Ok(entropy / max_entropy)
540 } else {
541 Ok(0.0)
542 }
543 }
544
    /// Entropy of the whole grid flattened into one distribution.
    fn compute_attention_entropy(&self, weights: &[Vec<f32>]) -> f32 {
        let values: Vec<f32> = weights.iter().flatten().cloned().collect();
        self.compute_entropy(&values)
    }
550
551 fn compute_sparsity_ratio(&self, weights: &[Vec<f32>]) -> f32 {
553 let total_count = weights.iter().map(|row| row.len()).sum::<usize>();
554 if total_count == 0 {
555 return 0.0;
556 }
557
558 let sparse_count = weights
559 .iter()
560 .flatten()
561 .filter(|&&w| w < self.config.attention_threshold)
562 .count();
563
564 sparse_count as f32 / total_count as f32
565 }
566
    /// Diversity of head specializations within a layer: the number of
    /// distinct specialization types divided by the total number of
    /// possible types. Returns 0.0 for fewer than two heads.
    fn compute_layer_diversity(&self, head_analyses: &[AttentionHeadAnalysis]) -> f32 {
        if head_analyses.len() < 2 {
            return 0.0;
        }

        let mut specialization_counts: HashMap<HeadSpecializationType, usize> = HashMap::new();
        for analysis in head_analyses {
            *specialization_counts.entry(analysis.specialization_type.clone()).or_insert(0) += 1;
        }

        let num_types = specialization_counts.len() as f32;
        // NOTE(review): 8.0 must track the number of
        // `HeadSpecializationType` variants; update both together.
        let max_types = 8.0;
        num_types / max_types
    }
584
    /// Finds pairs of heads within one layer whose similarity exceeds 0.8
    /// and computes the layer's overall redundancy score.
    fn analyze_head_redundancy(
        &self,
        head_analyses: &[AttentionHeadAnalysis],
    ) -> RedundancyAnalysis {
        let mut redundant_heads = Vec::new();
        // NOTE(review): group clustering is not implemented; this field is
        // always returned empty.
        let redundancy_groups = Vec::new();

        // All unordered pairs (i < j).
        for i in 0..head_analyses.len() {
            for j in (i + 1)..head_analyses.len() {
                let similarity = self.compute_head_similarity(&head_analyses[i], &head_analyses[j]);
                if similarity > 0.8 {
                    redundant_heads.push((i, j, similarity));
                }
            }
        }

        RedundancyAnalysis {
            redundant_head_pairs: redundant_heads,
            redundancy_groups,
            overall_redundancy_score: self.compute_overall_redundancy(head_analyses),
        }
    }
609
    /// Similarity of two heads: average of a 0/1 specialization-type match
    /// and a distribution-statistics term.
    ///
    /// NOTE(review): the distribution term `1 - (diffs / 3)` is not clamped
    /// and can go negative when statistics differ strongly -- confirm the
    /// intended range before comparing against thresholds.
    fn compute_head_similarity(
        &self,
        head1: &AttentionHeadAnalysis,
        head2: &AttentionHeadAnalysis,
    ) -> f32 {
        // 1.0 when both heads share a specialization label, else 0.0.
        let type_similarity =
            if head1.specialization_type == head2.specialization_type { 1.0 } else { 0.0 };

        let dist_similarity = {
            let d1 = &head1.attention_distribution;
            let d2 = &head2.attention_distribution;

            let mean_diff = (d1.mean_attention - d2.mean_attention).abs();
            let std_diff = (d1.std_attention - d2.std_attention).abs();
            let entropy_diff = (d1.entropy - d2.entropy).abs();

            1.0 - (mean_diff + std_diff + entropy_diff) / 3.0
        };

        (type_similarity + dist_similarity) / 2.0
    }
633
634 fn compute_overall_redundancy(&self, head_analyses: &[AttentionHeadAnalysis]) -> f32 {
636 if head_analyses.len() < 2 {
637 return 0.0;
638 }
639
640 let mut total_similarity = 0.0;
641 let mut pair_count = 0;
642
643 for i in 0..head_analyses.len() {
644 for j in (i + 1)..head_analyses.len() {
645 total_similarity +=
646 self.compute_head_similarity(&head_analyses[i], &head_analyses[j]);
647 pair_count += 1;
648 }
649 }
650
651 if pair_count > 0 {
652 total_similarity / pair_count as f32
653 } else {
654 0.0
655 }
656 }
657}
658
/// Aggregated analysis of all heads in one layer.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct LayerAttentionAnalysis {
    /// Index of the analyzed layer.
    pub layer_index: usize,
    /// Total number of heads supplied (analysis may cover fewer).
    pub num_heads: usize,
    /// Per-head analyses, in head order.
    pub head_analyses: Vec<AttentionHeadAnalysis>,
    /// Per-head attention maps, in head order.
    pub attention_maps: Vec<AttentionMap>,
    /// Distinct-specialization ratio for the layer.
    pub layer_diversity_score: f32,
    /// Pairwise redundancy findings for the layer.
    pub redundancy_analysis: RedundancyAnalysis,
}
669
670#[derive(Debug, Clone, Serialize, Deserialize)]
672pub struct RedundancyAnalysis {
673 pub redundant_head_pairs: Vec<(usize, usize, f32)>, pub redundancy_groups: Vec<Vec<usize>>,
675 pub overall_redundancy_score: f32,
676}
677
/// Model-level debugger that runs [`AttentionDebugger`] over every layer
/// and adds cross-layer analyses.
#[derive(Debug)]
pub struct TransformerDebugger {
    /// Model-level configuration (wraps an attention sub-config).
    pub config: TransformerDebugConfig,
    // Layer analyses retained from the most recent run.
    layer_analyses: Vec<LayerAttentionAnalysis>,
    // Shared per-layer attention analyzer.
    attention_debugger: AttentionDebugger,
}
685
/// Configuration for [`TransformerDebugger`].
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct TransformerDebugConfig {
    /// Settings forwarded to the per-layer attention debugger.
    pub attention_config: AttentionDebugConfig,
    /// Run per-layer analysis.
    pub enable_layer_analysis: bool,
    /// Run cross-layer (evolution/consistency/progression) analysis.
    pub enable_cross_layer_analysis: bool,
    /// Upper bound on layers analyzed; extra layers are skipped.
    pub max_layers_to_analyze: usize,
}
694
impl Default for TransformerDebugConfig {
    /// Default attention config, all analyses enabled, up to 48 layers.
    fn default() -> Self {
        Self {
            attention_config: AttentionDebugConfig::default(),
            enable_layer_analysis: true,
            enable_cross_layer_analysis: true,
            // 48 covers common large decoder stacks.
            max_layers_to_analyze: 48,
        }
    }
}
705
706impl TransformerDebugger {
    /// Creates a transformer debugger; the nested [`AttentionDebugger`]
    /// shares the attention sub-config.
    pub fn new(config: TransformerDebugConfig) -> Self {
        let attention_debugger = AttentionDebugger::new(config.attention_config.clone());

        Self {
            config,
            layer_analyses: Vec::new(),
            attention_debugger,
        }
    }
717
    /// Analyzes attention for a whole model: one inner `Vec<ArrayD<f32>>`
    /// of per-head matrices per layer, up to `max_layers_to_analyze`.
    /// Stores the layer analyses and optionally runs cross-layer analysis.
    pub fn analyze_transformer_attention(
        &mut self,
        model_attention_weights: &[Vec<ArrayD<f32>>],
    ) -> Result<TransformerAttentionAnalysis> {
        let mut layer_analyses = Vec::new();

        for (layer_idx, layer_weights) in model_attention_weights.iter().enumerate() {
            // Respect the configured cap on analyzed layers.
            if layer_idx >= self.config.max_layers_to_analyze {
                break;
            }

            let layer_analysis =
                self.attention_debugger.analyze_attention_layer(layer_idx, layer_weights)?;
            layer_analyses.push(layer_analysis);
        }

        // Keep a copy so `generate_model_attention_summary` can read it.
        self.layer_analyses = layer_analyses.clone();

        let cross_layer_analysis = if self.config.enable_cross_layer_analysis {
            Some(self.perform_cross_layer_analysis(&layer_analyses)?)
        } else {
            None
        };

        Ok(TransformerAttentionAnalysis {
            // Total layer count, even if fewer were analyzed.
            num_layers: model_attention_weights.len(),
            layer_analyses,
            cross_layer_analysis,
            model_attention_summary: self.generate_model_attention_summary()?,
        })
    }
751
    /// Aggregates per-layer results into cross-layer views: entropy and
    /// sparsity evolution, head-specialization consistency, and pattern
    /// progression.
    fn perform_cross_layer_analysis(
        &self,
        layer_analyses: &[LayerAttentionAnalysis],
    ) -> Result<CrossLayerAnalysis> {
        let attention_evolution = self.analyze_attention_evolution(layer_analyses)?;
        let head_consistency = self.analyze_head_consistency(layer_analyses)?;
        let pattern_progression = self.analyze_pattern_progression(layer_analyses)?;

        Ok(CrossLayerAnalysis {
            attention_evolution,
            head_consistency,
            pattern_progression,
            layer_diversity_trend: self.compute_layer_diversity_trend(layer_analyses),
        })
    }
768
769 fn analyze_attention_evolution(
771 &self,
772 layer_analyses: &[LayerAttentionAnalysis],
773 ) -> Result<AttentionEvolution> {
774 let mut entropy_trend = Vec::new();
775 let mut sparsity_trend = Vec::new();
776
777 for layer in layer_analyses {
778 let layer_entropy: f32 =
779 layer.attention_maps.iter().map(|map| map.attention_entropy).sum::<f32>()
780 / layer.attention_maps.len() as f32;
781 let layer_sparsity: f32 =
782 layer.attention_maps.iter().map(|map| map.sparsity_ratio).sum::<f32>()
783 / layer.attention_maps.len() as f32;
784
785 entropy_trend.push(layer_entropy);
786 sparsity_trend.push(layer_sparsity);
787 }
788
789 let evolution_type = self.classify_evolution_type(&entropy_trend);
790
791 Ok(AttentionEvolution {
792 entropy_trend,
793 sparsity_trend,
794 evolution_type,
795 })
796 }
797
798 fn classify_evolution_type(&self, entropy_trend: &[f32]) -> EvolutionType {
800 if entropy_trend.len() < 3 {
801 return EvolutionType::Stable;
802 }
803
804 let start_entropy = entropy_trend[0];
805 let end_entropy = entropy_trend[entropy_trend.len() - 1];
806 let change_ratio = (end_entropy - start_entropy) / start_entropy.max(1e-8);
807
808 match change_ratio {
809 r if r > 0.2 => EvolutionType::Increasing,
810 r if r < -0.2 => EvolutionType::Decreasing,
811 _ => EvolutionType::Stable,
812 }
813 }
814
    /// Records, per specialization type, which layers contain heads of
    /// that type, and scores how consistent specializations are across
    /// layers.
    fn analyze_head_consistency(
        &self,
        layer_analyses: &[LayerAttentionAnalysis],
    ) -> Result<HeadConsistency> {
        let mut specialization_consistency = HashMap::new();
        // NOTE(review): pattern consistency is never populated here.
        let pattern_consistency = HashMap::new();

        for layer in layer_analyses {
            for head in &layer.head_analyses {
                let spec_type = &head.specialization_type;
                // One entry per head, so a layer index may repeat.
                let layer_counts =
                    specialization_consistency.entry(spec_type.clone()).or_insert_with(Vec::new);
                layer_counts.push(layer.layer_index);
            }
        }

        Ok(HeadConsistency {
            specialization_consistency,
            pattern_consistency,
            consistency_score: self.compute_consistency_score(layer_analyses),
        })
    }
839
    /// Mean pairwise similarity of the per-layer specialization
    /// distributions. A single layer (or none) is trivially consistent
    /// and scores 1.0.
    fn compute_consistency_score(&self, layer_analyses: &[LayerAttentionAnalysis]) -> f32 {
        if layer_analyses.len() < 2 {
            return 1.0;
        }

        let mut layer_distributions = Vec::new();

        // Build a normalized specialization histogram per layer.
        for layer in layer_analyses {
            let mut distribution: HashMap<HeadSpecializationType, f32> = HashMap::new();
            for head in &layer.head_analyses {
                *distribution.entry(head.specialization_type.clone()).or_insert(0.0) += 1.0;
            }

            let total: f32 = distribution.values().sum();
            if total > 0.0 {
                for value in distribution.values_mut() {
                    *value /= total;
                }
            }

            layer_distributions.push(distribution);
        }

        // Average similarity over all unordered layer pairs.
        let mut total_similarity = 0.0;
        let mut pair_count = 0;

        for i in 0..layer_distributions.len() {
            for j in (i + 1)..layer_distributions.len() {
                let similarity = self.compute_distribution_similarity(
                    &layer_distributions[i],
                    &layer_distributions[j],
                );
                total_similarity += similarity;
                pair_count += 1;
            }
        }

        if pair_count > 0 {
            total_similarity / pair_count as f32
        } else {
            1.0
        }
    }
887
    /// Similarity of two normalized histograms as 1 minus half the L1
    /// distance (total variation distance); 1.0 = identical.
    fn compute_distribution_similarity(
        &self,
        dist1: &HashMap<HeadSpecializationType, f32>,
        dist2: &HashMap<HeadSpecializationType, f32>,
    ) -> f32 {
        // Union of keys so types missing on one side count as 0.0.
        let mut all_keys: std::collections::HashSet<_> = dist1.keys().collect();
        all_keys.extend(dist2.keys());

        let mut similarity = 0.0;
        for key in all_keys {
            let val1 = dist1.get(key).unwrap_or(&0.0);
            let val2 = dist2.get(key).unwrap_or(&0.0);
            similarity += (val1 - val2).abs();
        }

        // L1 distance of two probability vectors is at most 2.
        1.0 - (similarity / 2.0)
    }
906
907 fn analyze_pattern_progression(
909 &self,
910 layer_analyses: &[LayerAttentionAnalysis],
911 ) -> Result<PatternProgression> {
912 let mut pattern_evolution = Vec::new();
913
914 for layer in layer_analyses {
915 let mut pattern_counts: HashMap<AttentionPattern, usize> = HashMap::new();
916 for map in &layer.attention_maps {
917 *pattern_counts.entry(map.attention_pattern.clone()).or_insert(0) += 1;
918 }
919 pattern_evolution.push(pattern_counts);
920 }
921
922 let dominant_pattern_sequence = self.extract_dominant_patterns(&pattern_evolution);
923
924 Ok(PatternProgression {
925 pattern_evolution,
926 dominant_pattern_sequence,
927 })
928 }
929
930 fn extract_dominant_patterns(
932 &self,
933 pattern_evolution: &[HashMap<AttentionPattern, usize>],
934 ) -> Vec<AttentionPattern> {
935 pattern_evolution
936 .iter()
937 .map(|patterns| {
938 patterns
939 .iter()
940 .max_by_key(|(_, &count)| count)
941 .map(|(pattern, _)| pattern.clone())
942 .unwrap_or(AttentionPattern::Random)
943 })
944 .collect()
945 }
946
    /// Per-layer diversity scores, in layer order.
    fn compute_layer_diversity_trend(&self, layer_analyses: &[LayerAttentionAnalysis]) -> Vec<f32> {
        layer_analyses.iter().map(|layer| layer.layer_diversity_score).collect()
    }
951
    /// Condenses the stored layer analyses into model-wide totals,
    /// averages, a specialization histogram, and a health grade.
    /// Returns the all-zero default when nothing has been analyzed yet.
    fn generate_model_attention_summary(&self) -> Result<ModelAttentionSummary> {
        if self.layer_analyses.is_empty() {
            return Ok(ModelAttentionSummary::default());
        }

        let total_heads: usize = self.layer_analyses.iter().map(|layer| layer.num_heads).sum();
        let avg_diversity: f32 =
            self.layer_analyses.iter().map(|layer| layer.layer_diversity_score).sum::<f32>()
                / self.layer_analyses.len() as f32;
        let avg_redundancy: f32 = self
            .layer_analyses
            .iter()
            .map(|layer| layer.redundancy_analysis.overall_redundancy_score)
            .sum::<f32>()
            / self.layer_analyses.len() as f32;

        // Count heads per specialization type across all layers.
        let mut specialization_distribution: HashMap<HeadSpecializationType, usize> =
            HashMap::new();
        for layer in &self.layer_analyses {
            for head in &layer.head_analyses {
                *specialization_distribution
                    .entry(head.specialization_type.clone())
                    .or_insert(0) += 1;
            }
        }

        Ok(ModelAttentionSummary {
            total_layers: self.layer_analyses.len(),
            total_heads,
            average_diversity_score: avg_diversity,
            average_redundancy_score: avg_redundancy,
            specialization_distribution,
            model_attention_health: self
                .assess_model_attention_health(avg_diversity, avg_redundancy),
        })
    }
990
991 fn assess_model_attention_health(
993 &self,
994 diversity: f32,
995 redundancy: f32,
996 ) -> AttentionHealthStatus {
997 let health_score = diversity * (1.0 - redundancy); match health_score {
1000 s if s > 0.7 => AttentionHealthStatus::Excellent,
1001 s if s > 0.5 => AttentionHealthStatus::Good,
1002 s if s > 0.3 => AttentionHealthStatus::Fair,
1003 _ => AttentionHealthStatus::Poor,
1004 }
1005 }
1006}
1007
/// Top-level result of analyzing a whole model's attention.
#[derive(Debug, Serialize, Deserialize)]
pub struct TransformerAttentionAnalysis {
    /// Total layer count supplied (analysis may cover fewer).
    pub num_layers: usize,
    /// Per-layer analyses, in layer order.
    pub layer_analyses: Vec<LayerAttentionAnalysis>,
    /// Cross-layer analysis, when enabled in the config.
    pub cross_layer_analysis: Option<CrossLayerAnalysis>,
    /// Model-wide totals and health grade.
    pub model_attention_summary: ModelAttentionSummary,
}
1016
/// Analyses that relate layers to one another.
#[derive(Debug, Serialize, Deserialize)]
pub struct CrossLayerAnalysis {
    /// Entropy/sparsity trends across layers.
    pub attention_evolution: AttentionEvolution,
    /// How consistent head specializations are across layers.
    pub head_consistency: HeadConsistency,
    /// Dominant attention patterns per layer.
    pub pattern_progression: PatternProgression,
    /// Per-layer diversity scores, in layer order.
    pub layer_diversity_trend: Vec<f32>,
}
1025
/// Layer-by-layer trends of mean entropy and sparsity.
#[derive(Debug, Serialize, Deserialize)]
pub struct AttentionEvolution {
    /// Mean attention entropy per layer.
    pub entropy_trend: Vec<f32>,
    /// Mean sparsity ratio per layer.
    pub sparsity_trend: Vec<f32>,
    /// Coarse classification of the entropy trend.
    pub evolution_type: EvolutionType,
}
1033
1034#[derive(Debug, Clone, PartialEq, Eq, Hash, Serialize, Deserialize)]
1036pub enum EvolutionType {
1037 Increasing, Decreasing, Stable, }
1041
/// Cross-layer consistency of head specializations and patterns.
#[derive(Debug, Serialize, Deserialize)]
pub struct HeadConsistency {
    /// For each specialization type, the layers containing such heads
    /// (one entry per head, so indices may repeat).
    pub specialization_consistency: HashMap<HeadSpecializationType, Vec<usize>>,
    /// For each pattern, the layers where it appears (currently unused).
    pub pattern_consistency: HashMap<AttentionPattern, Vec<usize>>,
    /// Mean pairwise similarity of per-layer specialization histograms.
    pub consistency_score: f32,
}
1049
/// How attention patterns change from layer to layer.
#[derive(Debug, Serialize, Deserialize)]
pub struct PatternProgression {
    /// Pattern occurrence counts per layer.
    pub pattern_evolution: Vec<HashMap<AttentionPattern, usize>>,
    /// The most frequent pattern in each layer.
    pub dominant_pattern_sequence: Vec<AttentionPattern>,
}
1056
/// Model-wide attention statistics and a coarse health grade.
#[derive(Debug, Serialize, Deserialize)]
pub struct ModelAttentionSummary {
    /// Number of layers analyzed.
    pub total_layers: usize,
    /// Sum of head counts over analyzed layers.
    pub total_heads: usize,
    /// Mean per-layer diversity score.
    pub average_diversity_score: f32,
    /// Mean per-layer redundancy score.
    pub average_redundancy_score: f32,
    /// Head counts per specialization type, across all layers.
    pub specialization_distribution: HashMap<HeadSpecializationType, usize>,
    /// Health grade derived from diversity and redundancy.
    pub model_attention_health: AttentionHealthStatus,
}
1067
impl Default for ModelAttentionSummary {
    /// Empty summary: zero counts/scores and `Poor` health (no data).
    fn default() -> Self {
        Self {
            total_layers: 0,
            total_heads: 0,
            average_diversity_score: 0.0,
            average_redundancy_score: 0.0,
            specialization_distribution: HashMap::new(),
            model_attention_health: AttentionHealthStatus::Poor,
        }
    }
}
1080
/// Coarse health grade for a model's attention, from best to worst.
#[derive(Debug, Clone, PartialEq, Eq, Hash, Serialize, Deserialize)]
pub enum AttentionHealthStatus {
    /// Health score above 0.7.
    Excellent,
    /// Health score above 0.5.
    Good,
    /// Health score above 0.3.
    Fair,
    /// Health score 0.3 or below.
    Poor,
}
1089
/// Analyzes one layer's attention weights (`&[ArrayD<f32>]`) with a
/// default-configured [`AttentionDebugger`], using layer index 0.
/// Evaluates to `Result<LayerAttentionAnalysis>`.
#[macro_export]
macro_rules! debug_attention {
    ($attention_weights:expr) => {{
        let mut debugger = $crate::neural_network_debugging::AttentionDebugger::new(
            $crate::neural_network_debugging::AttentionDebugConfig::default(),
        );
        debugger.analyze_attention_layer(0, $attention_weights)
    }};
}
1100
/// Analyzes a whole model's attention weights (`&[Vec<ArrayD<f32>>]`, one
/// inner vec per layer) with a default-configured [`TransformerDebugger`].
/// Evaluates to `Result<TransformerAttentionAnalysis>`.
#[macro_export]
macro_rules! debug_transformer {
    ($model_weights:expr) => {{
        let mut debugger = $crate::neural_network_debugging::TransformerDebugger::new(
            $crate::neural_network_debugging::TransformerDebugConfig::default(),
        );
        debugger.analyze_transformer_attention($model_weights)
    }};
}
1111
1112#[cfg(test)]
1113mod tests {
1114 use super::*;
1115 use scirs2_core::ndarray::{ArrayD, IxDyn};
1116
1117 fn make_attention_config() -> AttentionDebugConfig {
1118 AttentionDebugConfig::default()
1119 }
1120
1121 fn make_uniform_weights(seq_len: usize) -> ArrayD<f32> {
1122 let val = 1.0 / seq_len as f32;
1123 ArrayD::from_elem(IxDyn(&[seq_len, seq_len]), val)
1124 }
1125
1126 fn make_diagonal_weights(seq_len: usize) -> ArrayD<f32> {
1127 let mut weights = ArrayD::zeros(IxDyn(&[seq_len, seq_len]));
1128 for i in 0..seq_len {
1129 weights[[i, i]] = 1.0;
1130 }
1131 weights
1132 }
1133
1134 fn make_sparse_weights(seq_len: usize) -> ArrayD<f32> {
1135 let mut weights = ArrayD::zeros(IxDyn(&[seq_len, seq_len]));
1136 for i in 0..seq_len {
1138 weights[[i, 0]] = 1.0;
1139 }
1140 weights
1141 }
1142
1143 #[test]
1144 fn test_attention_debug_config_default() {
1145 let config = make_attention_config();
1146 assert!(config.enable_attention_visualization);
1147 assert!(config.enable_head_analysis);
1148 assert_eq!(config.max_heads_to_analyze, 16);
1149 }
1150
1151 #[test]
1152 fn test_attention_debugger_creation() {
1153 let debugger = AttentionDebugger::new(make_attention_config());
1154 assert!(debugger.attention_maps.is_empty());
1155 assert!(debugger.head_analysis.is_empty());
1156 }
1157
1158 #[test]
1159 fn test_analyze_attention_layer_single_head() {
1160 let mut debugger = AttentionDebugger::new(make_attention_config());
1161 let weights = vec![make_uniform_weights(8)];
1162 let result = debugger.analyze_attention_layer(0, &weights);
1163 assert!(result.is_ok());
1164 let analysis = result.expect("analysis should succeed");
1165 assert_eq!(analysis.layer_index, 0);
1166 assert_eq!(analysis.num_heads, 1);
1167 assert_eq!(analysis.head_analyses.len(), 1);
1168 }
1169
1170 #[test]
1171 fn test_analyze_attention_layer_multiple_heads() {
1172 let mut debugger = AttentionDebugger::new(make_attention_config());
1173 let weights = vec![
1174 make_uniform_weights(8),
1175 make_diagonal_weights(8),
1176 make_sparse_weights(8),
1177 ];
1178 let result = debugger.analyze_attention_layer(0, &weights);
1179 assert!(result.is_ok());
1180 let analysis = result.expect("analysis should succeed");
1181 assert_eq!(analysis.num_heads, 3);
1182 assert_eq!(analysis.head_analyses.len(), 3);
1183 }
1184
1185 #[test]
1186 fn test_detect_attention_pattern_uniform() {
1187 let debugger = AttentionDebugger::new(make_attention_config());
1188 let seq_len = 10;
1189 let val = 1.0 / seq_len as f32;
1190 let weights: Vec<Vec<f32>> = (0..seq_len).map(|_| vec![val; seq_len]).collect();
1191 let pattern = debugger.detect_attention_pattern(&weights);
1192 assert!(matches!(pattern, AttentionPattern::Uniform));
1193 }
1194
1195 #[test]
1196 fn test_detect_attention_pattern_diagonal() {
1197 let debugger = AttentionDebugger::new(make_attention_config());
1198 let seq_len = 10;
1199 let weights: Vec<Vec<f32>> = (0..seq_len)
1200 .map(|i| {
1201 let mut row = vec![0.0; seq_len];
1202 for j in 0..seq_len {
1204 if (i as i32 - j as i32).abs() <= 1 {
1205 row[j] = 1.0;
1206 }
1207 }
1208 row
1209 })
1210 .collect();
1211 let pattern = debugger.detect_attention_pattern(&weights);
1212 assert!(matches!(pattern, AttentionPattern::Diagonal));
1213 }
1214
1215 #[test]
1216 fn test_detect_attention_pattern_empty() {
1217 let debugger = AttentionDebugger::new(make_attention_config());
1218 let weights: Vec<Vec<f32>> = vec![];
1219 let pattern = debugger.detect_attention_pattern(&weights);
1220 assert!(matches!(pattern, AttentionPattern::Random));
1221 }
1222
1223 #[test]
1224 fn test_measure_diagonal_strength() {
1225 let debugger = AttentionDebugger::new(make_attention_config());
1226 let seq_len = 8;
1227 let weights: Vec<Vec<f32>> = (0..seq_len)
1228 .map(|i| {
1229 let mut row = vec![0.01; seq_len];
1230 row[i] = 10.0;
1231 row
1232 })
1233 .collect();
1234 let strength = debugger.measure_diagonal_strength(&weights);
1235 assert!(strength > 0.5);
1236 }
1237
1238 #[test]
1239 fn test_measure_diagonal_strength_empty() {
1240 let debugger = AttentionDebugger::new(make_attention_config());
1241 let weights: Vec<Vec<f32>> = vec![];
1242 assert!((debugger.measure_diagonal_strength(&weights) - 0.0).abs() < f32::EPSILON);
1243 }
1244
1245 #[test]
1246 fn test_measure_uniformity() {
1247 let debugger = AttentionDebugger::new(make_attention_config());
1248 let seq_len = 8;
1249 let val = 1.0 / seq_len as f32;
1250 let weights: Vec<Vec<f32>> = (0..seq_len).map(|_| vec![val; seq_len]).collect();
1251 let uniformity = debugger.measure_uniformity(&weights);
1252 assert!(uniformity > 0.9);
1253 }
1254
1255 #[test]
1256 fn test_measure_uniformity_empty() {
1257 let debugger = AttentionDebugger::new(make_attention_config());
1258 let weights: Vec<Vec<f32>> = vec![];
1259 assert!((debugger.measure_uniformity(&weights) - 0.0).abs() < f32::EPSILON);
1260 }
1261
1262 #[test]
1263 fn test_has_block_structure_false() {
1264 let debugger = AttentionDebugger::new(make_attention_config());
1265 let seq_len = 8;
1266 let val = 1.0 / seq_len as f32;
1267 let weights: Vec<Vec<f32>> = (0..seq_len).map(|_| vec![val; seq_len]).collect();
1268 assert!(!debugger.has_block_structure(&weights));
1269 }
1270
1271 #[test]
1272 fn test_has_block_structure_small() {
1273 let debugger = AttentionDebugger::new(make_attention_config());
1274 let weights: Vec<Vec<f32>> = vec![vec![1.0], vec![1.0]];
1275 assert!(!debugger.has_block_structure(&weights));
1276 }
1277
1278 #[test]
1279 fn test_classify_head_specialization_local() {
1280 let debugger = AttentionDebugger::new(make_attention_config());
1281 let weights = make_diagonal_weights(10);
1282 let result = debugger.classify_head_specialization(&weights);
1283 assert!(result.is_ok());
1284 let spec = result.expect("classification should succeed");
1285 assert!(matches!(spec, HeadSpecializationType::LocalSyntax));
1286 }
1287
1288 #[test]
1289 fn test_compute_sparsity_ratio() {
1290 let debugger = AttentionDebugger::new(make_attention_config());
1291 let seq_len = 10;
1292 let mut weights: Vec<Vec<f32>> = (0..seq_len).map(|_| vec![0.0; seq_len]).collect();
1293 for i in 0..seq_len {
1295 weights[i][0] = 1.0;
1296 }
1297 let sparsity = debugger.compute_sparsity_ratio(&weights);
1298 assert!(sparsity > 0.8);
1299 }
1300
1301 #[test]
1302 fn test_measure_long_range_attention() {
1303 let debugger = AttentionDebugger::new(make_attention_config());
1304 let seq_len = 20;
1305 let weights: Vec<Vec<f32>> = (0..seq_len)
1307 .map(|_| {
1308 let mut row = vec![0.0; seq_len];
1309 row[seq_len - 1] = 1.0;
1310 row
1311 })
1312 .collect();
1313 let long_range = debugger.measure_long_range_attention(&weights);
1314 assert!(long_range > 0.3);
1315 }
1316
1317 #[test]
1318 fn test_measure_long_range_attention_small() {
1319 let debugger = AttentionDebugger::new(make_attention_config());
1320 let weights: Vec<Vec<f32>> = vec![vec![1.0]];
1321 assert!((debugger.measure_long_range_attention(&weights) - 0.0).abs() < f32::EPSILON);
1322 }
1323
1324 #[test]
1325 fn test_model_attention_summary_default() {
1326 let summary = ModelAttentionSummary::default();
1327 assert_eq!(summary.total_layers, 0);
1328 assert_eq!(summary.total_heads, 0);
1329 assert!(matches!(
1330 summary.model_attention_health,
1331 AttentionHealthStatus::Poor
1332 ));
1333 }
1334
1335 #[test]
1336 fn test_analyze_attention_layer_head_limit() {
1337 let mut config = make_attention_config();
1338 config.max_heads_to_analyze = 2;
1339 let mut debugger = AttentionDebugger::new(config);
1340 let weights = vec![
1341 make_uniform_weights(4),
1342 make_uniform_weights(4),
1343 make_uniform_weights(4),
1344 make_uniform_weights(4),
1345 ];
1346 let result = debugger.analyze_attention_layer(0, &weights);
1347 assert!(result.is_ok());
1348 let analysis = result.expect("analysis should succeed");
1349 assert_eq!(analysis.head_analyses.len(), 2);
1350 }
1351
1352 #[test]
1353 fn test_create_attention_map_wrong_dimensions() {
1354 let debugger = AttentionDebugger::new(make_attention_config());
1355 let weights_3d = ArrayD::zeros(IxDyn(&[2, 3, 4]));
1356 let result = debugger.create_attention_map(0, 0, &weights_3d);
1357 assert!(result.is_err());
1358 }
1359
1360 #[test]
1361 fn test_attention_entropy_computation() {
1362 let mut debugger = AttentionDebugger::new(make_attention_config());
1363 let weights = vec![make_uniform_weights(8)];
1364 let analysis =
1365 debugger.analyze_attention_layer(0, &weights).expect("analysis should succeed");
1366 assert!(analysis.attention_maps[0].attention_entropy > 0.0);
1368 }
1369
1370 #[test]
1371 fn test_attention_pattern_variants() {
1372 let patterns = [
1373 AttentionPattern::Diagonal,
1374 AttentionPattern::Block,
1375 AttentionPattern::Sparse,
1376 AttentionPattern::Uniform,
1377 AttentionPattern::Concentrated,
1378 AttentionPattern::Strided,
1379 AttentionPattern::Random,
1380 ];
1381 assert_eq!(patterns.len(), 7);
1382 }
1383
1384 #[test]
1385 fn test_head_specialization_variants() {
1386 let specs = [
1387 HeadSpecializationType::LocalSyntax,
1388 HeadSpecializationType::LongRange,
1389 HeadSpecializationType::Positional,
1390 HeadSpecializationType::ContentBased,
1391 HeadSpecializationType::Copying,
1392 HeadSpecializationType::Delimiter,
1393 HeadSpecializationType::Mixed,
1394 HeadSpecializationType::Redundant,
1395 ];
1396 assert_eq!(specs.len(), 8);
1397 }
1398
1399 #[test]
1400 fn test_attention_health_status_variants() {
1401 let statuses = [
1402 AttentionHealthStatus::Excellent,
1403 AttentionHealthStatus::Good,
1404 AttentionHealthStatus::Fair,
1405 AttentionHealthStatus::Poor,
1406 ];
1407 assert_eq!(statuses.len(), 4);
1408 }
1409
1410 #[test]
1411 fn test_attention_distribution_creation() {
1412 let dist = AttentionDistribution {
1413 mean_attention: 0.125,
1414 std_attention: 0.05,
1415 max_attention: 0.5,
1416 min_attention: 0.01,
1417 entropy: 2.8,
1418 effective_context_length: 6.5,
1419 };
1420 assert!(dist.mean_attention > 0.0);
1421 assert!(dist.max_attention > dist.mean_attention);
1422 assert!(dist.entropy > 0.0);
1423 }
1424
1425 #[test]
1426 fn test_attention_map_creation() {
1427 let map = AttentionMap {
1428 layer_index: 0,
1429 head_index: 3,
1430 sequence_length: 8,
1431 attention_weights: vec![vec![0.125; 8]; 8],
1432 attention_pattern: AttentionPattern::Uniform,
1433 attention_entropy: 3.0,
1434 sparsity_ratio: 0.0,
1435 };
1436 assert_eq!(map.layer_index, 0);
1437 assert_eq!(map.head_index, 3);
1438 assert_eq!(map.sequence_length, 8);
1439 }
1440
1441 #[test]
1442 fn test_attention_head_analysis_creation() {
1443 let analysis = AttentionHeadAnalysis {
1444 head_id: 0,
1445 layer_id: 0,
1446 specialization_type: HeadSpecializationType::ContentBased,
1447 attention_distribution: AttentionDistribution {
1448 mean_attention: 0.1,
1449 std_attention: 0.02,
1450 max_attention: 0.3,
1451 min_attention: 0.01,
1452 entropy: 2.5,
1453 effective_context_length: 5.0,
1454 },
1455 redundancy_score: 0.1,
1456 importance_score: 0.8,
1457 patterns_detected: vec![AttentionPattern::Random],
1458 };
1459 assert_eq!(analysis.head_id, 0);
1460 assert!(analysis.importance_score > analysis.redundancy_score);
1461 }
1462
1463 #[test]
1464 fn test_pattern_progression_creation() {
1465 let progression = PatternProgression {
1466 pattern_evolution: vec![HashMap::new()],
1467 dominant_pattern_sequence: vec![AttentionPattern::Diagonal, AttentionPattern::Sparse],
1468 };
1469 assert_eq!(progression.dominant_pattern_sequence.len(), 2);
1470 }
1471}